Sample Test Cases
Input: 11
1 4
1 9
4 5
4 3
4 2
9 10
9 7
3 6
3 8
3 11
2 4 3 3 1 2 4 3 2 4 1
3
3 4 7
Output: 3 4 1
Explanation:
          1
        /   \
       4     9
     / | \   | \
    5  3  2 10  7
      /|\
     6 8 11
Colors of nodes 1 to 11: 2 4 3 3 1 2 4 3 2 4 1
Let
col 1 = RED (R)
col 2 = BLUE (B)
col 3 = GREEN (G)
col 4 = YELLOW (Y)
          B
        /   \
       G     B
     / | \   | \
    R  G  Y  Y  Y
      /|\
     B G R
Query 1: The subtree rooted at node 3 has four nodes (3, 6, 8, 11) with colors (G, B, G, R). There are three distinct colors, so the answer is 3.
Query 2: The subtree rooted at node 4 has seven nodes (4, 5, 3, 2, 6, 8, 11) with colors (G, R, G, Y, B, G, R). There are four distinct colors, so the answer is 4.
Query 3: Node 7 is a leaf, so its subtree contains only node 7 and the answer is 1.
Approach
We can solve this problem by flattening the given tree and finding the number of unique elements in some specific subarrays of the flattened tree.
How to flatten the tree?
To flatten the tree into an array, run a DFS from the root node and record two times for every node: the start time (when the node is first visited) and the end time (when the DFS has finished all of its children and leaves the node).
While flattening the tree, we store the color of each node in the array instead of the node itself.
Steps to flatten the given tree 
 Let the adjacency list 'G' store the given tree, and let the arrays 'stTime[]' and 'enTime[]' store the start and end times, respectively.
 To store the flattened tree, declare an array 'linearTree[]'.
 Declare a counter variable 'timer', which indicates the current time. Initialize it to zero.
 Declare a function void dfs(int node, int par, int col[]), to traverse the given tree. 'node' denotes the current node, 'par' denotes the parent of the current node, and the array 'col[]' denotes the color of each node.
 Increment timer by 1, and store it as the start time of the current node (stTime[node] = timer).
 Instead of storing the current node at linearTree[timer], store its color (linearTree[timer] = col[node]).
 Iterate over all the neighbors of the current node except 'par' (its parent), i.e., over its children, and recursively call dfs for each of them.
 Increment timer by 1, and store it as the end time of the current node (enTime[node] = timer).
 Store the color of the current node at linearTree[timer] as well.
 Call dfs(1, -1, col): node 1 is the root, and since the root has no parent, pass -1 as 'par'.
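Now, in the flattened tree, the subarray [stTime[u], enTime[u]] contains exactly the nodes of the subtree rooted at u: the DFS enters and leaves every descendant of u after entering u and before leaving it, and no other node is visited in between. Each node appears twice in that range, which does not change the number of distinct colors. So, for every node u given in a query, we have to find the number of distinct elements in the subarray [stTime[u], enTime[u]].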
Steps to solve 
 Make a struct to store the start time, end time, and index of each query; the index lets us print the answers in the order of input.
 Declare a vector 'newQuery' of size Q and store the start time, end time, and index of each query in it.
 Sort newQuery by start time (breaking ties by end time) using a comparator. Only the order of the start times matters for the algorithm.
 Declare an array of vectors 'colIndex', where colIndex[c] stores, in increasing order, every position i with linearTree[i] = c. To build it, iterate i from 1 to 2 * N and push i into colIndex[linearTree[i]].
 Use a binary indexed tree (BIT) over the positions 1 to 2 * N for the update and query operations. The idea is to keep exactly one mark per color, at its first occurrence at or after the current left boundary l; then query(r), the prefix sum up to r, counts the colors that occur in [l, r].
 To track which occurrence of each color is marked, declare an array currentIndex[], where currentIndex[c] is the index in colIndex[c] of the smallest position of color c that has not been marked in the BIT yet.
 Initially, for every color c, add 1 at its first occurrence colIndex[c][0] and set currentIndex[c] = 1.
 Declare a vector 'answer' of size Q to store the answers in the order of input.
 Now process the queries in the order of newQuery. Let l, r, and idx be the start time, end time, and index of the current query. For every position p < l that has not been processed yet, let c = linearTree[p]: remove the mark at p (add -1), and if color c occurs again, mark its next occurrence (add 1 at colIndex[c][currentIndex[c]] and increment currentIndex[c]). After these updates, query(r) is the answer to the query at index idx.
 As the queries are sorted by l, a single pointer 'j' is enough: positions 1 to j - 1 have already been processed, so j only moves forward. Each of the 2 * N positions is processed at most once with at most two BIT updates, so all the updates take O(N log N) in total.
Code
#include <bits/stdc++.h>
using namespace std;
//M must exceed 2 * N (the flattened tree has 2 * N entries) and the largest color value
const int M = 3e5 + 7;
//adjacency list to store the given tree
vector <int> G[M];
//colIndex[c] stores the positions of color c in the flattened tree, in increasing order
vector <int> colIndex[M];
/*
currentIndex[c] is the index of the smallest element
in the vector colIndex[c] that has not been marked in the BIT yet
*/
int currentIndex[M];
//arrays to store the start time, end time, and flattened tree, respectively
int stTime[M], enTime[M], linearTree[M];
//array to store the binary indexed tree
int BIT[M];
//timer
int timer = 0;
//struct to keep the start time, end time, and index of a query as a single entity
//(named Data instead of data to avoid a clash with std::data in C++17)
struct Data{
    int l;
    int r;
    int idx;
};
/////////////////////BIT////////////////////////////
//point update: add 'change' at position idx
void update(int idx, int change){
    while(idx < M){
        BIT[idx] += change;
        //move to the next node whose range covers idx
        idx += (idx & -idx);
    }
}
//prefix query: sum of positions 1..idx
int query(int idx){
    int sum = 0;
    while(idx > 0){
        sum += BIT[idx];
        //strip the lowest set bit
        idx -= (idx & -idx);
    }
    return sum;
}
////////////////////////////////////////////////////
//function to flatten the given tree
void dfs(int node, int par, int col[]){
    //store the start time
    stTime[node] = ++timer;
    //instead of storing the node in the flattened tree, store its color
    linearTree[timer] = col[node];
    for(auto u : G[node]){
        //if u is the parent of the current node, it is already visited
        if(u == par)
            continue;
        //call dfs for the children of the current node
        dfs(u, node, col);
    }
    //store the end time
    enTime[node] = ++timer;
    //store the color again at the end time
    linearTree[timer] = col[node];
}
//comparator to sort vector newQuery
bool cmp(const Data &x, const Data &y){
    //sort w.r.t l
    if(x.l != y.l)
        return x.l < y.l;
    //if l is the same, sort w.r.t r
    return x.r < y.r;
}
void solve(int &N, int &Q, vector <int> &Query){
    //vector to store the start time and end time of each queried node, and the query's index, as a single entity
    vector <Data> newQuery(Q);
    //storing queries in the form of Data
    for(int i = 0; i < Q; ++i){
        newQuery[i].l = stTime[Query[i]];
        newQuery[i].r = enTime[Query[i]];
        newQuery[i].idx = i;
    }
    //sort vector newQuery using comparator cmp
    sort(newQuery.begin(), newQuery.end(), cmp);
    for(int i = 1; i <= 2 * N; ++i){
        //store the occurrences of each color c in the vector colIndex[c]
        colIndex[linearTree[i]].push_back(i);
        //mark the first occurrence of each color
        if(colIndex[linearTree[i]].size() == 1){
            update(i, 1);
            //the next unmarked occurrence is at index 1
            currentIndex[linearTree[i]] = 1;
        }
    }
    //vector to store answers in the order of input
    vector <int> answer(Q);
    //pointer: positions 1..j-1 have already been removed
    int j = 1;
    for(auto u : newQuery){
        for(; j < u.l; ++j){
            int col = linearTree[j];
            //remove the marked occurrence of col (it is position j itself)
            update(colIndex[col][currentIndex[col] - 1], -1);
            //if there is another occurrence of the same color
            if(currentIndex[col] < (int)colIndex[col].size()){
                //mark the next occurrence
                update(colIndex[col][currentIndex[col]], 1);
                ++currentIndex[col];
            }
        }
        //store the answer
        answer[u.idx] = query(u.r);
    }
    //print the answers w.r.t input
    for(auto u : answer)
        cout << u << " ";
}
signed main(){
    //number of nodes
    int N;
    cin >> N;
    for(int i = 1; i < N; ++i){
        int u, v;
        //edge between u and v
        cin >> u >> v;
        //store the edge in both directions; dfs skips the parent itself
        G[u].push_back(v);
        G[v].push_back(u);
    }
    //color of nodes
    vector <int> col(N + 1);
    for(int i = 1; i <= N; ++i){
        cin >> col[i];
    }
    //number of queries
    int Q;
    cin >> Q;
    //vector to store the queries
    vector <int> Query(Q);
    for(int i = 0; i < Q; ++i){
        cin >> Query[i];
    }
    //call dfs for the root node; -1 means it has no parent
    dfs(1, -1, col.data());
    solve(N, Q, Query);
    return 0;
}
Input
11
1 4
1 9
4 5
4 3
4 2
9 10
9 7
3 6
3 8
3 11
2 4 3 3 1 2 4 3 2 4 1
11
1 2 3 4 5 6 7 8 9 10 11
Output
4 1 3 4 1 1 1 1 2 1 1
Illustration (same tree and colors as in the sample test case above)
Flattened Tree : B G R R G B B G G R R G Y Y G B Y Y Y Y B B
newQuery (before sorting) : (1 22 0), (13 14 1), (5 12 2), (2 15 3), (3 4 4), (6 7 5), (19 20 6), (8 9 7), (16 21 8), (17 18 9), (10 11 10)
newQuery (after sorting) : (1 22 0), (2 15 3), (3 4 4), (5 12 2), (6 7 5), (8 9 7), (10 11 10), (13 14 1), (16 21 8), (17 18 9), (19 20 6)
answer (in the order of newQuery) : 4 4 1 3 1 1 1 1 2 1 1
answer (in the order of input) : 4 1 3 4 1 1 1 1 2 1 1
Time Complexity
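The time complexity is O((N + Q) log N + Q log Q). The DFS takes O(N); marking the first occurrences and processing the at most 2 * N removals (each with at most two BIT updates) take O(N log N); sorting the queries takes O(Q log Q); and the Q prefix queries take O(Q log N).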
Space Complexity
The space complexity is O(N + Q).
FAQs

What is BIT?
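BIT stands for binary indexed tree, also known as a Fenwick tree. It is an array-based data structure in which each entry stores the sum of a specific range of the underlying array, which lets it perform point updates and answer prefix-sum (and therefore range-sum) queries efficiently.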

What is the time complexity of update and query in BIT?
For both, it is O(logN), where N is the size of the array.

What is the subtree of a tree?
The subtree rooted at a node u consists of u itself and all of its descendants, that is, every node below u in the rooted tree.
Key Takeaways
In this article, we solved a problem on trees using the binary indexed tree (Fenwick tree). The Fenwick tree is an important data structure because it is both efficient and easy to implement. Check out this Coding Ninjas blog to get a better hold on it.
Happy learning!