洛谷3768 简单的数学题

题目大意:求$\sum{i = 1}^{n} \sum{j = 1}^{n} ij\gcd(i, j) \mod p$。

题解

遇到gcd的题目不仅仅可以反演,也可以采用一些其他方法。
这道题目就是,直接反演非常麻烦,所以采用下面的变换:

设$S(n)=\sum_{i = 1}^{n}i^2 \phi(i)$,那么有

用杜教筛,时间复杂度$O(n^{\frac{2}{3}})$。

貌似这种做法的底层原理还是莫比乌斯反演加上狄利克雷卷积,这个其实我没有太多的考证(其实是还没学到)。。。
不过运气比较好,这道题的两个维度都是$n$,如果是$m$和$n$就很难办了。。。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <bits/stdc++.h>
#define INF 2000000000
#define MAXN 4641600
using namespace std;
typedef long long ll;
ll n;
int phi[MAXN + 5], tot = 0, prime[MAXN >> 1], p, inv2, inv4, inv6;
int sum[MAXN + 5] = {0}, sum2[MAXN + 5];
bool vis[MAXN + 5] = {0};
map<ll, int> mp;
int poww(int a, int b, int c){
int res = 1;
while(b){
if(b & 1) res = 1ll * res * a % c;
a = 1ll * a * a % c, b >>= 1;
}
return res;
}
inline int poww3(int r){
int res = 1ll * r * r % p;
int res2 = 1ll * (r + 1) * (r + 1) % p;
res = 1ll * res * res2 % p;
res = 1ll * res * inv4 % p;
return res;
}
inline int poww2(int r){
int res = 1ll * r * (r + 1) % p;
int res2 = 1ll * (r + r + 1) * inv6 % p;
res = 1ll * res * res2 % p;
return res;
}
inline int q3(int l, int r){
int res = poww3(r) - poww3(l - 1);
if(res < 0) res += p;
return res;
}
inline int q2(int l, int r){
int res = poww2(r) - poww2(l - 1);
if(res < 0) res += p;
return res;
}
void getPhi(int N){
phi[1] = sum[1] = sum2[1] = 1;
for(int i = 2; i <= N; ++i){
if(!vis[i])
prime[++tot] = i, phi[i] = i - 1;
for(int j = 1; j <= tot; ++j){
ll t = 1ll * i * prime[j];
if(t > 1ll * N) break;
vis[t] = true, phi[t] = phi[i] * (prime[j] - 1);
if(i % prime[j] == 0){
phi[t] += phi[i];
break;
}
}
sum[i] = sum[i - 1] + phi[i];
if(sum[i] >= p) sum[i] -= p;
int term = 1ll * i * i % p;
term = 1ll * term * phi[i] % p;
sum2[i] = sum2[i - 1] + term;
if(sum2[i] >= p) sum2[i] -= p;
}
}
int solve_s(ll N){
if(N <= MAXN) return sum2[N];
if(mp.count(N)) return mp[N];
int ans = poww3(N % p);
for(ll i = 2, lst = 0; i <= N; i = lst + 1){
lst = N / (N / i);
int distract = 1ll * q2(i % p, lst % p) * solve_s(N / i) % p;
ans -= distract;
if(ans < 0) ans += p;
}
mp[N] = ans;
return ans;
}
void init(){
scanf("%d%lld", &p, &n);
getPhi(MAXN);
inv2 = poww(2, p - 2, p);
inv4 = poww(4, p - 2, p);
inv6 = poww(6, p - 2, p);
}
void solve(){
int ans = 0;
for(ll i = 1, lst = 0; i <= n; i = lst + 1){
lst = n / (n / i);
int term = 1ll * q3(i % p, lst % p) * solve_s(n / i) % p;
ans += term;
if(ans >= p) ans -= p;
}
printf("%d\n", ans);
}
int main(){
init();
solve();
return 0;
}