#include <cassert>
#ifdef DUMP
#include <iostream>
using std::cout;
using std::endl;
#endif

typedef unsigned int Umword;
typedef int Mword;

template<typename Key_trait> class Bin_tree;

/**
 * Intrusive AVL tree node.
 *
 * Objects stored in a Bin_tree derive from this class. All link and
 * balancing state is private and is managed by Bin_tree (a friend).
 */
struct Bin_tree_node
{
private:
  template<typename> friend class Bin_tree;

  /**
   * Recompute `height` and `balance` from the children's stored heights.
   *
   * A node without children has height 0. `balance` is
   * height(left) - height(right), where a missing child counts one level
   * below a childless node.
   */
  void recalc()
  {
    Umword const leftHeight  = (left  == nullptr) ? 0 : left->height  + 1;
    Umword const rightHeight = (right == nullptr) ? 0 : right->height + 1;
    // Subtract in the signed type so a taller right side yields a negative
    // balance without relying on unsigned wrap-around + conversion.
    balance = static_cast<Mword>(leftHeight) - static_cast<Mword>(rightHeight);
    height = (leftHeight < rightHeight) ? rightHeight : leftHeight;
  }

  /**
   * Replace the direct child `child` with `replace` (which may be null).
   *
   * Fixes the parent pointers of both nodes and detaches `child` from this
   * node. Heights are deliberately NOT recomputed here: during a rotation
   * `replace` does not have its final height yet, so recalculating now
   * would store a stale value. The caller is responsible for heights
   * (see Bin_tree::rebalance()).
   */
  void replaceChild(Bin_tree_node* child, Bin_tree_node* replace)
  {
    assert(child != nullptr);
    assert(child == left || child == right);

    if (child == left)
      left = replace;
    else
      right = replace;

    if (replace)
      replace->parent = this;
    child->parent = nullptr;
  }

protected:
  constexpr Bin_tree_node()
  : parent(nullptr), left(nullptr), right(nullptr), tree(nullptr),
    balance(0), height(0)
  {}

private:
  Bin_tree_node* parent;
  Bin_tree_node* left;
  Bin_tree_node* right;
  void* tree;      // Bin_tree this node is linked into, nullptr if detached

  // AVL bookkeeping, maintained by recalc()
  Mword balance;   // height(left) - height(right)
  Umword height;   // 0 for a node without children
};

/**
 * Bin_tree_node that carries a key of type `Key`, ordered with operator<.
 */
template<typename Key>
class Bin_tree_node_t : public Bin_tree_node
{
public:
  /// Key trait for Bin_tree: reads the stored key, orders with operator<.
  class Key_trait
  {
  public:
    typedef Key Key_type;

    static inline Key_type get_key(Bin_tree_node* node)
    { return static_cast<Bin_tree_node_t<Key>*>(node)->key(); }

    static inline bool compare(Key_type a, Key_type b)
    { return (a < b); }
  };

  Bin_tree_node_t(Key k) : Bin_tree_node(), _key(k) {}

  inline Key& key() { return _key; }

private:
  Key _key;
};

/**
 * Intrusive AVL tree of Bin_tree_node objects.
 *
 * Key_trait provides `Key_type`, `static Key_type get_key(Bin_tree_node*)`
 * and `static bool compare(Key_type a, Key_type b)`, which must return true
 * iff `a` orders before `b`. Two keys are equal when neither orders before
 * the other. Keys are unique: inserting a node whose key equals an existing
 * one fails. The tree links nodes by pointer; it never copies or frees them.
 */
template<typename Key_trait>
class Bin_tree
{
protected:
  typedef Bin_tree_node Node;
  typedef typename Key_trait::Key_type Key_type;

private:
  Node* root;

  /**
   * Descend from the root looking for `key`.
   *
   * \return The node holding `key` if present, otherwise the last node
   *         visited (the would-be parent of `key`); nullptr for an empty tree.
   */
  Node* findByKey(Key_type key) const
  {
    Node* iterparent = nullptr;
    Node* iter = root;
    while (iter != nullptr)
    {
      iterparent = iter;
      if (Key_trait::compare(key, Key_trait::get_key(iter)))
        iter = iter->left;
      else if (Key_trait::compare(Key_trait::get_key(iter), key))
        iter = iter->right;
      else
        break;
    }
    return iterparent;
  }

  /// Leftmost (smallest-key) node of the non-empty subtree rooted at `start`.
  Node* findLeft(Node* start) const
  {
    assert(start);
    Node* iter = start;
    while (iter->left != nullptr)
      iter = iter->left;
    return iter;
  }

  /**
   * Unlink `leaf`, a node with at most one child, by splicing that child
   * into its place. Then rebalance from the former parent upwards and reset
   * `leaf` to the detached state.
   */
  void leafRemove(Node* leaf)
  {
    assert(leaf);
    assert(leaf->left == nullptr || leaf->right == nullptr);

    Node* child = (leaf->left != nullptr) ? leaf->left : leaf->right;

    // Remember the parent now: replaceChild() clears leaf->parent, and the
    // rebalancing walk has to start at the old parent.
    Node* parent = leaf->parent;
    if (parent)
    {
      parent->replaceChild(leaf, child);
      rebalance(parent);
    }
    else
    {
      root = child;
      if (child)
        child->parent = nullptr;
    }

    leaf->parent = nullptr;
    leaf->left = nullptr;
    leaf->right = nullptr;
    leaf->tree = nullptr;
    leaf->recalc();
  }

