kmjp's blog

競技プログラミング参加記です

AtCoder ARC #143 : F - Counting Subsets

これだいぶ手間取った…。
https://atcoder.jp/contests/arc143/tasks/arc143_f

問題

正整数Nが与えられる。
1~Nの部分集合Sのうち、以下を満たすのは何通りか。

  • Sのいくつかの異なる要素の和として1~Nを表せ、かつそれぞれ表し方は2通り以下である。

解法

Sに2の累乗だけ入れると、1~Nの表し方は1通りに定まる。
以下、2の累乗でない数のうち最小値をaとする。以下aを総当たりしよう。
aより大きい最小の2の累乗を2^kとし、a未満の2の累乗はSに含まれているものとする。

ここで、Sの部分集合の和で作れるのは

  • 0~(a-1)が1通りずつ
  • a~(2^k-a-1)が2通りずつ
  • (2^k-a-1)~(2^k-2a-2)が1通りずつ

となる。例えばa=3の時、この値を並べると11122222111のようになる。
ここにさらに値を1個追加すると、位置をずらして2つこの並びを重ねた形になる。
例えばSに10を加えると11122222122122222111のようになる。

このように要素を加えて、この列の長さがN+1を超えるようにしよう。
((N+1)要素目の手前に並ぶ1の数+1)だけ、最後に追加可能な値となる。

この1と2の並びは二分木の要領で左右対称に並ぶので、どこの2の並びで長さN+1に到達するかを総当たりするとよい。

int N;
const ll mo=998244353;

ll dp[4030],S[4040];
int step[2515];
int stepS[2515][13];


ll hoge(int a,int b) {
	
	ll ret=0;
	int i,j;
	for(int num=1;num<=N;num++) {
		int s=0;
		while(1<<s<=num) s++;
		int tar=step[num];
		
		ZERO(dp);
		int mi=a+(num+1)/2*b+num/2*a;
		if(mi-a>=N) continue;
		
		if(num%2) {
			dp[a]=1;
		}
		else {
			dp[a+stepS[num][0]*b]=1;
		}
		
		for(i=1;i<s;i++) if(i!=tar) {
			int n=stepS[num][i];
			if(n==0) continue;
			//a以上2a以下
			FOR(j,N+1) {
				S[j]=dp[j];
				if(j>=n) (S[j]+=S[j-n])%=mo;
				
				if(j>=n*a) {
					dp[j]=S[j-n*a];
					if(j>=n*(2*a+1)) {
						(dp[j]+=mo-S[j-n*(2*a+1)])%=mo;
					}
				}
				else {
					dp[j]=0;
				}
			}
		}
		
		int n=stepS[num][tar];
		if(tar==0) {
			
			FOR(i,N) if(i+(n-1)*b<N&&i+n*b>=N) {
				ret+=dp[i];
			}
		}
		else if(tar==s-1) {
			FOR(i,N) if(dp[i]) {
				int lef=N-i;
				if(lef<=a) {
					(ret+=dp[i]*(lef+1))%=mo;
				}
				else if(lef<=2*a) {
					for(int l=lef;l<=2*a;l++) {
						if(l==2*a) {
							(ret+=dp[i]*(lef-a))%=mo;
						}
						else {
							(ret+=dp[i]*(lef+1-a))%=mo;
						}
						
					}
					
				}
			}
		}
		else {
			FOR(i,N) if(dp[i]) {
				for(int l=a;l<=2*a;l++) {
					if(i+(n-1)*l<N&&i+n*l>=N) {
						int lef=N-(i+(n-1)*l);
						int b2=2*a-l;
						int a2=(l-b2)/2;
						if(l==a*2) {
							(ret+=dp[i]*min(lef+1,a+1))%=mo;
						}
						else if(lef<=a2) {
							(ret+=dp[i]*(lef+1))%=mo;
						}
						else if(lef<=a) {
							(ret+=dp[i])%=mo;
						}
						else {
							(ret+=dp[i]*(lef+1-a))%=mo;
						}
						
					}
				}
			}
			
		}
		ret%=mo;
	}
	return ret;
	
}

void solve() {
	int i,j,k,l,r,x,y; string s;
	
	for(i=1;i<=2510;i++) {
		FOR(j,13) stepS[i][j]=stepS[i-1][j];
		x=i;
		while(x%2==0) x/=2, step[i]++;
		stepS[i][step[i]]++;
	}
	
	cin>>N;
	N++;
	ll ret=1;
	int step=0;
	for(i=1;i<N;i++) {
		k=0;
		while(1<<k<i) k++;
		y=1<<k;
		if(i==y) continue;
		int a=i;
		int b=y-i;
		
		ret+=hoge(a,b);
	}
	
	cout<<ret%mo<<endl;
}

まとめ

これ解けなくてだいぶ長い間放置してたので思い入れが強い。
問題設定がシンプルでいいね。