UOJ#424 【集训队作业2018】count
题意
我们定义长度为
题解
首先
考虑什么样的序列是同构的,那么我们首先要有一个能方便的表示区间最大值的位置的数据结构,那就是笛卡尔树。显然只要两个序列的笛卡尔树同构,这两个序列就同构。
那么关于
由于有多个相同的算在最左边,因此可以发现在笛卡尔树中,每个点的左儿子的键值都小于这个点的键值,右儿子则是小于等于。那么如果这颗笛卡尔树要是合法的,就有一个必要条件:记
接下来我们证明在
于是我们只需要求
设
平凡情况有
那么将上式表达为卷积,就有:
等价于:
直接做似乎很不可做,但是通过这个式子我们可以得到
于是就得到
那么可以通过矩阵快速幂来求
#include<cstdio>
#include<cstring>
#include<algorithm>
using std::swap;
const int mod=998244353;
inline int add(int a,int b)
{
return (a+=b)>=mod?a-mod:a;
}
inline int sub(int a,int b)
{
return (a-=b)<0?a+mod:a;
}
inline int mul(int a,int b)
{
return (long long)a*b%mod;
}
inline int qpow(int a,int b)
{
int res=1;
for(;b;a=mul(a,a),b>>=1)
if(b&1)
res=mul(res,a);
return res;
}
const int N=1e6+5;
int rev[N];
inline void ntt(int *f,int n,int p)
{
int w,wi,u,t;
register int i,j,k;
for(i=0;i<n;i++)
if(i<(rev[i]=i&1?rev[i^1]|n>>1:rev[i>>1]>>1))
swap(f[i],f[rev[i]]);
for(i=1;wi=qpow(qpow(3,(mod-1)/(i<<1)),p^1?mod-2:1),i<<1<=n;i<<=1)
for(j=0;w=1,j<n;j+=i<<1)
for(k=0;k<i;w=mul(w,wi),k++)
u=f[j+k],t=mul(w,f[j+k+i]),f[j+k]=add(u,t),f[j+k+i]=sub(u,t);
if(!~p)
for(w=qpow(n,mod-2),i=0;i<n;i++)
f[i]=mul(w,f[i]);
return;
}
inline void poly_mul(int *f,int *g,int n)
{
register int i;
memset(f+n,0,sizeof(int)*n);
memset(g+n,0,sizeof(int)*n);
ntt(f,n<<1,1);f==g?void():ntt(g,n<<1,1);
for(i=0;i<n<<1;i++)
f[i]=mul(f[i],g[i]);
ntt(f,n<<1,-1);f==g?void():ntt(g,n<<1,-1);
return;
}
int F[N],G[N];
int _g[N];
inline void poly_inv(int *f,int n)
{
register int i,j;
memset(_g,0,sizeof(int)*n);
_g[0]=qpow(f[0],mod-2);
for(i=1;i<<1<=n;i<<=1)
{
memcpy(F,f,sizeof(int)*(i<<1));
memcpy(G,_g,sizeof(int)*i);
poly_mul(G,G,i);poly_mul(F,G,i<<1);
for(j=0;j<i<<1;j++)
_g[j]=sub(add(_g[j],_g[j]),F[j]);
}
memcpy(f,_g,sizeof(int)*n);
return;
}
int a[2][2],b[2][2],res[2][2];
inline void matrix_qpow(int p)
{
res[0][0]=res[1][1]=1;res[0][1]=res[1][0]=0;
for(;p;p>>=1)
{
if(p&1)
{
b[0][0]=add(mul(res[0][0],a[0][0]),mul(res[0][1],a[1][0]));
b[0][1]=add(mul(res[0][0],a[0][1]),mul(res[0][1],a[1][1]));
b[1][0]=add(mul(res[1][0],a[0][0]),mul(res[1][1],a[1][0]));
b[1][1]=add(mul(res[1][0],a[0][1]),mul(res[1][1],a[1][1]));
memcpy(res,b,sizeof(b));
}
b[0][0]=add(mul(a[0][0],a[0][0]),mul(a[0][1],a[1][0]));
b[0][1]=add(mul(a[0][0],a[0][1]),mul(a[0][1],a[1][1]));
b[1][0]=add(mul(a[1][0],a[0][0]),mul(a[1][1],a[1][0]));
b[1][1]=add(mul(a[1][0],a[0][1]),mul(a[1][1],a[1][1]));
memcpy(a,b,sizeof(b));
}
return;
}
int n,m;
int f[N],g[N];
signed main()
{
int _=1<<17,w=1,wi=qpow(3,(mod-1)/_);
register int i;
scanf("%d%d",&n,&m);
if(n<m)
return puts("0"),0;
for(i=0;i<_;i++)
{
a[0][0]=0;a[0][1]=1;a[1][0]=sub(0,w);a[1][1]=1;
matrix_qpow(m);
f[i]=add(res[0][0],res[0][1]);g[i]=add(res[1][0],res[1][1]);
w=mul(w,wi);
}
ntt(f,_,-1);ntt(g,_,-1);
poly_inv(g,_);
poly_mul(f,g,_);
printf("%d\n",f[n]);
return 0;
}