  /**
   * Restore the AVL invariants after the height of one child subtree of
   * `start` changed by one, through an insertion or a removal below `start`.
   *
   * On entry the stored heights of `start` and its ancestors must still be
   * the values from before the change. Each visited node is recalculated and
   * rotated if its balance reached +/-2. The walk stops as soon as a subtree
   * keeps its previous height, since nothing above it can change. A rotation
   * after a removal may still shrink the subtree, so the walk does not stop
   * unconditionally after rotating.
   */
  void rebalance(Node* start)
  {
    Node* iter = start;
    while (iter != nullptr)
    {
      Umword const oldHeight = iter->height;
      iter->recalc();

      if (iter->balance > 1)
      {
        // Left side is taller.
        assert(iter->left);
        // Left side's right side is taller: double rotation.
        if (iter->left->balance < 0)
          rotateLeft(iter->left);
        rotateRight(iter);
        iter = iter->parent;  // new root of this subtree
      }
      else if (iter->balance < -1)
      {
        // Right side is taller.
        assert(iter->right);
        // Right side's left side is taller: double rotation.
        if (iter->right->balance > 0)
          rotateRight(iter->right);
        rotateLeft(iter);
        iter = iter->parent;  // new root of this subtree
      }

      if (iter->height == oldHeight)
        break;  // subtree height unchanged: ancestors are still valid

      iter = iter->parent;
    }
  }

  /// Rotate `node` down to the left; its right child takes its place.
  void rotateLeft(Node* node)
  {
    assert(node);
    assert(node->right);

    Node* partner = node->right;

    node->right = partner->left;
    if (node->right)
      node->right->parent = node;

    // No parent, partner is the new root
    if (node->parent == nullptr)
    {
      root = partner;
      partner->parent = nullptr;
    }
    else
      node->parent->replaceChild(node, partner);

    node->parent = partner;
    partner->left = node;

    assert(node->parent == nullptr || node->parent->left == node || node->parent->right == node);
    assert(partner->parent == nullptr || partner->parent->left == partner || partner->parent->right == partner);
    assert(node->left == nullptr || node->left->parent == node);
    assert(node->right == nullptr || node->right->parent == node);
    assert(partner->left == nullptr || partner->left->parent == partner);
    assert(partner->right == nullptr || partner->right->parent == partner);

    // Bottom-up: node is now partner's child.
    node->recalc();
    partner->recalc();
  }

  /// Rotate `node` down to the right; its left child takes its place.
  void rotateRight(Node* node)
  {
    assert(node);
    assert(node->left);

    Node* partner = node->left;

    node->left = partner->right;
    if (node->left)
      node->left->parent = node;

    // No parent, partner is the new root
    if (node->parent == nullptr)
    {
      root = partner;
      partner->parent = nullptr;
    }
    else
      node->parent->replaceChild(node, partner);

    node->parent = partner;
    partner->right = node;

    assert(node->parent == nullptr || node->parent->left == node || node->parent->right == node);
    assert(partner->parent == nullptr || partner->parent->left == partner || partner->parent->right == partner);
    assert(node->left == nullptr || node->left->parent == node);
    assert(node->right == nullptr || node->right->parent == node);
    assert(partner->left == nullptr || partner->left->parent == partner);
    assert(partner->right == nullptr || partner->right->parent == partner);

    // Bottom-up: node is now partner's child.
    node->recalc();
    partner->recalc();
  }

  /**
   * Recursively verify the subtree rooted at `node`.
   *
   * \param parent  Expected parent pointer of `node`.
   * \param lower   If non-null, every key in the subtree must order after it.
   * \param upper   If non-null, every key in the subtree must order before it.
   * \param height  Out: recomputed height of the subtree (-1 when empty).
   * \return true if links, ordering, stored height/balance and the AVL
   *         balance condition all hold.
   */
  bool checkSubtree(Node* node, Node* parent, Node* lower, Node* upper,
                    Mword& height) const
  {
    if (node == nullptr)
    {
      height = -1;
      return true;
    }

    if (node->parent != parent || node->tree != static_cast<void const*>(this))
      return false;
    if (lower && !Key_trait::compare(Key_trait::get_key(lower), Key_trait::get_key(node)))
      return false;
    if (upper && !Key_trait::compare(Key_trait::get_key(node), Key_trait::get_key(upper)))
      return false;

    Mword leftHeight = 0;
    Mword rightHeight = 0;
    if (!checkSubtree(node->left, node, lower, node, leftHeight)
        || !checkSubtree(node->right, node, node, upper, rightHeight))
      return false;

    Mword const balance = leftHeight - rightHeight;
    height = ((leftHeight < rightHeight) ? rightHeight : leftHeight) + 1;

    return balance >= -1 && balance <= 1
           && node->balance == balance
           && static_cast<Mword>(node->height) == height;
  }

#ifdef DUMP
  void dump(Node* node, int indent = 0)
  {
    if (node == nullptr)
      return;

    assert(node->tree == this);
    assert(node->left == nullptr || node->left->parent == node);
    dump(node->left, indent + 1);

    for (int i = 0; i < indent; i++)
      cout << "  ";
    cout << Key_trait::get_key(node) << endl;

    assert(node->right == nullptr || node->right->parent == node);
    dump(node->right, indent + 1);
  }
#endif

public:
  constexpr Bin_tree() : root(nullptr) {}

