すごく良くはないが、まぁレート増なので良いか…。
https://atcoder.jp/contests/arc207/tasks/arc207_a
問題
N個のランプがあり、i番のランプのコストはA[i]である。
いずれもランプは灯っていない。
1個ずる選択してランプを灯らせることを考える。
その際、その時点のA[i]分のコストがかかる。その後、まだ灯らせていないランプのコストが正の値であれば、それらをデクリメントできる。
ランプの灯らせ方N!通りに対し、総コストがX以下となる順番は何通りか。
解法
元のコストから、どれだけコストを削減できるかを考える。
i回目にランプjをともす場合、min(A[j],i-1)だけコストを削減できることになる。
まずランプをコスト昇順に並べておく。
dp(p,n,left,sum) := コストの大きい順に考える。コスト削減量p以上のランプを処理済みで、n番以降のランプを処理済みの場合、leftがまた位置未確定で、コスト削減量の和がsumとなるような組み合わせ
としてテーブルを考える。p,nを以下の通りデクリメントしていこう。
- p>min(N-1,A[n-1])のとき、p回目の点灯をどれにするかを考える。
- A[i]≧pであるものから選ぶ場合、dp(p-1,n,left-1,sum+p-1) += dp(p,n,left,sum)
- A[i]<pであるものから選ぶ場合、dp(p-1,n,left,sum+p) += dp(p,n,left,sum)
- p=min(N-1,A[n-1])の時、(n-1)番目のランプをいつ転倒するかを考える。
- コストがmin(N-1,A[n-1])だけ減るのは、min(N-1,A[n-1])+1回目以降で転倒するとき。dp(p,n-1,left,sum+min(N-1,A[n-1])) += dp(p,n,left,sum)
- コストがmin(N-1,A[n-1])未満だけ減るのは、min(N-1,A[n-1])回目以前で転倒するとき。dp(p,n-1,left+1,sum) += dp(p,n,left,sum)
上記はO(N^4)で済む。
最終的にdp(0,0,0,sum)のうち、必要なコスト削減量が達成できるsumにおける値の総和が解。
int N; ll X; int A[101]; const int NUM_=400001; static ll fact[NUM_+1],factr[NUM_+1],inv[NUM_+1]; const ll mo=998244353; ll comb(ll N_, ll C_) { if(C_<0 || C_>N_) return 0; return factr[C_]*fact[N_]%mo*factr[N_-C_]%mo; } ll from[101][6555]; ll to[101][6555]; void solve() { int i,j,k,l,r,x,y; string s; inv[1]=fact[0]=factr[0]=1; for (int i=2;i<=NUM_;++i) inv[i] = inv[mo % i] * (mo - mo / i) % mo; for (int i=1;i<=NUM_;++i) fact[i]=fact[i-1]*i%mo, factr[i]=factr[i-1]*inv[i]%mo; cin>>N>>X; ll S=0,S2=0; FOR(i,N) { cin>>A[i]; S+=A[i]; } sort(A,A+N); FOR(i,N) S2+=max(A[i]-i,0); if(S2>X) { cout<<0<<endl; return; } if(X>=S) { cout<<fact[N]<<endl; return; } int need=S-X; from[0][0]=1; int pre=N-1; FOR(i,N) { int v=min(N-1,A[N-1-i]); while(pre>v) { for(x=1;x<=i+1;x++) for(y=0;y<=1LL*N*(N+1)/2;y++) if(from[x][y]) (from[x-1][y+pre]+=from[x][y]*x)%=mo; pre--; } ZERO(to); // x未使用 // y減らす量 for(x=0;x<=i;x++) for(y=0;y<=1LL*N*(N+1)/2;y++) if(from[x][y]) { //より小さいところ to[x+1][y]+=from[x][y]; int unused=N-1-v-(i-x); if(unused>0) (to[x][y+v]+=from[x][y]*unused)%=mo; } swap(from,to); } while(pre>=0) { for(x=1;x<=N;x++) for(y=0;y<=1LL*N*(N+1)/2;y++) if(from[x][y]) (from[x-1][y+pre]+=from[x][y]*x)%=mo; pre--; } ll ret=0; for(i=S-X;i<=N*(N+1)/2;i++) ret+=from[0][i]; cout<<ret%mo<<endl; }
まとめ
これは少し手間取ったけどまぁ普通に解けたかな。