CSES Solutions – Tree Diameter

You are given a tree consisting of n nodes. The diameter of a tree is the maximum distance (the number of edges on the path) between any two nodes. Your task is to determine the diameter of the tree.

Examples:

Input: n = 5, edges = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 } };
Output: 3

Input: n = 4, edges = { { 1, 2 }, { 1, 3 }, { 3, 4 }};
Output: 3

Approach:

We can use Depth-First Search (DFS) to compute the distance of every node from a starting node. In a tree, the node farthest from any starting node is always an endpoint of some longest path. A second DFS from that endpoint therefore reaches the other end of a longest path. So the largest distance found in the second DFS is the diameter of the tree.

Step-by-step algorithm:

  • Run DFS from an arbitrary starting node (node 1) and find the node farthest from it; call it u.
  • Reset the distances of all the nodes.
  • Run DFS again, starting at u.
  • Find the node farthest from u; call it v.
  • The distance between u and v is the diameter of the tree.

Below is the implementation of the algorithm:

C++
#include <bits/stdc++.h>
using namespace std;

// Define the maximum number of nodes; node labels are 1-based, so valid
// labels are 1 .. MAX - 1.
const int MAX = 2e5 + 5;
// Adjacency list of the (undirected) tree: adj[u] lists the neighbours of u.
vector<int> adj[MAX];
// dist[v] = number of edges from the current DFS source to v.
int dist[MAX];

// Depth-first traversal that sets dist[v] = dist[node] + (depth of v below
// node) for every v reachable from `node` without passing through `parent`.
// The caller initialises dist[node] itself (0 for the source) and passes a
// label that is not a real node (0) as `parent` for the root call.
//
// An explicit stack is used instead of recursion. On a path-shaped tree the
// recursion depth equals the number of nodes (up to ~2e5 here), which can
// overflow the default call stack.
void dfs(int node, int parent)
{
    // Each entry is (vertex to expand, neighbour we arrived from).
    vector<pair<int, int> > pending;
    pending.push_back(make_pair(node, parent));
    while (!pending.empty()) {
        const pair<int, int> cur = pending.back();
        pending.pop_back();
        const int u = cur.first;
        const int cameFrom = cur.second;
        for (int child : adj[u]) {
            // Skip the edge leading back towards where we came from.
            if (child != cameFrom) {
                dist[child] = dist[u] + 1;
                pending.push_back(make_pair(child, u));
            }
        }
    }
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    // Static sample input: a tree on n nodes, given as an edge list.
    const int n = 5;
    const vector<pair<int, int> > edges
        = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 } };

    // Build the undirected adjacency list (store both directions).
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    // Pass 1: distances from node 1. max_element yields the first node with
    // the largest distance (node 1 itself if every distance is 0).
    dist[1] = 0;
    dfs(1, 0);
    const int endpoint
        = static_cast<int>(max_element(dist + 1, dist + n + 1) - dist);

    // Pass 2: clear distances of nodes 1..n and measure from that endpoint.
    fill(dist + 1, dist + n + 1, 0);
    dfs(endpoint, 0);

    // The largest distance from the endpoint is the diameter.
    const int diameter = *max_element(dist + 1, dist + n + 1);
    cout << diameter << "\n";

    return 0;
}
Java
import java.util.*;

public class TreeDiameter {
    static final int MAX = 200005;
    static List<Integer>[] adj = new ArrayList[MAX];
    static int[] dist = new int[MAX];

    static void dfs(int node, int parent) {
        for (int child : adj[node]) {
            if (child != parent) {
                dist[child] = dist[node] + 1;
                dfs(child, node);
            }
        }
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        // Number of nodes
        int n = 5;

        // Edges of the tree
        int[][] edges = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 } };

        for (int i = 1; i <= n; i++) {
            adj[i] = new ArrayList<>();
        }

        // Add edges to the adjacency list
        for (int[] edge : edges) {
            int a = edge[0], b = edge[1];
            adj[a].add(b);
            adj[b].add(a);
        }

        dist[1] = 0;
        dfs(1, 0);

        int maxDist = 0, node = 1;
        for (int i = 1; i <= n; i++) {
            if (dist[i] > maxDist) {
                maxDist = dist[i];
                node = i;
            }
        }

        Arrays.fill(dist, 0);
        dfs(node, 0);

        maxDist = 0;
        for (int i = 1; i <= n; i++) {
            maxDist = Math.max(maxDist, dist[i]);
        }

        System.out.println(maxDist);
    }
}


// This code is contributed by shiamgupta0987654321
Python3
# Adjacency list of the tree (node labels are 1-based; index 0 is unused padding)
adj = [[] for _ in range(2 * 10**5 + 5)]
# dist[v] = number of edges from the current DFS source to v
dist = [0] * (2 * 10**5 + 5)

# Depth-First Search function
def dfs(node, parent):
    """Set dist[v] = dist[node] + (depth of v below node) for every v
    reachable from ``node`` without passing through ``parent``.

    The caller initialises ``dist[node]`` itself and passes 0 (not a node)
    as the root's parent. An explicit stack is used instead of recursion.
    A path-shaped tree of n nodes would need recursion depth n, far beyond
    Python's default recursion limit of 1000.
    """
    # Each entry is (vertex to expand, neighbour we arrived from).
    stack = [(node, parent)]
    while stack:
        u, came_from = stack.pop()
        for child in adj[u]:
            # Skip the edge leading back towards where we came from
            if child != came_from:
                dist[child] = dist[u] + 1
                stack.append((child, u))

# Number of nodes (static input)
n = 5

# Edges of the tree (static input)
edges = [(1, 2), (1, 3), (3, 4), (3, 5)]

# Add edges to the adjacency list (undirected: store both directions)
for a, b in edges:
    adj[a].append(b)
    adj[b].append(a)

# Pass 1: distances from node 1
dist[1] = 0
dfs(1, 0)

# The farthest node from node 1 is an endpoint of a longest path.
# Only the real node labels 1..n are scanned. Scanning the whole padded list
# is wasteful, and it returns index 0 (not a node) when every distance is 0,
# e.g. for n = 1. max() returns the first maximal label, as dist.index() did.
node = max(range(1, n + 1), key=lambda v: dist[v])

# Pass 2: reset the distances and measure from that endpoint
for i in range(1, n + 1):
    dist[i] = 0
dfs(node, 0)

# The largest distance from the endpoint is the diameter
maxDist = max(dist[1:n + 1])

# Print the diameter of the tree
print(maxDist)

# This code is contributed by Ayush Mishra
JavaScript
class TreeDiameter {
    /**
     * Sets dist[v] = dist[node] + (depth of v below node) for every v reachable
     * from `node` without passing through `parent`. The caller initialises
     * dist[node] itself and passes 0 (not a node) as the root's parent.
     *
     * An explicit stack is used instead of recursion. Recursing once per tree
     * level overflows the JS call stack on deep, path-shaped trees.
     */
    static dfs(node, parent, adj, dist) {
        // Each entry is [vertex to expand, neighbour we arrived from].
        const stack = [[node, parent]];
        while (stack.length > 0) {
            const [u, cameFrom] = stack.pop();
            for (const child of adj[u]) {
                if (child !== cameFrom) {
                    dist[child] = dist[u] + 1;
                    stack.push([child, u]);
                }
            }
        }
    }

    static main() {
        const MAX = 200005;
        // One empty neighbour list per possible node label (labels are 1-based).
        const adj = Array.from({ length: MAX }, () => []);
        const dist = new Array(MAX).fill(0);

        // Number of nodes
        const n = 5;

        // Edges of the tree
        const edges = [[1, 2], [1, 3], [3, 4], [3, 5]];

        // Add edges to the adjacency list (both directions)
        for (const [a, b] of edges) {
            adj[a].push(b);
            adj[b].push(a);
        }

        // Pass 1: distances from node 1; keep the first farthest node
        dist[1] = 0;
        TreeDiameter.dfs(1, 0, adj, dist);

        let maxDist = 0;
        let node = 1;
        for (let i = 1; i <= n; i++) {
            if (dist[i] > maxDist) {
                maxDist = dist[i];
                node = i;
            }
        }

        // Pass 2: distances from that endpoint; the maximum is the diameter
        dist.fill(0);
        TreeDiameter.dfs(node, 0, adj, dist);

        maxDist = 0;
        for (let i = 1; i <= n; i++) {
            maxDist = Math.max(maxDist, dist[i]);
        }

        console.log(maxDist);
    }
}

TreeDiameter.main();

Output
3

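Time complexity: O(N), where N is the number of nodes in the tree; each of the two DFS passes visits every node and edge once.
Auxiliary space: O(N), where N is the number of nodes in the tree, for the adjacency list, the distance array, and the DFS stack.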