CSES Solutions – Path Queries
You are given a rooted tree consisting of n nodes. The nodes are numbered 1, 2, …, n, and node 1 is the root. Each node has a value.
Your task is to process the following two types of queries:
- 1 s x: change the value of node s to x
- 2 s: calculate the sum of values on the path from the root to node s
Examples:
Input: N = 5, Q = 3, v[] = {4, 2, 5, 2, 1}, edges[][] = {{1, 2}, {1, 3}, {3, 4}, {3, 5}}, queries[][] = {{2, 4}, {1, 3, 2}, {2, 4}}
Output:
11
8
Input: N = 10, Q = 2, v[] = {1, 8, 6, 8, 6, 2, 9, 2, 3, 2}, edges[][] = {{10, 5}, {6, 2}, {10, 7}, {5, 2}, {3, 9}, {8, 3}, {1, 4}, {6, 4}, {8, 7}}, queries[][] = {{2, 10}, {2, 6}}
Output:
27
11
Approach: To solve the problem, follow the below idea:
To efficiently handle these queries, we’ll use a combination of depth-first search (DFS) and segment trees.
Depth-First Search (DFS):
We’ll perform a depth-first traversal of the tree to compute the start and end times for each node. This will allow us to efficiently represent the tree as an array and facilitate range queries.
Segment Trees:
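A segment tree supports point updates and range-sum queries in O(log n) each. We store the node values in an Euler-tour array of length 2n: each node contributes +value at its start time and −value at its end time.

Summing this array from the root's start time up to just before node s's end time gives the root-to-s path sum. Every node that is not on that path has either both of its entries inside the range (so they cancel) or neither of them. The root, s and every ancestor in between contribute only their +value.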
Implementation:
- Store the tree as an adjacency list and run a DFS from the root to record each node's start (entry) and end (exit) time. Then initialize a segment tree over the 2n timestamps using these times.
- For each type 1 query, which changes a node's value to x, set the node's start-time position to x and its end-time position to −x.
- For each type 2 query, which asks for the sum of values on the path from the root to node s, query the segment tree for the sum over the range [start[1], end[s] − 1].
Step-by-step algorithm:
- Input: Read the tree’s size, initial node values, and queries.
- DFS: Traverse the tree using DFS to compute start and end times for each node.
- Segment Tree: Initialize a segment tree to represent node values efficiently.
- Initialization: Place each node's initial value at its start time and the negated value at its end time.
- Query Types: Handle two types of queries: change node value and calculate sum on path.
- Type 1 Query: Write x at the node's start time and −x at its end time.
- Type 2 Query: Query segment tree to calculate sum on the path.
- Updating Tree: When updating node values, update corresponding segment tree nodes.
- Querying Tree: When querying sum on a path, use segment tree to compute it efficiently.
- Output: Print the answer to every type 2 query.
Below is the implementation of the above algorithm:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 400001;
const ll MOD = 1e9 + 7;
// Node values live in an Euler-tour array with two slots per node (entry and
// exit), i.e. 2 * n positions. A recursive segment tree over m positions uses
// indices below 4 * m, hence 8 * maxn entries.
ll segTree[8 * maxn], a[400010], n, q;
std::vector<int> adj[maxn];
int timer = 0, st[maxn], en[maxn];

// Function to answer queries of type 2 (root-to-node path sums).
// Returns the sum of Euler-tour positions [i, j]; p is the segment-tree node
// covering [l, r]. The default r is the last Euler-tour index: dfs() hands out
// 2 * n timestamps, numbered 0 .. 2 * n - 1, and the root's exit time is the
// last of them.
ll query(int i, int j, int p = 1, int l = 0,
         int r = static_cast<int>(2 * n - 1))
{
    // No overlap: the range was clipped down to nothing
    if (i > j)
        return 0;
    // Complete overlap
    if (l >= i && r <= j) {
        return segTree[p];
    }
    // Partial overlap: clip [i, j] to each half and add both sums
    int mid = (l + r) / 2;
    ll left = query(i, std::min(j, mid), p * 2, l, mid);
    ll right = query(std::max(i, mid + 1), j, p * 2 + 1, mid + 1, r);
    return left + right;
}

// Function to update the segment tree.
// Sets Euler-tour position x to val and recomputes the sums of every node on
// the way back up to p. val is 64-bit so ll values (such as a[i]) are stored
// without narrowing. The default r must match query()'s default so that both
// functions address the same tree layout.
void update(int x, ll val, int p = 1, int l = 0,
            int r = static_cast<int>(2 * n - 1))
{
    // If l == r, then it is the leaf for position x
    if (l == r) {
        segTree[p] = val;
        return;
    }
    int mid = (l + r) / 2;
    if (x <= mid) {
        update(x, val, p * 2, l, mid);
    }
    else {
        update(x, val, p * 2 + 1, mid + 1, r);
    }
    segTree[p] = segTree[p * 2] + segTree[p * 2 + 1];
}
// Assigns Euler-tour timestamps: st[s] when s is entered and en[s] when it is
// left, recursing into every neighbour of s except its parent p. Each call
// consumes two values of the global counter timer.
void dfs(int s, int p)
{
    st[s] = timer;
    ++timer;
    for (const int child : adj[s]) {
        if (child == p)
            continue;
        dfs(child, s);
    }
    en[s] = timer;
    ++timer;
}
int main()
{
    // Sample input: the first example from the problem statement.
    n = 5, q = 3;
    int v[] = { 4, 2, 5, 2, 1 };
    vector<vector<int> > edges
        = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 } };
    vector<vector<int> > queries
        = { { 2, 4 }, { 1, 3, 2 }, { 2, 4 } };

    // Copy the values into the 1-based array a[].
    for (int node = 1; node <= n; ++node)
        a[node] = v[node - 1];

    // Build the undirected adjacency list from the n - 1 edges.
    for (int e = 0; e + 1 < n; ++e) {
        const int u = edges[e][0];
        const int w = edges[e][1];
        adj[u].push_back(w);
        adj[w].push_back(u);
    }

    dfs(1, 0);

    // +value at every node's entry time, then -value at every exit time.
    for (int node = 1; node <= n; ++node)
        update(st[node], a[node]);
    for (int node = 1; node <= n; ++node)
        update(en[node], -a[node]);

    for (const auto& row : queries) {
        const int type = row[0];
        const int s = row[1];
        if (type != 1) {
            // Path sum: everything from the root's entry up to just before
            // s's exit.
            cout << query(st[1], en[s] - 1) << endl;
            continue;
        }
        const int x = row[2];
        update(st[s], x);
        update(en[s], -x);
    }
}
import java.util.*;
public class Main {
    static final int maxn = 400001;
    static final int MOD = (int)1e9 + 7;
    // Euler-tour array: two slots per node (entry and exit), i.e. 2 * n
    // positions; a recursive segment tree over m positions uses indices
    // below 4 * m, hence 8 * maxn entries.
    static long[] segTree = new long[8 * maxn];
    static long[] a = new long[400010];
    static int n, q;
    @SuppressWarnings("unchecked")
    static List<Integer>[] adj = new ArrayList[maxn];
    static int timer = 0;
    static int[] st = new int[maxn];
    static int[] en = new int[maxn];

    // Last Euler-tour index. dfs() hands out 2 * n timestamps (0 .. 2n - 1),
    // so every query/update must cover [0, lastIndex()]; the root's exit time
    // is this last index.
    static int lastIndex() { return 2 * n - 1; }

    // function to answer queries of type 2: sum of Euler-tour positions
    // [i, j], where p is the segment-tree node covering [l, r]
    static long query(int i, int j, int p, int l, int r)
    {
        // No overlap: the range was clipped down to nothing
        if (i > j)
            return 0;
        // Complete overlap
        if (l >= i && r <= j)
            return segTree[p];
        // Partial overlap: clip [i, j] to each half and add both sums
        int mid = (l + r) / 2;
        long left
            = query(i, Math.min(j, mid), p * 2, l, mid);
        long right = query(Math.max(i, mid + 1), j,
                           p * 2 + 1, mid + 1, r);
        return left + right;
    }

    // function to update the segment tree: set position x to val and
    // refresh the sums above it. val is a long so node values need no
    // narrowing cast.
    static void update(int x, long val, int p, int l, int r)
    {
        // If l == r, then it is the leaf for position x
        if (l == r) {
            segTree[p] = val;
            return;
        }
        int mid = (l + r) / 2;
        if (x <= mid)
            update(x, val, p * 2, l, mid);
        else
            update(x, val, p * 2 + 1, mid + 1, r);
        segTree[p] = segTree[p * 2] + segTree[p * 2 + 1];
    }

    // Function to run dfs on the graph: st[s] is the entry time and en[s]
    // the exit time of node s
    static void dfs(int s, int p)
    {
        st[s] = timer++;
        for (int u : adj[s]) {
            if (u != p)
                dfs(u, s);
        }
        en[s] = timer++;
    }

    public static void main(String[] args)
    {
        // Sample Input
        n = 5;
        q = 3;
        int[] v = { 4, 2, 5, 2, 1 };
        int[][] edges
            = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 } };
        int[][] queries
            = { { 2, 4 }, { 1, 3, 2 }, { 2, 4 } };
        // 1-based indexing
        for (int i = 0; i <= n; i++)
            adj[i] = new ArrayList<>();
        for (int i = 1; i <= n; i++)
            a[i] = v[i - 1];
        // Construct the graph
        for (int i = 0; i < n - 1; i++) {
            int u = edges[i][0];
            int w = edges[i][1];
            adj[u].add(w);
            adj[w].add(u);
        }
        dfs(1, 0);
        int hi = lastIndex();
        // +value at each entry time, -value at each exit time
        for (int i = 1; i <= n; i++)
            update(st[i], a[i], 1, 0, hi);
        for (int i = 1; i <= n; i++)
            update(en[i], -a[i], 1, 0, hi);
        for (int[] qr : queries) {
            int t = qr[0];
            int s = qr[1];
            if (t == 1) {
                int x = qr[2];
                update(st[s], x, 1, 0, hi);
                update(en[s], -x, 1, 0, hi);
            }
            else
                System.out.println(query(st[1], en[s] - 1,
                                         1, 0, hi));
        }
    }
}
// This code is contributed by Ayush Mishra
import sys
from collections import defaultdict
# Capacity limits and flat segment-tree storage.
maxn = 400001
MOD = int(1e9) + 7
segTree = [0] * 1600010
a = [0] * 400010
# Node count and query count; assigned in main().
n = q = 0
# Adjacency lists keyed by node id.
adj = defaultdict(list)
# Euler-tour counter plus per-node entry (st) and exit (en) timestamps.
timer = 0
st = [0] * maxn
en = [0] * maxn
# Function to answer type 2 queries: returns the sum of segment-tree
# positions i..j, where node p covers the segment [l, r].
def query(i, j, p, l, r):
    if i > j:
        # Empty (fully clipped) range contributes nothing.
        return 0
    if i <= l and r <= j:
        # This node's whole segment lies inside the requested range.
        return segTree[p]
    mid = (l + r) // 2
    # Clip [i, j] to each half; the left half is evaluated first.
    return (query(i, min(j, mid), 2 * p, l, mid)
            + query(max(i, mid + 1), j, 2 * p + 1, mid + 1, r))
# Function to set segment-tree position x to val, then refresh the sums on
# the way back up; node p covers the segment [l, r].
def update(x, val, p, l, r):
    if l != r:
        mid = (l + r) // 2
        if x > mid:
            update(x, val, 2 * p + 1, mid + 1, r)
        else:
            update(x, val, 2 * p, l, mid)
        segTree[p] = segTree[2 * p] + segTree[2 * p + 1]
        return
    # Leaf for position x: overwrite it.
    segTree[p] = val
# Function to record Euler-tour timestamps: st[s] when s is entered and en[s]
# when it is left, visiting every neighbour except the parent p.
def dfs(s, p):
    global timer
    st[s], timer = timer, timer + 1
    for child in adj[s]:
        if child == p:
            continue
        dfs(child, s)
    en[s], timer = timer, timer + 1
def main():
    """Run the sample input and print the answer to every type 2 query.

    Builds the tree and assigns Euler-tour timestamps with dfs(). Then it
    stores +value at each node's entry time and -value at its exit time, and
    processes the queries. A type 2 query sums positions st[1] .. en[s] - 1.
    """
    global n, q
    # Sample Input
    n, q = 5, 3
    v = [4, 2, 5, 2, 1]
    edges = [(1, 2), (1, 3), (3, 4), (3, 5)]
    queries = [(2, 4), (1, 3, 2), (2, 4)]
    # dfs() recurses once per tree level; CPython's default limit (about 1000)
    # would raise RecursionError on deep, path-like trees.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), n + 100))
    # dfs() hands out 2*n timestamps (0 .. 2*n - 1), so the segment tree must
    # span all of them; the root's exit time is the last index.
    last = 2 * n - 1
    # 1-based indexing
    for i in range(1, n + 1):
        a[i] = v[i - 1]
    # Construct the graph
    for i in range(n - 1):
        edge_a, edge_b = edges[i]
        adj[edge_a].append(edge_b)
        adj[edge_b].append(edge_a)
    dfs(1, 0)
    for i in range(1, n + 1):
        update(st[i], a[i], 1, 0, last)
        update(en[i], -a[i], 1, 0, last)
    for row in queries:
        t, s = row[0], row[1]
        if t == 1:
            x = row[2]
            update(st[s], x, 1, 0, last)
            update(en[s], -x, 1, 0, last)
        else:
            print(query(st[1], en[s] - 1, 1, 0, last))


if __name__ == "__main__":
    main()
// Capacity limits and flat segment-tree storage.
const maxn = 400001;
const MOD = 1e9 + 7;
let segTree = new Array(1600010).fill(0);
let a = new Array(400010).fill(0);
// Node and query counts, assigned by the sample driver.
let n, q;
// Adjacency lists: one independent array per node id.
let adj = Array.from({ length: maxn }, () => []);
// Euler-tour counter plus per-node entry (st) and exit (en) timestamps.
let timer = 0;
let st = new Array(maxn).fill(0);
let en = new Array(maxn).fill(0);
// Function to answer queries of type 2 (root-to-node path sums): returns the
// sum of Euler-tour positions [i, j], where p is the segment-tree node
// covering [l, r]. dfs() hands out 2 * n timestamps (0 .. 2 * n - 1), so the
// default r is the last of them (the root's exit time).
function query(i, j, p = 1, l = 0, r = 2 * n - 1) {
  // No overlap: the range was clipped down to nothing
  if (i > j) return 0;
  // Complete overlap
  if (l >= i && r <= j) return segTree[p];
  // Partial overlap: clip [i, j] to each half and add both sums
  let mid = Math.floor((l + r) / 2);
  let left = query(i, Math.min(j, mid), p * 2, l, mid);
  let right = query(Math.max(i, mid + 1), j, p * 2 + 1, mid + 1, r);
  return left + right;
}
// Function to update the segment tree: set position x to val and refresh the
// sums above it. The default r must match query()'s default so both functions
// address the same tree layout.
function update(x, val, p = 1, l = 0, r = 2 * n - 1) {
  // If l == r, then it is the leaf for position x
  if (l === r) {
    segTree[p] = val;
    return;
  }
  let mid = Math.floor((l + r) / 2);
  if (x <= mid) {
    update(x, val, p * 2, l, mid);
  } else {
    update(x, val, p * 2 + 1, mid + 1, r);
  }
  segTree[p] = segTree[p * 2] + segTree[p * 2 + 1];
}
// Records Euler-tour timestamps: st[s] when s is entered and en[s] when it is
// left, visiting every neighbour except the parent p.
function dfs(s, p) {
  st[s] = timer;
  timer += 1;
  adj[s].forEach((child) => {
    if (child !== p) dfs(child, s);
  });
  en[s] = timer;
  timer += 1;
}
// Sample Input
n = 5;
q = 3;
let v = [4, 2, 5, 2, 1];
let edges = [[1, 2], [1, 3], [3, 4], [3, 5]];
let queries = [[2, 4], [1, 3, 2], [2, 4]];
// Construct the graph: undirected, so store both directions
for (let e = 0; e < n - 1; e++) {
  const [u, w] = edges[e];
  adj[u].push(w);
  adj[w].push(u);
}
dfs(1, 0);
// +value at every node's entry time (copying values 1-based into a[])...
for (let node = 1; node <= n; node++) {
  a[node] = v[node - 1];
  update(st[node], a[node]);
}
// ...then -value at every exit time
for (let node = 1; node <= n; node++) {
  update(en[node], -a[node]);
}
for (const [t, s, x] of queries) {
  if (t === 1) {
    update(st[s], x);
    update(en[s], -x);
  } else {
    // Path sum: from the root's entry up to just before s's exit
    console.log(query(st[1], en[s] - 1));
  }
}
// This code is contributed by Ayush Mishra
Output
11
8
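Time Complexity: O((n + q) * log n), where n is the number of nodes in the tree and q is the number of queries. The DFS takes O(n), building the segment tree takes 2n point updates of O(log n) each, and every query costs O(log n).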
Auxiliary Space: O(n), where n is the number of nodes.