C++写红黑树

洛谷P3369

/*
    https://www.luogu/problem/P3369
*/
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
#define INF 0x3f3f3f3f
#define maxn 105
#define minn -105
#define ll long long int
#define islc(x) ((x)&&(x)->fa->lc==(x))
#define isrc(x) ((x)&&(x)->fa->rc==(x))
#define bro(x) ((x)->fa->lc==(x)?((x)->fa->rc):((x)->fa->lc))

struct node
{
    node* fa; //父节点
    node* lc; //左孩子
    node* rc; //右孩子
    int val; //值
    bool RBc; //颜色,黑为0,红为1
    int num; //以this为根的子树的节点数
    node(int v,bool c,node* f=NULL,int s=1,node* l=NULL,node* r=NULL)
    {
        val=v;RBc=c;num=s;
        fa=f;lc=l;rc=r;
    }
    node(){ }
    void maintain() //维护当前节点数
    {
        num=1;
        if(lc)num+=lc->num;
        if(rc)num+=rc->num;
    }
    node* left_node() //寻找直接前驱,可能不在this下方
    {
        node* ptr=this;
        if(!lc) //无左子树,只能向上寻找
        {
            while(ptr->fa&&islc(ptr))ptr=ptr->fa; //只要当前ptr是左子节点,父节点肯定不比它小
            ptr=ptr->fa; //如果this是最小值,会返回NULL
        }
        else
        {
            ptr=ptr->lc;
            while(ptr->rc)ptr=ptr->rc;
        }
        return ptr;
    }
    node* right_node() //寻找直接后驱,可能不在this下方
    {
        node* ptr=this;
        if(!rc)
        {
            while(ptr->fa&&isrc(ptr))ptr=ptr->fa;
            ptr=ptr->fa;
        }
        else
        {
            ptr=ptr->rc;
            while(ptr->lc)ptr=ptr->lc;
        }
        return ptr;
    }
    node* succ() //后面删除过程中,要查找后继交换节点,即右支中值域最小节点
    {
        node* ptr=rc;
        while(ptr->lc)
        {
            ptr->num--; //注意这里num--
            ptr=ptr->lc;
        }
        return ptr;
    }
};

node* root(NULL); //树根

node* find_insert_pos(int val,int op) //查找插入的位置,tmp记录插入节点
{
    node* tmp=NULL;
    node* ptr=root;
    while(ptr)
    {
        tmp=ptr;
        ptr->num+=op; //在一些查找中顺带进行修改
        if(ptr->val>val)ptr=ptr->lc; //相等也往右走
        else ptr=ptr->rc;
    }
    return tmp;
}

node* find_node(int val,int op) //查找,同值节点中随便返回一个
{
    node* ptr=root;
    while(ptr&&(ptr->val!=val)) //与上面的差别在这,同值直接结束
    {
        ptr->num+=op;
        if(ptr->val>val)ptr=ptr->lc;
        else ptr=ptr->rc;
    }
    return ptr;
}

void left_revolve(node* cur_gfa,node* new_gfa) //左旋,一定要非空
{
    node* ptr=new_gfa->lc;
    if(cur_gfa==root)new_gfa->fa=NULL,root=new_gfa;
    else
    {
        if(islc(cur_gfa))cur_gfa->fa->lc=new_gfa; //维护此旋转子树与上方父节点的关系
        else cur_gfa->fa->rc=new_gfa;
        new_gfa->fa=cur_gfa->fa;
    }
    new_gfa->lc=cur_gfa,cur_gfa->fa=new_gfa;
    if(ptr)ptr->fa=cur_gfa;
    cur_gfa->rc=ptr; //可能为空
    cur_gfa->maintain(),new_gfa->maintain(); //维护当前节点数量
}

