kmjp's blog

競技プログラミング参加記です

TopCoder SRM 708 Div1 Medium PalindromicSubseq

何とか本番中に解けてよかった。
https://community.topcoder.com/stat?c=problem_statement&pm=14526

問題

N文字の文字列Sが与えられる。
数列X[i]の値を「Sの(非連続でもよい)部分文字列のうち、S[i]を含み、かつ回文であるものの数」とする。

sum((i+1)*X[i]) % (10^9+7)を答えよ。

解法

愚直にX[i]を求めて行こう。
S[i]を含む回文をどうもれなく列挙するかを考える。
部分文字列が回文であるならば、S[i]と対応する文字S[j]がどこかにあるはずである(S[i]が中心の場合、i=jとなる)。
これらは互いに重なり合う部分を持たない。
よって以下を求めよう。
F(L,R) := Sの非連続な部分文字列のうち、S[L]とS[R]を含む回文でかつS[L]とS[R]が対応するものの組み合わせ。
F(L,R)さえ求められれば、X[i] = sum_x(F(i,x))となる。

F(L,R)を、以下の4つの値を用いてDPしよう。
以下の値は第1引数と第2引数の順序はどちらでもいいので、対称性によりL≦Rの場合を考える。

  • Ins(L,R) := S[L..R]の部分文字列のうち、S[L]とS[R]を含む回文で、かつS[L]とS[R]が対応するものの組み合わせ
  • InsSum(L,R) := Ins(Gin(L',R')) (L≦L'≦R'≦R)
  • Out(L,R) := S[0..L]とS[R..(N-1)]から同数の文字を抽出した部分文字列のうち、前半と後半が互いに対称であり、かつS[L]とS[R]が対応するものの組み合わせ
  • OutSum(L,R) := sum(Out(L',R')) (0≦L'≦L、R≦R'≦N-1)

包除原理より、上記値は以下のように更新できる。

  • Ins(L,R) = InsSum(L+1,R-1) + 1 (S[L]==S[R]の場合。S[L]!=S[R]の場合0)
  • InsSum(L,R) = InsSum(L,R-1)+InsSum(L+1,R-1)-InsSum(L+1,R-1) + Ins(L,R)
  • Out(L,R) = OutSum(L+1,R-1) (S[L]==S[R]の場合。S[L]!=S[R]の場合0)
  • OutSum(L,R) = OutSum(L,R+1)+OutSum(L-1,R)-OutSum(L-1,R+1) + Out(L,R)

こうすると、F(L,R)はS[L],S[R]を含み、S[(L+1)...(R-1)]の部分文字列が回文で、かつS[0..L]とS[R..(N-1)]が対称なので、F(L,R) = Ins(L,R) * OutSum(L-1,R+1)で求められる。
上記すべての処理はO(N^2)なので何とか間に合う。
long long型の3K*3K配列を4つ作るとMLEするので、メモリ消費量だけ注意。

ll mo=1000000007;
ll X[3030];
int din[3030][3030];
int dins[3030][3030];
int douts[3030][3030];


class PalindromicSubseq {
	public:
	int solve(string S) {
		int N=S.size();
		int i,d,j;
		
		ZERO(din);
		ZERO(dins);
		ZERO(douts);
		
		for(i=1;i<=N;i++) din[i][i]=dins[i][i]=1;
		for(i=1;i<=N-1;i++) {
			din[i][i+1]=S[i-1]==S[i];
			dins[i][i+1]=2+din[i][i+1];
		}
		for(d=3;d<=N;d++) {
			for(i=1;i+d-1<=N;i++) {
				int j=i+d-1;
				if(S[i-1]==S[j-1]) {
					din[i][j]=dins[i+1][j-1]+1;
					if(din[i][j]>=mo) din[i][j]-=mo;
				}
				ll ret=(ll)dins[i+1][j] + dins[i][j-1] - dins[i+1][j-1] + din[i][j];
				while(ret<0) ret += mo;
				while(ret>=mo) ret -= mo;
				dins[i][j]=ret;
			}
		}
		
		FOR(i,N+2) douts[0][i]=douts[i][N+1]=1;
		for(d=N;d>=1;d--) {
			for(i=1;i+d-1<=N;i++) {
				int j=i+d-1;
				ll ret=douts[i-1][j] + douts[i][j+1] - douts[i-1][j+1] + mo;
				if(S[i-1]==S[j-1]) ret += douts[i-1][j+1];
				douts[i][j]=(ret+mo)%mo;
			}
		}
		
		ll ret=0;
		FOR(i,N) {
			ll tmp=0;
			FOR(j,N) if(S[i]==S[j]) {
				int x = min(i,j);
				int y = max(i,j);
				(tmp += 1LL*din[x+1][y+1]*douts[x][y+2])%=mo;
			}
			ret ^= (i+1)*tmp%mo;
		}
		return (int)ret;
		
	}
}

まとめ

もう少し速く解きたかったけど、まぁいい方か。