0
点赞
收藏
分享

微信扫一扫

[NOI2019] 机器人

一枚路过的程序猿 2022-04-25 阅读 71
c++

目录

      • 题面
      • 题解
        • 20pts: n , B i ≤ 7 n,B_i\leq 7 n,Bi7
        • 35pts: B i ≤ 100 B_i\leq 100 Bi100
        • 50pts: B i ≤ 1 0 4 B_i\leq 10^4 Bi104
        • 60pts: A i = 1 , B i = 1 0 9 A_i=1,B_i=10^9 Ai=1,Bi=109
        • 100pts:

题面

题面描述

题解

20pts: n , B i ≤ 7 n,B_i\leq 7 n,Bi7

B i B_i Bi n n n都非常小,暴力模拟即可

35pts: B i ≤ 100 B_i\leq 100 Bi100

假设一个全局最高点 p p p,钦定 [ 1 , p − 1 ] [1,p-1] [1,p1]内任意点高度小于等于 p p p的高度, [ p + 1 , n ] [p+1,n] [p+1,n]内任意点的高度小于 p p p的高度,显然, p p p [ 1 , n ] [1,n] [1,n]分割为 [ 1 , p − 1 ] [1,p-1] [1,p1] [ p + 1 , n ] [p+1,n] [p+1,n]两个子问题。由此可以区间DP。
d p l , r , x dp_{l,r,x} dpl,r,x表示 [ l , r ] [l,r] [l,r]最大值恰好为x时的方案数,暴力枚举断点和最大值转移,容易得到:
d p l , r , x = ∑ m i d ∑ k ≤ x d p l , m i d − 1 , k ∑ k < x d p m i d + 1 , r , k dp_{l,r,x}=\sum_{mid}\sum_{k\leq x}dp_{l,mid-1,k}\sum_{k<x} dp_{mid+1,r,k} dpl,r,x=midkxdpl,mid1,kk<xdpmid+1,r,k

利用前缀和优化

参考代码:

#include<bits/stdc++.h>
#define ll long long
#define uit unsigned int 
using namespace std;
const int N=301;
const ll M=1e9+7;
int n,A[N],B[N],dp[N][N][110];
int main(){
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	int MAX=0;
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d%d",&A[i],&B[i]),MAX=max(MAX,B[i]);
	for(int i=1;i<=n;i++){
		for(int j=1;j<=MAX;j++){
			if(j>=A[i]&&j<=B[i]) dp[i][i][j]++;
			dp[i][i][j]=(dp[i][i][j]+dp[i][i][j-1])%M;
		}
	}
	for(int len=2;len<=n;len++)
		for(int i=1;i+len-1<=n;i++){
			int j=i+len-1;
			for(int k=i;k<=j;k++)
				if(abs((k-i)-(j-k))<=2)
				for(int w=A[k];w<=B[k];w++){
					if(k==i) dp[i][j][w]=(dp[i][j][w]+dp[k+1][j][w-1])%M;
					else if(k==j) dp[i][j][w]=(dp[i][j][w]+dp[i][k-1][w])%M;
					else dp[i][j][w]=(dp[i][j][w]+1ll*dp[i][k-1][w]*dp[k+1][j][w-1]%M)%M;
				}
			for(int k=1;k<=MAX;k++)
				dp[i][j][k]=(dp[i][j][k]+dp[i][j][k-1])%M;
		}
	printf("%d\n",dp[1][n][MAX]);
	return 0;
}

50pts: B i ≤ 1 0 4 B_i\leq 10^4 Bi104

注意到用到的区间数远小于 n 2 n^2 n2,只对有用的区间进行转移即可。时间与空间耗费有所降低,足以通过 B i ≤ 1 0 4 B_i\leq 10^4 Bi104的测试点。
这里可以采用递归实现。

参考代码:

#include<bits/stdc++.h>
#define ll long long
#define uit unsigned int
#define mul(x,y) (1ll*x*y%M)
#define pls(x,y) ((x+y)%M) 
using namespace std;
const int N=305;
const ll M=1e9+7;
int n,idx=0,MAX=0,A[N],B[N],id[N][N],dp[2520][10001];
void solve(int l,int r){
	if(id[l][r]) return ;
	id[l][r]=++idx;
	if(l>r) dp[id[l][r]][0]=1;
	else{
		for(int k=l;k<=r;k++)
			if(abs((k-l)-(r-k))<=2){
				solve(l,k-1);solve(k+1,r);
				for(int j=A[k];j<=B[k];j++)
					dp[id[l][r]][j]=pls(dp[id[l][r]][j],mul(dp[id[l][k-1]][j],dp[id[k+1][r]][j-1]));
			}
	}
	for(int j=1;j<=MAX;j++) dp[id[l][r]][j]=pls(dp[id[l][r]][j],dp[id[l][r]][j-1]);
}
int main(){
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d%d",&A[i],&B[i]),MAX=max(MAX,B[i]);
	solve(1,n);
	printf("%d\n",dp[id[1][n]][MAX]);
	return 0;
}

