kmjp's blog

競技プログラミング参加記です

AtCoder ARC #114 : F - Permutation Division

ARCの最終問題、AGCの中間の問題より実装が重い気がする。
https://atcoder.jp/contests/arc114/tasks/arc114_f

問題

1~NのPermutation Pと、整数Kが与えられる。
PをK個の空でない連続な部分列に分割したとする。
その後、K個の部分列を並べ替えて再度連結し、辞書順でできるだけ大きくなるようにする。

並べ替え後の結果が辞書順最小になるように分割したとき、再連結後の数列を求めよ。

解法

先頭要素がK以下の場合を考える。
K個の部分列にすると、各部分列の先頭を1~Kにすることができる。
そうすると並べ替えによって先頭がKになる。
これ以外の分割は、先頭に(K+1)以上の要素が出てくるので好ましくない。
よって、分割の仕方は一意に決まる。

以下、先頭要素がK+1以上のケースを考える。
分割した先頭要素が大きくなると好ましくない。
そこで、先頭からできるだけ先頭要素が減少列になるように選ぶことを考える。
そうすると、これらの要素は、Pから変化しない。
ただし減少列でK要素確保できない場合、末尾の部分は若干並べ替えが生じてしまう。

dp[i]を、P[0]を先頭都市、P[i]を末尾とするPの部分列のうち最長の減少列とする。
K分割のうちmin(K,dp[i])個の部分列は、この減少列を先頭とする要素で賄える。
あとは、P[i+1]以降でmax(0,K-dp[i])個の要素を選択しなければならない。
その場合、P[i]未満の要素のうち、数列の末尾に近いmax(0,K-dp[i])個を専用要素として選べばよい。
P[i]以降、数列の末尾に近いmax(0,K-dp[i])個の手前の要素は、P[i]と同じ部分列に入れる。

iを総当たりすると、dp[i]の計算はRMQをつかえるSegTreeでO(NlogN)、P[i]と同じ部分列に入れる長さの算出はBIT上の二分探索でO(Nlog^2N)で計算できる。
iを総当たりしながら、この部分列が最長であるものを探そう。

int N,K;
int P[202020],R[202020];
int dp[202020];

template<class V,int NV> class SegTree_1 {
public:
	vector<V> val;
	static V const def=0;
	V comp(V l,V r){ return max(l,r);};
	
	SegTree_1(){val=vector<V>(NV*2,def);};
	V getval(int x,int y,int l=0,int r=NV,int k=1) { // x<=i<y
		if(r<=x || y<=l) return def;
		if(x<=l && r<=y) return val[k];
		return comp(getval(x,y,l,(l+r)/2,k*2),getval(x,y,(l+r)/2,r,k*2+1));
	}
	void update(int entry, V v) {
		entry += NV;
		val[entry]=comp(v,val[entry]); //上書きかチェック
		while(entry>1) entry>>=1, val[entry]=comp(val[entry*2],val[entry*2+1]);
	}
};
SegTree_1<int,1<<18> lis;

template<class V, int ME> class BIT {
public:
	V bit[1<<ME];
	V operator()(int e) {if(e<0) return 0;V s=0;e++;while(e) s+=bit[e-1],e-=e&-e; return s;}
	void add(int e,V v) { e++; while(e<=1<<ME) bit[e-1]+=v,e+=e&-e;}
	int lower_bound(V val) {
		V tv=0; int i,ent=0;
		for(i=ME-1;i>=0;i--) if(tv+bit[ent+(1<<i)-1]<val) tv+=bit[ent+(1<<i)-1],ent+=(1<<i);
		return ent;
	}
};
BIT<int,19> bit;


void solve() {
	int i,j,k,l,r,x,y; string s;
	
	cin>>N>>K;
	FOR(i,N) cin>>P[i];
	if(P[0]<=K) {
		vector<vector<int>> V;
		FOR(i,N) {
			if(P[i]<=K) {
				V.push_back({P[i]});
			}
			else {
				V.back().push_back(P[i]);
			}
		}
		sort(ALL(V));
		reverse(ALL(V));
		FORR(v,V) FORR(a,v) cout<<a<<" ";
		cout<<endl;
	}
	else {
		FOR(i,N) {
			x=P[i];
			R[x]=i;
			if(P[i]<=P[0]) {
				dp[x]=lis.getval(x,N+1)+1;
				lis.update(x,dp[x]);
			}
		}
		vector<int> ret={0,0,0};
		for(i=1;i<=P[0];i++) {
			x=R[i];
			if(bit(N)-bit(x)>=K-dp[i]) {
				y=x;
				for(j=17;j>=0;j--) if(y+(1<<j)<N&&bit(N)-bit(y+(1<<j))>=K-dp[i]) y+=1<<j;
				ret=max(ret,{y,min(K,dp[i]),i});
			}
			bit.add(x,1);
		}
		vector<vector<int>> V;
		FOR(i,ret[0]+1) cout<<P[i]<<" ";
		for(i=ret[0]+1;i<N;i++) {
			if(P[i]<ret[2]) {
				V.push_back({P[i]});
			}
			else {
				V.back().push_back(P[i]);
			}
		}
		assert(V.size()==K-ret[1]);
		sort(ALL(V));
		reverse(ALL(V));
		FORR(v,V) FORR(a,v) cout<<a<<" ";
		cout<<endl;
	}
}

まとめ

CDEの時間によるけど、これ本番で解き切るのしんどそう。