kmjp's blog

競技プログラミング参加記です

AtCoder ARC #135 : F - Delete 1, 4, 7, ...

うーん、本番でこういうの思いつく気がしないなぁ。
https://atcoder.jp/contests/arc135/tasks/arc135_f

問題

整数N,Kが与えられる。
整数列A=(1,2,...,N)を考える。

数列Aを、そこから1,4,7....,3k+1番目の要素を削除したものに置き換える処理をK回行ったとする。
最終的なAの総和を求めよ。

解法

以下、0-indexで考え、数列も0始まりとする。
最終的にK回処理後の|A|を解に加算すれば、そこのつじつまは合う。

1回処理前のAと処理後のA'は、A'[i]=floor(A[3*i/2])となるし、|A'|=floor(2*|A|/3)となる。

  • Kが十分(K≧38)大きい場合
    • 最終的なAのサイズは小さいので、上記式に基づきK回処理を巻き戻せば、もともとAのどこから来たか計算できる。
  • Kが小さい(K<38)の場合
    • f(x) = floor(3*i/2)とする。最終的なAに対し、A[i]=f^K(i)となるので、この和を求めよう。
    • f(x)は、f(x+a*(2^K))=f(x)+a*(3^K)となる。よって、Kがかなり(20以下)小さいならば、x % (2^K)が一致するようなxについて、f(x)をまとめて計算することができる。
    • とはいえこの問題ではKがもう少し大きい。そこで半分全列挙を行う。
    • X=ceil(K/2)、Y=floor(K/2)とし、f^X(x)を0≦x<2^Xの範囲で,f^Y(x)を0≦x<2^Yの範囲で計算しておく。
    • 求めたいのはsum(f^K(i))=sum(f^Y(f^X(i)))である。内側のカッコについて、iを2^Xで割った値で分類すると、i=a+b*2^Xであるような一連のiに対し、sum(f^Y(a)+f^Y(a+3^X)+f^Y(a+2*3^X)....)を求めることになる。
    • ここで、sum(f^Y(a)+f^Y(a+3^X)+...+f^Y(a+(2^c-1)*2^X))をダブリングで求めておけば、上記sumはaごとにO(logN)回計算することで求めることができる。
ll N,M,TM,K;
const ll mo=998244353;
ll H[1<<20][50];
ll p3[50];

int L,R;
ll rev(ll a) {
	return H[a&((1<<R)-1)][0]+(a>>R)*p3[R];
}

void solve() {
	int i,j,k,l,r,x,y; string s;
	
	p3[0]=1;
	FOR(i,48) p3[i+1]=p3[i]*3;
	
	cin>>N>>K;
	L=K/2;
	R=K-L;
	TM=M=N;
	FOR(i,K) M=2*M/3;
	FOR(i,R) TM=2*TM/3;
	ll ret=M;
	if(M<=5000000) {
		FOR(i,M) {
			ll cur=i;
			FOR(j,K) cur=(cur*3+2)/2;
			ret+=cur%mo;
		}
	}
	else {
		int mask;
		FOR(mask,1<<R) {
			H[mask][0]=mask;
			FOR(i,R) H[mask][0]=(H[mask][0]*3+2)/2;
		}
		for(i=1;i<=48;i++) {
			FOR(mask,1<<R) {
				if((N>>(i-1))<p3[L]) continue;
				ll f=mask+(p3[L]<<(i-1));
				ll a=f/(1<<R)%mo;
				ll b=f%(1<<R);
				(H[mask][i]=H[mask][i-1]+H[b][i-1]+((a*(p3[R]%mo)%mo)*((1LL<<(i-1))%mo)))%=mo;
			}
		}
		
		ll ss=0;
		FOR(mask,1<<L) {
			ll a=mask;
			FOR(i,L) a=(a*3+2)/2;
			if(rev(a)>=N) continue;
			ll b=a;
			for(i=48;i>=0;i--) {
				if(TM/p3[L]<(1LL<<i)) continue;
				if(b+(1LL<<i)*p3[L]<TM&&rev(b+(1LL<<i)*p3[L])<N) b+=(1LL<<i)*p3[L];
			}
			ll num=(b-a)/p3[L]+1;
			ll s=a;
			
			ss+=num;
			FOR(i,48) if(num&(1LL<<i)) {
				
				int cm=s&((1<<R)-1);
				ll t=s>>R;
				(ret+=H[cm][i]+(((t%mo)*(p3[R]%mo)%mo*((1LL<<i)%mo))))%=mo;
				s+=p3[L]<<i;
			}
			
		}
	}
	cout<<ret%mo<<endl;
}

まとめ

一つ一つのステップは理解できるものの、これを自力で詰められる気がしないなぁ…。