60pts: A i = 1 , B i = 1 0 9 A_i=1,B_i=10^9 Ai=1,Bi=109

前50pts采用以上所述的区间DP,考虑10pts的特殊点: A i = 1 , B i = 1 0 9 A_i=1,B_i=10^9 Ai=1,Bi=109
因为 A i = A j , B i = B j A_i=A_j,B_i=B_j Ai=Aj,Bi=Bj,任意两点等价,不用考虑值域变化对不同点的不同影响。
l = r l=r l=r时, d p l , r , x dp_{l,r,x} dpl,r,x的最高项为 x 0 x^0 x0,其前缀和最高项为 x 1 x^1 x1.观察dp时对区间合并的过程,显然 d p l , r , x dp_{l,r,x} dpl,r,x的前缀和是 ( r − l + 1 ) (r-l+1) (rl+1)次多项式。
要求 d p 1 , n , V dp_{1,n,V} dp1,n,V,递归dp求出前 n + 1 n+1 n+1项,然后拉插即可。

参考代码:

#include<bits/stdc++.h>
#define ll long long
#define uit unsigned int
#define mul(x,y) (1ll*x*y%M)
#define pls(x,y) ((x+y)%M)
using namespace std;
const int N=305;
const ll M=1e9+7;
int n,idx=0,MAX=0,A[N],B[N],id[N][N],dp[2520][10001];ll pw[N],cir[N],f[2520][55];
ll ftp(ll b,ll p){
	ll r=1;
	while(p){
		if(p&1) r=r*b%M;
		b=b*b%M;
		p>>=1;
	}
	return r%M;
}
void solve(int l,int r){
	if(id[l][r]) return ;
	id[l][r]=++idx;
	if(l>r) dp[id[l][r]][0]=1;
	else{
		for(int k=l;k<=r;k++)
			if(abs((k-l)-(r-k))<=2){
				solve(l,k-1);solve(k+1,r);
				for(int j=A[k];j<=B[k];j++)
					dp[id[l][r]][j]=pls(dp[id[l][r]][j],mul(dp[id[l][k-1]][j],dp[id[k+1][r]][j-1]));
			}
	}
	for(int j=1;j<=MAX;j++) dp[id[l][r]][j]=pls(dp[id[l][r]][j],dp[id[l][r]][j-1]);
}
void sol1(){
	solve(1,n);
	printf("%d\n",dp[id[1][n]][MAX]);
}
void DP(int l,int r){
    if(id[l][r]) return;
    id[l][r]=++idx;
    if(l>r) dp[id[l][r]][0]=1;
    else{
        for(int k=l;k<=r;k++)
            if(abs((k-l)-(r-k))<=2){
                DP(l,k-1);DP(k+1,r);
                for(int j=1;j<=n+1;j++)
                    dp[id[l][r]][j]=pls(dp[id[l][r]][j],mul(dp[id[l][k-1]][j],dp[id[k+1][r]][j-1]));
            }
    }
    for(int j=1;j<=n+1;j++) dp[id[l][r]][j]=pls(dp[id[l][r]][j],dp[id[l][r]][j-1]);
}
ll lang(){
    ll val=1,res=0,val2=0;
    for(int i=1;i<=n+1;i++){
        val=dp[id[1][n]][i];val2=1;
        for(int j=1;j<=n+1;j++)
            if(i!=j) val=1ll*val*((MAX-j)%M)%M,val2=1ll*val2*((i-j+M)%M)%M;
        val=val*ftp((val2%M+M)%M,M-2)%M;
        res=(res+val)%M;
    }
    return res;
}
void sol2(){
	pw[0]=1ll;for(int i=1;i<=n+1;i++) pw[i]=pw[i-1]*(ll)i%M;
	cir[n+1]=ftp(pw[n+1],M-2);for(int i=n;i>=0;i--) cir[i]=cir[i+1]*(ll)(i+1)%M;
	DP(1,n);
	printf("%lld\n",lang());
}
int main(){
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	bool onenine=1;
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d%d",&A[i],&B[i]),MAX=max(MAX,B[i]);
		if(A[i]!=1||B[i]!=(int)(1e9)) onenine=0;
	}
	if(n<=300&&MAX<=10000){sol1();return 0;}
	if(onenine){sol2();return 0;}
	return 0;
}

100pts:

