0
点赞
收藏
分享

微信扫一扫

[JZOJ6053]【NOI2019模拟2019.3.12】Mas的仙人掌 【概率与期望】


Description

[JZOJ6053]【NOI2019模拟2019.3.12】Mas的仙人掌 【概率与期望】_c++


[JZOJ6053]【NOI2019模拟2019.3.12】Mas的仙人掌 【概率与期望】_#define_02

Solution

其实这道题也不太难,有一点很Tricky的东西

一开始我觉得这道题是DP,死磕结果磕不出来,连m^2都不会
后来看了题解发现自己走进了死胡同。

首先将每条路径的贡献分开计算,那么就是要求这条路径选的概率,乘上与它相交的所有路径的不选概率之积。m^2是容易的

考虑如何优化
我们希望相交的就计算一次(无论相交几条边),不相交的计算0次
接下来是一个很妙的操作,首先树上路径的交仍然是一个路径
考虑如何在路径上构造出"1",有1=路径的边数-路径上长度为2的子路径数

那么对于每一条路径,在这条路径上所有边以及所有长度为2的子路径都乘上这条边不选的概率,统计一条路径的贡献就算出这条路径每条边的权值乘积除以每个长度为2的子路径乘积,用前缀积即可,这样完美的解决了要求相交只算一次的问题。

注意长度为2的子路径有可能是在lca上的,这种情况要特殊判断,用个map记一下两棵子树分别是谁即可。

Code

#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define N 1300005
#define mo 998244353
#define LL long long
using namespace std;
int n,m,fs[N],nt[2*N],dt[2*N],fa[22][N],dep[N],ask[N][4],m1;
struct num
{
LL x,y;
}s1[N],s2[N],sr[N];
num operator *(num a,num b) {return (num){a.x*b.x%mo,a.y+b.y};}
map<LL,num> h;
map<LL,int> bp;
int jump(int k,int v)
{
for(int c=0;v;v>>=1,c++) if(v&1) k=fa[c][k];
return k;
}
LL ksm(LL k,LL n)
{
LL s=1;
for(;n;n>>=1,k=k*k%mo) if(n&1) s=s*k%mo;
return s;
}
num T(LL x)
{
if(x==0) return (num){1,1};
else return (num){x,0};
}
num nT(LL x)
{
if(x==0) return ((num){1,-1});
else return ((num){ksm(x,mo-2),0});
}
num nR(num x)
{
return((num){ksm(x.x,mo-2),-x.y});
}
void dfs(int k)
{
dep[k]=dep[fa[0][k]]+1;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa[0][k]) fa[0][p]=k,dfs(p);
}
}
int lca(int x,int y)
{
if(dep[x]>dep[y]) swap(x,y);
for(int j=dep[y]-dep[x],c=0;j;c++,j>>=1) if(j&1) y=fa[c][y];
for(int j=19;x!=y;)
{
while(j&&fa[j][x]==fa[j][y]) j--;
x=fa[j][x],y=fa[j][y];
}
return x;
}
void make(int k)
{
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa[0][k]) make(p),s1[k]=s1[k]*s1[p],s2[k]=s2[k]*s2[p];
}
}
void bd(int k)
{
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa[0][k]) s1[p]=s1[p]*s1[k],s2[p]=s2[p]*s2[k],bd(p);
}
}
void link(int x,int y)
{
nt[++m1]=fs[x];
dt[fs[x]=m1]=y;
}
int main()
{
cin>>n>>m;
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
link(x,y),link(y,x);
}
dfs(1);
fo(i,0,n) s1[i]=s2[i]=T(1);
fo(j,1,21) fo(i,1,n) fa[j][i]=fa[j-1][fa[j-1][i]];
fo(i,1,m)
{
int x,y,p;
LL v;
scanf("%d%d%lld",&x,&y,&v);
num vl=T(v),nvl=nT(v);
if(dep[x]>dep[y]) swap(x,y);
p=lca(x,y);
ask[i][0]=x,ask[i][1]=y,ask[i][2]=v,ask[i][3]=p;
s1[x]=s1[x]*vl,s1[y]=s1[y]*vl,s1[p]=s1[p]*nvl*nvl;

if(p==x)
{
int q=jump(y,dep[y]-dep[p]-1);
s2[y]=s2[y]*vl,s2[q]=s2[q]*nvl;
}
else
{
int u=jump(x,dep[x]-dep[p]-1),q=jump(y,dep[y]-dep[p]-1);

LL cm=(LL)min(u,q)*(LL)(n+1)+max(q,u);
if(!bp[cm]) bp[cm]=1,h[cm]=vl;
else h[cm]=h[cm]*vl;
s2[x]=s2[x]*vl,s2[y]=s2[y]*vl;
s2[u]=s2[u]*nvl,s2[q]=s2[q]*nvl;
}
}
make(1);
bd(1);
LL ans=0;
fo(i,1,m)
{
int x=ask[i][0],y=ask[i][1],p=ask[i][3];
LL v=ask[i][2];
if(p==x)
{
int q=jump(y,dep[y]-dep[x]-1);
num st=s1[y]*nR(s1[x])*s2[q]*nR(s2[y])*nT(v)*T((1-v+mo)%mo);
if(st.y==0) ans=(ans+st.x)%mo;
}
else
{
int u=jump(x,dep[x]-dep[p]-1),q=jump(y,dep[y]-dep[p]-1);
LL cm=(LL)min(u,q)*(LL)(n+1)+max(q,u);
if(!bp[cm]) bp[cm]=1,h[cm]=T(1);
num st=s1[y]*nR(s1[p])*s1[x]*nR(s1[p])*s2[q]*nR(s2[y])*s2[u]*nR(s2[x])*nR(h[cm])*nT(v)*T((1-v+mo)%mo);
if(st.y==0) ans=(ans+st.x)%mo;
}
}
printf("%lld\n",(ans+mo)%mo);
}


举报

相关推荐

0 条评论