えらく実装に手間取った問題。
http://codeforces.com/contest/280/problem/D
問題
N要素の数列A[i]が与えられる。
ここで以下のM個のクエリを処理せよ。
- l,rが与えられるので、A[l]=rに更新する。
- l,r,kが与えられるので、A[l]~A[r]のうち、交差しないK個以下の連続部分列の総和の最大値を答える。
解法
Editorialを見て解答。
最初Editorialを半分だけ見て解いたら、面倒なSegTree処理を書いたのにO(M*k^2*logN)かかりTLE。
後半もちゃんと読むとO(M*k*logN)の解法が書いてあった。
2つ目のクエリを考える。
ここは最小コストフローの考え方を用いて、以下のように解ける。
- A[l,r]の間で総和が最大になる連続部分列を求める。
- 上記総和が0なら終了。そうでなければ上記部分列の値の符号を反転する。
- 上記処理をK回繰り返す。
容量1のフローを流して、逆向きのフローを流す、という最小コストフローの処理に確かに似ている。
あとは符号反転を行いつつA[l,r]の最大部分列を高速に求めれば良い。
これにはセグメントツリーを用いて、A[l,r]の区間に対応する添え字kに対し以下の値を用いて処理していく。
- A[l]~A[r]の総和
- A[l]~A[r]のうち、両端の値を使わない最大連続部分列
- A[l]~A[r]のうち、A[l]からつながる最大連続部分列
- A[l]~A[r]のうち、A[r]からつながる最大連続部分列
- さらに、上記4つをA[l]~A[r]を符号反転した-A[l]~-A[r]に対し同様の処理をする。
値の更新クエリは1回O(logN)。
連続部分列の検索1回はO(logN)、符号反転もO(logN)、K回繰り返すとO(K*logN)。
よって全体でO(M*K*logN)。
const int NV=1<<17; int neg[NV*2]; int val[4][NV*2][2][2]; int LL[4][NV*2][2][2]; int RR[4][NV*2][2][2]; int N,M; void update_val(int entry, int t) { int (&v)[2][2]=val[t][entry], (&c0)[2][2]=val[t][entry*2], (&c1)[2][2]=val[t][entry*2+1]; int (&L)[2][2]=LL[t][entry], (&L0)[2][2]=LL[t][entry*2], (&L1)[2][2]=LL[t][entry*2+1]; int (&R)[2][2]=RR[t][entry], (&R0)[2][2]=RR[t][entry*2], (&R1)[2][2]=RR[t][entry*2+1]; v[0][0]=v[0][1]=v[1][0]=v[1][1]=-100000000; if(v[1][1]<c0[1][1]+c1[1][1]) v[1][1]=c0[1][1]+c1[1][1], L[1][1]=L0[1][1], R[1][1]=R1[1][1]; if(v[1][0]<c0[1][0]) v[1][0]=c0[1][0], L[1][0]=L0[1][0], R[1][0]=R0[1][0]; if(v[1][0]<c0[1][1]) v[1][0]=c0[1][1], L[1][0]=L0[1][1], R[1][0]=R0[1][1]; if(v[1][0]<c0[1][1]+c1[1][0]) v[1][0]=c0[1][1]+c1[1][0], L[1][0]=L0[1][1], R[1][0]=R1[1][0]; if(v[0][1]<c1[0][1]) v[0][1]=c1[0][1], L[0][1]=L1[0][1], R[0][1]=R1[0][1]; if(v[0][1]<c1[1][1]) v[0][1]=c1[1][1], L[0][1]=L1[1][1], R[0][1]=R1[1][1]; if(v[0][1]<c0[0][1]+c1[1][1]) v[0][1]=c0[0][1]+c1[1][1], L[0][1]=L0[0][1], R[0][1]=R1[1][1]; if(v[0][0]<c0[0][0]) v[0][0]=c0[0][0], L[0][0]=L0[0][0], R[0][0]=R0[0][0]; if(v[0][0]<c0[0][1]) v[0][0]=c0[0][1], L[0][0]=L0[0][1], R[0][0]=R0[0][1]; if(v[0][0]<c1[0][0]) v[0][0]=c1[0][0], L[0][0]=L1[0][0], R[0][0]=R1[0][0]; if(v[0][0]<c1[1][0]) v[0][0]=c1[1][0], L[0][0]=L1[1][0], R[0][0]=R1[1][0]; if(v[0][0]<c0[0][1]+c1[1][0]) v[0][0]=c0[0][1]+c1[1][0], L[0][0]=L0[0][1], R[0][0]=R1[1][0]; } void getval(int x,int y,int l,int r,int k) { int m=(l+r)/2; if(r<=x || y<=l) return; if(x<=l && y>=r && neg[k]>=0) { memmove(val[2][k],val[neg[k]][k],sizeof(val[2][k])); memmove(LL[2][k],LL[neg[k]][k],sizeof(LL[2][k])); memmove(RR[2][k],RR[neg[k]][k],sizeof(RR[2][k])); return; } if(neg[k]>-1) neg[k*2]=neg[k*2+1]=neg[k], neg[k]=-1; if(y<=m) { getval(x,y,l,m,k*2); memmove(val[2][k],val[2][k*2],sizeof(val[2][k])); memmove(LL[2][k],LL[2][k*2],sizeof(LL[2][k])); memmove(RR[2][k],RR[2][k*2],sizeof(RR[2][k])); return; } if(x>=m) { getval(x,y,m,r,k*2+1); memmove(val[2][k],val[2][k*2+1],sizeof(val[2][k])); memmove(LL[2][k],LL[2][k*2+1],sizeof(LL[2][k])); memmove(RR[2][k],RR[2][k*2+1],sizeof(RR[2][k])); return; } getval(x,y,l,m,k*2); getval(x,y,m,r,k*2+1); update_val(k,2); } void negateseg(int x,int y,int l,int r,int k) { int m=(l+r)/2; if(r<=x || y<=l) return; if(x<=l && y>=r && neg[k]>=0) { neg[k]^=1; return; } if(neg[k]>-1) neg[k*2]=neg[k*2+1]=neg[k], neg[k]=-1; negateseg(x,y,l,m,k*2); negateseg(x,y,m,r,k*2+1); } void update(int t, int entry, int v) { int x,y,ma=1; entry += NV; val[t][entry][0][0]=val[t][entry][0][1]=val[t][entry][1][0]=-100000000; val[t][entry][1][1]=v; LL[t][entry][1][1]=entry-NV; RR[t][entry][1][1]=entry-NV+1; while(entry>1) { entry>>=1; update_val(entry,t); } } void solve() { int i,j,k,l,r,x,y; string s; cin>>N; FOR(i,NV*2) val[0][i][0][0]=val[0][i][0][1]=val[0][i][1][0]=val[0][i][1][1]=-100000000; FOR(i,NV*2) val[1][i][0][0]=val[1][i][0][1]=val[1][i][1][0]=val[1][i][1][1]=-100000000; FOR(i,N) { cin>>r; update(0,i+1,r); update(1,i+1,-r); } cin>>M; while(M--) { cin>>x>>l>>r; if(x==0) { update(0,l,r); update(1,l,-r); } else { cin>>k; int ret=0, ma=1; neg[1]=0; while(k-->0) { getval(l,r+1,0,NV,1); int ma=0; int l2=NV,r2=0; if(ma<val[2][1][0][0]) ma=val[2][1][0][0], l2=LL[2][1][0][0], r2=RR[2][1][0][0]; if(ma<val[2][1][0][1]) ma=val[2][1][0][1], l2=LL[2][1][0][1], r2=RR[2][1][0][1]; if(ma<val[2][1][1][0]) ma=val[2][1][1][0], l2=LL[2][1][1][0], r2=RR[2][1][1][0]; if(ma<val[2][1][1][1]) ma=val[2][1][1][1], l2=LL[2][1][1][1], r2=RR[2][1][1][1]; if(ma<=0) break; ret+=ma; negateseg(l2,r2,0,NV,1); } cout << ret << endl; } } }
まとめ
問題設定は割とシンプルなのに、えらい実装が手間。
本番ほとんど正答者いないしね…。