平衡二叉树 | AVL 树

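四种不平衡的情况：LL（左孩子的左子树过高，对失衡结点右旋一次）、RR（右孩子的右子树过高，对失衡结点左旋一次）、LR（左孩子的右子树过高，先对左孩子左旋，再对失衡结点右旋）、RL（右孩子的左子树过高，先对右孩子右旋，再对失衡结点左旋）。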

参考代码：

/// One node of an AVL tree: a key/value pair, links to the two children, and
/// the cached height of the subtree rooted at this node (a leaf has height 1,
/// an empty subtree counts as 0). Caching the height keeps rebalancing O(1)
/// per node instead of re-walking the whole subtree.
template <typename K, typename V>
struct AVLNode {
    K key;
    V value;
    AVLNode *left, *right;
    int height;

    AVLNode(K const& key, V const& value)
        : key(key), value(value), left(nullptr), right(nullptr), height(1) {}
};

/// Self-balancing binary search tree (AVL tree) mapping keys of type K to
/// values of type V. K must provide operator<; two keys are treated as equal
/// when neither is less than the other.
///
/// After every insertion the heights of the two children of any node differ
/// by at most one, so put/get run in O(log n). The tree owns its nodes: they
/// are freed by the destructor and deep-copied when the tree is copied.
template <typename K, typename V>
class AVLTree {
    using Node = AVLNode<K, V>;

    Node* root;

    // Inserts key/value into the subtree rooted at h (overwriting the value if
    // the key is already present) and returns the rebalanced subtree root.
    Node* put(Node* h, K const& key, V const& value) {
        if (h == nullptr) return new Node(key, value);

        if (key < h->key)
            h->left = put(h->left, key, value);
        else if (h->key < key)
            h->right = put(h->right, key, value);
        else
            h->value = value;

        return balance(h);
    }

    // Single left rotation around h; repairs the RR case (h's right child's
    // right subtree is too tall). Returns the new subtree root.
    Node* rotationRR(Node* h) {
        Node* x = h->right;
        h->right = x->left;
        x->left = h;
        update(h);  // h is now x's child, so its height must be fixed first
        update(x);
        return x;
    }

    // Single right rotation around h; repairs the LL case (h's left child's
    // left subtree is too tall). Returns the new subtree root.
    Node* rotationLL(Node* h) {
        Node* x = h->left;
        h->left = x->right;
        x->right = h;
        update(h);
        update(x);
        return x;
    }

    // LR case: rotate the left child left, then rotate h right.
    Node* rotationLR(Node* h) {
        h->left = rotationRR(h->left);
        return rotationLL(h);
    }

    // RL case: rotate the right child right, then rotate h left.
    Node* rotationRL(Node* h) {
        h->right = rotationLL(h->right);
        return rotationRR(h);
    }

    // Refreshes x's cached height and, if its balance factor left [-1, 1],
    // applies the matching rotation. Returns the (possibly new) subtree root.
    Node* balance(Node* x) {
        update(x);
        int bal_factor = diff(x);
        if (bal_factor > 1) {
            // ">= 0": a balanced left child must also take the single LL
            // rotation; a double rotation would leave the tree unbalanced.
            if (diff(x->left) >= 0)
                x = rotationLL(x);
            else
                x = rotationLR(x);
        } else if (bal_factor < -1) {
            if (diff(x->right) > 0)
                x = rotationRL(x);
            else
                x = rotationRR(x);
        }
        return x;
    }

    // Returns the node holding key in the subtree rooted at h, or nullptr.
    Node* get(Node* h, K const& key) const {
        while (h) {
            if (key < h->key)
                h = h->left;
            else if (h->key < key)
                h = h->right;  // larger keys live in the right subtree
            else
                return h;
        }
        return nullptr;
    }

    // Cached height of the subtree rooted at h (0 for an empty subtree).
    int height(Node* h) const { return h ? h->height : 0; }

    // Recomputes h's cached height from its children's cached heights.
    void update(Node* h) {
        int lh = height(h->left);
        int rh = height(h->right);
        h->height = (lh > rh ? lh : rh) + 1;
    }

    // Balance factor: left height minus right height.
    int diff(Node* h) const { return height(h->left) - height(h->right); }

    // Pre-order traversal. NOTE: passes each node's *key* to visit even though
    // the callback parameter is typed V, so K must be convertible to V.
    void preorder(Node* node, void (*visit)(V)) {
        if (!node) return;
        visit(node->key);
        preorder(node->left, visit);
        preorder(node->right, visit);
    }

    // Frees every node of the subtree rooted at h (post-order).
    static void destroy(Node* h) {
        if (!h) return;
        destroy(h->left);
        destroy(h->right);
        delete h;
    }

    // Returns a deep copy of the subtree rooted at h, including cached heights.
    // If an allocation throws, the partially built copy is freed.
    static Node* clone(Node const* h) {
        if (!h) return nullptr;
        Node* copy = new Node(h->key, h->value);
        copy->height = h->height;
        try {
            copy->left = clone(h->left);
            copy->right = clone(h->right);
        } catch (...) {
            destroy(copy);
            throw;
        }
        return copy;
    }

   public:
    AVLTree() : root(nullptr) {}

    AVLTree(AVLTree const& other) : root(clone(other.root)) {}

    AVLTree(AVLTree&& other) noexcept : root(other.root) { other.root = nullptr; }

    // Copy-and-swap: handles copy and move assignment and self-assignment.
    AVLTree& operator=(AVLTree other) {
        Node* tmp = root;
        root = other.root;
        other.root = tmp;  // old nodes are freed when `other` is destroyed
        return *this;
    }

    ~AVLTree() { destroy(root); }

    /// Inserts key with value, or overwrites the value if key already exists.
    void put(K const& key, V const& value) { root = put(root, key, value); }

    /// Calls visit on every key in pre-order (see the private overload note).
    void preorder(void (*visit)(V)) { preorder(root, visit); }

    /// Returns a reference to key's value, inserting V{} first if key is absent.
    V& get(K const& key) {
        Node* x = get(root, key);
        if (x == nullptr) {
            put(key, V{});
            x = get(root, key);
        }
        return x->value;
    }

    /// Same as get(key): inserts a value-initialized V when key is missing.
    V& operator[](K const& key) { return get(key); }
};