此时不再限制 A i = A j & & B i = B j A_i=A_j\&\&B_i=B_j Ai=Aj&&Bi=Bj
每个 d p l , r , x dp_{l,r,x} dpl,r,x是在值域上的分段函数,若仍直接对整个值域处理, d p l , r , x dp_{l,r,x} dpl,r,x的最高项不再是 x 0 x^0 x0
考虑针对每个 A i , B i A_i,B_i Ai,Bi,对值域分段,拆分为若干个左闭右开的区间 [ c t , c t + 1 ) [c_t,c_{t+1}) [ct,ct+1)。区间数量是 O ( n ) O(n) O(n)的。
逐段dp,每一段只需dp [ c t , c t + n − 1 ] [c_t,c_t+n-1] [ct,ct+n1]得到 n + 1 n+1 n+1项,做前缀和,拉插,再进行下一段即可。
注意实现方式。

时间复杂度: O ( n 2 m ) O(n^2m) O(n2m), m m m为有用区间数

参考代码:

#include<bits/stdc++.h>
#define re register
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
typedef unsigned int uit;
const int N=302,M=1e9+7;
template<typename T> void read(T &x){
    x=0;bool flag=1;char c=getchar();
    for(;c<'0'||c>'9';c=getchar()) if(c=='-') flag=0;
    for(;c>='0'&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
    x=flag?x:-x;
}
int n,m,idx=0,L,T,A[N],B[N],LS[N<<1],id[N][N],dp[50010][310],pw[N],cir[N],tmp[N];bool vis[N][N];
inline void add(re int &x,re int y){
    (x+=y)>=M?x-=M:x;
}
inline int ftp(ll b,ll p){
    int r=1;
    while(p){
        if(p&1) r=1ll*r*b%M;
        b=1ll*b*b%M;
        p>>=1;
    }
    return r%M;
}
inline void init(re int l,re int r){
    if(l>r||id[l][r]) return;
    id[l][r]=++idx;
    for(re int i=max(l,mid-1);i<=min(r,mid+1);i++){
        if(abs((i-l)-(r-i))<=2){
            init(l,i-1);init(i+1,r);
        }
    }
}
inline void DP(int l,int r){
    if(vis[l][r]||l>r) return ;
    vis[l][r]=1;
    for(re int i=max(l,mid-1);i<=min(r,mid+1);i++){
        if(abs((i-l)-(r-i))<=2){
            DP(l,i-1);DP(i+1,r);
            if(A[i]<=T&&B[i]>T){
                for(re int j=1;j<=L;j++)
                    add(dp[id[l][r]][j],1ll*dp[id[l][i-1]][j]*dp[id[i+1][r]][j-1]%M);
            }
        }
    }
    for(re int j=1;j<=L;j++)
        add(dp[id[l][r]][j],dp[id[l][r]][j-1]);
}
inline void lang(re int l,re int r){
    if(r-l<=n){
        for(re int i=1;i<=idx;i++) dp[i][0]=dp[i][r-l+1];
        return ;
    }
    for(re int i=1;i<=idx;i++) dp[i][0]=0;
    tmp[n+1]=1;for(int i=n;i>=0;i--) tmp[i]=1ll*tmp[i+1]*(r-(l+i))%M;
    re int val=1,res=0;
    for(re int i=l;i<=l+n;i++){
        res=1ll*val*tmp[i-l+1]%M*cir[i-l]%M*cir[l+n-i]%M;
        if((l+n-i)&1) res=M-res;
        for(re int j=1;j<=idx;j++) add(dp[j][0],1ll*res*dp[j][i-l+1]%M);
        val=1ll*val*(r-i)%M;
    }
}
int main(){
//    freopen(".in","r",stdin);
//    freopen(".out","w",stdout);
    read(n);
    for(re int i=1;i<=n;i++){
        read(A[i]);read(B[i]);++B[i];
        LS[++m]=A[i];LS[++m]=B[i];
    }
    sort(LS+1,LS+m+1);m=unique(LS+1,LS+m+1)-LS-1;
    for(int i=1;i<=n;i++) A[i]=lower_bound(LS+1,LS+m+1,A[i])-LS,B[i]=lower_bound(LS+1,LS+m+1,B[i])-LS;
    pw[0]=1ll;for(re int i=1;i<=n+1;i++) pw[i]=1ll*pw[i-1]*i%M;
    cir[n+1]=ftp(pw[n+1],M-2);for(re int i=n;i>=0;i--) cir[i]=1ll*cir[i+1]*(i+1ll)%M;
    init(1,n);
    for(int i=0;i<=n+1;i++) dp[0][i]=1;
    for(T=1;T<m;T++){
        L=min(LS[T+1]-LS[T],n+1);
        memset(vis,0,sizeof vis);
        DP(1,n);
        lang(LS[T],LS[T+1]-1);
        for(re int i=1;i<=idx;i++) for(re int j=1;j<=L;j++) dp[i][j]=0;
    }
    printf("%d\n",dp[id[1][n]][0]);
    return 0;
}
举报

相关推荐

0 条评论