参考代码:
/// One node of an AVLTree: a key/value pair, links to the two child subtrees,
/// and the cached height of the subtree rooted at this node (a leaf has height 1).
template <typename K, typename V>
struct AVLNode {
    K key;
    V value;
    AVLNode *left, *right;
    // Cached subtree height, maintained by AVLTree so that balancing reads it in
    // O(1) instead of re-walking the whole subtree on every insertion.
    int height;
    AVLNode(K key, V value)
        : key(key), value(value), left(nullptr), right(nullptr), height(1) {}
};

/// Ordered key -> value map implemented as a self-balancing AVL tree.
///
/// Keys are compared with `operator<` only; two keys are considered equal when
/// neither is less than the other. Insertion and lookup take O(log n). The tree
/// owns its nodes: the destructor frees them and copies are deep.
template <typename K, typename V>
class AVLTree {
    using Node = AVLNode<K, V>;
    Node* root;

    /// Height of a possibly-empty subtree (0 for nullptr); reads the cached value.
    static int height(Node* h) { return h ? h->height : 0; }

    /// Recomputes h's cached height from its children's cached heights.
    static void updateHeight(Node* h) {
        int const lh = height(h->left);
        int const rh = height(h->right);
        h->height = (lh > rh ? lh : rh) + 1;
    }

    /// Balance factor of h: positive when the left subtree is taller.
    static int diff(Node* h) { return height(h->left) - height(h->right); }

    /// Frees every node of the subtree rooted at h (post-order).
    static void destroy(Node* h) {
        if (!h) return;
        destroy(h->left);
        destroy(h->right);
        delete h;
    }

    /// Returns a deep copy of the subtree rooted at h; on allocation failure the
    /// partial copy is freed and the exception is rethrown.
    static Node* clone(Node* h) {
        if (!h) return nullptr;
        Node* copy = new Node(h->key, h->value);
        copy->height = h->height;
        try {
            copy->left = clone(h->left);
            copy->right = clone(h->right);
        } catch (...) {
            destroy(copy);
            throw;
        }
        return copy;
    }

    /// Inserts (key, value) into the subtree rooted at h, or overwrites the value
    /// if key already exists. Returns the (possibly new) root of that subtree.
    Node* put(Node* h, K const& key, V const& value) {
        if (h == nullptr) return new Node(key, value);
        if (key < h->key)
            h->left = put(h->left, key, value);
        else if (h->key < key)
            h->right = put(h->right, key, value);
        else
            h->value = value;
        return balance(h);
    }

    /// Left rotation (fixes a right-right imbalance); returns the new subtree root.
    Node* rotationRR(Node* h) {
        Node* x = h->right;
        h->right = x->left;
        x->left = h;
        updateHeight(h);  // h is now below x, so refresh it first
        updateHeight(x);
        return x;
    }

    /// Right rotation (fixes a left-left imbalance); returns the new subtree root.
    Node* rotationLL(Node* h) {
        Node* x = h->left;
        h->left = x->right;
        x->right = h;
        updateHeight(h);  // h is now below x, so refresh it first
        updateHeight(x);
        return x;
    }

    /// Left-right double rotation: rotate the left child left, then h right.
    Node* rotationLR(Node* h) {
        h->left = rotationRR(h->left);
        return rotationLL(h);
    }

    /// Right-left double rotation: rotate the right child right, then h left.
    Node* rotationRL(Node* h) {
        h->right = rotationLL(h->right);
        return rotationRR(h);
    }

    /// Refreshes x's height and, if its balance factor is outside [-1, 1],
    /// applies the matching rotation. Returns the new subtree root.
    Node* balance(Node* x) {
        updateHeight(x);
        int const bal_factor = diff(x);
        if (bal_factor > 1) {
            // A left child with balance 0 is also fixed by a single rotation.
            if (diff(x->left) >= 0)
                x = rotationLL(x);
            else
                x = rotationLR(x);
        } else if (bal_factor < -1) {
            if (diff(x->right) > 0)
                x = rotationRL(x);
            else
                x = rotationRR(x);
        }
        return x;
    }

    /// Returns the node holding key in the subtree rooted at h, or nullptr.
    static Node* get(Node* h, K const& key) {
        while (h) {
            if (key < h->key)
                h = h->left;
            else if (h->key < key)
                h = h->right;
            else
                return h;
        }
        return nullptr;
    }

    /// Preorder (node, left, right) traversal helper; see the public overload.
    void preorder(Node* node, void (*visit)(V)) {
        if (!node) return;
        visit(node->key);
        preorder(node->left, visit);
        preorder(node->right, visit);
    }

public:
    AVLTree() : root(nullptr) {}
    ~AVLTree() { destroy(root); }

    /// Deep copy: the new tree owns its own nodes.
    AVLTree(AVLTree const& other) : root(clone(other.root)) {}
    AVLTree& operator=(AVLTree const& other) {
        if (this != &other) {
            Node* copy = clone(other.root);  // copy first so a throw leaves *this intact
            destroy(root);
            root = copy;
        }
        return *this;
    }

    /// Move: steals other's nodes and leaves other empty.
    AVLTree(AVLTree&& other) noexcept : root(other.root) { other.root = nullptr; }
    AVLTree& operator=(AVLTree&& other) noexcept {
        if (this != &other) {
            destroy(root);
            root = other.root;
            other.root = nullptr;
        }
        return *this;
    }

    /// Inserts key with value, or overwrites the stored value if key is present.
    void put(K const& key, V const& value) { root = put(root, key, value); }

    /// Calls visit once per node in preorder (node, left subtree, right subtree).
    /// NOTE(review): the node's *key* is passed even though the callback is typed
    /// on V, so this only compiles when K converts to V. Kept as-is for existing
    /// callers — confirm whether value was intended.
    void preorder(void (*visit)(V)) { preorder(root, visit); }

    /// Returns a reference to the value stored under key, first inserting a
    /// value-initialized V{} when key is absent (same idea as std::map::operator[]).
    V& get(K const& key) {
        Node* x = get(root, key);
        if (x == nullptr) {
            put(key, V{});
            x = get(root, key);
        }
        return x->value;
    }

    /// Same as get(key): inserts V{} if key is absent.
    V& operator[](K const& key) { return get(key); }
};