kmjp's blog

競技プログラミング参加記です

Codeforces #940 : Div2 F2. Frequency Mismatch (Hard Version)

これは実装しんどい。
https://codeforces.com/contest/1957/problem/F2

問題

N頂点の木を成すグラフが与えられる。
各点vには整数値A[v]が設定されている。

以下のクエリに答えよ。
2つのパスと、整数値Kが与えられる。
1つ目のパス上の点における整数値cの登場頻度をX(c)、2つ目のパスの点における整数値cの登場頻度をY(c)とする。
X(c)!=Y(c)となるcがZ個ある場合、min(Z,K)個答えよ。

解法

値に対するハッシュ値を定めておく。
整数値の区間[L,R]において、パス上に現れる区間内の整数値に対応するハッシュ値の総和をSegtreeで持って置く。
これにより、整数値の区間[L,R]内に、2つのパスで頻度が異なる整数値が1個以上あるかどうか高速に求められる。

区間の二分探索をK回繰り返し、そのような整数値をK個見つけよう。

template<class V> class PersistentSegTree_add {  //1点更新・区間和
public:
	V val;
	int L,R;
	PersistentSegTree_add *left, *right;

	PersistentSegTree_add(int L_,int R_,V v): L(L_),R(R_),val(v),left(NULL),right(NULL) {}
	PersistentSegTree_add(PersistentSegTree_add* p): L(p->L),R(p->R),val(p->val),left(p->left),right(p->right) {}
	V comp(V l,V r){ return l+r;}
	
	V getval(int x,int y) { // x<=i<y
		if(R<=x || y<=L) return 0;
		if(x<=L && R<=y) return val;
		V a=left?left->getval(x,y):0;
		V b=right?right->getval(x,y):0;
		
		return comp(a,b);
	}
	PersistentSegTree_add* update(int entry, V v) {
		PersistentSegTree_add* ret;
		if(L+1==R) {
			ret=new PersistentSegTree_add(L,R,comp(val,v));
		}
		else {
			int M=(L+R)/2;
			ret=new PersistentSegTree_add(L,R,0);
			ret->left=left;
			ret->right=right;
			if(entry<M) {
				if(ret->left==NULL) ret->left=new PersistentSegTree_add(L,M,0);
				else ret->left=new PersistentSegTree_add(left);
				ret->left=ret->left->update(entry,v);
			}
			else {
				if(ret->right==NULL) ret->right=new PersistentSegTree_add(M,R,0);
				else ret->right=new PersistentSegTree_add(right);
				ret->right=ret->right->update(entry,v);
			}
			V a=ret->left?ret->left->val:0;
			V b=ret->right?ret->right->val:0;
			ret->val=comp(a,b);
		}
		return ret;
		
	}
};

const ll mo1=1000000007;
const ll mo2=998244353;
PersistentSegTree_add<ll>* root1,*PST1[101010];
PersistentSegTree_add<ll>* root2,*PST2[101010];
ll po[2][202020];


int N,Q;
int A[101010];
vector<int> E[101010];
int P[21][200005],D[200005];

int lca(int a,int b) {
	int ret=0,i,aa=a,bb=b;
	if(D[aa]>D[bb]) swap(aa,bb);
	for(i=19;i>=0;i--) if(D[bb]-D[aa]>=1<<i) bb=P[i][bb];
	for(i=19;i>=0;i--) if(P[i][aa]!=P[i][bb]) aa=P[i][aa], bb=P[i][bb];
	return (aa==bb)?aa:P[0][aa];               // vertex
}

void dfs(int cur,int pre) {
	if(cur==0) {
		PST1[cur]=root1->update(A[cur],po[0][A[cur]]);
		PST2[cur]=root2->update(A[cur],po[1][A[cur]]);
	}
	else {
		PST1[cur]=PST1[pre]->update(A[cur],po[0][A[cur]]);
		PST2[cur]=PST2[pre]->update(A[cur],po[1][A[cur]]);
		D[cur]=D[pre]+1;
		P[0][cur]=pre;
	}
	

	FORR(e,E[cur]) if(e!=pre) dfs(e,cur);
	
}

ll val(
	PersistentSegTree_add<ll>* a1,
	PersistentSegTree_add<ll>* a2,
	PersistentSegTree_add<ll>* a3,
	PersistentSegTree_add<ll>* a4
) {
	ll ret=0;
	if(a1) ret+=a1->val;
	if(a2) ret+=a2->val;
	if(a3) ret-=a3->val;
	if(a4) ret-=a4->val;
	return ret;
}

