kmjp's blog

競技プログラミング参加記です

AtCoder ARC #139 : F - Many Xor Optimization Problems

式変形が難しすぎる…。
https://atcoder.jp/contests/arc139/tasks/arc139_f

問題

整数N,Mが与えられる。
0~(2^M-1)の値を取るN要素の数列は、(2^M)^N通り考えられる。
それぞれにおいて、いくつかの要素を選んでxorを取った場合の最大値の総和を求めよ。

解法

各数列において、基底ベクトルがK個からなる場合を考え、各bitが最大値に含まれるケースを列挙することになる。
詳細な式変形はEditorialを参照。
Editorialで1か所誤りがあるのと、計算方法が省略されている点を補足しておく。

誤っているのは、後半 2^{\frac{K(K-1)}{2}}を掛けている箇所。
これは割る方が正しい。 \displaystyle \prod_{i=1}^K 2^{X_i-(i-1)}の2^(-i+1)の部分をくくりだしたものであるため。

あと最後、 \displaystyle 2^{X_K} \times \prod_{i=1}^K 2^{X_i}の計算だが、多項式を2つペアにしてそれぞれ分割統治しながらNTTで掛け合わせていく。
上記の式は、先頭に2^{X_K}がついているので、後半のProdの部分で2^{X_K}より大きい値が掛け合わされていてはならない。
そこで、頭の2^{X_K}が確定した場合の多項式と、未確定の場合の多項式をペアで持つことで対応する。

ll N,M;
const ll mo=998244353;


ll modpow(ll a, ll n = mo-2) {
	ll r=1; a%=mo;
	while(n) r=r*((n%2)?a:1)%mo,a=a*a%mo,n>>=1;
	return r;
}

template<class T> vector<T> fft(vector<T> v, bool rev=false) {
	int n=v.size(),i,j,m;
	for(int m=n; m>=2; m/=2) {
		T wn=modpow(5,(mo-1)/m);
		if(rev) wn=modpow(wn);
		for(i=0;i<n;i+=m) {
			T w=1;
			for(int j1=i,j2=i+m/2;j2<i+m;j1++,j2++) {
				T t1=v[j1],t2=v[j2];
				v[j1]=t1+t2;
				v[j2]=ll(t1+mo-t2)*w%mo;
				while(v[j1]>=mo) v[j1]-=mo;
				w=(ll)w*wn%mo;
			}
		}
	}
	for(i=0,j=1;j<n-1;j++) {
		for(int k=n>>1;k>(i^=k);k>>=1);
		if(i>j) swap(v[i],v[j]);
	}
	if(rev) {
		ll rv = modpow(n);
		FOR(i,n) v[i]=(ll)v[i]*rv%mo;
	}
	return v;
}

template<class T> vector<T> MultPoly(vector<T> P,vector<T> Q,bool resize=false) {
	if(resize) {
		int maxind=0,pi=0,qi=0,i;
		int s=2;
		FOR(i,P.size()) if(norm(P[i])) pi=i;
		FOR(i,Q.size()) if(norm(Q[i])) qi=i;
		maxind=pi+qi+1;
		while(s*2<maxind) s*=2;
		P.resize(s*2);Q.resize(s*2);
		if(s<=16) { //fastpath
			vector<T> R(s*2);
			for(int x=0;x<2*s;x++) for(int y=0;x+y<2*s;y++) (R[x+y]+=P[x]*Q[y])%=mo;
			return R;
		}
	}
	P=fft(P), Q=fft(Q);
	for(int i=0;i<P.size();i++) P[i]=(ll)P[i]*Q[i]%mo;
	return fft(P,true);
}

ll P[525252];
ll B[252525];

void solve() {
	int i,j,k,l,r,x,y; string s;
	
	cin>>N>>M;
	
	P[0]=1;
	for(i=1;i<=520000;i++) P[i]=P[i-1]*(1+mo-modpow(2,i))%mo;
	ll v=1;
	for(i=1;i<=N;i++) {
		v=v*(modpow(2,i)-1)%mo*modpow(2,i-1)%mo;
		B[i]=P[N]*modpow(P[N-i])%mo*modpow(P[i])%mo;
		B[i]=B[i]*v%mo;
	}
	
	vector<pair<vector<ll>,vector<ll>>> Q;
	FOR(i,M) {
		vector<ll> a={1,modpow(2,i)};
		vector<ll> b={0,modpow(2,2*i)};
		Q.push_back({a,b});
	}
	while(Q.size()>1) {
		vector<pair<vector<ll>,vector<ll>>> Q2;
		for(i=0;i<Q.size();i+=2) {
			if(i==Q.size()-1) {
				Q2.push_back(Q[i]);
			}
			else {
				auto a=MultPoly(Q[i].first,Q[i+1].first,1);
				auto b=MultPoly(Q[i].first,Q[i+1].second,1);
				auto c=Q[i].second;
				b.resize(max(b.size(),c.size()));
				c.resize(max(b.size(),c.size()));
				FOR(j,b.size()) (b[j]+=c[j])%=mo;
				Q2.push_back({a,b});
			}
		}
		swap(Q,Q2);
	}
	vector<ll> C=Q[0].first,D=Q[0].second;
	C.resize(max(M,N)+2);
	D.resize(max(M,N)+2);
	
	ll ret=0;
	for(k=1;k<=min(M,N);k++) {
		ll a=B[k]*modpow(modpow(2,1LL*k*(k-1)/2))%mo;
		ll p=((modpow(2,M)+mo-1)*C[k]-(k+1)*C[k+1])%mo;
		ll q=2*D[k];
		ret+=a*(p+q-C[k])%mo;
	}
	
	cout<<(ret%mo*modpow(2)+mo)%mo<<endl;
	
	
}

まとめ

Editorialの誤りは心と時間を浪費するので、問題文やテストケースと同じぐらい気を配ってほしいなぁ…。