Dynamic Programming on Trees
Definition
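DP on Trees solves a problem for a node by combining the results already computed for the subtrees rooted at its children. It is essentially a post-order traversal (Left, Right, Root) combined with DP logic: every child is finished before its parent, so the parent's answer can be built from its children's answers.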
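To make the post-order idea concrete, here is a minimal Python sketch. It computes the height of every node of a small binary tree and records the order in which nodes are finalized. The Node constructor with optional left and right arguments and the function name heights_in_post_order are illustrative assumptions, not part of the original code. The point is only that each node is finished after its children.

```python
class Node:
    # Illustrative node class: children can be passed to the constructor.
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def heights_in_post_order(root, order, height):
    """Compute height(u) = 1 + max(height(left), height(right)) for every node.

    A node's entry is written only after both children are finished, so
    `order` records the post-order (Left, Right, Root) sequence.
    """
    if root is None:
        return 0  # empty subtree: base case
    l = heights_in_post_order(root.left, order, height)
    r = heights_in_post_order(root.right, order, height)
    height[root.data] = max(l, r) + 1  # combine the children's answers
    order.append(root.data)
    return height[root.data]


#        1
#       / \
#      2   3
#     / \
#    4   5
root = Node(1, Node(2, Node(4), Node(5)), Node(3))
order, height = [], {}
heights_in_post_order(root, order, height)
print(order)   # [4, 5, 2, 3, 1]
print(height)  # {4: 1, 5: 1, 2: 2, 3: 1, 1: 3}
```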
How it Works
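- Base Cases: An empty (null) subtree returns a fixed value. In the diameter code below it has height 0, so a leaf node has height 1.
- Recursive Step: For a node u, compute its answer from the answers already computed for its children v1, v2, and so on.
- Example Problem (Diameter of a Tree): The length of the longest path between any two nodes, counted here as the number of nodes on the path. For any node, the longest path in its subtree either passes through that node or lies entirely inside one of its child subtrees.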
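The recursive step above allows any number of children. This is a hedged sketch for a general (not necessarily binary) tree. It assumes a hypothetical input format in which the tree is a Python dictionary mapping each node to the list of its children. The longest path through a node can use at most two child subtrees, so it is enough to keep the two largest child heights:

```python
def tree_diameter(children, root):
    """Diameter (in nodes) of a rooted tree given as {node: [child, ...]}.

    For each node u, the longest path through u joins the two tallest
    child subtrees: top1 + top2 + 1.
    """
    best = 0  # largest diameter seen so far

    def height(u):
        nonlocal best
        top1 = top2 = 0  # two largest child heights (0 if missing)
        for v in children.get(u, []):
            h = height(v)  # child is finished before u (post-order)
            if h > top1:
                top1, top2 = h, top1
            elif h > top2:
                top2 = h
        best = max(best, top1 + top2 + 1)  # longest path passing through u
        return top1 + 1  # height of u's subtree

    height(root)
    return best


children = {1: [2, 3, 4], 2: [5], 5: [6], 3: [7]}
print(tree_diameter(children, 1))  # 6  (path 6-5-2-1-3-7)
```

With exactly two children, top1 + top2 + 1 reduces to the L + R + 1 used for binary trees.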

Example Logic (Tree Diameter)
For every node, calculate:
- Height of left subtree (L).
- Height of right subtree (R).
- Diameter passing through this node = L + R + 1 (the nodes on the left branch, the nodes on the right branch, and the node itself).
- Update the global maximum diameter.
- Return max(L, R) + 1 to the parent.
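A short trace helps check the steps above on a tree whose diameter does not pass through the root. The sketch below reuses the same L + R + 1 logic but also logs each node's values. The trace function and the tree shape are illustrative choices, not taken from the original.

```python
class Node:
    # Illustrative node class: children can be passed to the constructor.
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def trace(root, rows, res):
    """Same logic as the diameter function: return the height, update res[0],
    and log (node, L, R, L + R + 1, height) for each node in post-order."""
    if root is None:
        return 0
    l = trace(root.left, rows, res)
    r = trace(root.right, rows, res)
    res[0] = max(res[0], l + r + 1)
    rows.append((root.data, l, r, l + r + 1, max(l, r) + 1))
    return max(l, r) + 1


#          1
#         /
#        2
#       / \
#      3   4
#     /     \
#    5       6
# The longest path 5-3-2-4-6 lies inside the subtree of node 2,
# so it never passes through the root (node 1).
root = Node(1, Node(2, Node(3, Node(5)), Node(4, None, Node(6))))
rows, res = [], [0]
trace(root, rows, res)
for node, l, r, through, h in rows:
    print(f"node {node}: L={l} R={r} L+R+1={through} height={h}")
print("Diameter is", res[0])  # Diameter is 5
```

At the root, L + R + 1 is only 4. The maximum of 5 was recorded earlier, at node 2. This is why the running global maximum is needed instead of just the root's value.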
Code Example
```python
# Node class
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


# Diameter of Binary Tree, counted in nodes.
# Returns the height of the subtree rooted at `root` and stores the
# largest L + R + 1 seen so far in res[0].
def diameter(root, res):
    if root is None:
        return 0
    # Heights of left and right subtree
    l = diameter(root.left, res)
    r = diameter(root.right, res)
    # Update maximum diameter (longest path through this node)
    res[0] = max(res[0], l + r + 1)
    # Return height
    return max(l, r) + 1


# Build tree
root = Node(1)
root.left = Node(2)
root.right = Node(3)
res = [0]  # an empty tree has diameter 0
diameter(root, res)
print("Diameter is", res[0])
```
```java
// Node of a binary tree
class Node {
    int data;
    Node left, right;

    Node(int data) {
        this.data = data;
        left = right = null;
    }
}

public class DiameterOfTree {
    // Returns the height of the subtree rooted at root and stores the
    // largest diameter (in nodes) found so far in res[0].
    public static int diameter(Node root, int[] res) {
        if (root == null)
            return 0;
        int l = diameter(root.left, res);
        int r = diameter(root.right, res);
        res[0] = Math.max(res[0], l + r + 1);
        return Math.max(l, r) + 1;
    }

    public static void main(String[] args) {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        int[] res = new int[]{0}; // an empty tree has diameter 0
        diameter(root, res);
        System.out.println("Diameter is " + res[0]);
    }
}
```
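```c
#include <stdio.h>
#include <stdlib.h>

struct Node {
    int data;
    struct Node* left;
    struct Node* right;
};

struct Node* newNode(int data) {
    struct Node* node = malloc(sizeof(struct Node));
    if (node == NULL) {
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    node->data = data;
    node->left = NULL;
    node->right = NULL;
    return node;
}

/* Returns the height of the subtree rooted at root and stores the
   largest diameter (in nodes) found so far in *res. */
int diameter(struct Node* root, int* res) {
    if (root == NULL)
        return 0;
    int l = diameter(root->left, res);
    int r = diameter(root->right, res);
    if (l + r + 1 > *res)
        *res = l + r + 1;
    return (l > r ? l : r) + 1;
}

/* Releases every node (also a post-order traversal). */
void freeTree(struct Node* root) {
    if (root == NULL)
        return;
    freeTree(root->left);
    freeTree(root->right);
    free(root);
}

int main(void) {
    struct Node* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    int res = 0; /* an empty tree has diameter 0 */
    diameter(root, &res);
    printf("Diameter is %d\n", res);
    freeTree(root);
    return 0;
}
```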
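```cpp
#include <algorithm>
#include <iostream>
using namespace std;

class Node {
public:
    int data;
    Node* left;
    Node* right;
    Node(int data) {
        this->data = data;
        left = right = nullptr;
    }
};

// Returns the height of the subtree rooted at root and stores the
// largest diameter (in nodes) found so far in res.
int diameter(Node* root, int& res) {
    if (root == nullptr)
        return 0;
    int l = diameter(root->left, res);
    int r = diameter(root->right, res);
    res = max(res, l + r + 1);
    return max(l, r) + 1;
}

// Releases every node in post-order.
void freeTree(Node* root) {
    if (root == nullptr)
        return;
    freeTree(root->left);
    freeTree(root->right);
    delete root;
}

int main() {
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(3);
    int res = 0; // an empty tree has diameter 0
    diameter(root, res);
    cout << "Diameter is " << res << endl;
    freeTree(root);
    return 0;
}
```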
Complexity Analysis
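- Time Complexity: O(N), where N is the number of nodes. Every node is visited exactly once.
- Space Complexity: O(H), where H is the height of the tree, for the recursion stack. H is O(log N) for a balanced tree but can reach O(N) for a skewed one.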
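The recursion depth equals H, so a skewed tree with many nodes can exceed CPython's default recursion limit (usually 1000). In that case a recursive version would raise RecursionError. One common workaround is to simulate the post-order with an explicit stack. The sketch below assumes a hypothetical helper named diameter_iterative and the same node-counting convention as the code example:

```python
class Node:
    # Illustrative node class: children can be passed to the constructor.
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def diameter_iterative(root):
    """Diameter (in nodes) using an explicit stack instead of recursion.

    Each node is pushed twice: first to schedule its children, then
    (expanded=True) to combine their heights, which yields post-order.
    """
    if root is None:
        return 0
    height = {}  # node -> height of its subtree; a missing child counts as 0
    best = 0
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            l = height.get(node.left, 0)
            r = height.get(node.right, 0)
            best = max(best, l + r + 1)
            height[node] = max(l, r) + 1
        else:
            stack.append((node, True))  # revisit after the children
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
    return best


# A left-skewed chain of 10,000 nodes, so H = N.
root = Node(0)
cur = root
for i in range(1, 10_000):
    cur.left = Node(i)
    cur = cur.left
print(diameter_iterative(root))  # 10000
```

The explicit stack takes over the role of the call stack. However, the height dictionary keeps one entry per node, so this version trades recursion depth for O(N) extra memory.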
Use Cases
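- Network Routing: Finding optimal paths in hierarchical (tree-shaped) networks. For example, the largest hop distance between any two nodes is the tree's diameter measured in edges.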
- Company Hierarchy: Calculating total salary of a department (subtree sum).
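The company-hierarchy case follows the same pattern, using a sum instead of a maximum. This is a minimal sketch that assumes a hypothetical org chart, stored as a dictionary from each manager to their direct reports, and made-up salary figures:

```python
def department_totals(reports, salary, root):
    """Total salary of every subtree (department) in an org chart.

    reports: {manager: [direct reports]}, salary: {employee: amount}.
    total(u) = salary(u) + sum of total(v) over u's direct reports v.
    """
    total = {}

    def visit(u):
        t = salary[u]
        for v in reports.get(u, []):
            t += visit(v)  # each report's department is finished before u
        total[u] = t
        return t

    visit(root)
    return total


# Hypothetical data for illustration only.
reports = {"CEO": ["CTO", "CFO"], "CTO": ["Dev1", "Dev2"]}
salary = {"CEO": 300, "CTO": 200, "CFO": 180, "Dev1": 120, "Dev2": 110}
totals = department_totals(reports, salary, "CEO")
print(totals["CTO"])  # 430
print(totals["CEO"])  # 910
```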