You are given two binary search trees; the goal is to produce a sorted array containing the elements of both trees.
I wanted to know if there is a simpler approach, and whether the use of boost::variant is justified in the code. The code uses a stack to iterate over each binary search tree; the stack maintains the state of an in-order traversal of the tree.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <limits>
#include <random>
#include <stack>
#include <vector>

#include <boost/range/irange.hpp>
#include <boost/variant.hpp>
using namespace std;
using boost::irange;
using boost::variant;
// Binary tree node: an int payload plus raw pointers to the two children.
// A missing child is represented by a null pointer.
struct Node
{
    Node* left;   // left child, or null
    Node* right;  // right child, or null
    int val;      // value stored in this node

    // Stores the given child links and value as-is; nothing is allocated here.
    Node(Node* l, Node* r, int v)
        : left{l}, right{r}, val{v}
    {}
};
typedef variant<Node*, int> stkElemT;
typedef stack<stkElemT> bstStkT;
// Visitor that extracts the Node* alternative from a stack element.
// Returns the stored pointer when the element holds a Node*, and nullptr when
// it holds an int.
struct stkElemVisitorNode : public boost::static_visitor<Node*>
{
    // int alternative: there is no pointer to hand back.
    Node* operator()(int) const { return nullptr; }
    // Node* alternative: taken by value so the visitor can also be applied to
    // a const variant (a Node*& parameter cannot bind to a const element).
    Node* operator()(Node* ptr) const { return ptr; }
};
// Visitor that extracts the int alternative from a stack element.
// Returns the stored value when the element holds an int, and -1 when it holds
// a Node*. Note that -1 is therefore indistinguishable from a stored value of
// -1; callers that must tell the two apart have to inspect the variant itself.
struct stkElemVisitorInt : public boost::static_visitor<int>
{
    // int alternative: hand the value back unchanged.
    int operator()(int val) const { return val; }
    // Node* alternative: taken by value so the visitor can also be applied to
    // a const variant (a Node*& parameter cannot bind to a const element).
    int operator()(Node*) const { return -1; }
};
// Expands the node on top of the stack along its left-most path.
//
// While the top element is a non-null Node*, that element is replaced by
// (bottom to top) its right child, its value and its left child; absent
// children are skipped. The loop ends as soon as the top element is an int or
// a null pointer, i.e. when the value of the left-most node is on top.
//
// An empty stack is left untouched instead of calling top() on it, and the
// expansion is done iteratively so the call depth does not grow with the
// length of the left chain.
void fillPathStkRecurse(bstStkT& bstStk)
{
    while (!bstStk.empty())
    {
        // boost::get on a pointer yields null when the element holds an int.
        Node** slot = boost::get<Node*>(&bstStk.top());
        // Copy the pointer out before pop() invalidates `slot`.
        Node* node = (slot != nullptr) ? *slot : nullptr;
        if (node == nullptr)
            return; // top is a value (or a null pointer): nothing to expand

        bstStk.pop();
        if (node->right)
            bstStk.push(node->right);
        bstStk.push(node->val);
        if (node->left)
            bstStk.push(node->left);
    }
}
int getTopVal(const bstStkT& bstStk)
{
assert(!bstStk.empty());
stkElemT topE = bstStk.top();
int val = boost::apply_visitor(stkElemVisitorInt(), topE);
return val;
}
// Advances the traversal by one element: drops the value on top of the stack
// and, if anything is left, expands the new top until a value is on top again.
// Does nothing for an empty stack.
//
// The top element is expected to be an int value. This is asserted by checking
// the variant's active alternative directly rather than comparing against the
// -1 sentinel, so a tree that actually stores -1 does not trip the assertion.
void incrBstStk(bstStkT& bstStk)
{
    if (bstStk.empty())
        return;

    assert(boost::get<int>(&bstStk.top()) != nullptr);
    bstStk.pop();

    if (!bstStk.empty())
        fillPathStkRecurse(bstStk); // expand until a value is on top
}
// Builds a tree from vals[start..end], with BOTH bounds inclusive.
//
// The middle element (rounded towards `start`) becomes the root and the two
// halves are built recursively, so a sorted `vals` yields a binary search
// tree. An empty range (end < start) yields nullptr.
//
// Nodes are allocated with `new`; nothing in this function frees them.
//
// vals  : source values; indices start..end must be valid when end >= start
// start : index of the first element to include
// end   : index of the last element to include
Node* create_tree(std::vector<int>& vals, int start, int end)
{
    if (end < start)
        return nullptr; // empty range: no node

    // Integer midpoint; equals floor((start + end) / 2) for end >= start
    // without a detour through floating point.
    const int mid = start + (end - start) / 2;

    Node* left = create_tree(vals, start, mid - 1);
    Node* right = create_tree(vals, mid + 1, end);
    return new Node(left, right, vals[mid]);
}
vector<int> merge_bst(Node* root1, Node* root2)
{
vector<int> res;
bstStkT bstStk1;
bstStk1.push(root1);
fillPathStkRecurse(bstStk1);
bstStkT bstStk2;
bstStk2.push(root2);
fillPathStkRecurse(bstStk2);
while(1)
{
//cout<<"stk sizes = "<<bstStk1.size()<<" "<<bstStk2.size()<<endl;
if(bstStk1.empty() && bstStk2.empty())
break;
int val1 = numeric_limits<int>::max();
if(!bstStk1.empty())
val1 = getTopVal(bstStk1);
int val2 = numeric_limits<int>::max();
if(!bstStk2.empty())
val2 = getTopVal(bstStk2);
if(val1 < val2)//consume bstStk1
{
res.push_back(val1);
incrBstStk(bstStk1);
}
else
{
res.push_back(val2);
incrBstStk(bstStk2);
}
}
return res;
}
int main(int argc, char** argv)
{
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> uid5k(0, 1000); // distribution in range [1, 6]
int n = 10000;
for(auto k: irange(0, 10000))
{
vector<int> inVec1;
for(auto i: irange(0, n))
inVec1.push_back(uid5k(rng));
sort(inVec1.begin(), inVec1.end());
Node* root1 = create_tree(inVec1, 0, n-1);
vector<int> inVec2;
for(auto i: irange(0, n))
inVec2.push_back(uid5k(rng));
sort(inVec2.begin(), inVec2.end());
Node* root2 = create_tree(inVec2, 0, n-1);
vector<int> merged_vec(inVec1.begin(), inVec1.end());
merged_vec.insert(end(merged_vec), begin(inVec2), end(inVec2));
sort(begin(merged_vec), end(merged_vec));
auto res = merge_bst(root1, root2);
assert(res == merged_vec);
}
return 0;
}