In this blog, we will discuss a binary search tree (BST) problem: find the sum of all nodes that are at a distance ‘K’ from a given node and have smaller values than that node. Problems based on binary search trees are common in competitive programming contests and coding interviews. Here we will discuss an efficient approach that uses the ordering property of the BST.
You are given a binary search tree, an integer ‘K’, and the value of a target node. Your task is to find the sum of all nodes that are at distance ‘K’ from the target node and whose values are less than the target node's value. The distance between two nodes is the number of edges on the path between them.
Explanation
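Let’s understand the problem statement with the help of an example. Consider the BST built by inserting 70, 40, 100, 10, 55, 85, and 120 (the same values used in the sample input later): 70 is the root, 40 and 100 are its left and right children, 10 and 55 are the children of 40, and 85 and 120 are the children of 100.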
Suppose TARGET = 40 and K = 1. The nodes at distance one from node 40 are {10, 55, 70}. Of these, only 10 is smaller than 40, so the answer in this case is 10.
In the same tree, suppose the target is 100 and K = 1. The nodes at distance one from node 100 are {70, 85, 120}. Of these, 70 and 85 are smaller than 100, so the answer is 70 + 85 = 155.
Approach
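A straightforward solution is to collect all nodes at distance ‘K’ from the target node, compare each one with the target, and add those with smaller values to the answer. Reaching these nodes means searching outward from the target through its left child, its right child, and its parent. The BST property lets us prune this search with a depth-first traversal. Every node in the target's left subtree is smaller than the target and every node in its right subtree is larger, so only the left subtree has to be searched downward. An ancestor is smaller than the target exactly when the target lies in the ancestor's right subtree; in that case the ancestor and its whole left subtree are smaller than the target. An ancestor whose left subtree contains the target is larger than the target, as is its right subtree, so neither contributes to the answer.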
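Before any pruning, the straightforward idea can be sketched as follows. This is a minimal brute-force sketch, not the algorithm used in this blog: it records each node's parent, runs a breadth-first search (rather than a depth-first search) outward from the target for exactly ‘K’ levels, and sums the nodes on that level that are smaller than the target. The names `mapParents` and `bruteForceKDistSum` are hypothetical, and the sketch assumes K ≥ 0.

```cpp
#include <cassert>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Minimal BST node with the same fields as the program's Node class.
struct Node {
    int val;
    Node *left = nullptr, *right = nullptr;
    explicit Node(int v) : val(v) {}
};

// Standard BST insertion; equal values go to the left, as in the program.
Node *insertNode(Node *root, int val) {
    if (root == nullptr) return new Node(val);
    if (val > root->val) root->right = insertNode(root->right, val);
    else root->left = insertNode(root->left, val);
    return root;
}

// Hypothetical helper: records every node's parent so the search can move upwards.
void mapParents(Node *node, Node *parent, std::unordered_map<Node *, Node *> &parentOf) {
    if (node == nullptr) return;
    parentOf[node] = parent;
    mapParents(node->left, node, parentOf);
    mapParents(node->right, node, parentOf);
}

// Hypothetical brute force: BFS outward from the target for k levels, then
// sum the nodes on level k whose value is smaller than the target's value.
int bruteForceKDistSum(Node *root, int target, int k) {
    std::unordered_map<Node *, Node *> parentOf;
    mapParents(root, nullptr, parentOf);

    // Locate the target node with an ordinary BST search.
    Node *start = root;
    while (start != nullptr && start->val != target)
        start = (target < start->val) ? start->left : start->right;
    if (start == nullptr) return 0;  // target not present in the tree

    std::vector<Node *> level{start};
    std::unordered_set<Node *> seen{start};
    for (int d = 0; d < k && !level.empty(); ++d) {
        std::vector<Node *> next;
        // Neighbours are the left child, the right child, and the parent.
        for (Node *n : level)
            for (Node *nb : {n->left, n->right, parentOf[n]})
                if (nb != nullptr && seen.insert(nb).second)
                    next.push_back(nb);
        level = std::move(next);
    }

    int sum = 0;
    for (Node *n : level)
        if (n->val < target) sum += n->val;
    return sum;
}
```

On the example tree, `bruteForceKDistSum(root, 40, 1)` gives 10 and `bruteForceKDistSum(root, 100, 1)` gives 155, matching the worked examples. It needs an extra map of all N parents, but it is a handy reference for testing a pruned solution.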
Algorithm
Initialize a global variable SUM = 0 to store the final answer.
Declare a function kDistSum() that takes ‘ROOT’, ‘TARGET’, and ‘K’ as input. It returns the distance between ‘ROOT’ and the ‘TARGET’ node (or -1 if ‘TARGET’ is not in the subtree of ‘ROOT’) and, along the way, adds to ‘SUM’ every node at distance ‘K’ from ‘TARGET’ whose value is smaller than ‘TARGET’.
Description of kDistSum() function:
Parameters Involved:
ROOT - the root of the subtree.
TARGET - target node.
K - distance.
Working of kDistSum() function:
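If ROOT == NULLPTR:
Return -1.
If ROOT->VAL == TARGET:
Call sumDown(ROOT->LEFT, K-1) to add the nodes in the left subtree of ‘TARGET’ that are at distance ‘K’ (its right subtree contains only larger values).
Return 0.
If TARGET < ROOT->VAL:
Declare a variable LEFTDIST = kDistSum(ROOT->LEFT, TARGET, K).
If LEFTDIST == -1, return -1.
Otherwise, return LEFTDIST+1. ROOT and its right subtree are larger than ‘TARGET’, so nothing is added here, but the distance is still passed up to the ancestors.
Otherwise (TARGET > ROOT->VAL):
Declare a variable RIGHTDIST = kDistSum(ROOT->RIGHT, TARGET, K).
If RIGHTDIST == -1, return -1.
If K == RIGHTDIST+1:
SUM += ROOT->VAL.
Else:
sumDown(ROOT->LEFT, K-RIGHTDIST-2).
Return RIGHTDIST+1.
Finally, call kDistSum(ROOT, TARGET, K) on the root of the whole tree and print ‘SUM’.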
Declare a function sumDown() that takes ‘ROOT’ and ‘K’ as input. It adds to ‘SUM’ the values of all nodes in the subtree of ‘ROOT’ that are exactly ‘K’ levels below ‘ROOT’.
Working of sumDown() function:
If ROOT == NULLPTR or K < 0:
Return.
If K == 0:
SUM = SUM + ROOT->VAL;
Return.
sumDown(ROOT->LEFT, K-1).
sumDown(ROOT->RIGHT,K-1).
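The pseudocode keeps its running total in a global ‘SUM’, which has to be reset before every new query. A minimal sketch that follows the same two-function structure but passes the total by reference is shown here; the extra `sum` parameter and the wrapper `sumSmallerAtDistanceK` are illustrative additions, not part of the original program.

```cpp
#include <cassert>

// Minimal BST node with the same fields as the program's Node class.
struct Node {
    int val;
    Node *left = nullptr, *right = nullptr;
    explicit Node(int v) : val(v) {}
};

// Standard BST insertion; equal values go to the left, as in the program.
Node *insertNode(Node *root, int val) {
    if (root == nullptr) return new Node(val);
    if (val > root->val) root->right = insertNode(root->right, val);
    else root->left = insertNode(root->left, val);
    return root;
}

// Adds every node exactly k levels below `root` to `sum`.
void sumDown(Node *root, int k, int &sum) {
    if (root == nullptr || k < 0) return;
    if (k == 0) { sum += root->val; return; }
    sumDown(root->left, k - 1, sum);
    sumDown(root->right, k - 1, sum);
}

// Returns the distance from `root` to the target (-1 if the target is not
// in this subtree); nodes smaller than the target at distance k go into `sum`.
int kDistSum(Node *root, int target, int k, int &sum) {
    if (root == nullptr) return -1;
    if (root->val == target) {
        // Only the target's left subtree holds smaller values.
        sumDown(root->left, k - 1, sum);
        return 0;
    }
    if (target < root->val) {
        // root and its right subtree are larger than the target: add nothing,
        // but still report the distance to the ancestors.
        int leftDist = kDistSum(root->left, target, k, sum);
        return leftDist == -1 ? -1 : leftDist + 1;
    }
    int rightDist = kDistSum(root->right, target, k, sum);
    if (rightDist == -1) return -1;
    if (rightDist + 1 == k)
        sum += root->val;                             // root is exactly k away
    else
        sumDown(root->left, k - rightDist - 2, sum);  // smaller nodes under root->left
    return rightDist + 1;
}

// Hypothetical wrapper name: every query starts from a fresh total.
int sumSmallerAtDistanceK(Node *root, int target, int k) {
    int sum = 0;
    kDistSum(root, target, k, sum);
    return sum;
}
```

Because each call of `sumSmallerAtDistanceK` starts from zero, the two example queries (target 40 with K = 1, then target 100 with K = 1) can be run back to back without a reset, giving 10 and 155.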
Program
#include <iostream>
#include <vector>
using namespace std;
// BST Node
class Node
{
public:
int val;
Node *left, *right;
Node(int val)
{
this->val = val;
this->right = nullptr;
this->left = nullptr;
}
};
// To store the sum.
int sum = 0;
Node *insertNode(Node *root, int val)
{
// Base Condition to terminate the recursive call.
if (root == nullptr)
{
Node *ptr = new Node(val);
return ptr;
}
// if the value of val is greater than root->val.
else if (val > root->val)
{
root->right = insertNode(root->right, val);
}
else
{
root->left = insertNode(root->left, val);
}
// return node after inserting.
return root;
}
// Adds to the global sum every node exactly k levels below root.
void sumDown(Node *root, int k)
{
// Base condition to terminate the recursive call.
if (root == nullptr || k < 0)
return;
if (k == 0)
{
sum += root->val;
return;
}
sumDown(root->left, k - 1);
sumDown(root->right, k - 1);
}
// Returns the distance from root to the target node (-1 if the target
// is not in this subtree) and adds qualifying nodes to the global sum.
int kDistSum(Node *root, int target, int k)
{
// Base condition to terminate the recursive call.
if (root == nullptr)
return -1;
// If the target matches the root value, only its left subtree
// can contain values smaller than the target.
if (root->val == target)
{
sumDown(root->left, k - 1);
return 0;
}
if (target < root->val)
{
int leftDist = kDistSum(root->left, target, k);
if (leftDist == -1)
return -1;
// root and its right subtree are larger than the target, so nothing
// is added here; just pass the distance up to the ancestors.
return leftDist + 1;
}
int rightDist = kDistSum(root->right, target, k);
if (rightDist == -1)
return -1;
// root is smaller than the target: add it if it is exactly k away,
// otherwise look for nodes at the remaining distance in its left subtree.
if (rightDist + 1 == k)
sum += root->val;
else
sumDown(root->left, k - rightDist - 2);
return rightDist + 1;
}
int main()
{
int n;
cin >> n;
vector<int> vec(n);
for (int i = 0; i < n; i++)
{
cin >> vec[i];
}
Node *root = nullptr;
for (int i = 0; i < n; i++)
{
root = insertNode(root, vec[i]);
}
int target, k;
cin >> target >> k;
// The returned distance is not needed here; the answer accumulates in sum.
kDistSum(root, target, k);
cout << sum;
return 0;
}
Input
7
70 40 100 10 55 85 120
100 1
Output
155
Time Complexity
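O(N), where ‘N’ is the number of nodes in the given binary search tree. Every node is visited at most once: kDistSum() follows the path from the root to the target, and each sumDown() call explores a separate subtree hanging off that path, stopping after at most ‘K’ levels.
Space Complexity
O(H), where ‘H’ is the height of the tree, for the recursion stack.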
What is a Binary Search Tree?
A Binary Search Tree is a special case of a Binary Tree in which all the elements in the left subtree of a given node are smaller than the element itself and all the elements in the right subtree are larger than it.
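The definition constrains whole subtrees, not just a node's immediate children, which is easy to get wrong in code. A minimal sketch of a checker that carries lower and upper bounds down the tree follows; `isBST` is a hypothetical helper name. It uses strict comparisons to match the definition, so a tree with duplicate values (which `insertNode` in the program places in the left subtree) would be rejected.

```cpp
#include <cassert>
#include <climits>

// Minimal BST node with the same fields as the program's Node class.
struct Node {
    int val;
    Node *left = nullptr, *right = nullptr;
    explicit Node(int v) : val(v) {}
};

// Hypothetical helper: true if every value in the subtree lies strictly
// between `low` and `high` and the BST property holds at every node.
// long long bounds keep INT_MIN and INT_MAX values valid.
bool isBST(const Node *node, long long low = LLONG_MIN, long long high = LLONG_MAX) {
    if (node == nullptr) return true;
    if (node->val <= low || node->val >= high) return false;
    // The left subtree must stay below this node, the right subtree above it.
    return isBST(node->left, low, node->val) &&
           isBST(node->right, node->val, high);
}
```

For instance, if 55 in the example tree is replaced by 75, the value is still larger than its parent 40, so a parent-child comparison would pass; but 75 sits in the left subtree of 70, and the bounds check returns false.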
What is Inorder Traversal?
Inorder Traversal is a tree traversal technique that visits nodes in the order left → root → right: for each node, we first traverse its left subtree, then print the value of the node itself, and finally traverse its right subtree.
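A minimal recursive sketch of inorder traversal; instead of printing, it appends each value to a vector so the result is easy to inspect. The function name `inorder` is illustrative. For a binary search tree, inorder traversal visits the values in ascending order.

```cpp
#include <cassert>
#include <vector>

// Minimal BST node with the same fields as the program's Node class.
struct Node {
    int val;
    Node *left = nullptr, *right = nullptr;
    explicit Node(int v) : val(v) {}
};

// Standard BST insertion; equal values go to the left, as in the program.
Node *insertNode(Node *root, int val) {
    if (root == nullptr) return new Node(val);
    if (val > root->val) root->right = insertNode(root->right, val);
    else root->left = insertNode(root->left, val);
    return root;
}

// Illustrative inorder traversal: left subtree, then the node, then the right subtree.
void inorder(const Node *node, std::vector<int> &out) {
    if (node == nullptr) return;
    inorder(node->left, out);
    out.push_back(node->val);
    inorder(node->right, out);
}
```

Running it on the tree built from the sample input (70 40 100 10 55 85 120) yields 10 40 55 70 85 100 120.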
Conclusion
In this blog, we have learned a straightforward and efficient approach to finding the sum of all nodes with smaller values at a distance ‘K’ from a given node in a binary search tree. This problem is a variation of the BST traversal problem, so practice tree traversal questions before moving further.
Learning never stops, and there is a lot more to learn.
So head over to our practice platform Coding Ninjas Studio to practice top problems, attempt mock tests, read interview experiences, and much more. Till then, Happy Coding!