kmjp's blog

競技プログラミング参加記です

Mujin Programming Challenge 2017 : D - Oriented Tree

終わってみるとコードはそこまで長くないんだよな。
http://mujin-pc-2017.contest.atcoder.jp/tasks/mujin_pc_2017_d

問題

木を成すN頂点の無向グラフが与えられる。
(N-1)個の辺に向きを付けた有向グラフを考える。
この時、2頂点(s,t)に対し以下が定義される。
d(s,t) := 辺を無向辺として頂点s→tの最短路を通る際、逆向きの辺を通過した回数
有向グラフにおける値Dとは、d(s,t)の最大値とする。

辺の向きの割り当てかた2^(N-1)通りのうち、Dが最小となるものはいくつあるか。

解法

グラフの直径をAとし、直径を成す2頂点を(u,v)とする。また、中心をc、e(s,t)をs,t間の距離とする。
d(u,v)+d(v,u)=Aなので、Aが偶数ならDが最小となるのはd(u,v)=d(v,u)=A/2の場合。

以下、Aが偶数の場合を考える。
頂点x,yに対し値h(x)を以下のように割り当てる。

  • x→yに辺が張られている場合、h(x)=h(y)+1

こうするとh(u)=h(v)+d(v,u)-d(u,v)より、h(u)=h(v)となる。
全頂点xに対しh(x)に同じ値を足したり引いたりしても上記定義に反しないため、仮にh(u)=h(v)=0とおく。
辺の向きを割り当てるということは、結局h(*)を定めることに相当する。
よってd(s,t)>A/2となる(s,t)が存在しないようDPでh(*)を求めていこう。

d(s,t)≦A/2であるためには、以下の条件を満たさなければならない。
d(s,t)=(e(s,t)-h(s)+h(t))/2、max(d(s,t),d(t,s))≦A/2より、|h(s)-h(t)|≦2*floor(A/2)-e(s,t)
Aが偶数なので|h(s)-h(t)|≦A-e(s,t)
ここで中心cと各頂点wに対し、|h(w)|≦A/2-e(c,w)であれば、別の頂点w'に対し、c-wの関係とc-w'の関係を足し合わせてに対し|h(w)-h(w')|≦|h(w)|+|h(w')|≦A-e(w,w')となるため、上記条件を全頂点に対して成り立たせることができる。
よって、中心cを根とした木DPを行い、|h(w)|≦A/2-e(c,w)の範囲でh(w)を総当たりし、h(w)がその値を取りうるようなSubTreeの数を総当たりしよう。

Aが奇数の場合はちょっとややこしい。
Dはceil(A/2)である。
よって直径を成す(u,v)に対しh(u)とh(v)の絶対値は1差があってもよい。
また、中心cを挟んで、vとは別にuと直径を成す頂点v'があるとき、h(v)とh(v')は異なっていてもよい(h(u)=0の時、片方は1でもう片方は-1であってよい)

そこで包除原理で数え上げていく。
直径が奇数なので、中心となる辺を挟んだ2頂点を(a,b)とする。a側にあるaの最遠点の集合をP、b側にあるbの最遠点の集合をQとする。
まず、以下を考える。

  • p∈Pに対しh(p)=0、q∈Qに対しh(q)=1か-1となるケースを数え上げる。
  • a側の頂点wに対し|h(w)|≦floor(A/2)-e(a,w)
  • b側の頂点wに対し|h(w)|≦floor(A/2)+1-e(b,w)
  • 最終的に、h(a)とh(b)に1差があるケースの積の総和を数える

同様に、逆のケースを考える。

  • p∈Pに対しh(p)=1か-1、q∈Qに対しh(q)=0となるケースを数え上げる。
  • a側の頂点wに対し|h(w)|≦floor(A/2)+1-e(a,w)
  • b側の頂点wに対し|h(w)|≦floor(A/2)-e(b,w)
  • 最終的に、h(a)とh(b)に1差があるケースの積の総和を数える

上記処理は、h(q)やh(p)が2種類の値を取りうるケースを考えているが、
「全頂点xに対しh(x)に同じ値を足したり引いたりしても上記定義に反しない」より、以下のケースは実は重複カウントしている。

  • 前者においてすべてのqに対しh(q)=1とした場合と、後者においてすべてのpに対しh(p)=-1とした場合
  • 前者においてすべてのqに対しh(q)=-1とした場合と、後者においてすべてのpに対しh(p)=1とした場合

