In this tutorial we will see how to reach the root node from any node in a Java tree. In the tree below, the root can be reached from any node.

The traversal is achieved by adding a simple method to the Node class. It is very concise:

    public Node getRoot() {
        if (parent == null) {
            return this;
        }
        return parent.getRoot();
    }

So how does this function work? Every node first checks whether it is the root itself, which is the case only when it has no parent. If it is not the root, it calls getRoot() on its parent, and the call bubbles up the tree until it reaches the root. Once the root is found, that node is returned back down the chain of calls, so the node that asked receives a reference to the root node. Each call moves exactly one level up, so a node at depth d makes d + 1 calls in total. Here is the image depicting the same flow:
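The same bubbling-up flow can also be observed in text. The sketch below uses a hypothetical `TraceNode` class (not part of the tutorial's code) whose `getRoot(List<String> trace)` follows the same logic as `Node.getRoot()` but also records the id of every node the call passes through.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical variant of Node (not from the tutorial) whose getRoot()
// records each hop, so the bubbling-up flow can be printed.
public class TraceNode {
    private final String id;
    private final TraceNode parent;

    public TraceNode(String id, TraceNode parent) {
        this.id = id;
        this.parent = parent;
    }

    // Same logic as Node.getRoot(), plus a log of every node visited.
    public TraceNode getRoot(List<String> trace) {
        trace.add(id);
        if (parent == null) {
            return this; // no parent: this node is the root
        }
        return parent.getRoot(trace); // ask the parent, one level up
    }

    public String getId() {
        return id;
    }

    public static void main(String[] args) {
        TraceNode root = new TraceNode("root", null);
        TraceNode child1 = new TraceNode("child-1", root);
        TraceNode child12 = new TraceNode("child-12", child1);
        TraceNode child122 = new TraceNode("child-122", child12);

        List<String> trace = new ArrayList<>();
        TraceNode found = child122.getRoot(trace);
        System.out.println("visited: " + String.join(" -> ", trace));
        System.out.println("root: " + found.getId());
    }
}
```

Running `main` should print `visited: child-122 -> child-12 -> child-1 -> root` followed by `root: root`, which is the path the diagram describes.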

Let us see an example.

*Node.java*

    package com.programtak.tree.tutorial;

    import java.util.ArrayList;
    import java.util.List;

    public class Node {

        private String id;
        private final List<Node> children = new ArrayList<>();
        private final Node parent;

        public Node getRoot() {
            if (parent == null) {
                return this;
            }
            return parent.getRoot();
        }

        public Node(Node parent) {
            this.parent = parent;
        }

        public String getId() {
            return id;
        }

        public void setId(String id) {
            this.id = id;
        }

        public List<Node> getChildren() {
            return children;
        }

        public Node getParent() {
            return parent;
        }
    }
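Java does not eliminate tail calls, so every level of the recursive `getRoot()` adds a stack frame. For an extremely deep tree, such as a degenerate chain, the recursion may end in a `StackOverflowError`, depending on the JVM's stack size. As an alternative technique (a loop instead of the tutorial's recursion), here is a minimal sketch using a hypothetical `IterativeNode` class that keeps only the parent link.

```java
// Hypothetical class (not from the tutorial): same parent-pointer design
// as Node, but getRoot() walks upwards with a loop instead of recursion.
public class IterativeNode {
    private final IterativeNode parent;

    public IterativeNode(IterativeNode parent) {
        this.parent = parent;
    }

    public IterativeNode getParent() {
        return parent;
    }

    // Follow parent links until a node without a parent is reached.
    // Uses constant stack space regardless of the tree's depth.
    public IterativeNode getRoot() {
        IterativeNode current = this;
        while (current.parent != null) {
            current = current.parent;
        }
        return current;
    }

    public static void main(String[] args) {
        // Build a degenerate tree: a single chain 100,000 nodes deep.
        IterativeNode root = new IterativeNode(null);
        IterativeNode deepest = root;
        for (int i = 0; i < 100_000; i++) {
            deepest = new IterativeNode(deepest);
        }
        System.out.println(deepest.getRoot() == root); // prints true
    }
}
```

The recursive version is shorter and reads naturally; the loop is mainly worth considering when trees can become very deep.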

Now let us build a tree based on the diagram shown above and use the getRoot() method to print the root of every node.

    package com.programtak.tree.tutorial;

    public class TreeTest {

        public static void main(String[] args) {
            Node treeRootNode = new Node(null);
            treeRootNode.setId("root");

            // add child to root node
            Node childNode1 = addChild(treeRootNode, "child-1");
            // add child to the child node created above
            addChild(childNode1, "child-11");
            Node childNode12 = addChild(childNode1, "child-12");
            addChild(childNode12, "child-121");
            addChild(childNode12, "child-122");

            // add child to root node
            Node child2 = addChild(treeRootNode, "child-2");
            // add child to the child node created above
            addChild(child2, "child-21");

            printTree(treeRootNode, " ");
        }

        private static Node addChild(Node parent, String id) {
            Node node = new Node(parent);
            node.setId(id);
            parent.getChildren().add(node);
            return node;
        }

        private static void printTree(Node node, String appender) {
            System.out.println(appender + node.getId() + ", root node is: " + node.getRoot().getId());
            for (Node each : node.getChildren()) {
                printTree(each, appender + appender);
            }
        }
    }

The output of the traversal above prints every node together with its root node:

     root, root node is: root
      child-1, root node is: root
        child-11, root node is: root
        child-12, root node is: root
            child-121, root node is: root
            child-122, root node is: root
      child-2, root node is: root
        child-21, root node is: root
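The output only compares ids. To confirm that every node resolves to the very same root object, one option is to walk the tree and compare references with `==`. The sketch below assumes a minimal stand-in for `Node`, nested in a hypothetical `RootCheck` class, together with a hypothetical helper `allShareRoot`; the tree in `main` has the same shape as the one built in `TreeTest`.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

public class RootCheck {

    // Minimal stand-in for the tutorial's Node class (hypothetical).
    static class Node {
        final String id;
        final Node parent;
        final List<Node> children = new ArrayList<>();

        Node(String id, Node parent) {
            this.id = id;
            this.parent = parent;
            if (parent != null) {
                parent.children.add(this); // register with the parent
            }
        }

        // Same recursive logic as the tutorial's getRoot().
        Node getRoot() {
            return parent == null ? this : parent.getRoot();
        }
    }

    // Hypothetical helper: visits every node under `root` (inclusive) and
    // checks that getRoot() returns that exact object, not just an equal id.
    static boolean allShareRoot(Node root) {
        Deque<Node> pending = new ArrayDeque<>();
        pending.push(root);
        while (!pending.isEmpty()) {
            Node node = pending.pop();
            if (node.getRoot() != root) {
                return false;
            }
            for (Node child : node.children) {
                pending.push(child);
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // Same shape as the tree built in TreeTest.
        Node root = new Node("root", null);
        Node child1 = new Node("child-1", root);
        new Node("child-11", child1);
        Node child12 = new Node("child-12", child1);
        new Node("child-121", child12);
        new Node("child-122", child12);
        Node child2 = new Node("child-2", root);
        new Node("child-21", child2);

        System.out.println(allShareRoot(root)); // prints true
    }
}
```

Calling `allShareRoot` on a non-root node returns `false`, because that node's descendants resolve to the real root rather than to it.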