DP on Trees and Graphs

advanced dynamic-programming tree-DP graph

So far, our DP has been on arrays and strings. But DP can also work on trees and graphs. The core idea stays the same — break a problem into subproblems and cache results. The difference is that our subproblems follow the tree structure instead of array indices.

Put simply, we solve the problem for the leaf nodes first and then work our way up to the root. Post-order traversal does this naturally, because it finishes both children before it visits their parent.

The Tree DP Pattern

The general approach:

  1. Traverse the tree (usually post-order / DFS)
  2. At each node, compute the answer using the answers from its children
  3. Return the result up to the parent
  4. Memoize when the same subproblem can be reached in more than one way (rare in trees, where each node has a single parent; common in graphs)

House Robber III

Problem: The usual House Robber rule applies (you can't rob two adjacent houses), but now the houses are arranged in a binary tree. Two directly linked nodes (a parent and its child) can't both be robbed. Find the maximum total you can rob.

At each node we have two choices:

  • Rob it — take this node’s value + the “skip” results from both children
  • Skip it — take the best results from both children (whether they robbed or skipped)

We return a pair [robThis, skipThis] from each node.

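Figure — House Robber III: Rob or Skip Each Node
Tree (level order): root 3; its children 2 and 3; grandchildren 3 and 1.
Robbed: the root (3) and the two grandchildren (3 and 1). Skipped: the two children (2 and 3).
Rob root (3) + grandchildren (3 + 1) = 7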
/**
 * House Robber III: the largest total obtainable from a binary tree in which
 * a node and its direct child may never both be robbed.
 *
 * @param {{val: number, left: object|null, right: object|null}|null} root
 *   tree root; nodes expose `val`, `left` and `right`.
 * @returns {number} best total for the whole tree (0 for an empty tree).
 */
function rob(root) {
  // visit(node) -> [best total if node is robbed, best total if node is skipped]
  const visit = (node) => {
    if (!node) {
      return [0, 0];
    }
    const [leftRob, leftSkip] = visit(node.left);
    const [rightRob, rightSkip] = visit(node.right);
    // Robbing this node forces both children to be skipped.
    const taken = node.val + leftSkip + rightSkip;
    // Skipping this node lets each child pick whichever option is better.
    const passed = Math.max(leftRob, leftSkip) + Math.max(rightRob, rightSkip);
    return [taken, passed];
  };
  const [withRoot, withoutRoot] = visit(root);
  return Math.max(withRoot, withoutRoot);
}
// Time: O(n), Space: O(h) where h = tree height
def rob(root):
    """House Robber III: best total from a binary tree where a node and its
    direct child may never both be robbed.

    Nodes are read through their ``val``, ``left`` and ``right`` attributes.
    Returns 0 for an empty tree.
    """
    def best_pair(node):
        # -> (best total if node is robbed, best total if node is skipped)
        if not node:
            return (0, 0)
        take_left, pass_left = best_pair(node.left)
        take_right, pass_right = best_pair(node.right)
        # Robbing this node means both children must be skipped.
        taken = node.val + pass_left + pass_right
        # Skipping it frees each child to take its better option.
        passed = max(take_left, pass_left) + max(take_right, pass_right)
        return (taken, passed)

    return max(best_pair(root))
# Time: O(n), Space: O(h) where h = tree height
/**
 * House Robber III: best total from a binary tree where a node and its
 * direct child may never both be robbed.
 *
 * @param root tree root (may be null, giving 0)
 * @return the maximum robbable total
 */
static int rob(TreeNode root) {
    int[] pair = dfs(root); // {best if root robbed, best if root skipped}
    return pair[0] > pair[1] ? pair[0] : pair[1];
}

/**
 * Post-order helper.
 *
 * @param node subtree root (may be null)
 * @return {best total if node is robbed, best total if node is skipped};
 *         {0, 0} for a null node
 */
static int[] dfs(TreeNode node) {
    if (node == null) {
        return new int[]{0, 0};
    }
    int[] l = dfs(node.left);
    int[] r = dfs(node.right);
    // Robbing this node forces both children to be skipped.
    int taken = node.val + l[1] + r[1];
    // Skipping it lets each child choose its better option.
    int passed = Math.max(l[0], l[1]) + Math.max(r[0], r[1]);
    return new int[]{taken, passed};
}
// Time: O(n), Space: O(h) where h = tree height

The key trick: instead of memoizing in a hash map, we return both states (rob/skip) from each recursive call. Clean and efficient.

Tree Diameter

Problem: Find the length of the longest path between any two nodes in a binary tree. The path doesn’t need to go through the root.

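At each node, the longest path passing through it equals leftDepth + rightDepth. Here depth(node) counts the nodes on the longest downward path starting at node (0 for an empty child). That sum is therefore the number of edges on the path, which is how the diameter is measured. We track the global maximum while computing depths.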

/**
 * Length (in edges) of the longest path between any two nodes of a binary
 * tree; the path need not pass through the root.
 *
 * @param {{left: object|null, right: object|null}|null} root - tree root.
 * @returns {number} the diameter (0 for an empty or single-node tree).
 */
function diameterOfBinaryTree(root) {
  let best = 0;
  // height(node): nodes on the longest downward path from node (0 for null).
  const height = (node) => {
    if (!node) return 0;
    const l = height(node.left);
    const r = height(node.right);
    // l + r edges join the deepest leaf on each side through this node.
    if (l + r > best) best = l + r;
    return Math.max(l, r) + 1;
  };
  height(root);
  return best;
}
// Time: O(n), Space: O(h)
def diameter_of_binary_tree(root):
    """Length, in edges, of the longest path between any two nodes.

    The path does not have to pass through the root. Nodes are read through
    their ``left`` and ``right`` attributes. Returns 0 for an empty tree.
    """
    best = [0]  # boxed so the nested helper can update it

    def height(node):
        # Nodes on the longest downward path from node (0 for a missing node).
        if not node:
            return 0
        lh = height(node.left)
        rh = height(node.right)
        # lh + rh edges link the deepest leaves on either side via this node.
        best[0] = max(best[0], lh + rh)
        return max(lh, rh) + 1

    height(root)
    return best[0]
# Time: O(n), Space: O(h)
static int maxDiameter = 0; // running best, shared with depth()

/**
 * Length, in edges, of the longest path between any two nodes.
 *
 * @param root tree root (may be null, giving 0)
 * @return the tree's diameter
 */
static int diameterOfBinaryTree(TreeNode root) {
    maxDiameter = 0; // the field outlives each call, so reset it first
    depth(root);
    return maxDiameter;
}

/**
 * Nodes on the longest downward path from {@code node}; 0 for null.
 * Also records the longest path through {@code node} in maxDiameter.
 */
static int depth(TreeNode node) {
    if (node == null) {
        return 0;
    }
    int l = depth(node.left);
    int r = depth(node.right);
    if (l + r > maxDiameter) {
        maxDiameter = l + r;
    }
    return (l > r ? l : r) + 1;
}
// Time: O(n), Space: O(h)

Binary Tree Maximum Path Sum

Problem: Find the maximum path sum in a binary tree. A path can start and end at any node. Node values can be negative.

This is the harder cousin of tree diameter. Same pattern, but we track sums instead of depths, and we need to handle negative values.

At each node:

  • The path through this node = node.val + leftGain + rightGain, where a child's gain is the best sum of a downward path starting at that child
  • But we can only return one branch to the parent (can’t fork upward)
  • Negative gains should be treated as 0 (just don’t take that branch)
/**
 * Largest sum of node values along any path (start and end anywhere).
 * Node values may be negative.
 *
 * @param {{val: number, left: object|null, right: object|null}|null} root
 * @returns {number} the best path sum (-Infinity for an empty tree).
 */
function maxPathSum(root) {
  let best = -Infinity;
  // downGain(node): best sum of a downward path that starts at node.
  const downGain = (node) => {
    if (!node) return 0;
    // A branch with a negative gain is simply not taken.
    const l = Math.max(0, downGain(node.left));
    const r = Math.max(0, downGain(node.right));
    best = Math.max(best, node.val + l + r); // path bending at node
    // Only one branch may continue upward to the parent.
    return node.val + Math.max(l, r);
  };
  downGain(root);
  return best;
}
def max_path_sum(root):
    """Largest sum of node values along any path (start and end anywhere).

    Node values may be negative. Nodes are read through ``val``, ``left``
    and ``right``. An empty tree yields ``float('-inf')``.
    """
    best = [float('-inf')]  # boxed so the nested helper can update it

    def down_gain(node):
        # Best sum of a downward path that starts at node.
        if not node:
            return 0
        # Negative branches are dropped (treated as 0).
        lg = max(0, down_gain(node.left))
        rg = max(0, down_gain(node.right))
        best[0] = max(best[0], node.val + lg + rg)  # path bending at node
        return node.val + max(lg, rg)  # only one branch may go upward

    down_gain(root)
    return best[0]
static int maxSum; // running best, shared with gain()

/**
 * Largest sum of node values along any path (start and end anywhere).
 *
 * @param root tree root; an empty tree yields Integer.MIN_VALUE
 * @return the best path sum
 */
static int maxPathSum(TreeNode root) {
    maxSum = Integer.MIN_VALUE; // any real path beats this
    gain(root);
    return maxSum;
}

/**
 * Best sum of a downward path starting at {@code node} (0 for null).
 * Also records the best path that bends at {@code node} in maxSum.
 */
static int gain(TreeNode node) {
    if (node == null) {
        return 0;
    }
    int fromLeft = gain(node.left);
    int fromRight = gain(node.right);
    // Negative branches are not worth taking.
    if (fromLeft < 0) fromLeft = 0;
    if (fromRight < 0) fromRight = 0;
    int through = node.val + fromLeft + fromRight;
    if (through > maxSum) {
        maxSum = through;
    }
    // Only one branch may continue upward to the parent.
    return node.val + (fromLeft > fromRight ? fromLeft : fromRight);
}

The Tree DP Template

All three problems above follow the same structure:

  1. DFS / post-order traversal — solve children first
  2. Combine children results at each node
  3. Update a global answer (diameter, max sum, etc.)
  4. Return local result to parent (depth, gain, etc.)

The trick is figuring out what to return vs what to track globally:

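| Problem          | Return to parent           | Track globally                          |
|------------------|----------------------------|-----------------------------------------|
| House Robber III | [robValue, skipValue]      | — (the answer is the max of the root's pair) |
| Tree Diameter    | depth (single branch)      | left + right (path through node)        |
| Max Path Sum     | max gain (single branch)   | val + left + right (path through node)  |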

Graph DP: Floyd-Warshall

Trees are easier because they have no cycles — a plain DFS visits each node exactly once. Graphs are trickier, but one classic graph DP is Floyd-Warshall, which finds the shortest paths between ALL pairs of nodes.

Idea: We try using each node k as an intermediate stop. Can going through k shorten the path from i to j?

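State: dp[i][j] = shortest known distance from node i to node j. After the outer loop has processed nodes 0..k, it is the shortest distance that uses only nodes 0..k as intermediate stops. When the loop finishes, it is the true shortest distance.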

Recurrence: dp[i][j] = min(dp[i][j], dp[i][k] + dp[k][j])

/**
 * All-pairs shortest paths via Floyd-Warshall.
 *
 * @param {number[][]} graph - square adjacency matrix; graph[i][j] is the
 *   weight of edge i -> j. Presumably Infinity marks a missing edge and the
 *   diagonal is 0 (not checked here). Infinity + x stays Infinity, so such
 *   entries never win a comparison.
 * @returns {number[][]} a fresh matrix where result[i][j] is the shortest
 *   distance found from i to j; the input matrix is not modified.
 */
function floydWarshall(graph) {
  const dist = graph.map((row) => Array.from(row)); // per-row copies
  const n = graph.length;
  for (let via = 0; via < n; via++) {
    for (let from = 0; from < n; from++) {
      for (let to = 0; to < n; to++) {
        const detour = dist[from][via] + dist[via][to];
        // Route through `via` only when it is strictly shorter.
        if (detour < dist[from][to]) {
          dist[from][to] = detour;
        }
      }
    }
  }
  return dist;
}
// Time: O(V³), Space: O(V²)
def floyd_warshall(graph):
    """All-pairs shortest paths via Floyd-Warshall.

    ``graph[i][j]`` is the weight of edge i -> j; presumably
    ``float('inf')`` marks a missing edge and the diagonal is 0 (not checked
    here). Returns a new matrix of shortest distances; ``graph`` itself is
    not modified (each row is sliced).
    """
    size = len(graph)
    dist = [row[:] for row in graph]
    for mid in range(size):
        for src in range(size):
            for dst in range(size):
                detour = dist[src][mid] + dist[mid][dst]
                # Route through `mid` only when strictly shorter.
                if detour < dist[src][dst]:
                    dist[src][dst] = detour
    return dist
# Time: O(V³), Space: O(V²)
/**
 * All-pairs shortest paths via Floyd-Warshall.
 *
 * <p>Java ints have no infinity, so {@code Integer.MAX_VALUE} is treated as
 * "no edge / unreachable". Entries equal to it are never used as a leg of a
 * detour. Without that check, MAX_VALUE + MAX_VALUE wraps around to a
 * negative number, and MAX_VALUE + (negative weight) forges a phantom path.
 *
 * @param graph square adjacency matrix: graph[i][j] is the weight of edge
 *              i -> j, Integer.MAX_VALUE for no edge, 0 on the diagonal
 * @return a new matrix where result[i][j] is the shortest distance from i
 *         to j (Integer.MAX_VALUE if unreachable); graph is not modified
 */
static int[][] floydWarshall(int[][] graph) {
    int n = graph.length;
    int[][] dp = new int[n][]; // rows are filled with copies below
    for (int i = 0; i < n; i++) dp[i] = graph[i].clone();
    for (int k = 0; k < n; k++) {           // try each intermediate node
        for (int i = 0; i < n; i++) {
            // No known path i -> k, so k cannot shorten anything from i.
            if (dp[i][k] == Integer.MAX_VALUE) continue;
            for (int j = 0; j < n; j++) {
                if (dp[k][j] == Integer.MAX_VALUE) continue;
                // Add in long so two large finite weights cannot overflow int.
                long through = (long) dp[i][k] + dp[k][j];
                if (through < dp[i][j]) {
                    // Strictly below an int value, so the cast cannot overflow upward.
                    dp[i][j] = (int) through;
                }
            }
        }
    }
    return dp;
}
// Time: O(V³), Space: O(V²)

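Floyd-Warshall is O(V^3), which is fine for small graphs (up to ~500 nodes). For single-source shortest paths with non-negative edge weights, Dijkstra is faster. But when we need all pairs, Floyd-Warshall is beautifully simple.

The code expects an adjacency matrix with 0 on the diagonal and Infinity (or a very large sentinel value) where there is no edge. It handles negative edge weights but not negative cycles; a negative cycle shows up afterwards as dp[i][i] < 0.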

When to Use Tree/Graph DP

  • Tree DP: Whenever we need to compute something for every node based on its subtree. Post-order DFS is our friend.
  • Graph DP: When the problem has overlapping subproblems on a graph structure. Less common in interviews than tree DP, but Floyd-Warshall shows up regularly.
  • Key difference: Trees have no cycles, so memoization is usually not needed (each node is visited once). Graphs may need explicit visited tracking or topological ordering.