  /**
   * Link the detached `node` into this tree and rebalance.
   *
   * \return false (leaving `node` untouched) if a node with an equal key is
   *         already in the tree, true otherwise.
   */
  bool insert(Node* node)
  {
    assert(node->tree == nullptr);

    Node* parent = nullptr;
    if (root == nullptr)
      root = node;  // empty tree: node becomes the root
    else
    {
      // Find a suitable parent.
      parent = findByKey(Key_trait::get_key(node));
      assert(parent);
      if (Key_trait::compare(Key_trait::get_key(node), Key_trait::get_key(parent)))
        parent->left = node;
      else if (Key_trait::compare(Key_trait::get_key(parent), Key_trait::get_key(node)))
        parent->right = node;
      else
        return false;  // a node with the same key exists
    }

    node->parent = parent;
    node->left = nullptr;
    node->right = nullptr;
    node->tree = this;
    node->height = 0;
    node->balance = 0;

    // parent's stored height is still the pre-insert value, as rebalance()
    // requires.
    if (parent)
      rebalance(parent);
    return true;
  }

  /**
   * Unlink `node` from this tree, rebalance, and reset `node` to the
   * detached state.
   *
   * A node with two children is replaced by its in-order successor (the
   * leftmost node of its right subtree). The successor is first unlinked via
   * leafRemove() and then takes over `node`'s position and children.
   */
  void remove(Node* node)
  {
    assert(node);
    assert(node->tree == this);

    if (node->left == nullptr || node->right == nullptr)
    {
      leafRemove(node);
      return;
    }

    Node* replacement = findLeft(node->right);
    // May rebalance, and even rotate around, `node`; node's links are read
    // only afterwards.
    leafRemove(replacement);

    if (node->parent == nullptr)
    {
      root = replacement;
      replacement->parent = nullptr;
    }
    else
      node->parent->replaceChild(node, replacement);

    replacement->left = node->left;
    replacement->right = node->right;
    replacement->tree = this;
    if (replacement->left)
      replacement->left->parent = replacement;
    if (replacement->right)
      replacement->right->parent = replacement;
    // Same children as node, so same height: nothing above needs updating.
    replacement->recalc();

    node->left = nullptr;
    node->right = nullptr;
    node->parent = nullptr;
    node->tree = nullptr;
    node->recalc();
  }

  /// Node whose key equals `key`, or nullptr if there is none.
  Node* lookup(Key_type key) const
  {
    if (root == nullptr)
      return nullptr;

    Node* node = findByKey(key);
    if (node == nullptr)
      __builtin_unreachable();  // findByKey() never returns null for a non-empty tree

    if (Key_trait::compare(Key_trait::get_key(node), key)
        || Key_trait::compare(key, Key_trait::get_key(node)))
      return nullptr;  // search ended at a different key
    return node;
  }

  inline bool in_tree(Node* node)
  { return (node->tree == this); }

  /**
   * Verify all structural and AVL invariants of the tree (debugging aid).
   * Visits every node once.
   *
   * \return true if the tree is a valid, correctly balanced AVL tree.
   */
  bool check_invariants() const
  {
    Mword height = 0;
    return checkSubtree(root, nullptr, nullptr, nullptr, height);
  }

#ifdef DUMP
  void dump()
  {
    cout << "---------------------------------------------" << endl;
    dump(root);
    cout << "---------------------------------------------" << endl;
  }
#endif
};

/**
 * Type-safe wrapper around Bin_tree for nodes of type T, where T derives
 * from Bin_tree_node.
 */
template<typename T, typename Key_trait>
class Bin_tree_t : Bin_tree<Key_trait>
{
private:
  typedef Bin_tree<Key_trait> Base;
  typedef typename Key_trait::Key_type Key_type;

public:
  inline bool insert(T* node)
  { return Base::insert(static_cast<Bin_tree_node*>(node)); }

  inline void remove(T* node)
  { Base::remove(static_cast<Bin_tree_node*>(node)); }

  inline T* lookup(Key_type key)
  { return static_cast<T*>(Base::lookup(key)); }

  inline bool in_tree(T* node)
  { return Base::in_tree(static_cast<Bin_tree_node*>(node)); }

  inline bool check_invariants() const
  { return Base::check_invariants(); }

#ifdef DUMP
  inline void dump()
  { Base::dump(); }
#endif
};