void right_revolve(node* cur_gfa,node* new_gfa) //右旋,一定要非空
{
    node* ptr=new_gfa->rc;
    if(cur_gfa==root)new_gfa->fa=NULL,root=new_gfa;
    else
    {
        if(islc(cur_gfa))cur_gfa->fa->lc=new_gfa;
        else cur_gfa->fa->rc=new_gfa;
        new_gfa->fa=cur_gfa->fa;
    }
    new_gfa->rc=cur_gfa,cur_gfa->fa=new_gfa;
    if(ptr)ptr->fa=cur_gfa;
    cur_gfa->lc=ptr;
    cur_gfa->maintain(),new_gfa->maintain();
}

/*插入完成后进行颜色修正,ptr颜色必为红,
ptr可能是新插入的叶子节点,也可能进行过递归操作的子树的根(所以黑高是不确定的)*/
void solve_double_red(node* ptr)
{
    while(!ptr->fa||ptr->fa->RBc) //向上迭代,出现双红就进行修正
    {
        if(ptr==root) //特判
        {
            ptr->RBc=0;
            return;
        }
        node* faptr=ptr->fa;
        node* uncptr=bro(faptr);
        node* gfaptr=faptr->fa;
        if(uncptr&&uncptr->RBc) //叔叔节点是红色,改色,向上递归进行修正
        {
            gfaptr->RBc=1;
            faptr->RBc=uncptr->RBc=0;
            ptr=gfaptr;
            continue;
        }
        //以下叔叔节点不存在或为黑色
        if(isrc(faptr)) //注意祖父关系与父子关系不同侧,其情况不同
        {
            if(islc(ptr))right_revolve(faptr,ptr); //非同侧,双红进行单旋,使之同侧,并且不改变向下的颜色性质
            faptr=gfaptr->rc; //faptr更新成当前真正的父节点
            left_revolve(gfaptr,faptr);
            faptr->RBc=0,gfaptr->RBc=1;
            break;
        }
        else
        {
            if(isrc(ptr))left_revolve(faptr,ptr);
            faptr=gfaptr->lc;
            right_revolve(gfaptr,faptr);
            faptr->RBc=0,gfaptr->RBc=1;
            break;
        }
    }
}

void insert_node(int val) //插入新点
{
    node* tmp=find_insert_pos(val,1); //返回插入位置的父节点
    if(!tmp) //插入节点的父节点
    {
        root=new node(val,0);
        return;
    }
    node* ptr=new node(val,1,tmp);
    if(tmp->val<=val)tmp->rc=ptr; //注意等于这时往右支走
    else tmp->lc=ptr;
    solve_double_red(ptr); //颜色修正,ptr一定为红
}

node* find_lower(int val) //查找最后一个比特定值严格小于的数
{
    node* ptr=root;
    node* tmp=NULL; //父节点
    while(ptr)
    {
        tmp=ptr;
        if(ptr->val<val)ptr=ptr->rc; //注意val相等的点插在右支
        else ptr=ptr->lc; //ptr->val==val往左走
    }
    if(!tmp)return NULL;
    if(tmp->val<val)return tmp;
    return tmp->left_node();
}

node* find_higher(int val) //查找第一个比特定值严格大于的数
{
    node* ptr=root;
    node* tmp=NULL;
    while(ptr)
    {
        tmp=ptr;
        if(ptr->val>val)ptr=ptr->lc; //相等默认往右走
        else ptr=ptr->rc;
    }
    if(!tmp)return NULL;
    if(tmp->val>val)return tmp;
    return tmp->right_node();
}

int get_k_rank(int val,node* ptr)
{
    if(!ptr)return 1;
    if(ptr->val>=val)return get_k_rank(val,ptr->lc);
    int ans=(ptr->lc)?ptr->lc->num:0; //这时查询值一定严格大于左支任意一个节点
    ans++;
    return ans+get_k_rank(val,ptr->rc); //右支可能还有比他小的
}