PersistentSegTree_add<ll>* LL(PersistentSegTree_add<ll>* cur) {
	if(cur) return cur->left;
	return NULL;
}
PersistentSegTree_add<ll>* RR(PersistentSegTree_add<ll>* cur) {
	if(cur) return cur->right;
	return NULL;
}



void dfs2(int L,int R,
	PersistentSegTree_add<ll>* p1a1,
	PersistentSegTree_add<ll>* p1a2,
	PersistentSegTree_add<ll>* p1a3,
	PersistentSegTree_add<ll>* p1a4,
	PersistentSegTree_add<ll>* p2a1,
	PersistentSegTree_add<ll>* p2a2,
	PersistentSegTree_add<ll>* p2a3,
	PersistentSegTree_add<ll>* p2a4,
	PersistentSegTree_add<ll>* p1b1,
	PersistentSegTree_add<ll>* p1b2,
	PersistentSegTree_add<ll>* p1b3,
	PersistentSegTree_add<ll>* p1b4,
	PersistentSegTree_add<ll>* p2b1,
	PersistentSegTree_add<ll>* p2b2,
	PersistentSegTree_add<ll>* p2b3,
	PersistentSegTree_add<ll>* p2b4,
	vector<int>& ret) {
	
	if(ret.size()>=10) return;
	ll a1=val(p1a1,p1a2,p1a3,p1a4);
	ll a2=val(p2a1,p2a2,p2a3,p2a4);
	ll b1=val(p1b1,p1b2,p1b3,p1b4);
	ll b2=val(p2b1,p2b2,p2b3,p2b4);
	if(a1==b1&&a2==b2) return;
	
	if(L+1==R) {
		ret.push_back(L);
	}
	else {
		dfs2(L,(L+R)/2,LL(p1a1),LL(p1a2),LL(p1a3),LL(p1a4),LL(p2a1),LL(p2a2),LL(p2a3),LL(p2a4),LL(p1b1),LL(p1b2),LL(p1b3),LL(p1b4),LL(p2b1),LL(p2b2),LL(p2b3),LL(p2b4),ret);
		dfs2((L+R)/2,R,RR(p1a1),RR(p1a2),RR(p1a3),RR(p1a4),RR(p2a1),RR(p2a2),RR(p2a3),RR(p2a4),RR(p1b1),RR(p1b2),RR(p1b3),RR(p1b4),RR(p2b1),RR(p2b2),RR(p2b3),RR(p2b4),ret);
	}
	
}
	


void solve() {
	int i,j,k,l,r,x,y; string s;
	
	root1=new PersistentSegTree_add<ll>(0,1<<17,0);
	root2=new PersistentSegTree_add<ll>(0,1<<17,0);
	po[0][0]=po[1][0]=1;
	FOR(i,1<<17) {
		po[0][i+1]=po[0][i]*12345%mo1;
		po[1][i+1]=po[1][i]*123456%mo2;
	}
	
	cin>>N;
	FOR(i,N) cin>>A[i];
	FOR(i,N-1) {
		cin>>x>>y;
		E[x-1].push_back(y-1);
		E[y-1].push_back(x-1);
	}
	dfs(0,0);
	FOR(i,19) FOR(x,N) P[i+1][x]=P[i][P[i][x]];
	cin>>Q;
	while(Q--) {
		int u1,v1,u2,v2,k,lc1,lc2;
		cin>>u1>>v1>>u2>>v2>>k;
		u1--,v1--,u2--,v2--;
		lc1=lca(u1,v1);
		lc2=lca(u2,v2);
		vector<int> ret;
		
		dfs2(0,1<<17,PST1[u1],PST1[v1],PST1[lc1],(lc1)?PST1[P[0][lc1]]:NULL,
		     PST2[u1],PST2[v1],PST2[lc1],(lc1)?PST2[P[0][lc1]]:NULL,
		     PST1[u2],PST1[v2],PST1[lc2],(lc2)?PST1[P[0][lc2]]:NULL,
		     PST2[u2],PST2[v2],PST2[lc2],(lc2)?PST2[P[0][lc2]]:NULL,
		     ret);


		if(ret.size()>k) ret.resize(k);
		cout<<ret.size();
		FORR(r,ret) cout<<" "<<r;
		cout<<"\n";
	}
}

まとめ

永続データ構造系、なかなか本番に頑張って使おうとなりにくいな…。