Find the root node of a Java tree from any node

In this tutorial we will see how to traverse from any node in a Java tree up to its root node. In the tree below, the root can be reached starting from any node.

The traversal is achieved by adding one simple method to the Node class. It is very concise.

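// Returns the root of the tree this node belongs to
public Node getRoot() {
 // A node without a parent is the root itself
 if(parent == null){
  return this;
 }
 // Otherwise ask the parent, and pass its answer back unchanged
 return parent.getRoot();
}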
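The recursive version makes one nested call per level of the tree, so for a very deep tree (for example a long chain of nodes) the call stack grows with the depth and could in principle end in a StackOverflowError. A loop that follows the parent links performs the same walk without growing the stack. The sketch below shows that alternative; it is not the tutorial's own method, and the class IterativeRootDemo with its trimmed-down nested Node are hypothetical names used only for illustration.

```java
public class IterativeRootDemo {

    // Hypothetical, trimmed-down Node: only the parent link matters here
    static final class Node {
        private final Node parent;

        Node(Node parent) {
            this.parent = parent;
        }

        // Loop-based getRoot(): follow parent links until a node has no parent
        Node getRoot() {
            Node current = this;
            while (current.parent != null) {
                current = current.parent;
            }
            return current;
        }
    }

    public static void main(String[] args) {
        Node root = new Node(null);
        // Build a single chain 100,000 levels deep below the root
        Node deepest = root;
        for (int i = 0; i < 100_000; i++) {
            deepest = new Node(deepest);
        }
        System.out.println(deepest.getRoot() == root); // prints true
    }
}
```

Because parent is final and assigned in the constructor, a parent must already exist before its child is created, so with this constructor the parent links cannot form a cycle and both the recursive and the loop-based walk always terminate.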
 

So how does this magical getRoot() method work? Every node first checks whether it is the root itself, which is possible only when it has no parent. If it is not the root, it calls getRoot() on its parent, and the call bubbles up one level at a time until it reaches the root. Once the root is found, it returns itself, that reference is passed back down through each waiting call, and the node that originally asked gets a reference to the root node. Here is the image depicting the same flow:

Let us see an example.
Node.java 

package com.programtak.tree.tutorial;

import java.util.ArrayList;
import java.util.List;

public class Node {
 private String id;
 private final List<Node> children = new ArrayList<>();
 private final Node parent;

 
 public Node getRoot() {
  if(parent == null){
   return this;
  }
  return parent.getRoot();
 }
 
 public Node(Node parent) {
  this.parent = parent;
 }

 public String getId() {
  return id;
 }

 public void setId(String id) {
  this.id = id;
 }

 public List<Node> getChildren() {
  return children;
 }

 public Node getParent() {
  return parent;
 }

}

Now let us build a tree based on the diagram shown above and use the getRoot() method to print the root of every node.

TreeTest.java

package com.programtak.tree.tutorial;

public class TreeTest {

 public static void main(String[] args) {
  Node treeRootNode = new Node(null);
  treeRootNode.setId("root");
  // add child to root node 
  Node childNode1 = addChild(treeRootNode, "child-1");
  // add child to the child node created above
  addChild(childNode1, "child-11");
  
  Node childNode12 = addChild(childNode1, "child-12");
  addChild(childNode12, "child-121");
  addChild(childNode12, "child-122");
  
  
  // add child to root node
  Node child2 = addChild(treeRootNode, "child-2");
  // add child to the child node created above
  addChild(child2, "child-21");

  
  printTree(treeRootNode, " ");

 }

 private static Node addChild(Node parent, String id) {
   Node node = new Node(parent);
   node.setId(id);
   parent.getChildren().add(node);
   
   return node;
 }

 private static void printTree(Node node, String appender) {
  System.out.println(appender + node.getId() +",  root node is: " +  node.getRoot().getId());
  for (Node each : node.getChildren()) {
   printTree(each, appender + appender);
  }
 }
}

Here is the output of the above traversal, which prints every node together with its root node:

 root,  root node is: root
  child-1,  root node is: root
    child-11,  root node is: root
    child-12,  root node is: root
        child-121,  root node is: root
        child-122,  root node is: root
  child-2,  root node is: root
    child-21,  root node is: root
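The output shows only the final answer for each node. To watch the bubbling up and the passing back described earlier for a single node, here is a minimal sketch of a traced variant of the same recursion. The class RootTraceDemo, its nested Node, and the method getRootTraced are hypothetical names for illustration; the recursion itself mirrors getRoot().

```java
public class RootTraceDemo {

    // Hypothetical Node with an id and a parent link, mirroring Node.java
    static final class Node {
        private final String id;
        private final Node parent;

        Node(String id, Node parent) {
            this.id = id;
            this.parent = parent;
        }

        // Same recursion as getRoot(), but records each step in the log
        Node getRootTraced(StringBuilder log) {
            if (parent == null) {
                log.append(id).append(" is the root\n");
                return this;
            }
            log.append(id).append(" asks ").append(parent.id).append('\n');
            Node root = parent.getRootTraced(log);
            log.append(id).append(" receives ").append(root.id).append('\n');
            return root;
        }
    }

    public static void main(String[] args) {
        // The path root -> child-1 -> child-12 -> child-121 from the tutorial's tree
        Node root = new Node("root", null);
        Node child1 = new Node("child-1", root);
        Node child12 = new Node("child-12", child1);
        Node child121 = new Node("child-121", child12);

        StringBuilder log = new StringBuilder();
        child121.getRootTraced(log);
        System.out.print(log);
    }
}
```

Running it prints the calls going up and the answer coming back down:

child-121 asks child-12
child-12 asks child-1
child-1 asks root
root is the root
child-1 receives root
child-12 receives root
child-121 receives root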