int query_rank_k(int rak,node* ptr) //ptr非空
{
    if(!ptr)return -1;
    if((ptr->lc)==NULL)
    {
        if(rak==1)return ptr->val;
        return query_rank_k(rak-1,ptr->rc);
    }
    if(ptr->lc->num>=rak)return query_rank_k(rak,ptr->lc);
    rak-=ptr->lc->num;
    if(rak==1)return ptr->val;
    else return query_rank_k(rak-1,ptr->rc);
}

void solve_double_black(node* ptr) //黑色修正,思路是让当前删除位置的一支黑高+1
{
    while(ptr!=root&&!ptr->RBc) //兄弟节点必定非空
    {
        node* broptr=bro(ptr);
        node* faptr=ptr->fa;
        if(broptr->RBc) //如果兄弟节点为红色,进行单旋
        {
            if(islc(broptr))right_revolve(faptr,broptr);
            else left_revolve(faptr,broptr);
            faptr->RBc=1,broptr->RBc=0;
            broptr=bro(ptr); //更新现在的bro
        }
        //注意接下来的兄弟节点一定非空,故它只能为黑色
        node* insdc; //兄弟内向儿子
        node* outsdc; //兄弟外向儿子
        if(islc(broptr))insdc=broptr->rc,outsdc=broptr->lc;
        else insdc=broptr->lc,outsdc=broptr->rc;
        if(!(insdc&&insdc->RBc)&&!(outsdc&&outsdc->RBc)) //兄弟两个孩子不存在或为黑
        {
            broptr->RBc=1; //兄弟为红
            if(!faptr->RBc) //父为黑,向上迭代修改
            {
                ptr=faptr;
                continue;
            }
            else //父为红,改成黑色,直接结束
            {
                faptr->RBc=0;
                break;
            }
        }
        if(outsdc&&outsdc->RBc) //兄弟有外侧红孩儿
        {
            if(islc(broptr))right_revolve(faptr,broptr);
            else left_revolve(faptr,broptr);
            outsdc->RBc=0,broptr->RBc=faptr->RBc,faptr->RBc=0;
            break;
        }
        if(insdc&&insdc->RBc) //兄弟内侧孩子为红色且外侧非红,修改成外侧红孩儿
        {
            if(islc(insdc))right_revolve(broptr,insdc);
            else left_revolve(broptr,insdc);
            broptr->RBc=1,insdc->RBc=0;
        }
    }
}

void remove_node(int val) //删除,思路是先进行值的临时交换,将删除值转移到叶节点,再进行颜色修正
{
    node* ptr=find_node(val,-1);
    node* next=NULL;
    while(ptr->lc||ptr->rc) //向下交换节点值域,而删除节点的值已无意义,故这里只要覆盖即可
    {
        if(!ptr->lc)next=ptr->rc; //只要是链型直接覆盖
        else if(!ptr->rc)next=ptr->lc;
        else next=ptr->succ();
        ptr->num--;
        ptr->val=next->val;
        ptr=next;
    }
    if(!ptr->RBc)
    {
        ptr->num--; //看似多余,但是在向上维护父节点节点域,这里这样写就方便了很多
        solve_double_black(ptr);
    }
    if(ptr==root) //特判根
    {
        root=NULL;
        delete ptr;
        return;
    }
    if(islc(ptr))ptr->fa->lc=NULL;
    else ptr->fa->rc=NULL;
    ptr->fa->maintain();
    delete ptr;
}

int main()
{
    int n;
    cin>>n;
    for(int i=0;i<n;i++)
    {
        int chose,k;
        cin>>chose>>k;
        switch(chose)
        {
            case 1:{insert_node(k);break;}
            case 2:{remove_node(k);break;}
            case 3:{cout<<get_k_rank(k,root)<<"\n";break;}
            case 4:{cout<<query_rank_k(k,root)<<"\n";break;}
            case 5:{cout<<find_lower(k)->val<<"\n";break;}
            case 6:{cout<<find_higher(k)->val<<"\n";break;}
        }
    }
    return 0;
}

更多推荐

C++写红黑树 翻仓库发现的大宝贝