10010

定义 $f$ 函数:$f(0) = 0$,当 $x > 0$ 时

$$f(x) =
\begin{cases}
y & z = 0\
f(z) + 2 & z \neq 0 \land \mathrm{lowbit}(z) = \mathrm{lowbit}(x) \times 2 \
y & z \neq 0 \land \mathrm{lowbit}(z) \neq \mathrm{lowbit}(x) \times 2
\end{cases}$$

给出一个长度为 $n$ 的 $01$ 序列 $a$。有 $m$ 次操作,每次操作都形如以下的两种:

  • 1 l r:查询序列 $a$ 的区间 $[l, r]$ 组成的二进制数 $x$ 的 $f(x)$ 值。
  • 2 x:令 $a_x$ 反转($0$ 变成 $1$,$1$ 变成 $0$)。

$\dagger$ $\mathrm{lowbit}(i)$ 表示 $i$ 在二进制表示下最低位的 $1$ 及其后面所有的 $0$ 构成的数值。

$\dagger$ $\land$ 表示逻辑且。

每个测试点中包含多组测试数据。输入的第一行包含一个正整数 $T(1 \leq T \leq 110)$,表示数据组数。对于每组测试数据:

第一行两个正整数 $n, m(1 \leq n \leq 5.1 \times 10^5, 1 \leq m \leq 5 \times 10^5)$,分别表示字符串长度和操作次数。

第二行一个长度为 $n$ 的字符串,表示初始的 $01$ 序列 $a$。

接下来 $m$ 行,每行若干个整数,第一个数 $\mathrm{opt}$ 表示操作类型:

  • 若 $\mathrm{opt} = 1$,则后面两个数 $l, r(1 \leq l \leq r \leq n)$,表示查询的区间。
  • 若 $\mathrm{opt} = 2$,则后面一个数 $x(1 \leq x \leq n)$,表示修改的位置。

保证所有测试数据中 $n$ 之和不超过 $6.2 \times 10^5$,$m$ 之和不超过 $7.1 \times 10^5$。

时间限制 2000ms。

题目链接


线段树、树状数组、哈希

分析:实际上只需要求出在区间内部且包含区间右端点的形如 $100001000100$ 的串最长有多长,之后利用一些容易维护的信息就可以得到答案。

Try 1/Solution 1:注意到合法串内至多包含 $\sqrt{n}$ 个 $1$,求出区间中最靠右的 $1$ 之后逐个检查可以在 $O(\sqrt{n})$ 的时间复杂度内完成一次询问。使用链表维护全部 $1$ 的位置即可。

不一定能通过本题,时间复杂度 $O(n \sqrt{n})$。

耗时 1341ms 通过本题,猜测是由于杭电仅能安放一组数据,而较大数据会拿来卡线段树等解法,跑不满所以通过的。

在本地构造一个极长合法串测试,在单组数据 $n,m$ 拉满,开启 O2 优化,全为询问的情况下平均运行时间为 900ms,说明该做法可以在跑满的情况下通过本题。不开 O2 优化的运行时间约为开启 O2 优化时的 10 倍。

