kmjp's blog

競技プログラミング参加記です

AtCoder ARC #186 (AtCoder Japan Open -予選-) : B - Typical Permutation Descriptor

幸いこれはすんなり解けた。
https://atcoder.jp/contests/arc186/tasks/arc186_b

問題

1-indexedなN要素の数列Aが与えられる。なお、0≦A[i]<iである。
以下を満たす、1-indexedな1~Nの順列Pは何通りか。

  • A[i]<j<iであるi,jに対し、P[j]>P[i]
  • A[i]>0ならばP[A[i]]<P[i]

解法

Pを0-indexedな0~Nの順列とし、P[0]=0に固定すると、2つ目の条件を無視できる。

P[L,R]のうち未確定要素を、値の小さい順に定めることを考える。
A[L,R]のうち最小値であり、また同着で最小値の要素が複数あるならそのうち最大のindexを考える。これをMとする。

こうするとP[M]が区間内最小なことが確定する。
この時、A[M+1....R]がすべてM以上であれば、P[L...(M-1)]とP[(M+1)...R]の値は互いに影響しないので、それぞれの組み合わせを再帰的に計算して掛け合わせればよい。
そうでない場合、また同様にP[L,R]の未確定要素のうち最小値を定めて行く。

int N,A[303030];
const ll mo=998244353;

template<class V,int NV> class SegTree_Pair {
public:
	vector<pair<V,int> > val;
	static V const def=-(1<<30);
	pair<V,int> comp(pair<V,int> l,pair<V,int> r){ return max(l,r);}
	SegTree_Pair(){
		val.resize(NV*2);
		int i;
		FOR(i,NV) val[i+NV]=make_pair(def,i);
		for(i=NV-1;i>=1;i--) val[i]=comp(val[2*i],val[2*i+1]);
	};
	pair<V,int> getval(int x,int y,int l=0,int r=NV,int k=1) {
		if(r<=x || y<=l) return make_pair(def,NV);
		if(x<=l && r<=y) return val[k];
		return comp(getval(x,y,l,(l+r)/2,k*2),getval(x,y,(l+r)/2,r,k*2+1));
	}
	void update(int entry, V v) {
		entry += NV;
		val[entry]=make_pair(v,entry-NV);
		while(entry>1) entry>>=1, val[entry]=comp(val[entry*2],val[entry*2+1]);
	}
};

template<class V, int ME> class BIT {
public:
	V bit[1<<ME];
	V operator()(int e) {if(e<0) return 0;V s=0;e++;while(e) s+=bit[e-1],e-=e&-e; return s;}
	void add(int e,V v) { e++; while(e<=1<<ME) bit[e-1]+=v,e+=e&-e;}
};
BIT<int,20> bit;

SegTree_Pair<int,1<<20> stmi,stma;

ll ret=1;

ll comb(ll N_, ll C_) {
	const int NUM_=400001;
	static ll fact[NUM_+1],factr[NUM_+1],inv[NUM_+1];
	if (fact[0]==0) {
		inv[1]=fact[0]=factr[0]=1;
		for (int i=2;i<=NUM_;++i) inv[i] = inv[mo % i] * (mo - mo / i) % mo;
		for (int i=1;i<=NUM_;++i) fact[i]=fact[i-1]*i%mo, factr[i]=factr[i-1]*inv[i]%mo;
	}
	if(C_<0 || C_>N_) return 0;
	return factr[C_]*fact[N_]%mo*factr[N_-C_]%mo;
}


void dfs(int L,int R) {
	if(bit(R-1)-bit(L-1)<=1) return;
	
	auto p=stmi.getval(L,R);
	int M=p.second;
	
	auto p2=stma.getval(M+1,R);
	bit.add(M,-1);
	stmi.update(M,-1000000);
	stma.update(M,-1000000);
	if(p2.first<M) {
		dfs(L,R);
		return;
	}
	else {
		int a=bit(M-1)-bit(L-1);
		int b=bit(R-1)-bit(M);
		ret=ret*comb(a+b,a)%mo;
		dfs(L,M);
		dfs(M+1,R);
	}
	
	
	
}


void solve() {
	int i,j,k,l,r,x,y; string s;
	
	cin>>N;
	FOR(i,N) {
		cin>>A[i+1];
		stmi.update(i+1,-A[i+1]);
		stma.update(i+1,A[i+1]);
		bit.add(i+1,1);
	}
	dfs(1,N+1);
	cout<<ret<<endl;
	
}

まとめ

想定解と違ったっぽいな。