幸いこれはすんなり解けた。
https://atcoder.jp/contests/arc186/tasks/arc186_b
問題
1-indexedなN要素の数列Aが与えられる。なお、0≦A[i]<iである。
以下を満たす、1-indexedな1~Nの順列Pは何通りか。
- A[i]<j<iであるi,jに対し、P[j]>P[i]
- A[i]>0ならばP[A[i]]<P[i]
解法
Pを0-indexedな0~Nの順列とし、P[0]=0に固定すると、2つ目の条件を無視できる。
P[L,R]のうち未確定要素を、値の小さい順に定めることを考える。
A[L,R]のうち最小値であり、また同着で最小値の要素が複数あるならそのうち最大のindexを考える。これをMとする。
こうするとP[M]が区間内最小なことが確定する。
この時、A[M+1....R]がすべてM以上であれば、P[L...(M-1)]とP[(M+1)...R]の値は互いに影響しないので、それぞれの組み合わせを再帰的に計算して掛け合わせればよい。
そうでない場合、また同様にP[L,R]の未確定要素のうち最小値を定めて行く。
int N,A[303030]; const ll mo=998244353; template<class V,int NV> class SegTree_Pair { public: vector<pair<V,int> > val; static V const def=-(1<<30); pair<V,int> comp(pair<V,int> l,pair<V,int> r){ return max(l,r);} SegTree_Pair(){ val.resize(NV*2); int i; FOR(i,NV) val[i+NV]=make_pair(def,i); for(i=NV-1;i>=1;i--) val[i]=comp(val[2*i],val[2*i+1]); }; pair<V,int> getval(int x,int y,int l=0,int r=NV,int k=1) { if(r<=x || y<=l) return make_pair(def,NV); if(x<=l && r<=y) return val[k]; return comp(getval(x,y,l,(l+r)/2,k*2),getval(x,y,(l+r)/2,r,k*2+1)); } void update(int entry, V v) { entry += NV; val[entry]=make_pair(v,entry-NV); while(entry>1) entry>>=1, val[entry]=comp(val[entry*2],val[entry*2+1]); } }; template<class V, int ME> class BIT { public: V bit[1<<ME]; V operator()(int e) {if(e<0) return 0;V s=0;e++;while(e) s+=bit[e-1],e-=e&-e; return s;} void add(int e,V v) { e++; while(e<=1<<ME) bit[e-1]+=v,e+=e&-e;} }; BIT<int,20> bit; SegTree_Pair<int,1<<20> stmi,stma; ll ret=1; ll comb(ll N_, ll C_) { const int NUM_=400001; static ll fact[NUM_+1],factr[NUM_+1],inv[NUM_+1]; if (fact[0]==0) { inv[1]=fact[0]=factr[0]=1; for (int i=2;i<=NUM_;++i) inv[i] = inv[mo % i] * (mo - mo / i) % mo; for (int i=1;i<=NUM_;++i) fact[i]=fact[i-1]*i%mo, factr[i]=factr[i-1]*inv[i]%mo; } if(C_<0 || C_>N_) return 0; return factr[C_]*fact[N_]%mo*factr[N_-C_]%mo; } void dfs(int L,int R) { if(bit(R-1)-bit(L-1)<=1) return; auto p=stmi.getval(L,R); int M=p.second; auto p2=stma.getval(M+1,R); bit.add(M,-1); stmi.update(M,-1000000); stma.update(M,-1000000); if(p2.first<M) { dfs(L,R); return; } else { int a=bit(M-1)-bit(L-1); int b=bit(R-1)-bit(M); ret=ret*comb(a+b,a)%mo; dfs(L,M); dfs(M+1,R); } } void solve() { int i,j,k,l,r,x,y; string s; cin>>N; FOR(i,N) { cin>>A[i+1]; stmi.update(i+1,-A[i+1]); stma.update(i+1,A[i+1]); bit.add(i+1,1); } dfs(1,N+1); cout<<ret<<endl; }
まとめ
想定解と違ったっぽいな。