#include<bits/stdc++.h>
using namespace std;
list<int>::iterator ex[510'005];
int n,m,cnt,c[510'005];
char s[510'005];
void add(int x,int y){
    cnt+=y;
    while(x<=n) c[x]+=y,x+=x&-x;
}
int sum(int x){
    int ret=0;
    while(x) ret+=c[x],x-=x&-x;
    return ret;
}
int frank(int x){
    int ret=0;
    for(int i=1<<__lg(n);i;i>>=1) if(ret+i<=n&&c[ret+i]<x) ret+=i,x-=c[ret];
    return ret+1;
}
void solution(){
    cin>>n>>m,cnt=0;
    list<int> w;
    for(int i=1;i<=n;i++){
        cin>>s[i],s[i]-='0';
        if(s[i]){
            w.push_back(i);
            ex[i]=prev(w.end());
            add(i,1);
        }
    }
    for(int i=1;i<=m;i++){
        static int x,l,r,pre,opt;
        cin>>opt;
        if(opt==1){
            cin>>l>>r,pre=sum(r);
            if(pre==0){
                cout<<0<<'\n';
                continue;
            }
            pre=frank(pre);
            if(pre<l){
                cout<<0<<'\n';
                continue;
            }
            int ans=0,len=r-pre+1;
            auto it=ex[pre];
            while(it!=w.begin()){
                int las=*it;
                --it;
                if(*it<l||las-*it!=len+1) break;
                ++len,ans+=2;
            }
            ans+=len-1;
            cout<<ans<<'\n';
        }
        else{
            cin>>x,s[x]^=1;
            if(s[x]){
                pre=sum(x);
                if(pre==0){
                    w.push_front(x);
                    ex[x]=w.begin();
                }
                else if(pre==cnt){
                    w.push_back(x);
                    ex[x]=prev(w.end());
                }
                else{
                    pre=frank(pre+1);
                    ex[x]=w.insert(ex[pre],x);
                }
                add(x,1);
            }
            else{
                w.erase(ex[x]);
                add(x,-1);
            }
        }
    }
    for(int i=1;i<=n;i++) c[i]=0;
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    int T=1;
    cin>>T;
    while(T--) solution();
    return 0;
}

Try 2/Solution 2:二分区间内 $1$ 的数量,然后利用数据结构维护区间哈希值(类似子串哈希的手法)来检查,但这样是 $O(n \log^2 n)$ 的,不一定能通过本题(再试一下)。

由于区间内至多有 $\sqrt{n}$ 个 $1$,二分上界可以直接定为 $\sqrt{n}$,从而让二分的 $\log_2 \sqrt{n} = \log_2 n^{\frac{1}{2}} = \frac{1}{2} \log_2 n$ 带有一个 $\frac{1}{2}$ 的常数。

使用线段树维护区间哈希值(双哈希),为方便维护把 $01$ 串翻转,这样只用从左到右找形如 $010010001$ 的结构了。

居然也能过,平均耗时 1490ms,常数比较优秀。

#include<bits/stdc++.h>
using namespace std;
long long slen[510'005];
int n,m,uim,c[510'005];
char a[510'005];
template<int mod,int B,int O>//<模数,进制,1的对应值>
struct SegmentTree{
    int f[510'005],h[510'005],s[510'005*4];
    void set(){
        static constexpr int lim=510'000;
        f[0]=1,h[0]=0ll;//用 h 维护 1010010001 每个 1 位置上的前缀哈希值
        for(int i=1;i<=lim;i++) f[i]=1ll*f[i-1]*B%mod;
        for(int i=1;i<=lim;i++) h[i]=(1ll*h[i-1]*f[i]+O)%mod;
    }
    int pushup(int ls,int rs,int rlen){
        return (1ll*ls*f[rlen]+rs)%mod;
    }
    void build(int u,int l,int r){
        if(l==r) return s[u]=a[l]*O,void();
        int mid=(l+r)>>1;
        build(u*2,l,mid),build(u*2+1,mid+1,r);
        s[u]=pushup(s[u*2],s[u*2+1],r-mid);
    }
    void update(int u,int l,int r,int p){
        if(l==r) return s[u]=a[l]*O,void();
        int mid=(l+r)>>1;
        if(p<=mid) update(u*2,l,mid,p);
        else update(u*2+1,mid+1,r,p);
        s[u]=pushup(s[u*2],s[u*2+1],r-mid);
    }
    int query(int u,int l,int r,int ql,int qr){
        if(l>=ql&&r<=qr) return s[u];
        int ret=0,mid=(l+r)>>1;
        if(ql<=mid) ret=query(u*2,l,mid,ql,qr);//注意这里 pushup 的第三个参数不能是 qr-mid
        if(qr>mid) ret=pushup(ret,query(u*2+1,mid+1,r,ql,qr),min(r,qr)-mid);
        return ret;
    }
    bool equal(int l,int r,int sl,int sr){
        //判断区间 [l,r] 哈希值是否等于第 sl 个 1 到第 sr 个 1 的哈希值
        int L=query(1,1,n,l,r);
        int R=h[sr]-1ll*h[sl-1]*f[slen[sr]-slen[sl-1]]%mod;
        R=(R+mod)%mod;
        return L==R;
    }
};
SegmentTree<1'000'000'007,13,3> S1;
SegmentTree<998'244'353,17,11> S2;
void add(int x,int y){
    while(x<=n) c[x]+=y,x+=x&-x;
}
int sum(int x){
    int ret=0;
    while(x) ret+=c[x],x-=x&-x;
    return ret;
}
int frank(int x){
    int ret=0;
    for(int i=1<<__lg(n);i;i>>=1) if(ret+i<=n&&c[ret+i]<x) ret+=i,x-=c[ret];
    return ret+1;
}
int answer(int ql,int qr){
    if(sum(qr)-sum(ql-1)==0) return 0;
    int pos=frank(sum(ql-1)+1);
    int ret=1,len=pos-ql+1;
    int l=1,r=uim;
    while(l<=r){
        int mid=(l+r)>>1;
        if(slen[len+mid-1]-slen[len-1]>qr-ql+1){
            r=mid-1;
            continue;
        }
        int d=slen[len+mid-1]-slen[len-1];
        if(S1.equal(ql,ql+d-1,len,len+mid-1)&&S2.equal(ql,ql+d-1,len,len+mid-1)) ret=mid,l=mid+1;
        else r=mid-1;
    }
    return (ret-1)*2+(len+ret-2);
}
void solution(){
    cin>>n>>m,uim=0;
    for(int i=1;slen[i]<=n;i++) uim=i;
    for(int i=1;i<=n;i++) cin>>a[i];
    reverse(a+1,a+n+1);//翻转字符串
    for(int i=1;i<=n;i++){
        a[i]-='0';
        if(a[i]) add(i,1);
    }
    S1.build(1,1,n),S2.build(1,1,n);
    for(int i=1;i<=m;i++){
        static int x,l,r,opt;
        cin>>opt;
        if(opt==1){
            cin>>l>>r,swap(l,r);
            l=n-l+1,r=n-r+1;
            cout<<answer(l,r)<<'\n';
        }
        else{
            cin>>x,x=n-x+1,a[x]^=1;
            S1.update(1,1,n,x),S2.update(1,1,n,x);
            add(x,a[x] ? 1:-1);
        }
    }
    for(int i=1;i<=n;i++) c[i]=0;
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    static constexpr int lim=510'000;
    S1.set(),S2.set();
    for(int i=1;i<=lim;i++) slen[i]=slen[i-1]+i;
    int T=1;
    cin>>T;
    while(T--) solution();
    return 0;
}

使用树状数组维护区间哈希值(双哈希)。由于线段树能过,树状数组不可能不能过。注意树状数组维护的手法和线段树不同,采用了更简单的前缀和哈希而不是递推哈希。写线段树版本代码的时候我还不会前缀和哈希。

耗时 1045ms。

#include<bits/stdc++.h>
using namespace std;
int n,m,uim,c[510'005];
long long w[510'005];
char a[510'005];
template<int mod,int B>//<模数,进制>
struct BIT{
    int c[510'005],f[510'005],h[510'005],iw[510'005],inv[510'005];
    int mul(int x,long long y){
        int ret=1;
        while(y){
            if(y&1) ret=1ll*ret*x%mod;
            x=1ll*x*x%mod,y>>=1;
        }
        return ret;
    }
    void set(){
        static constexpr int lim=510'000;
        f[0]=1;
        for(int i=1;i<=lim;i++) f[i]=1ll*f[i-1]*B%mod;
        inv[lim]=mul(f[lim],mod-2);
        for(int i=lim;i>=1;i--) inv[i-1]=1ll*inv[i]*B%mod;
        iw[0]=1;
        for(int i=1;i<=lim;i++) iw[i]=1ll*iw[i-1]*inv[i]%mod;
        for(int i=1;i<=lim;i++) h[i]=(h[i-1]+mul(B,w[i]-1))%mod;
    }
    void add(int x,int y){
        while(x<=n) c[x]=(c[x]+y)%mod,x+=x&-x;
    }
    void update(int x,int k){add(x,k*f[x-1]);}
    int sum(int x){
        int ret=0;
        while(x) ret=(ret+c[x])%mod,x-=x&-x;
        return ret;
    }
    int hash(int l,int r){
        int ret=1ll*(sum(r)-sum(l-1))*inv[l-1]%mod;
        ret=(ret+mod)%mod;
        return ret;
    }
    bool equal(int l,int r,int ql,int qr){
        int x=1ll*(h[r]-h[l-1])*iw[l-1]%mod;
        x=(x+mod)%mod;
        return x==hash(ql,qr);
    }
    void clear(){
        for(int i=1;i<=n;i++) c[i]=0;
    }
};
BIT<1'000'000'007,97> C1;
BIT<998'244'353,31> C2;
void add(int x,int y){
    while(x<=n) c[x]+=y,x+=x&-x;
}
int sum(int x){
    int ret=0;
    while(x) ret+=c[x],x-=x&-x;
    return ret;
}
int frank(int x){
    int ret=0;
    for(int i=1<<__lg(n);i;i>>=1) if(ret+i<=n&&c[ret+i]<x) ret+=i,x-=c[ret];
    return ret+1;
}
void update(int x,int k){
    add(x,k);
    C1.update(x,k);
    C2.update(x,k);
}
int answer(int ql,int qr){
    if(sum(qr)-sum(ql-1)==0) return 0;
    int pos=frank(sum(ql-1)+1);
    int len=pos-ql+1;
    int l=1,r=uim,ret=1;
    while(l<=r){
        int mid=(l+r)>>1;
        if(w[len+mid-1]-w[len-1]>qr-ql+1){
            r=mid-1;
            continue;
        }
        int nr=ql+w[len+mid-1]-w[len-1]-1;
        if(!C1.equal(len,len+mid-1,ql,nr)||!C2.equal(len,len+mid-1,ql,nr)) r=mid-1;
        else ret=mid,l=mid+1;
    }
    ret=(len-1+ret-1)+(ret-1)*2;
    return ret;
}
void solution(){
    cin>>n>>m,uim=0;
    for(int i=1;w[i]<=n;i++) uim=i;
    for(int i=1;i<=n;i++) cin>>a[i];
    reverse(a+1,a+n+1);
    for(int i=1;i<=n;i++){
        a[i]-='0';
        if(a[i]) update(i,1);
    }
    for(int i=1;i<=m;i++){
        static int x,l,r,opt;
        cin>>opt;
        if(opt==1){
            cin>>l>>r,swap(l,r);
            l=n-l+1,r=n-r+1;
            cout<<answer(l,r)<<'\n';
        }
        else{
            cin>>x,x=n-x+1,a[x]^=1;
            if(a[x]) update(x,1);
            else update(x,-1);
        }
    }
    for(int i=1;i<=n;i++) c[i]=0;
    C1.clear(),C2.clear();
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    static constexpr int lim=510'000;
    for(int i=1;i<=lim;i++) w[i]=w[i-1]+i;
    C1.set(),C2.set();
    int T=1;
    cin>>T;
    while(T--) solution();
    return 0;
}

Fail 1:仍然维护前缀哈希值,在线段树上二分,时间复杂度 $O(n \log n)$。

需要利用已知信息(第一个 $1$ 在区间中出现的位置)$O(1)$ 计算一段区间的哈希值,否则无法做到 $O(n \log n)$ 的时间复杂度。

令区间 $[l,r]$ 中第一个 $1$ 的出现位置为 $p$,这个 $1$ 到区间左端点 $l$ 包含的字符数量为 $len$(包含 $1$ 本身),那么区间的子区间 $[l,k]$ 如果合法则必然为 $len$ 位置出现一个 $1$,$2len+1$ 位置出现一个 $1$ 直到下一个 $1$ 的位置超出区间右端点。

既然预处理了 $1010010001$ 这样的合法串每个 $1$ 位置上的哈希值,计算出区间 $[l,k]$ 中包含 $1$ 的数量就可以了,现在需要得到如下不等式的最小整数解 $x$(共 $x$ 项):

$$len+(len+1)+(len+2)+\cdots (len+x-1) > k-l+1$$

$$\frac{1}{2} (2len+x-1) x > k-l+1$$

$$\frac{2lenx+x^2-x}{2} > k-l+1$$

$$\frac{x^2 +(2len-1)x}{2} > k-l+1$$

$$x^2 + (2len-1) x > 2(k-l+1)$$

$$x^2 + (2len-1) x - 2(k-l+1) >0$$

由于 $len,k,l$ 都是常数,直接解方程 $x^2 + (2len-1) x - 2(k-l+1)=0$ 得到实数 $x$ 后下取整就可以了。对大质数取模,使用 __int128 类型存储乘法结果。

由于常数太大无法通过(幽默)。

Solution 3:直接维护所需信息,时间复杂度 $O(n \log n)$。

耗时 452ms 通过本题。

#include<bits/stdc++.h>
using namespace std;
struct Segment{int l,r,w,p,len,pre,suf,sel,ser;}s[510'005*4];
constexpr Segment emp=Segment{0,0,0,0,0,0,0,0,0};
//维护区间和 (w) 维护区间长度 (len) 维护区间前缀/后缀 0 数量 (pre/suf)
//维护区间去掉形如 1000 后缀之后剩余部分之后形成形如 100010010 最长合法连续段左端点位置 (p)
//维护区间第二靠左的 1 的位置 (sel) 维护区间第二靠右的 1 的位置 (ser)
int n,m,c[510'005];
void add(int x,int y){while(x<=n) c[x]+=y,x+=x&-x;}
int sum(int x){int ret=0;while(x) ret+=c[x],x-=x&-x;return ret;}
int frank(int x){
    int ret=0;
    for(int i=1<<__lg(n);i;i>>=1) if(ret+i<=n&&x>c[ret+i]) ret+=i,x-=c[ret];
    return ret+1;
}
void pushup(Segment &u,const Segment &ls,const Segment &rs){
    u.l=ls.l,u.r=rs.r;
    u.w=ls.w+rs.w;
    u.len=ls.len+rs.len;
    if(ls.w==0) u.pre=ls.pre+rs.pre;
    else u.pre=ls.pre;
    if(rs.w==0) u.suf=ls.suf+rs.suf;
    else u.suf=rs.suf;
    if(u.w<=1) u.sel=0;
    else if(ls.w>=2) u.sel=ls.sel;
    else if(ls.w==1) u.sel=rs.l+rs.pre;
    else u.sel=rs.sel;
    if(u.w<=1) u.ser=0;
    else if(rs.w>=2) u.ser=rs.ser;
    else if(rs.w==1) u.ser=ls.r-ls.suf;
    else u.ser=ls.ser;
    if(u.w<=1) u.p=0;
    else if(rs.p==0){
        if(rs.w==0) u.p=ls.p;
        else if(rs.w==1){
            if(ls.w==1) u.p=ls.r-ls.suf;
            else{
                int x=ls.suf+rs.pre;
                if(x+1==(ls.r-ls.suf)-ls.ser-1) u.p=ls.p;
                else u.p=ls.r-ls.suf;
            }
        }
    }
    else{
        if(rs.p!=rs.l+rs.pre) u.p=rs.p;
        else if(ls.w==0) u.p=rs.p;
        else if(ls.w==1){
            int x=ls.suf+rs.pre;
            if(x==(rs.sel-rs.p-1)+1) u.p=ls.r-ls.suf;
            else u.p=rs.p;
        }
        else{
            int x=ls.suf+rs.pre;
            if(x==(rs.sel-rs.p-1)+1){
                if(x+1==(ls.r-ls.suf)-ls.ser-1) u.p=ls.p;
                else u.p=ls.r-ls.suf;
            }
            else u.p=rs.p;
        }
    }
}
void build(int u,int l,int r){
    s[u].l=l,s[u].r=r,s[u].w=0,s[u].p=0;
    s[u].len=0,s[u].pre=0,s[u].suf=0;
    s[u].sel=0,s[u].ser=0;
    if(l==r){
        int x=cin.get();
        while(x!='0'&&x!='1') x=cin.get();
        x-='0',s[u].len=1;
        if(x==1) add(l,1),s[u].w=1;
        else s[u].pre=1,s[u].suf=1;
        return;
    }
    int mid=(l+r)>>1;
    build(u*2,l,mid),build(u*2+1,mid+1,r);
    pushup(s[u],s[u*2],s[u*2+1]);
}
void update(int u,int p){
    if(s[u].l==s[u].r){
        if(s[u].w==1){
            add(p,-1),s[u].w=0;
            s[u].pre=1,s[u].suf=1;
        }
        else{
            add(p,1),s[u].w=1;
            s[u].pre=0,s[u].suf=0;
        }
        return;
    }
    if(p<=s[u*2].r) update(u*2,p);
    else update(u*2+1,p);
    pushup(s[u],s[u*2],s[u*2+1]);
}
Segment query(int u,int l,int r){
    if(s[u].l>=l&&s[u].r<=r) return s[u];
    if(r<=s[u*2].r) return query(u*2,l,r);
    if(l>s[u*2].r) return query(u*2+1,l,r);
    Segment ret=emp;
    pushup(ret,query(u*2,l,r),query(u*2+1,l,r));
    return ret;
}
int answer(int l,int r){
    Segment x=query(1,l,r);
    if(x.w==0) return 0;
    if(x.w==1) return x.suf;
    if(x.suf+1!=(x.r-x.suf)-x.ser-1) return x.suf;
    int pre=sum(x.p),ret=(sum(x.r)-pre)*2;
    int pos=frank(pre+1);
    ret+=pos-x.p-1;
    return ret;
}
void solution(){
    cin>>n>>m;
    for(int i=1;i<=n;i++) c[i]=0;
    build(1,1,n);
    for(int i=1;i<=m;i++){
        static int l,r,x,opt;
        cin>>opt;
        if(opt==1){
            cin>>l>>r;
            cout<<answer(l,r)<<'\n';
        }
        else{
            cin>>x;
            update(1,x);
        }
    }
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    int T=1;
    cin>>T;
    while(T--) solution();
    return 0;
}
最后修改:2025 年 12 月 23 日
如果觉得我的文章对你有用,请随意赞赏