よってこれらを取り除こう。
実際はh(p)とh(q)は1ずれるが、(a-b)間の辺の向きを無視し、h(p)=0になるケースとh(q)=0となるケースを下記の通り考える。

  • a側の頂点wに対し|h(w)|≦floor(A/2)-e(a,w)
  • b側の頂点wに対し|h(w)|≦floor(A/2)-e(b,w)

aとbの間の辺の向きによって、h(a)とh(b)の間で許容される値が以下の通りとのあるので、それぞれDPの結果を掛け合わせ引き算する。

  • a→bに辺が張られている場合、h(a)=h(b)またはh(a)=h(b)-2であってよい。
  • b→aに辺が張られている場合、h(a)=h(b)またはh(a)=h(b)+2であってよい。
int N;
vector<vector<int>> E;
vector<int> D;
ll dp[1020][1020];
ll mo=1000000007;

pair<int,int> farthest(vector<vector<int>>& E, int cur,int pre,int d,vector<int>& D) {
	D[cur]=d;
	pair<int,int> r={d,cur};
	FORR(e,E[cur]) if(e!=pre) r=max(r, farthest(E,e,cur,d+1,D));
	return r;
}

pair<int,vector<int>> diameter(vector<vector<int>>& E) { // diameter,center
	vector<int> D[2];
	D[0].resize(E.size());
	D[1].resize(E.size());
	auto v1=farthest(E,0,0,0,D[0]);
	auto v2=farthest(E,v1.second,v1.second,0,D[0]);
	farthest(E,v2.second,v2.second,0,D[1]);
	pair<int,vector<int>> R;
	R.first = v2.first;
	for(int i=E.size()-1;i>=0;i--) if(D[0][i]+D[1][i]==R.first && abs(D[0][i]-D[1][i])<=1) R.second.push_back(i);
	return R;
}

void dfs(int cur,int pre,int d,int md) {
	FORR(e,E[cur]) if(e!=pre) dfs(e,cur,d+1,md);
	for(int x=-503;x<=503;x++) if(abs(x) <= md-d){
		dp[cur][505+x]=1;
		FORR(e,E[cur]) if(e!=pre) (dp[cur][505+x] *= (dp[e][505+x+1]+dp[e][505+x-1]))%=mo;
	}
}

void solve() {
	int i,j,k,l,r,x,y; string s;
	
	cin>>N;
	E.resize(N);
	FOR(i,N-1) {
		cin>>x>>y;
		E[x-1].push_back(y-1);
		E[y-1].push_back(x-1);
	}
	
	auto dia = diameter(E);
	if(dia.first%2==0) {
		dfs(dia.second[0],-1,0,dia.first/2,0);
		cout<< accumulate(dp[dia.second[0]],dp[dia.second[0]]+1010,0LL)%mo<<endl;
	}
	else {
		ll ret=0;
		dfs(dia.second[0],dia.second[1],0,dia.first/2);
		dfs(dia.second[1],dia.second[0],0,dia.first/2+1);
		for(x=3;x<=1008;x++) (ret += dp[dia.second[0]][x]*(dp[dia.second[1]][x+1]+dp[dia.second[1]][x-1]))%=mo;
		
		ZERO(dp);
		dfs(dia.second[0],dia.second[1],0,dia.first/2+1);
		dfs(dia.second[1],dia.second[0],0,dia.first/2);
		for(x=3;x<=1008;x++) (ret += dp[dia.second[0]][x]*(dp[dia.second[1]][x+1]+dp[dia.second[1]][x-1]))%=mo;
		
		ZERO(dp);
		dfs(dia.second[0],dia.second[1],0,dia.first/2);
		dfs(dia.second[1],dia.second[0],0,dia.first/2);
		for(i=3;i<=1008;i++) ret -= 2*dp[dia.second[0]][i]*dp[dia.second[1]][i]%mo;
		for(i=3;i<=1008;i++) ret -= dp[dia.second[0]][i]*dp[dia.second[1]][i+2]%mo;
		for(i=3;i<=1008;i++) ret -= dp[dia.second[0]][i]*dp[dia.second[1]][i-2]%mo;
		cout<<(ret%mo+mo)%mo<<endl;
	}
	
	
}

まとめ

これ文章で書くのかなり面倒くさい…。
直径の両端h(u)とh(v)の値を考えるとわかりやすくなるね。