luogu4723(线性递推)

题意

给定一个 $n$ 阶递推式

给定系数 $a$ 和前 $0\sim n-1$ 项,求 $f(m)$

$n\le32000,m\le10^9$

题解

算法

  1. 构造 $n$ 次多项式 $G(x)$ ,其中 $gn=1$ ,$g_i=-a{n-i}$
  2. 用多项式快速幂求 $H(x)=x^m \bmod{G(x)}$
  3. $\displaystyle f(m)=\sum_{k=0}^{n-1}h_kf(k)$

证明

可以构造 $n$ 阶转移矩阵 $A_0$ ,现在需要求 $A_0^m$ ,时间复杂度为 $O(n^3\log m)$

但是时间复杂度过高,需要优化

我们构造一个矩阵的多项式 $G(A)$ ,满足 $G(A_0)=0$ ,那么只需要求 $H(A)=A^m\mod{G(A)}$ ,即可求得答案,即

这个利用多项式快速幂就可以了

现在问题的关键是求 $G(A)$ ,由 $Cayley-Hamiton$ 定理可得:$gn=1$ ,$g_i=-a{n-i}$

这里证明暂时还没学qaq




代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/**
*         ┏┓    ┏┓
*         ┏┛┗━━━━━━━┛┗━━━┓
*         ┃       ┃  
*         ┃   ━    ┃
*         ┃ >   < ┃
*         ┃       ┃
*         ┃... ⌒ ...  ┃
*         ┃ ┃
*         ┗━┓ ┏━┛
*          ┃ ┃ Code is far away from bug with the animal protecting          
*          ┃ ┃ 神兽保佑,代码无bug
*          ┃ ┃           
*          ┃ ┃       
*          ┃ ┃
*          ┃ ┃           
*          ┃ ┗━━━┓
*          ┃ ┣┓
*          ┃ ┏┛
*          ┗┓┓┏━━━━━━━━┳┓┏┛
*           ┃┫┫ ┃┫┫
*           ┗┻┛ ┗┻┛
*/

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#include<map>
#include<stack>
#include<cmath>
#include<set>
#include<bitset>
#include<complex>
#include<cstdlib>
#include<assert.h>
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,l,r) for(int i=l;i>=r;i--)
#define link(x) for(edge *j=h[x];j;j=j->next)
#define mem(a) memset(a,0,sizeof(a))
#define ll long long
#define eps 1e-8
#define succ(x) (1<<x)
#define lowbit(x) (x&(-x))
#define mid (x+y>>1)
#define sqr(x) ((x)*(x))
#define NM 270005
#define nm 400005
using namespace std;
const double pi=acos(-1);
const ll inf=998244353;
ll read(){
ll x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f*x;
}


inline void reduce(ll&x){x+=x>>63&inf;}
inline ll qpow(ll x,ll t){
ll s=1;
for(;t;t>>=1,x=x*x%inf)if(t&1)s=s*x%inf;
return s;
}

namespace Poly{
int lim,bit,rev[NM],w[NM],W[NM];
ll invn;
void clear(ll*a,ll*b){if(a<b)memset(a,0,sizeof(ll)*(b-a));}
void init(int m){
for(lim=1,bit=0;lim<m;lim<<=1)bit++;invn=qpow(lim,inf-2);
inc(i,1,lim-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
ll t=qpow(3,(inf-1)/lim);W[0]=1;
inc(i,1,lim)W[i]=W[i-1]*t%inf;
}
void fft(ll*a,int f=0){
int n=lim;
inc(i,1,n-1)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int k=1;k<n;k<<=1){
int t=n/k>>1;
for(int i=0,j=0;i<k;i++,j+=t)w[i]=W[f?n-j:j];
for(int i=0;i<n;i+=k<<1)
for(int j=0;j<k;j++){
ll x=a[i+j],y=w[j]*a[i+j+k]%inf;
reduce(a[i+j]=x+y-inf);reduce(a[i+j+k]=x-y);
}
}
if(f)inc(i,0,n-1)a[i]=a[i]*invn%inf;
}
ll _a[NM];
void inv(ll*b,ll*a,int m){
if(m==1){b[0]=qpow(a[0],inf-2);return;}
inv(b,a,m+1>>1);init(m<<1);
copy(a,a+m,_a);clear(_a+m,_a+lim);clear(b+(m+1)/2,b+lim);
fft(b);fft(_a);
inc(i,0,lim-1)b[i]=b[i]*(2-_a[i]*b[i]%inf+inf)%inf;
fft(b,1);
}
ll _b[NM];
void div(ll*c,ll*a,ll*b,int n,int m){
reverse_copy(a,a+n,c);reverse(b,b+m);
clear(b+m,b+n-m+1);
inv(_b,b,n-m+1);reverse(b,b+m);
init(n-m+1<<1);
clear(c+n-m+1,c+lim);clear(_b+n-m+1,_b+lim);
fft(_b);fft(c);
inc(i,0,lim-1)c[i]=c[i]*_b[i]%inf;
fft(c,1);
reverse(c,c+n-m+1);
}
void mod(ll*c,ll*a,ll*b,int n,int m){
div(c,a,b,n,m);
init(n);
clear(c+n-m+1,c+lim);
copy(b,b+m,_b);clear(_b+m,_b+lim);m--;
fft(_b);fft(c);
inc(i,0,lim-1)c[i]=c[i]*_b[i]%inf;
fft(c,1);
inc(i,0,m-1)reduce(c[i]=a[i]-c[i]);
clear(c+m,c+lim);clear(_b,_b+lim);
}
ll tmp[NM],_c[NM];
//线性递推专用
void pow(ll*c,ll*a,ll*b,int m,ll t){
copy(a,a+m,_c);c[0]=1;clear(c+1,c+m);
for(;t;t>>=1){
if(t&1){
init(m<<1);clear(_c+m,_c+lim);clear(c+m,c+lim);
fft(c);fft(_c);
inc(i,0,lim-1)c[i]=c[i]*_c[i]%inf;
fft(c,1);fft(_c,1);
mem(tmp);
mod(tmp,c,b,m<<1,m);
copy(tmp,tmp+m,c);
}
init(m<<1);clear(_c+m,_c+lim);
fft(_c);
inc(i,0,lim-1)_c[i]=sqr(_c[i])%inf;
fft(_c,1);
mem(tmp);
mod(tmp,_c,b,m<<1,m);
copy(tmp,tmp+m,_c);
}
}
}

int n;
ll m,a[NM],b[NM],c[NM],ans;


int main(){
m=read();n=read();
inc(i,0,n-1)reduce(a[i]=-read());
reverse(a,a+n);a[n]=1;
inc(i,0,n-1)reduce(b[i]=read());
c[1]=1;
Poly::pow(c,c,a,n+1,m);
inc(i,0,n-1)reduce(ans+=c[i]*b[i]%inf-inf);
printf("%lld\n",ans);
return 0;
}