kmjp's blog

競技プログラミング参加記です

Codeforces ECR #060 : G. Recursive Queries

ECRのG、いつも実装が重めなんだよなぁ。
https://codeforces.com/contest/1117/problem/G

問題

Permutationとなっている数列Pが与えられる。
f(L,R)は、以下のように定義される

  • R<Lなら0
  • P[L..R]のうち最大値をP[M]とする。f(L,R)=(R-L+1)+f(L,M-1)+f(M+1,R)

Q個のクエリ(L,R)が与えられる。
それぞれf(L,R)を求めよ。

解法

f(L,R)の定義によると、f(L,R)を求める際はP[L..R]のうち大きな順にMを選択し、Mを含む有効な(未削除の)要素数の総和を取っていくことに相当する。
Mを取り除く瞬間、f(L',R')を求める際にA[L'...R']の最大値がMであったとする。
この際(R'-L'+1)が答えに加算されるだが、左右別々に考えることにする。
すなわち、L'は何でもよく、f(*,R')を求める際、P[M]がP[M...R']で最大なら(R'-M+1)、最大値が別途P[Q]があったとすると(Q-M)だけ解に加算される。

そこで、各Mに対し、right(M)は最寄りのP[M]<P[right(M)]となる要素だとする。
これはRMQなのでSegTreeで容易に実装できる。

次に、クエリをL順にソートしておき、平面走査していくのだが、以下の2つのBITを持っておく。

  • 各Mをキーとし、(right(M)-M)を値として、総和を高速に求めるBIT
  • 各Qをキーとし、right(M)=QとなるMの個数

平面走査の際、上記2つの値を更新・取得すると、f(L,R)のうち、各M∈[L,R]に対しmin(R,right(M))-Mの総和が取れる。

以後同様に逆向きの走査を行う。

int N,Q;
int P[1010101];
int L[1010101],R[1010101];

int Lma[1010101],Rma[1010101];
vector<int> delL[1010101],delR[1010101];
vector<int> QL[1010101],QR[1010101];
template<class V,int NV> class SegTree_1 {
public:
	vector<V> val;
	static V const def=0;
	V comp(V l,V r){ return max(l,r);};
	
	SegTree_1(){val=vector<V>(NV*2,def);};
	V getval(int x,int y,int l=0,int r=NV,int k=1) { // x<=i<y
		if(r<=x || y<=l) return def;
		if(x<=l && r<=y) return val[k];
		return comp(getval(x,y,l,(l+r)/2,k*2),getval(x,y,(l+r)/2,r,k*2+1));
	}
	void update(int entry, V v) {
		entry += NV;
		val[entry]=comp(v,val[entry]);
		while(entry>1) entry>>=1, val[entry]=comp(val[entry*2],val[entry*2+1]);
	}
};
SegTree_1<int,1<<20> st;

template<class V, int ME> class BIT {
public:
	V bit[1<<ME];
	V operator()(int e) {if(e<0) return 0;V s=0;e++;while(e) s+=bit[e-1],e-=e&-e; return s;}
	void add(int e,V v) { e++; while(e<=1<<ME) bit[e-1]+=v,e+=e&-e;}
};
BIT<ll,20> sum;
BIT<int,20> mu;

ll ret[1010101];

void solve() {
	int i,j,k,l,r,x,y; string s;
	
	scanf("%d%d",&N,&Q);
	for(i=1;i<=N;i++) {
		scanf("%d",&P[i]);
		Lma[i]=st.getval(P[i],N+1);
		delL[Lma[i]].push_back(i);
		st.update(P[i],i);
		sum.add(i,i-Lma[i]);
	}
	FOR(i,st.val.size()) st.val[i]=0;
	for(i=N;i>=1;i--) {
		Rma[i]=N+1-st.getval(P[i],N+1);
		delR[Rma[i]].push_back(i);
		st.update(P[i],N+1-i);
	}
	
	FOR(i,Q) {
		scanf("%d",&L[i]);
		QL[L[i]].push_back(i);
	}
	FOR(i,Q) {
		scanf("%d",&R[i]);
		QR[R[i]].push_back(i);
	}
	
	FOR(i,N+1) {
		FORR(q,QL[i]) ret[q]+=(sum(R[q])-sum(L[q]-1))-1LL*i*(mu(R[q])-mu(L[q]-1));
		FORR(x,delL[i]) {
			sum.add(x,x+1-(sum(x)-sum(x-1)));
			mu.add(x,1);
		}
	}
	ZERO(mu.bit);
	ZERO(sum.bit);
	for(i=1;i<=N;i++) {
		sum.add(i,Rma[i]-i-1);
	}
	for(i=N+1;i>=1;i--) {
		FORR(q,QR[i]) ret[q]+=(sum(R[q])-sum(L[q]-1))+1LL*i*(mu(R[q])-mu(L[q]-1));
		FORR(x,delR[i]) {
			sum.add(x,-x-(sum(x)-sum(x-1)));
			mu.add(x,1);
		}
	}
	
	FOR(i,Q) cout<<ret[i]<<" ";
	cout<<endl;
	
}

まとめ

説明が面倒でだいぶ手を抜いている…。