[BZOJ2006]超级钢琴

题目大意

从一个序列中,找出k段连续的子序列,且每个子序列的长度不小于L,不大于R。

一个序列的权值定义为序列中所有数的和,求k个子序列和的最大值。具体看原题

Solution

这道题应该用主席树来做,但是太麻烦。相比起来,ST表+堆的算法显然比较方便(可能我代码写得太丑)。

考虑每次取出一段最大的合法的子序列,这样取k次。第一次取的时候,如果枚举了一个右端点r,那么r-R+1<=l<=r-L+1。
如果用前缀和sum数组表示这个区间的和,就是sum[r]-sum[l-1]。当r固定时,l也在一个固定的区间内变化,只要sum[l]取这个区间中的最小值,这个值就是这个右端点所对应的最大值。最小值用st表维护。再将所有的右端点对应的最大值扔进一个大根堆中,每次取最大值。

然而当一个右端点取了其对应的区间后,这个区间就不能取了,即这个左端点就不能取了。但是ST表是不能删数的(可以写主席树)。
但是可以转化为一个右端点一个左端点l后,原来可行的区间为[x,y],现在变为了[x,l-1]和[l+1,y]。所以可以将这个可行的区间扔进堆中。
具体就是开priority_queue<pair<pair<int,int>,pair<int,int> > >q 这样一个优先队列,每次把分别把前缀和最小值,右端点,对应的可行区间的x,y四个值放入堆中。

Code
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<queue>
#define mp make_pair
using namespace std;
typedef long long ll;
priority_queue<pair<pair<int,int>,pair<int,int> > >q;
int Max[500010][22],Pos[500010][22],sum[500010],a[500010];
int main()
{
	int n,k,L,R;
	scanf("%d%d%d%d",&n,&k,&L,&R);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	for (int i=1;i<=n;i++)
		sum[i]=sum[i-1]+a[i];
	int n1=log(n)/log(2);
	for (int i=1;i<=n;i++)
		Max[i][0]=sum[i],Pos[i][0]=i;
	for (int i=1;i<=n1;i++)
	{
		int q1=n-(1<<i)+1;
		for (int j=1;j<=q1;j++)
		{
			int r1=j+(1<<(i-1));
			if (Max[j][i-1]>Max[r1][i-1])
				Max[j][i]=Max[j][i-1],Pos[j][i]=Pos[j][i-1];
			else Max[j][i]=Max[r1][i-1],Pos[j][i]=Pos[r1][i-1];
			//Max[j][i]=max(Max[j][i-1],Max[j+(1<<(i-1))][i-1]);
		}
	}
	for (int i=1;i<=n;i++)
	{
		int l1=i+L-1,r1=i+R-1;
		r1=min(r1,n);
		if (l1>r1) break;
		int k=log(r1-l1+1)/log(2),Sum,pos;
		if (Max[l1][k]>Max[r1-(1<<k)+1][k])
			Sum=Max[l1][k],pos=Pos[l1][k];
		else Sum=Max[r1-(1<<k)+1][k],pos=Pos[r1-(1<<k)+1][k];
		//printf("%d %d\n",i,Sum);
		q.push(mp(mp(Sum-sum[i-1],i),mp(l1,r1)));
	}
	ll ans=0;
	while (k--)
	{
		pair<int,int>Fi=q.top().first;
		pair<int,int>Se=q.top().second;
		int sum1=Fi.first;
		int start=Fi.second;
		int l=Se.first;
		int r=Se.second;
		q.pop();
		ans+=sum1;
		int k=log(r-l+1)/log(2),Sum,pos;
		if (Max[l][k]>Max[r-(1<<k)+1][k])
			pos=Pos[l][k];
		else pos=Pos[r-(1<<k)+1][k];
		int l1=pos+1,r1=r,pos1=pos;
		//printf("%d %d %d %d\n",sum1,pos,l,r);
		if (l1<=r1)
		{
			int k=log(r1-l1+1)/log(2);
			if (Max[l1][k]>Max[r1-(1<<k)+1][k])
				Sum=Max[l1][k],pos=Pos[l1][k];
			else Sum=Max[r1-(1<<k)+1][k],pos=Pos[r1-(1<<k)+1][k];
			q.push(mp(mp(Sum-sum[start-1],start),mp(l1,r1)));
		}
		l1=l,r1=pos1-1;
		if (l1<=r1)
		{
			int k=log(r1-l1+1)/log(2);
			if (Max[l1][k]>Max[r1-(1<<k)+1][k])
				Sum=Max[l1][k],pos=Pos[l1][k];
			else Sum=Max[r1-(1<<k)+1][k],pos=Pos[r1-(1<<k)+1][k];
			q.push(mp(mp(Sum-sum[start-1],start),mp(l1,r1)));
		}
	}
	printf("%lld\n",ans);
	return 0;
}

还没有评论,快来抢沙发!

发表评论