O(R*C*4^C)で書くつもりだったのに…。
http://yukicoder.me/problems/846
問題
R行C列の二次元グリッド状に人が並んでいる。
y行x列の人が倉敷市を知っている確率をP[y][x]とする。(各々の確率は独立である)
それぞれの人は、合計4ポイント以上を得ると手を挙げる。
ポイントの得方は以下の2通りである。
- 倉敷市を知っている人は人によって異なるポイント(4-S[y][x])ポイントを得る。
- 前及び左右隣に手を挙げている人がいると、1人あたり1ポイント得る。
最終的に手を挙げている人数の期待値を求めよ。
解法
O(R*C*4^C)解法
各人の手の上げ下げは、以下で決まる。
- 各人が倉敷市を知っているか否か
- 左右の人が手を挙げているか否か
- 前の人が手を挙げているか否か
後ろの人の影響は受けないため、前から順にDPしていくと良さそうなのは容易に想像がつく。
一方左右両方の人の影響を受けるので、左右は一方向にDPしていくことができない。
そこで、前から順に各行について以下を総当たりする。
- 今着目している行の人の倉敷の知ってる知らない(2^C)通り
- 前の人の手の上げ下げ(2^C)通り
(前の行から順に処理していくので)前の人の手の上げ下げに対応する確率はわかっているとして、今着目している行で誰が手を挙げるかを求めよう。
左から右に手を挙げる影響が伝搬するケースと、右から左に手を挙げる影響が伝搬するケースがあるので両方1回ずつ順次チェックしていくと良い。
これで今着目している行で誰が手を挙げるかがわかるし、その事象が起きる確率もわかる。
よって(手を挙げた人数)×(確率)を解に加算していくと良い。
またこの確率は、次の行の人の手の上げ下げを計算するのに利用できる。
R行でO*1の組み合わせを総当たりし、かつ伝搬処理が毎回O(C)かかるので計算量は全体でO(R*C*4^C)。
これでも4sはギリギリ間に合う。
O((R+C)*4^C)解法
writerさんに指名された感があるので、少し計算量を落としてみよう。
手を挙げる伝搬処理O(C)を前処理として先に済ませることを考える。
自分が倉敷市を知っている場合、手持ちのポイントは0~4のいずれかである。
これに前の人が手を挙げている場合1ポイント追加されるので、追加されないケースも含め手持ちのポイントは0~5のいずれかになる。
自分が手を挙げるには、まだ0~4ポイント足りない可能性がある。
両隣の人が手を挙げてもあと2ポイントしかもらえないので、3ポイント足りないと4ポイント足りないは結局同じである。
よって、横1行分の人が手を挙げるのにあと0~3ポイント不足している状態で、最終的に誰が手を挙げるかを前処理しよう。
これで1人あたり0~3で2bit、計(2*C)bitのbitmaskに対し、最終的に各人が手を上げる上げないC bitのbitmaskを返すテーブルが作れる。
この処理はO(C*4^C)である。
次に各行のループ内では、先に前の人の手の上げ方2^C通りを決める。
そうすると、仮に全員が倉敷市を知っていた場合、前の人も合わせて各人手を挙げるまであと何ポイント必要かを求められる。
そして内側のループで、倉敷市を知っている/以内の2^C通りを総当たりしよう。
倉敷市を知らない=絶対に手を上げない=手を挙げるまであと3ポイント必要である、とみなし、ビット演算を使って「各人手を挙げるまであと何ポイント必要か」の情報を更新すれば、あとは先ほど作ったテーブルを参照するだけでO(C)の処理を回避して最終的に手を挙げる人の状態を求めることができる。
これによりこのループ処理をO(R*4^C)で終えることができる。
よって前処理と合わせてO((R+C)*4^C)である。
ここではこれ以上頑張らないが、さらに早くすることも可能。
「倉敷市を知らない人は、前の人の状態がどうだろうと手を上げないのは確定する」と考えると、4^Cかかる一部のループを3^Cに抑えることも出来そうだ。
以下のコードは、fastはO((R+C)*4^C)、slowはO(R*C*4^C)のコードである。
R,Cはそんなに大きくないので、頑張っても3倍程度しか高速されないのね…。
int H,W; double P[101][101]; int S[101][101]; int memomask[1<<22]; int failmask[1<<11]; double dp[13][1<<11]; double fast() { int i,j,k,l,r,x,y,mask,cur,up; FOR(mask,1<<(2*W)) { int& stand=memomask[mask]; FOR(j,2) FOR(i,W) { x=(j)?i:(W-1-i); if(((stand>>(x))&1) + ((stand>>(x+2))&1) >= ((mask>>(2*x))&3)) stand |= 1<<(1+x); } stand>>=1; } FOR(mask,1<<W) FOR(x,W) if((mask&(1<<x))==0) failmask[mask] |= 3<<(2*x); double ret=0; dp[0][0]=1; FOR(y,H) { double p[1<<11]; FOR(cur,1<<W) { double pat=1; FOR(x,W) { if(cur&(1<<x)) pat *= P[y][x]; else pat *= 1-P[y][x]; } p[cur]=pat; } FOR(up,1<<W) if(dp[y][up]>1e-12) { int curmask=0; FOR(x,W) { r = 4-S[y][x]+((up>>x)&1); r = max(0,min(4-r,3)); curmask |= r<<(2*x); } FOR(cur,1<<W) dp[y+1][memomask[curmask | failmask[cur]]] += p[cur]*dp[y][up]; } FOR(mask,1<<W) ret += dp[y+1][mask]*__builtin_popcount(mask); } return ret; } double slow() { int i,j,k,l,r,x,y; string s; ZERO(dp); double ret=0; dp[0][0]=1; FOR(y,H) { for(int cur=0;cur<1<<W;cur++) { double pat=1; int hand=0; FOR(x,W) { if(cur&(1<<x)) { pat *= P[y][x]; if(S[y][x]==0) hand |= 1<<(x+1); } else pat *= 1-P[y][x]; } for(int up=0;up<1<<W;up++) if(dp[y][up]>1e-12) { int hand2=hand; FOR(j,2) FOR(i,W) { x=(j)?i:(W-1-i); if(cur&(1<<x)) { int p=4-S[y][x]; if(hand2&(1<<x)) p++; if(hand2&(1<<(x+2))) p++; if(up&(1<<x)) p++; if(p>=4) hand2 |= 1<<(x+1); } } hand2 >>=1; ret += pat*dp[y][up]*__builtin_popcount(hand2); dp[y+1][hand2] += pat*dp[y][up]; } } } return ret; } void solve() { int i,j,k,l,r,x,y; string s; cin>>H>>W; FOR(y,H) FOR(x,W) cin>>r, P[y][x] = r/100.0; FOR(y,H) FOR(x,W) cin>>S[y][x]; _P("%.12lf\n",fast()); //_P("%.12lf\n",slow()); }
まとめ
まともな解説、URL差し替えません?
*1:2^C)*(2^C