平衡二叉树 | 红黑树

注：下面的实现基于《算法（第四版）》（Sedgewick & Wayne），是红黑树的简化版本，即左倾红黑树（LLRB）：只实现了插入和查找，没有实现删除。

/// Colour of a red-black tree node, i.e. the colour of the link from its
/// parent to it.
enum Color { RED, BLACK };

/// Left-leaning red-black binary search tree (LLRB), following
/// Sedgewick & Wayne, "Algorithms" (4th ed.). Supports insertion and lookup
/// only; there is no deletion.
///
/// Keys are ordered with `operator<` alone: two keys are considered equal
/// when neither is less than the other. The tree owns its nodes; they are
/// released by the destructor, and copying a tree performs a deep copy.
template <class K, class V>
class RBTree {
    struct RBNode {
        K key;
        V value;
        size_t N;     // number of nodes in the subtree rooted here
        Color color;  // colour of the link from the parent to this node
        RBNode *left, *right;

        RBNode(K const& key, V const& value, size_t N, Color color)
            : key(key), value(value), N(N), color(color), left(nullptr), right(nullptr) {}
    };

    using Node = RBNode;

    Node* root = nullptr;

    // Subtree size; an empty (null) subtree has size 0.
    static size_t size(const Node* h) { return h ? h->N : 0; }

    // Null links count as black.
    static bool isRed(const Node* h) { return h != nullptr && h->color == RED; }

    // Inserts key/value into the subtree rooted at h (overwriting the value if
    // the key already exists) and returns the new root of that subtree,
    // restoring the left-leaning invariants on the way back up.
    static Node* put(Node* h, K const& key, V const& value) {
        if (h == nullptr) return new Node(key, value, 1, RED);

        if (key < h->key)
            h->left = put(h->left, key, value);
        else if (h->key < key)
            h->right = put(h->right, key, value);
        else
            h->value = value;

        // A red link appears on the right: make it lean left.
        if (isRed(h->right) && !isRed(h->left)) h = rotateLeft(h);
        // Two consecutive red links on the left: rotate to balance them.
        if (isRed(h->left) && isRed(h->left->left)) h = rotateRight(h);
        // Red links on both sides: split the temporary 4-node.
        if (isRed(h->left) && isRed(h->right)) flipColors(h);

        h->N = 1 + size(h->left) + size(h->right);
        return h;
    }

    // Makes both children black and h red, passing the red link upward.
    static void flipColors(Node* h) {
        h->left->color = h->right->color = BLACK;
        h->color = RED;
    }

    // Turns h's red right link into a left link; returns the new subtree root.
    static Node* rotateLeft(Node* h) {
        Node* x = h->right;
        h->right = x->left;
        x->left = h;
        x->color = h->color;
        h->color = RED;  // the link x -> h now carries the red colour
        x->N = h->N;
        h->N = 1 + size(h->left) + size(h->right);
        return x;
    }

    // Turns h's red left link into a right link; returns the new subtree root.
    static Node* rotateRight(Node* h) {
        Node* x = h->left;
        h->left = x->right;
        x->right = h;
        x->color = h->color;
        h->color = RED;  // the link x -> h now carries the red colour
        x->N = h->N;
        h->N = 1 + size(h->left) + size(h->right);
        return x;
    }

    // Standard BST search below h; returns nullptr when key is absent.
    static Node* find(Node* h, K const& key) {
        while (h != nullptr) {
            if (key < h->key)
                h = h->left;
            else if (h->key < key)
                h = h->right;
            else
                return h;
        }
        return nullptr;
    }

    // Frees every node in the subtree rooted at h.
    static void destroy(Node* h) {
        if (h == nullptr) return;
        destroy(h->left);
        destroy(h->right);
        delete h;
    }

    // Deep-copies the subtree rooted at h; nothing leaks if a copy throws.
    static Node* clone(const Node* h) {
        if (h == nullptr) return nullptr;
        Node* copy = new Node(h->key, h->value, h->N, h->color);
        try {
            copy->left = clone(h->left);
            copy->right = clone(h->right);
        } catch (...) {
            destroy(copy);
            throw;
        }
        return copy;
    }

   public:
    RBTree() = default;

    /// Deep copy: the new tree owns its own nodes.
    RBTree(RBTree const& other) : root(clone(other.root)) {}

    /// Steals other's nodes, leaving other empty.
    RBTree(RBTree&& other) noexcept : root(other.root) { other.root = nullptr; }

    /// Copy-and-swap assignment; covers both copy and move assignment.
    RBTree& operator=(RBTree other) noexcept {
        Node* tmp = root;
        root = other.root;
        other.root = tmp;
        return *this;
    }

    ~RBTree() { destroy(root); }

    /// Number of key/value pairs stored.
    size_t size() const { return size(root); }

    /// True when key is present; never modifies the tree.
    bool contains(K const& key) const { return find(root, key) != nullptr; }

    /// Inserts key with value, or overwrites the value of an existing key.
    void put(K const& key, V const& value) {
        root = put(root, key, value);
        root->color = BLACK;  // the root is always black
    }

    /// Returns a reference to the value for key. Like std::map::operator[],
    /// a missing key is first inserted with a value-initialised V{}.
    V& get(K const& key) {
        Node* x = find(root, key);
        if (x == nullptr) {
            put(key, V{});
            x = find(root, key);
        }
        return x->value;
    }

    /// Same as get(key): inserts V{} if key is absent.
    V& operator[](K const& key) { return get(key); }
};