侧边栏壁纸
博主头像
GabrielxD

列車は必ず次の駅へ。では舞台は?私たちは?

  • 累计撰写 674 篇文章
  • 累计创建 128 个标签
  • 累计收到 20 条评论

目 录CONTENT

文章目录

【算法模板】工具 & 模板 & 技巧

GabrielxD
2022-05-24 / 2 评论 / 10 点赞 / 1,785 阅读 / 24,782 字
温馨提示:
本文最后更新于 2023-11-13,若内容或图片失效,请留言反馈。部分素材来自网络,若不小心影响到您的利益,请联系我们删除。
本链接已停止更新,最新内容请转至 GabWiki 以获得更流畅的浏览体验

排序

快速排序

  1. 找到分界点 p(有三种选择:q[l]q[l + r >> 1]q[r])。
  2. 将区间 [l,r][l, r] 划分为两段, 使得分界点左边所有数 LeftpLeft \le p,分界点右边所有数 RightqRight \ge q
  3. 递归地排序左右两边:sort(Left)sort(Right)

应用:【快速选择】第k个数

void quickSort(int[] nums, int l, int r) {
    if (l >= r) return;
    int p = nums[l + (r - l >> 1)], i = l - 1, j = r + 1;
    while (i < j) {
        do ++i; while (nums[i] < p);
        do --j; while (nums[j] > p);
        if (i < j) {
            int t = nums[i];
            nums[i] = nums[j];
            nums[j] = t;
        }
    }
    quickSort(nums, l, j);
    quickSort(nums, j + 1, r);
}
void quick_sort(int nums[], int l, int r) {
    if (l >= r) return;
    int p = nums[l + (r - l >> 1)], i = l - 1, j = r + 1;
    while (i < j) {
        do ++i; while (nums[i] < p);
        do --j; while (nums[j] > p);
        if (i < j) swap(nums[i], nums[j]);
    }
    quick_sort(nums, l, j);
    quick_sort(nums, j + 1, r);
}

归并排序

  1. 把区间 [l,r][l, r] 从中点(mid = l + r >> 1)分割为 [l,mid][l, mid] 以及 [mid+1,r][mid + 1, r]
  2. 递归排序左右两边:sort(l, mid)sort(mid + 1, r)
  3. 归并,将左右两个有序序列合成为一个有序序列。

应用:【归并排序】逆序对的数量

int[] tmp; // 全局开一个与nums一样长的数组(在主函数中赋值)

void mergeSort(int[] nums, int l, int r) {
    if (l >= r) return;
    int mid = l + (r - l >> 1);
    mergeSort(nums, l, mid);
    mergeSort(nums, mid + 1, r);
    int i = l, j = mid + 1, k = 0;
 	while (i <= mid && j <= r) {
        if (nums[i] <= nums[j]) tmp[k++] = nums[i++];
        else tmp[k++] = nums[j++];
    }
    while (i <= mid) tmp[k++] = nums[i++];
    while (j <= r) tmp[k++] = nums[j++];
    for (i = l, j = 0; i <= r; ++i, ++j) nums[i] = tmp[j];
}
int tmp[N]; // 全局开一个与nums一样长的数组

void merge_sort(int nums[], int l, int r) {
    if (l >= r) return;
    int mid = l + (r - l >> 1);
    merge_sort(nums, l, mid);
    merge_sort(nums, mid + 1, r);
    int i = l, j = mid + 1, k = 0;
    while (i <= mid && j <= r) {
        if (nums[i] <= nums[j]) tmp[k++] = nums[i++];
        else tmp[k++] = nums[j++];
    }
    while (i <= mid) tmp[k++] = nums[i++];
    while (j <= r) tmp[k++] = nums[j++];
    for (i = l, j = 0; i <= r; ++i, ++j) nums[i] = tmp[j];
}

二分查找

整数二分

通用模板

check() 判断 mid 是否满足某种性质。

模板一

区间 [l,r][l, r] 被划分成 [l,mid][l, mid][mid+1,r][mid + 1, r] 时使用。

bool check(int x) {/* ... */} // 检查x是否满足某种性质

int bsearch(int l, int r) {
    while (l < r) {
        int mid = l + (r - l >> 1);
        if (check(mid)) r = mid;
        else l = mid + 1;
    }
    return l;
}
模板二
bool check(int x) {/* ... */} // 检查x是否满足某种性质

int bsearch(int l, int r) {
    while (l < r) {
        int mid = l + (r - l + 1 >> 1);
        if (check(mid)) l = mid;
        else r = mid - 1;
    }
    return l;
}

相等

int binarySearch(int[] nums, int target) {
    int left = 0, right = nums.length - 1;
    while (left <= right) {
        int mid = left + ((right - left) >> 1);
        if (nums[mid] == target) return mid;
        else if (nums[mid] < target) left = mid + 1;
        else right = mid - 1;
    }
    return -1;
}
int binary_search(vector<int>& nums, int target) {
    int left = 0, right = nums.size() - 1;
    while (left <= right) {
        int mid = left + ((right - left) >> 1);
        if (nums[mid] == target) return mid;
        else if (nums[mid] < target) left = mid + 1;
        else right = mid - 1;
    }
    return -1;
}
左边界

找到等于 x 的元素中最靠左的索引,没有符合要求的索引则返回 -1
等同于下文的:大于等于

int left_bound(vector<int>& nums, int target) {
    int left = 0, right = nums.size() - 1;
    while (left < right) {
        int mid = left + (right - left >> 1);
        if (nums[mid] >= target) right = mid;
        else left = mid + 1;
    }
    return nums[left] == target ? left : -1;
}
右边界

找到等于 x 的元素中最靠右的索引,没有符合要求的索引则返回 -1
等同于下文的:小于等于

int right_bound(vector<int>& nums, int target) {
    int left = 0, right = nums.size() - 1;
    while (left < right) {
        int mid = left + (right - left + 1 >> 1);
        if (nums[mid] <= target) left = mid;
        else right = mid - 1;
    }
    return nums[left] == target ? left : -1;
}

小于等于

找到小于等于 x 的元素中最大的索引,没有符合要求的索引则返回 -1

int bsearch_leq(vector<int>& nums, int target) {
    int left = 0, right = nums.size() - 1;
    while (left < right) {
        int mid = left + ((right - left + 1) >> 1);
        if (nums[mid] <= target) left = mid;
        else right = mid - 1;
    }
    return nums[left] <= target ? left : -1;
}

大于等于

找到大于等于 x 的元素中最小的索引,没有符合要求的索引则返回 -1

int bsearch_geq(vector<int>& nums, int target) {
    int left = 0, right = nums.size() - 1;
    while (left < right) {
        int mid = left + ((right - left) >> 1);
        if (nums[mid] >= target) right = mid;
        else left = mid + 1;
    }
    return nums[left] >= target ? left : -1;
}

浮点数二分

eps 表示精度,取决于题目对精度的要求(一般来说取题目精度要求小两个数量级,比如题目要求 10610^{-6} 那么 eps 按经验会取 10810^{-8})。

const double eps = 1e-6;

double bsearch(double l, double r) {
    while (r - l > eps) {
        double mid = (l + r) / 2;
        if (check(mid)) r = mid;
        else l = mid;
    }
    return l;
}

nm 次方根

模板题:数的三次方根

double root(double n, int m, double eps) {
    double l = 0.0, r = Math.max(n, 1.0);
    while (r - l > eps) {
        double c = (l + r) / 2;
        double prod = 1.0;
        for (int i = 0; i < m; ++i) prod *= c;
        if (prod >= n) r = c;
        else l = c;
    }
    return l;
}

高精度

待补充…

(Java 用不到,一时半会不会补www)

前缀和

一维前缀和

前缀和也是可以用在异或运算上的

原数组下标从 00 开始:

  • 定义:s[i]=a[0]+a[1]+...+a[i1]=j=0i1a[j]s[i] = a[0] + a[1] + ... + a[i-1] = \sum_{j=0}^{i-1}a[j]
  • 构造:s[0]=0s[0] = 0s[i]=s[i1]+a[i1]s[i] = s[i - 1] + a[i - 1]
  • 使用:sum(a[lr])=s[r+1]S[l]sum(a[l \dots r]) = s[r + 1] - S[l]

原数组下标从 11 开始:

  • 定义:s[i]=a[1]+a[2]+...+a[i]=j=1ia[j]s[i] = a[1] + a[2] + ... + a[i] = \sum_{j=1}^{i}{a[j]}
  • 构造:s[0]=0s[0]=0s[i]=s[i1]+a[i]s[i] = s[i - 1] + a[i]
  • 使用:sum(a[lr])=S[r]S[l1]sum(a[l \dots r]) =S[r] - S[l - 1]

二维前缀和

矩阵横纵坐标均从 11 开始:

  • 定义:s[i][j]=ij列格子左上部分所有元素的和=x=1i(y=1ja[x][y])s[i][j] = 第i行j列格子左上部分所有元素的和 = \sum_{x=1}^{i}(\sum_{y=1}^j{a[x][y]})
  • 构造:s[i][j]=s[i1][j]+s[i][j1]s[i1][j1]+a[i][j]s[i][j] = s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + a[i][j]
  • 使用:以 (x1,y1)(x1, y1) 为左上角,(x2,y2)(x2, y2) 为右下角的子矩阵的和为:
    s[x2][y2]s[x11][y2]s[x2][y11]+s[x11][y11]s[x2][y2] - s[x1 - 1][y2] - s[x2][y1 - 1] + s[x1 - 1][y1 - 1]

差分

一维差分

模板题:797. 差分 - AcWing题库

给数组 b 的区间 [l,r][l, r] 中的每个数加上 c

b[l] += c;
b[r + 1] -= c

二维差分

模板题:798. 差分矩阵 - AcWing题库

给矩阵 b 中以 (x1,y1)(x1, y1) 为左上角,(x2,y2)(x2, y2) 为右下角的子矩阵中的所有元素加上 c

b[x1][y1] += c;
b[x2 + 1][y1] -= c;
b[x1][y2 + 1] -= c;
b[x2 + 1][y2 + 1] += c;

双指针

常见问题分类:

  1. 对于一个序列,用两个指针维护一段区间。
  2. 对于两个序列,维护某种次序,比如归并排序中合并两个有序序列的操作。

实例:【双指针】最长连续不重复子序列

for (int i = 0, j = 0; i < n; ++i) {
    while (j < i && check(i, j)) ++j;
	// 具体问题的逻辑
}

滑动窗口

void slidingWindow(char[] s) {
    int n = s.length;
    Map<Character, Integer> window = new HashMap<>();
    int left = 0, right = 0;
    while (right < n) {
        char curr = s[right++];
    	window.put(curr, window.getOrDefault(curr, 0) + 1);
        // 更新窗口内数据
        // ...
        // 判断窗口是否需要收缩
        while (窗口需要收缩) {
            // 在窗口每次收缩前更新答案
            // ...
            curr = s[left++];
            window.put(curr, window.getOrDefault(curr, 0) - 1);
            // 更新窗口内数据
        	// ...
        }
        // 在窗口收缩后更新答案
        // ...
    }
}
void sliding_window(string s) {
    int n = s.length();
    unordered_map<char, int> window;
    int left = 0, right = 0;
    while (right < n) {
        char curr = s[right++];
        ++window[curr];
        // 更新窗口内数据
        // ...
        // 判断窗口是否需要收缩
        while (窗口需要收缩) {
            // 在窗口每次收缩前更新答案
            // ...
            curr = s[left++];
            --window[curr];
            // 更新窗口内数据
        	// ...
        }
        // 在窗口收缩后更新答案
        // ...
    }
}

位运算

  • n 的第 k 位数字:n >> k & 1
  • 获取 n 二进制中最低位的 1 (例如 n=110100(2)n = 110100_{(2)} , 那么 lowbit(n)=100(2)lowbit(n) = 100_{(2)} ):
    • lowbit(n) = n & -n
    • lowbit(n) = n & (n ^ (n-1))

异或运算

  • 交换律:pq=qpp \oplus q = q \oplus p
  • 结合律:p(qr)=(pq)rp \oplus (q \oplus r) = (p \oplus q) \oplus r
  • 恒等律:p0=pp \oplus 0 = p
  • 归零律:pp=0p \oplus p = 0
  • 自反性:pqq=p0=pp \oplus q \oplus q = p \oplus 0 = p

离散化

离散化是一种比较特殊的哈希方式。

详细介绍:离散化 - OI Wiki

实例:【离散化, 前缀和】区间和「离散化经典应用」

C++ 模板

vector<int> alls; // 存储所有待离散化的值
sort(alls.begin(), alls.end()); // 排序
alls.erase(unique(alls.begin(), alls.end()), alls.end()); // 去重

二分求出 x 对应的离散化的值:

int find(int x) {
    int l = 0, r = alls.size() - 1;
    while (l < r) {
        int mid = l + r >> 1;
        // 找到第一个大于等于x的位置
        if (alls[mid] >= x) r = mid;
        else l = mid + 1;
    }
    // 映射若从0开始直接返回 从1开始返回要+1
    return r + 1;
}

Java 模板

List<Integer> alls = new ArrayList<>(); // 存储所有待离散化的值
Collections.sort(alls); // 排序
alls = alls.subList(0, unique()); // 去重 

unique() 函数实现:
实现原理见:【双指针】删除有序数组中的重复项

int unique() {
    int n = alls.size();
    int slow = 0;
    for (int fast = 0; fast < n; ++fast) {
        // 注意这里拿到的是整数包装类 所以一定要使用equals()方法判断是否相同
        if (!alls.get(slow).equals(alls.get(fast))) alls.set(++slow, alls.get(fast));
    }
    return slow + 1;
}

二分求出 x 对应的离散化的值:

int find(int x) {
    int l = 0, r = alls.size() - 1;
    while (l < r) {
        int mid = l + r >> 1;
        if (alls.get(mid) >= x) r = mid;
        else l = mid + 1;
    }
    return l + 1;
}

区间合并

将所有存在交集的区间合并

C++ 模板

void merge(vector<PII>& segs) {
    vector<PII> res;
    sort(segs.begin(), segs.end());
    int st = NEG_INF, ed = NEG_INF;
    for (auto& seg : segs) {
        if (ed < seg.first) {
            if (st != NEG_INF) res.push_back({st, ed});
            st = seg.first, ed = seg.second;
        } else ed = max(ed, seg.second);
    }
    if (st != NEG_INF) res.push_back({st, ed});
    segs = res;
}

Java 模板

List<int[]> merge(List<int[]> segs) {
    segs.sort((a, b) -> a[0] - b[0]);
    List<int[]> res = new ArrayList<>();
    int st = NEG_INF, ed = NEG_INF;	
    for (int[] seg : segs) {
        if (ed < seg[0]) {
            if (st != NEG_INF) res.add(new int[]{st, ed});
            st = seg[0];
            ed = seg[1];
        } else ed = Math.max(ed, seg[1]);
    }
    if (st != NEG_INF) res.add(new int[]{st, ed});
    return res;
}

链表

单链表

数组模拟单链表:

// vals存储节点的值 nexts存储指向下一个节点的指针
int vals[N], nexts[N];
// head表示指向表头的指针(vals[head]就会取到表头的值)
// idx指向当前可用的节点(链表尾的空节点)
int head, idx;

// 初始化 头节点指针为-1表示没有头节点 idx=1表示当前可用的节点是第1个节点
void init() {
    head = -1;
    idx = 0;
}

// 向链表头部插入一个值为x的节点
void insert_to_head(int x) {
    vals[idx] = x;
    nexts[idx] = head;
    head = idx++;
}

// 在第k个节点*之后*插入一个值为x的节点
void insert(int k, int x) {
    vals[idx] = x;
    nexts[idx] = nexts[k];
    nexts[k] = idx++;
}

// 删除第k个节点*之后*的一个节点
void remove(int k) {
    nexts[k] = nexts[nexts[k]];
}

双链表

数组模拟双链表:

// vals存储节点的值 nexts存储指向下一个节点的指针 prevs存储上一个节点的指针
int vals[N], prevs[N], nexts[N];
int idx; // idx指向当前可用的节点(链表尾的空节点)

// 初始化 在双链表中默认第0个节点为头节点 第1个节点为尾节点
// 此时头节点下一个节点是尾节点 尾节点上一个节点是头节点 idx=2表示当前可用的节点是第2个节点
void init() {
    nexts[0] = 1;
    prevs[1] = 0;
    idx = 2;
}

// 在第k个节点*之后*插入一个值为x的节点
void insert(int k, int x) {
    vals[idx] = x;
    prevs[idx] = k;
    nexts[idx] = nexts[k];
    prevs[nexts[idx]] = idx;
    nexts[k] = idx++;
}

// 删除第k个节点
void remove(int k) {
    prevs[nexts[k]] = prevs[k];
    nexts[prevs[k]] = nexts[k];
}

数组模拟栈:

int stk[N];
int tt = 0; // 栈顶初始化为0 栈中元素从1开始

// 向栈顶推入一个数x
void push(int x) {
    stk[++tt] = x;
}

// 从栈顶弹出一个数
void pop() {
    --tt;
}

// 判断栈是否为空 tt==0时表示栈为空
bool empty() {
    return !tt;
}

// 返回栈顶的值
int top() {
    return stk[tt];
}

队列

数组模拟普通队列:

int que[N];
int hh = 0, tt = -1; // hh表示队头 tt表示队尾 队中元素从0开始

// 向队尾放入一个数x
void offer(int x) {
    que[++tt] = x;
}

// 从队头拉出一个数
void poll() {
    ++hh;
}

// 判断队列是否为空 tt<hh时表示队列为空
bool empty() {
    return tt < hh;
}

// 返回队头的值
int front() {
    return que[hh];
}

// 返回队尾的值
int back() {
    return que[tt];
}

数组模拟循环队列:

int que[N];
int hh = 0, tt = 0; // hh表示队头 tt表示队尾的后一个位置

// 向队尾放入一个数x
void offer(int x) {
    que[tt++] = x;
    if (tt == N) tt = 0;
}

// 从队头拉出一个数
void poll() {
    if (++hh == N) hh = 0;
}

// 返回队头的值
int front() {
    return que[hh];
}

// 判断队列是否为空 tt==hh时表示队列为空
bool empty() {
    return tt == hh;
}

单调栈

模板题:找出每个数左边第一个比它大/小的数

int stk[N], tt = 0;
for (int x : nums) {
    while (tt > 0 && check(stk[tt], x)) --tt;
    stk[++tt] = x;
}

单调队列

模板题:找出滑动窗口中的最大值/最小值

int que[N], hh = 0, tt = -1;
for (int i = 0; i < n; ++i) {
    while (hh <= tt && check_out(que[hh])) ++hh; // 判断队头是否滑出窗口
    while (hh <= tt && check(que[tt], nums[i])) --tt;
    q[++tt] = i;
}

应用

求滑动窗口最小值

// 数字序列长度为n  序列输入至a[0...n] 滑动窗口长度为k
int que[N], hh = 0, tt = -1; // 数组模拟双端队列

for (int i = 0; i < n; ++i) {
    if (hh <= tt && i - que[hh] == k) ++hh;
    while (hh <= tt && a[que[tt]] >= a[i]) --tt;
    que[++tt] = i;
    if (i >= k - 1) printf("%d ", a[que[hh]]); // 输出滑动窗口最小值(仅在滑动窗口中元素满k个时输出)
}

求滑动窗口最大值

// 数字序列长度为n  序列输入至a[0...n] 滑动窗口长度为k
int que[N], hh = 0, tt = -1; // 数组模拟双端队列

for (int i = 0; i < n; ++i) {
    if (hh <= tt && i - que[hh] == k) ++hh;
    while (hh <= tt && a[que[tt]] <= a[i]) --tt;
    que[++tt] = i;
    if (i >= k - 1) printf("%d ", a[que[hh]]); // 输出滑动窗口最大值(仅在滑动窗口中元素满k个时输出)
}

KMP 算法

经典应用:【KMP算法, Rabin-Karp算法, 快速幂】找出字符串中第一个匹配项的下标

// pat: 模式串, str: 主串
int m = pat.length, n = str.length;

// 求next数组
int[] nexts = new int[m];
nexts[0] = -1;
for (int i = 1, j = -1; i < m; ++i) {
    while (j >= 0 && pat[j + 1] != pat[i]) j = nexts[j]; // 前后缀不相同了 向前回退
    if (pat[j + 1] == pat[i]) ++j; // 找到相同的前后缀
    nexts[i] = j;
}

// 匹配
for (int i = 0, j = -1; i < n; ++i) {
	while (j >= 0 && pat[j + 1] != str[i]) j = nexts[j];
	if (pat[j + 1] == str[i]) ++j;
	if (j + 1 == m) {
		// 匹配成功
		j = nexts[j]; // 回退进行下一次匹配
	}
}

Trie 树

模板题:Trie字符串统计

// N为最大可能插入的字符串数 M为最大可能的节点数(最大可能的字符串总长度)
int sons[M][26], cnt[N], idx;

void insert(char word[]) {
    int p = 0;
    for (int i = 0; word[i]; ++i) {
        int curr = word[i] - 'a';
        if (!sons[p][curr]) sons[p][curr] = ++idx;
        p = sons[p][curr];
    }
    ++cnts[p];
}

int count(char word[]) {
    int p = 0;
    for (int i = 0; word[i]; ++i) {
        int curr = word[i] - 'a';
        if (!sons[p][curr]) return 0;
        p = sons[p][curr];
    }
    return cnts[p];
}

并查集

路径压缩优化的并查集

模板题:合并集合

class UnionFind {
    private int[] root;
    
    public UnionFind(int size) {
        root = new int[size];
        for (int i = 0; i < size; ++i) root[i] = i;
    }
    
    public int find(int n) {
        return n == root[n] ? n : (root[n] = find(root[n]));
    }

    public void union(int p, int q) {
        root[find(p)] = find(q);
    }
    
    public boolean isConnected(int p, int q) {
        return find(p) == find(q);
    }
}
class UnionFind {
private:
    int* root;

public:
    UnionFind(int size) : root(new int[size]) {
        for (int i = 0; i < size; ++i) root[i] = i;
    }

    int find(int n) {
        return n == root[n] ? n : (root[n] = find(root[n]));
    }

    void join(int p, int q) {
        root[find(p)] = find(q);
    }
    
    bool is_connected(int p, int q) {
        return find(p) == find(q);
    }
};

基于路径压缩的按秩合并优化的并查集

class UnionFind {
    private int[] root;
    private int[] rank;
    
    public UnionFind(int size) {
        root = new int[size];
        rank = new int[size];
        for (int i = 0; i < size; ++i) root[i] = i;
    }
    
    public int find(int n) {
        return n == root[n] ? n : (root[n] = find(root[n]));
    }

    public void union(int p, int q) {
        int rootP = find(p), rootQ = find(q);
        if (rootP == rootQ) return;
        if (rank[rootP] > rank[rootQ]) root[rootQ] = rootP;
        else {
        	root[rootP] = rootQ;
            if (rank[rootP] == rank[rootQ]) ++rank[rootP];
        }
    }
    
    public boolean isConnected(int p, int q) {
        return find(p) == find(q);
    }
}
class UnionFind {
private:
    int* root;
    int* rank;

public:
    UnionFind(int size) : root(new int[size]), rank(new int[size]) {
        for (int i = 0; i < size; ++i) {
            root[i] = i;
            rank[i] = 1;
        }
    }

    int find(int n) {
        return n == root[n] ? n : (root[n] = find(root[n]));
    }

    void join(int p, int q) {
        int root_p = find(p), root_q = find(q);
        if (root_p == root_q) return;
        if (rank[root_p] > rank[root_q]) root[root_q] = root_p;
        else {
        	root[root_p] = root_q;
            if (rank[root_p] == rank[root_q]) ++rank[root_p];
        }
    }
    
    bool is_connected(int p, int q) {
        return find(p) == find(q);
    }
};

统计连通分量的综合并查集

class UnionFind {
    public int groups;
    private int[] root;
    private int[] rank;
    
    public UnionFind(int size) {
        groups = size;
        root = new int[size];
        rank = new int[size];
        for (int i = 0; i < size; ++i) root[i] = i;
    }
    
    public int find(int n) {
        return n == root[n] ? n : (root[n] = find(root[n]));
    }

    public void union(int p, int q) {
        int rootP = find(p), rootQ = find(q);
        if (rootP == rootQ) return;
        if (rank[rootP] > rank[rootQ]) root[rootQ] = rootP;
        else {
        	root[rootP] = rootQ;
            if (rank[rootP] == rank[rootQ]) ++rank[rootP];
        }
        --groups;
    }
    
    public boolean isConnected(int p, int q) {
        return find(p) == find(q);
    }
}

维护每个集合大小的并查集

模板题:连通块中点的数量

class UnionFind {
    private int[] roots, sizes;
    
    public UnionFind(int size) {
        roots = new int[size];
        for (int i = 0; i < size; ++i) {
            roots[i] = i;
            sizes[i] = 1;
        }
    }
    
    public int find(int n) {
        return n == root[n] ? n : (root[n] = find(root[n]));
    }

    public void union(int p, int q) {
        int rp = find(p), rq = find(q);
        if (rp != rq) {
            root[rp] = rq;
        	sizes[rq] += sizes[rp];
        }
    }
    
    public boolean isConnected(int p, int q) {
        return find(p) == find(q);
    }
}

维护每个节点到根节点距离的并查集

class UnionFind {
    private int[] roots, dists;
    
    public UnionFind(int size) {
        roots = new int[size];
        for (int i = 0; i < size; ++i) {
            roots[i] = i;
            dists[i] = 0;
        }
    }
    
    public int find(int n) {
        return n == root[n] ? n : (root[n] = find(root[n]));
    }

    public void union(int p, int q) {
        roots[find(p)] = find(q);
        dists[find(p)] = distance; // 根据具体问题,初始化find(a)的偏移量
    }
    
    public boolean isConnected(int p, int q) {
        return find(p) == find(q);
    }
}

反集

作用

并查集反集用来储存并查集维护集合性质相反(互斥)的集合。

并查集的反集适用于只有元素只有两种性质的题目,也就是说,这个元素不属于并查集维护集合,则其必定属于另一个集合。

原理

如果要将两个性质相斥的元素 a,ba,b 合并,可以用并查集合并 union(a+n,b),union(b+n,a)union(a + n, b), \enspace union(b + n, a)a+na+n 相当于 aa 的虚拟敌人, union(a+n,b)union(a + n, b) 相当于把 bbaa 的虚拟敌人合并, union(b+n,a)union(b + n, a) 同理)。再之,如果 aacc 的性质互斥,合并 union(a+n,c),union(c+n,a)union(a + n, c), \enspace union(c + n, a) ,此时 bbcc 性质相同,它们也成功合并在了一起(此时并查集中: {{a,b+n,c+n},{b,a+n,c}}\{\{a, b + n, c + n\}, \enspace\{b, a + n, c\}\} )。

实现方式

在初始化并查集时初始化两倍于数据范围大小的并查集,超出数据范围部分称为反集。

经典应用

判二分图。

模板题:堆排序模拟堆

// h存储堆中的值, h[1]是堆顶, x的左儿子是2x, 右儿子是2x + 1
// ph[k]存储第k个插入的点在堆中的位置
// hp[k]存储堆中下标是k的点是第几个插入的
int h[N], ph[N], hp[N], size;

// 交换两个点,及其映射关系
void heap_swap(int a, int b) {
    swap(ph[hp[a]],ph[hp[b]]);
    swap(hp[a], hp[b]);
    swap(h[a], h[b]);
}

void down(int k) {
    int min = k;
    if (k * 2 <= size && h[k * 2] < h[min]) min = k * 2;
    if (k * 2 + 1 <= size && h[k * 2 + 1] < h[min]) min = k * 2 + 1;
    if (k != min) {
        heap_swap(min, k);
        down(min);
    }
}

void up(int k) {
    while (k / 2 > 0 && h[k] < h[k / 2]) {
        heap_swap(k, k / 2);
        k /= 2;
    }
}

// O(n)时间把数组建成堆
for (int i = n / 2; i > 0; --i) down(i);

Java API

Queue<Integer> minHeap = new PriorityQueue<>(); // 小根堆
Queue<Integer> maxHeap = new PriorityQueue<>((a, b) -> b - a); // 大根堆

C++ STL

priority_queue<int, vector<int>, greater<int>> min_heap; // 小根堆
priority_queue<int> max_heap; // 大根堆

哈希

一般哈希

模板题:【哈希表】模拟散列表「哈希表基础」

素数选取表:

lwr upr prime
2^5 2^6 53
2^6 2^7 101
2^7 2^8 193
2^8 2^9 389
2^9 2^10 769
2^10 2^11 1531
2^11 2^12 3061
2^12 2^13 6113
2^13 2^14 12253
2^14 2^15 24379
2^15 2^16 48883
2^16 2^17 97787
2^17 2^18 195677
2^18 2^19 391627
2^19 2^20 783259
2^20 2^21 1566401
2^21 2^22 3133987
2^22 2^23 6269119
2^23 2^24 12538073
2^24 2^25 25082363
2^25 2^26 50170979
2^26 2^27 100353503
2^27 2^28 200730139
2^28 2^29 401498927
2^29 2^30 803081491

开放寻址法

// 一般N取比数据范围大两倍的素数 INF取0x3f3f3f3f
int ht[N];

// 把ht中的值全部初始化为INF
memset(ht, 0x3f, sizeof(ht));

// 如果x在哈希表中就返回x的下标 否则返回x应该插入的位置
int find(int x) {
    int y = (x % N + N) % N;
    while (ht[y] != INF && ht[y] != x) {
        if (++y == N) y = 0;
    }
    return y;
}

拉链法

拉链的方法与单链表向头节点前插入相同。

// 一般N取两到三倍数据范围的素数
int ht[N], vals[N], nexts[N], idx;

// 初始化
void init() {
    idx = 0;
    memset(ht, -1, sizeof(ht));
}

// 向哈希表中插入一个数
void insert(int x) {
    int y = (x % N + N) % N;
    // 从单链表头插入
    vals[idx] = x;
    nexts[idx] = ht[y];
    ht[y] = idx++;
}

// 在哈希表中查询某个数是否存在
bool find(int x) {
    int y = (x % N + N) % N;
    for (int i = ht[y]; i != -1; i = nexts[i]) {
        if (vals[i] == x) return true;
    }
    return false;
}

字符串哈希

核心思想:将字符串看成P进制数,P的经验值是 1311311333113331,取这两个值的冲突概率低。
小技巧 - 自然溢出:取模的数用 2642^{64},这样直接用 unsigned long long (Java 用 long)存储,溢出的结果就是取模的结果,这样可以省去取模的代码和时间。

模板题:字符串哈希

typedef unsigned long long ULL;

// h[k]存储字符串前k个字母的哈希值, p[k]存储 P^k mod 2^64
ULL h[N], p[N]; 

// 初始化
void init() {
    // s是下标从0开始的字符串 n是其长度
    p[0] = 1;
    for (int i = 1; i <= n; ++i) {
        h[i] = h[i - 1] * P + str[i - 1];
        p[i] = p[i - 1] * P;
    }
}

// 计算子串 str[l...r] 的哈希值
ULL sub_hash(int l, int r) {
    return h[r] - h[l - 1] * p[r - l + 1];
}

BFS

平面坐标系中模拟向周围 BFS。

const int DIR[][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}}; // 四个方向
const int DIR[][2] = {{1, -1}, {1, 0}, {1, 1}, {0, -1}, {0, 1}, {-1, -1}, {-1, 0}, {-1, 1}}; // 八个方向

void bfs(int x, int y) {
    queue<PII> que;
    // ...
    while (!que.empty()) {
        auto pr = que.front();
        que.pop();
        int x = pr.first, y = pr.second;
        // ...
        for (auto& DIR : DIRS) {
            if (nx >= 0 && nx < n && ny >= 0 && ny < m /* 其它条件 */) {
            	que.push({nx, ny});
                // ...
            }
        }
    }
}

求「最少能从初始状态到达结束状态的步数」问题。

import java.util.*;
import java.io.*;

public class Main {
    // 状态对象
    static class State {
        T content; // 状态
        int step; // 步数

        State(T content, int step) {
            // 初始化
            this.content = content;
            this.step = step;
        }

        // 重写equals方法对比状态是否相同
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof State)) return false;
            State s = (State) obj;
            // 对比状态
            // ...
        }

        // 重写hashCode方法用于存入集合 以便于快速确认状态的存在与否
        @Override
        public int hashCode() {
            // ...
        }

        // 由当前状态生成新的状态
        State newFrom(T content) {
            // ...
            return new State(content, step + 1);
        }
    }

    static State origin, target;
    
    public static void main(String[] args) throws IOException {
        // 输入状态 初始状态的步数皆为 0
        origin = new State(/* 初始状态 */, 0);
        target = new State(/* 目标状态 */, 0);
        // 输出结果(最少能从初始状态到达结束状态的步数)
        System.out.println(bfs());
    }

    static int bfs() {
        Queue<State> que = new LinkedList<>(); // BFS辅助队列
    	Set<State> vis = new HashSet<>(); // 判断状态是否已经见到过 用于剪枝
        que.offer(origin); // 放入初始状态
        while (!que.isEmpty()) {
            // 取出队尾状态
            State state = que.poll();
            // base case: 当前状态和达到目标则直接返回步数即是最短步数
            if (state.equals(target)) return state.step;
            // 枚举所有合法的状态转移
            for (eachValid in possibilities) {
                // 生成新状态
                State newState = state.newFrom(eachValid);
                // 若新生成的状态在以前没遇到过就把它加入队列与集合
                if (!set.contains(newState)) {
                    set.add(newState);
                    que.add(newState);
                }
            }
        }
        return -1;
    }
}

DFS

应用

指数型枚举

1n1∼nnn 个整数中随机选取任意多个,输出所有可能的选择方案。

int n;
boolean used = new int[n + 1];

// 尝试选/不选x
void dfs(int x) {
    if (x > n) {
        for (int i = 1; i <= n; ++i) {
            if (used[i]) System.out.print(i + " ");
        }
        System.out.print("\n");
        return;
    }
    dfs(x + 1);
    used[x] = true;
    dfs(x + 1);
    used[x] = false;
}

// usage: dfs(1)
int n, used[N];

// 尝试选/不选x
void dfs(int x) {
    if (x > n) {
        for (int i = 1; i <= n; ++i) {
            if (used[i]) printf("%d ", i);
        }
        return;
    }
    dfs(x + 1);
    used[x] = true;
    dfs(x + 1);
    used[x] = false;
}

// usage: dfs(1)

组合型枚举

1n1∼nnn 个整数中随机选出 mm 个,输出所有可能的选择方案。

int n, m;
int[] chosen = new int[m];

// 尝试把st...n放在chosen[idx]
void dfs(int idx, int st) {
    if (idx + n - st + 1 < m) return; // 剪枝
    if (idx == m) {
        for (int x : chosen) System.out.print(x + " ");
        System.out.print("\n");
        return;
    }
    for (int i = st; i <= n; ++i) {
        chosen[idx] = i;
        dfs(idx + 1, i + 1);
    }
}

// usage: dfs(0, 1)
int chosen[M], idx;

// 尝试把st...n放在chosen[idx]
void dfs(int idx, int st) {
    if (idx + n - st + 1 < m) return; // 剪枝
    if (idx == m) {
        for (int i = 0; i < m; ++i) printf("%d ", chosen[i]);
        puts("");
        return;
    }
    for (int i = st; i <= n; ++i) {
        chosen[idx] = i;
        dfs(idx + 1, i + 1);
    }
}

// usage: dfs(0, 1)

排列型枚举

1n1∼nnn 个整数排成一行后随机打乱顺序,输出所有可能的次序方案。

int n;
int[] chosen = new int[n];
boolean[] used = new boolean[n + 1];

// 尝试把1...n中没用过的数放在chosen[idx]
void dfs(int idx) {
    if (idx == n) {
        for (int x : chosen) System.out.print(x + " ");
        System.out.print("\n");
        return;
    }
    for (int x = 1; x <= n; ++x) {
        if (!used[x]) {
            chosen[idx] = x;
            used[x] = true;
            dfs(idx + 1);
            used[x] = false;
        }
    }
}

// usage: dfs(0)
int n, chosen[N];
bool used[N];

// 尝试把1...n中没用过的数放在chosen[idx]
void dfs(int idx) {
    if (idx == n) {
        for (int i = 0; i < n; ++i) printf("%d ", chosen[i]);
        puts("");
        return;
    }
    for (int x = 1; x <= n; ++x) {
        if (!used[x]) {
            chosen[idx] = x;
            used[x] = true;
            dfs(idx + 1);
            used[x] = false;
        }
    }
}

// usage: dfs(0)

剪枝

  • 优化搜索顺序。
    • 大部分情况下,我们应该优先搜索分支较少的节点。
  • 排除等效冗余。
  • 可行性剪枝。
  • 最优性剪枝。
  • 记忆化搜索(DP)。

树与图

存储

树是一种特殊的图(无环连通图),与图的存储方式相同:对于无向图中的边 aba \to b,存储两条有向边 ab,baa \to b, b \to a,因此我们可以只考虑有向图的存储。

邻接矩阵

g[a][b]g[a][b] 存储边 aba \to b,使用空间较多,适合存稠密图。

邻接表

链式前向星

加边函数类似于单链表中的插入函数。

// N为点数 M为边数(一般取N的两倍)
// heads中每个节点下标都存储着一个链表(头)
int heads[N], vals[M], nexts[M], idx;

// 初始化
idx = 0;
Arrays.fill(heads, -1);

// 添加一条边a->b
void add(int a, int b) {
    vals[idx] = b;
    nexts[idx] = heads[a];
    heads[a] = idx++;
}
// N为点数 M为边数(一般取N的两倍)
int h[N], e[M], ne[M], idx = 0;

// 初始化
idx = 0;
memset(h, -1, sizeof(h));

// 添加一条边a->b
void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

遍历

时间复杂度:O(N+M)O(N + M)NN 表示顶点数,MM 表示边数)

深度优先遍历

模板题:树的重心

int heads[N], vals[M], nexts[M], idx;
int vis[N]; // 记录点是否被遍历过

void dfs(int u) {
    vis[u] = true;
    for (int i = heads[u]; i != -1; i = nexts[i]) {
        int v = vals[i];
        if (!vis[v]) dfs(v);
    }
}

宽度优先遍历

模板题:图中点的层次

int heads[N], vals[M], nexts[M], idx;
int vis[N]; // 记录点是否被遍历过

void bfs(int u) {
    queue<int> que;
    que.push(u); // 从顶点u开始遍历
	vis[u] = true;
    while (!que.empty()) {
        int t = que.front();
        que.pop();
        for (int i = heads[t]; i != -1; i = nexts[i]) {
            int v = vals[i];
            if (!vis[v]) {
                que.push(v);
                st[v] = true;
            }
        }
    }
}

拓扑排序

时间复杂度:O(N+M)O(N + M)NN 表示顶点数,MM 表示边数)

模板题:有向图的拓扑序列

int n, m, idx;
int heads[N], vals[M], nexts[M], ind[N]; // ind维护所有顶点的入度
int que[N], hh, tt = -1; // 维护一个队列

bool top_sort() {
    // 把所有入度为0的顶点先入队
    for (int i = 1; i <= n; ++i) {
        if (!ind[i]) que[++tt] = i;
    }
    while (hh <= tt) {
        int t = que[hh++];
        for (int i = heads[t]; i != -1; i = nexts[i]) {
            int v = vals[i];
            if (--ind[v] == 0) que[++tt] = v;
        }
    }
    // 如果所有点都入队了 说明存在拓扑序列 否则不存在拓扑序列
    return tt == n - 1;
}

if (top_sort()) {
    // 存在拓扑排序 输出拓扑排序
    for (int i = 0; i < n; ++i) printf("%d ", que[i]);
}

最短路算法

image-20221111182633376

单源最短路

朴素 Dijkstra 算法

基于贪心

时间复杂度: O(N2+M)O(N^2 + M)NN 表示顶点数,MM 表示边数)

一般用在不存在负权边的稠密图[1]求单源最短路中。

求图中 srcdestsrc \rightarrow dest 的最短路,算法步骤:

变量定义:

  • 使用邻接矩阵( g )存图,n 表示顶点的数量。
  • 维护 distsdists[u] 表示从点 srcsrc 出发到达点 uu 的最短距离
  • 维护 visvis[u]true 时表示 srcusrc \rightarrow u 的距离已经被确定
  1. 初始化:把 dists 中所有数初始化为正无穷(一般用 INF=0x3f3f3f3f 代替), dists[src] = 0
  2. 循环 n1n-1 次:
    1. 取出还未确定最短距离( vis[i] == false )的,距离 src 最近的顶点(mn)。
    2. mn 置为最短距离已确认的点。
    3. mn 更新 src 到其他点的最短距离。
  3. 循环结束后若 dists 中存储的 srcdest 的距离不为 INF 就说明 src 到点 dest 的最短距离是 dists[dest]

模板代码:

static final int INF = 0x3f3f3f3f;
int n, m;
int[][] g;
    
// 返回src到dest的最短距离 不存在则返回-1
static int dijkstra(int src, int dest) {
    int[] dists = new int[n + 1];
    boolean[] vis = new boolean[n + 1];
    Arrays.fill(dists, INF);
    dists[src] = 0;
    for (int i = 1; i < n; ++i) {
        int mn = -1;
        for (int j = 1; j <= n; ++j) {
            if (!vis[j] && (mn == -1 || dists[j] < dists[mn])) mn = j;
        }
        vis[mn] = true;
        for (int j = 1; j <= n; ++j) {
            dists[j] = Math.min(dists[j], dists[mn] + g[mn][j]);
        }
    }
    return dists[dest] == INF ? -1 : dists[dest];
}
const int INF = 0x3f3f3f3f;
int n;
int g[N][N], dists[N];
bool vis[N];

int dijkstra(int src, int dest) {
    memset(dists, 0x3f, sizeof(dists));
    dists[src] = 0;
    for (int i = 1; i < n; ++i) {
        int mn = -1;
        for (int j = 1; j <= n; ++j) {
            if (!vis[j] && (mn == -1 || dists[j] < dists[mn])) mn = j;
        }
        vis[mn] = true;
        for (int j = 1; j <= n; ++j) {
            dists[j] = min(dists[j], dists[mn] + g[mn][j]);
        }
    }
    return dists[dest] == INF ? -1 : dists[dest];
}
堆优化 Dijkstra 算法

基于贪心

时间复杂度: O(MlogN)O(M \log{N})NN 表示顶点数,MM 表示边数)

一般用在不存在负权边的稀疏图[1:1]求单源最短路中。

和朴素 Dijkstra 算法思路一样,但是会把遍历求出最小值的操作优化为小根堆取堆顶,由于是稀疏图,使用邻接表存图。

模板代码:

static final int INF = 0x3f3f3f3f;
int n, m;
int[] h, e, ne, w; // 邻接表存图
int idx;

// 返回src到dest的最短距离 不存在则返回-1
int dijkstra(int src, int dest) {
    int[] dists = new int[n + 1]; // 所有点到src的距离
    // 小根堆 以最短距离为标准排序
    Queue<int[]> pq = new PriorityQueue<>((a, b) -> a[0] - b[0]);
    Arrays.fill(dists, INF);
    dists[src] = 0;
    pq.offer(new int[]{0, src});
    while (!pq.isEmpty()) {
        int[] curr = pq.poll();
        int d = curr[0], mn = curr[1];
        for (int i = h[mn]; i != -1; i = ne[i]) {
            int v = e[i];
            if (d + w[i] < dists[v]) {
                dists[v] = d + w[i];
                pq.offer(new int[]{dists[v], v});
            }
        }
    }
    return dists[dest] == INF ? -1 : dists[dest];
}
typedef pair<int, int> PII;

const int INF = 0x3f3f3f3f;
int heads[N], vals[N], weights[N], nexts[N], idx;
int dists[N];

int dijkstra(int src, int dest) {
    memset(dists, 0x3f, sizeof(dists));
    dists[src] = 0;
    priority_queue<PII, vector<PII>, greater<PII>> heap; // 小根堆
    heap.push({0, src}); // 优先队列存pair时默认按照first排序 first存储距离 second存储节点编号
    while (!heap.empty()) {
        auto curr = heap.top();
        heap.pop();
        int mn = curr.second;
        if (vis[mn]) continue;
        int dist = curr.first;
        vis[mn] = true;
        for (int i = heads[mn]; i != -1; i = nexts[i]) {
            int j = vals[i];
            if (dist + weights[i] < dists[j]) {
                dists[j] = dist + weights[i];
                heap.push({dists[j], j});
            }
        }
    }
    return dists[dest] == INF ? -1 : dists[dest];
}
Bellman-Ford 算法

基于动态规划

时间复杂度: O(NM)O(NM)NN 表示顶点数,MM 表示边数)

一般用在存在负权边的图求单源最短路中。

求图中点 src 到点 dest 的最短路,算法步骤:
使用结构体(edges)存图(只要能把所有边存下来用于遍历就行),n 表示顶点的数量。
维护 dists 表示从点 src 出发到每个顶点的最短距离。伪代码:

for n: // 迭代k次时dists数组表示从src经过不超过k条边走到每个点的最短距离
    for a, b, w in edges:
        dists[b] = min(dists[b], dists[a] + w) // 松弛操作

循环之后,所有边的距离一定满足:dist[b]<dist[a]+wdist[b] < dist[a] + w (三角不等式)。

模板代码:

const INF = 0x3f3f3f3f;
struct Edge {
    int a, b, w;
} edges[M];
int dists[N], bak[N];
int n, m, k; // n表示点数 m表示边数

int bellman_ford(int src, int dest, int k) {
    memset(dists, 0x3f, sizeof(dists));
    dists[src] = 0;
    while (k--) {
        memcpy(bak, dists, sizeof(dists)); // 备份上一次的结果 防止更新串联
        for (int i = 0; i < m; ++i) {
            int a = edges[i].a, b = edges[i].b, w = edges[i].w;
            dists[b] = min(dists[b], bak[a] + w);
        }
    }
    return dists[dest];
}

int ans = bellman_ford(1, n, k); // 求最短路
if (ans > INF / 2) // 有可能出现最短路不存在但是dists[dest]被更新的情况 所以只要取到的ans比数据范围大就记为不存在最短路
else printf("%d\n", ans);
SPFA 算法

Bellman-Ford 算法的宽搜优化,基于动态规划

时间复杂度:一般情况下 O(M)O(M) 最坏情况下 O(NM)O(NM)NN 表示顶点数,MM 表示边数)

一般用在不存在负权回路的图求单源最短路中。

SPFA 算法是优化 Bellman-Ford 算法得到的,Bellman-Ford 算法每次迭代时是枚举所有边来更新,但实际上每次迭代时并不是每一条边都可以用来更新,于是 SPFA 就在这一点上做了优化,它利用BFS来进行每一次的迭代,尽可能地保证每一次迭代都更新最短距离。

求图中点 src 到点 dest 的最短路,算法步骤:

使用邻接表存图,n 表示顶点的数量。
维护队列 que 存储所有最短距离变小了的顶点
维护 dists 从点 src 出发到每个顶点的最短距离
维护 has 表示某个点是否在队列中。

  1. 初始化:把 dists 中所有数初始化为正无穷(一般用 INF=0x3f3f3f3f 代替), dists[src] = 0,把起始点放入队列,并标记起始点在队中。
  2. 当队列不为空时,循环:
    1. 取出队头(u),标识 u 已不在队中( has[u] = false )。
    2. 遍历顶点 u 的所有出边:
      • 拿到顶点 v ,尝试用 u 的最短路更新 dists[v],如果可以更新并且 v 不在队列中,就把 v 入队。
  3. 循环结束后如果 dists[dest] 不大于 INF / 2 就说明 srcdest 的最短路存在且为 dists[dest]

模板代码:

static final int INF = 0x3f3f3f3f;
int n, m;
int[] h, e, ne, w;
int idx;

// 返回src到dest的最短距离 不存在则返回INF
static int spfa(int src, int dest) {
    int[] dists = new int[n + 1];
    boolean[] has = new boolean[n + 1];
    Arrays.fill(dists, INF);
    dists[src] = 0;
    Queue<Integer> que = new LinkedList<>();
    que.offer(src);
    has[src] = true;
    while (!que.isEmpty()) {
        int u = que.poll();
        has[u] = false;
        for (int i = h[u]; i != -1; i = ne[i]) {
            int v = e[i];
            if (dists[u] + w[i] < dists[v]) {
                dists[v] = dists[u] + w[i];
                if (!has[v]) {
                    que.offer(v);
                    has[v] = true;
                }
            }
        }
    }
    return dists[dest];
}
const int INF = 0x3f3f3f3f;
int heads[N], vals[M], wts[M], nexts[M], idx;
int dists[N];
bool has[N];

int spfa(int src, int dest) {
    memset(dists, 0x3f, sizeof(dists));
    queue<int> que;
    dists[src] = 0;
    que.push(src);
    has[src] = true;
    while (!que.empty()) {
        int u = que.front();
        que.pop();
        has[u] = false;
        for (int i = heads[u]; i != -1; i = nexts[i]) {
            int v = vals[i];
            if (dists[u] + wts[i] < dists[v]) {
                dists[v] = dists[u] + wts[i];
                if (!has[v]) {
                    que.push(v);
                    has[v] = true;
                }
            }
        }
    }
    return dists[dest];
}
SPFA 算法判复环

时间复杂度:O(NM)O(NM)NN 表示顶点数,MM 表示边数)

一般用在图求是否存在负权回路中。

算法步骤与普通 SPFA 算法几乎相同。

原理:如果某条最短路径上有 n 个点(除了自己),那么加上自己之后一共有 n+1 个点,由抽屉原理一定有两个点相同,所以存在环。

模板代码:

int n, m;
int[] h, e, ne, w;
int idx;

// 如果图中存在负环返回true 否则返回false
boolean spfa() {
    // 不用初始化dists 因为我们只需能看出短路的变化趋势即可 不需要真的求出最短路
    int[] dists = new int[n + 1];
    boolean[] has = new boolean[n + 1];
    int[] cnts = new int[n + 1]; // cnts存储这每个点最短路中经过的点数
    Queue<Integer> que = new LinkedList<>();
    // 有些节点的最短路可能并不经过负权回路 为了找出整个图中是否存在负权回路 要把所有顶点全部算上
    for (int u = 1; u <= n; ++u) {
        que.offer(u);
        has[u] = true;
    }
    while (!que.isEmpty()) {
        int u = que.poll();
        has[u] = false;
        for (int i = h[u]; i != -1; i = ne[i]) {
            int v = e[i];
            if (dists[u] + w[i] < dists[v]) {
                dists[v] = dists[u] + w[i];
                cnts[v] = cnts[u] + 1;
                // 如果从1号点到x的最短路中包含至少n个点(不包括自己) 说明存在负权回路
                if (cnts[v] >= n) return true;
                if (!has[v]) {
                    que.offer(v);
                    has[v] = true;
                }
            }
        }
    }
    return false;
}
int heads[N], vals[M], wts[M], nexts[M], idx;
int dists[N], cnts[N];
bool vis[N];

bool spfa() {
    queue<int> que;
    for (int x = 1; x <= n; ++x) {
        que.push(x);
        vis[x] = true;
    }
    while (!que.empty()) {
        int u = que.front();
        que.pop();
        vis[u] = false;
        for (int i = heads[u]; i != -1; i = nexts[i]) {
            int v = vals[i];
            if (dists[u] + wts[i] < dists[v]) {
                dists[v] = dists[u] + wts[i];
                cnts[v] = cnts[u] + 1;
                if (cnts[v] >= n) return true;
                if (!vis[v]) {
                    que.push(v);
                    vis[v] = true;
                }
            }
        }
    }
    return false;
}

多源汇最短路

Floyd 算法

基于动态规划

时间复杂度: O(N3)O(N^3)NN 表示顶点数)

一般用在不存在负权回路的图求多源汇最短路中。

使用邻接表存图,n 表示顶点的数量,维护 dists 用于存储任意点到任意点的最短距离。

算法步骤(伪代码):

// 初始化 n表示顶点的数量
for i from 1 to n:
    for j from 1 to n:
        d[i][j] = i == j ? 0 : INF

// Floyd算法
for k from 1 to n:
    for i from 1 to n:
        for j from 1 to n:
            d[i][j] = min(d[i][j], d[i][k] + d[k][j])

模板代码:

static final int INF = 0x3f3f3f3f;
int n; // 顶点数量
int dists[N][N]; // 邻接矩阵存图

// 初始化邻接矩阵(在输入所有边之前)
for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= n; ++j) {
        if (i == j) continue;
        dists[i][j] = INF;
    }
}

void floyd() {
    for (int k = 1; k <= n; ++k) {
        for (int i = 1; i <= n; ++i) {
            for (int j = 1; j <= n; ++j) {
                dists[i][j] = Math.min(dists[i][j], dists[i][k] + dists[k][j]);
            }
        }
    }
}

// src->dest的最短路
if (dists[src][dest] > INF / 2) // 不存在最短路
else // ans: dists[src][dest]
const INF = 0x3f3f3f3f;
int n;
int dists[N][N];

void floyd() {
    for (int k = 1; k <= n; ++k) {
        for (int i = 1; i <= n; ++i) {
            for (int j = 1; j <= n; ++j) {
                dists[i][j] = min(dists[i][j], dists[i][k] + dists[k][j]);
            }
        }
    }
}

// src->dest的最短路
if (dists[src][dest] > INF / 2) // 不存在最短路
else // ans: dists[src][dest]

最小生成树算法

image-20221115094034018

  • 生成树生成树 指的是「无向图」中,具有该图的 全部顶点边数最少连通子图
  • 最小生成树最小生成树 指的是「加权无向图」中总权重最小的生成树。

朴素 Prim 算法

时间复杂度: O(N2)O(N^2)NN 表示顶点数)

一般用在稠密图[1:2]求最小生成树中。

Dijkstra 算法 相似,算法步骤:

使用邻接矩阵(g)存图,n 表示顶点的数量,维护 dists 用于存储集合外的点到集合的最短距离,维护 vis 表示某个点是否已经被放入最小生成树。

  1. 初始化:把 dists 中所有数初始化为正无穷。
  2. 循环 n 次:
    1. 找到集合外距离集合最近的点(mn)。
    2. 若该点不是第一个迭代的点且 与集合内所有顶点都不相连 ,说明该图不是一个连通图,不存在最小生成树。
    3. mn 放入最小生成树( vis[mn] = true )并更新最小生成树的总权值( res += vis[mn] )。
    4. mn 更新其他点到集合的最短距离。

为什么 Dijkstra 算法外层循环 n-1 次而 Prim 算法要循环 n 次?

——因为 Dijkstra 算法在迭代前先把源点加入了集合,所以它只用迭代 n-1 个点,而 Prim 算法要迭代所有 n 个顶点。

模板代码:

const INF = 0x3f3f3f3f;
int n; // 顶点数量
int g[N][N]; // 邻接矩阵存图
dists[N];
bool vis[N];

// 返回最小生成树的权值和 若不存在则返回INF
int prim() {
    memset(dists, 0x3f, sizeof(dists));
    int res = 0;
    for (int i = 0; i < n; ++i) {
        int mn = -1;
        for (int j = 1; j <= n; ++j) {
            if (!vis[j] && (mn == -1 || dists[j] < dists[mn])) mn = j;
        }
        if (i) {
            if (dists[mn] == INF) return INF; // 发现不是所有点连通 不存在最小生成树
            res += dists[mn]; // 把mn加入最小生成树 并把dists[mn]累加进权值和
        }
        vis[mn] = true;
        for (int j = 1; j <= n; ++j) dists[j] = min(dists[j], g[j][mn]);
    }
    return res;
}

int ans = prim(); // 若ans==INF 说明该图不存在最小生成树
// 最小生成树的总权重为: ans

Kruskal 算法

时间复杂度: O(MlogM)O(M \log M)MM 表示边数)

一般用在稀疏图[1:3]求最小生成树中。

算法步骤:

  1. 将所有边按照权重从小到大排序。
  2. 枚举每条边 uwvu \xleftrightarrow{w} v
    • 如果 u,vu, v 不连通,将这条边加入集合中。

模板代码:

const int INF = 0x3f3f3f3f;
int m; // 边数
struct Edge {
    int u, v, w;
    
    bool operator<(const Edge& e) {
        return w < e.w;
    }
} edges[M]; // 结构体存储所有边 并按照权值升序排序
int roots[N];

// 并查集操作
int find(int x) {
    return x == roots[x] ? x : (roots[x] = find(roots[x]));
}

void join(int p, int q) {
    roots[find(p)] = find(q);
}

// 返回最小生成树的权值和 若不存在则返回INF
int kruskal() {
    sort(edges, edges + m);
    for (int i = 1; i <= n; ++i) roots[i] = i;
    int cnt = 0, res = 0;
    for (int i = 0; i < m; ++i) {
        int u = edges[i].u, v = edges[i].v, w = edges[i].w;
        if (find(u) != find(v)) {
            join(u, v);
            res += w;
            if (++cnt == n - 1) break;
        }
    }
    return cnt < n - 1 ? INF : res;
}

int ans = kruskal(); // 若ans==INF 说明该图不存在最小生成树
// 最小生成树的总权重为: ans

二分图

定义

二分图(Bipartite graph),又称二部图 。

它是:所有顶点由两个集合组成,且两个集合内部没有边的图。

换言之,存在一种方案,将所有顶点划分成满足以上性质的两个集合。

img

性质

  • 如果两个集合中的点分别染成黑色和白色,可以发现二分图中的每一条边都一定是连接一个黑色点和一个白色点。

  • 二分图不存在长度为奇数的环

    因为每一条边都是从一个集合走到另一个集合,只有走偶数次才可能回到同一个集合。

染色法判定二分图

时间复杂度: O(N+M)O(N + M)NN 表示顶点数,MM 表示边数)

模板代码:

int n, m; // 顶点数 边数
int[] h = new int[N], e = new int[M], ne = new int[M]; // 邻接表存图
int idx;
int[] colors; // 记录每个节点的颜色: 0表示还未被染色  1和2为两种不同的颜色

// 把顶点u染成c色 返回是否有矛盾
boolean dye(int u, int c) {
    colors[u] = c; // 把u染色c色
    // 枚举与u相连的所有顶点
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        // 若顶点v没染过色 给它染与u不同的色 如果矛盾就说明不是二分图
        // 若顶点v染过色 看它与u的颜色是否相同来判断是否有矛盾
        if (colors[v] == 0 && !dye(v, 3 - c) || colors[v] == c) return false;
    }
    return true;
}

// 枚举所有顶点
for (int u = 1; u <= n; ++u) {
    if (colors[u] == 0) {
        // 尝试给还未被染色的顶点染色
        if (!dye(u, 1)) // 染色矛盾 该图不是二分图
    }
}
// 枚举能走完就说明该图是二分图
int n, m; // 顶点数 边数
int h[N], e[M], ne[M], idx; // 邻接表存图
int colors[N]; // 记录每个节点的颜色: 0表示还未被染色  1和2为两种不同的颜色

bool dfs(int u, int c) {
    colors[u] = c; // 给u染色
    // 枚举与u相连的所有顶点
    for (int i = heads[u]; i != -1; i = nexts[i]) {
        int v = vals[i];
        // 若顶点v没染过色 给它染与u不同的色 如果矛盾就说明不是二分图
        // 若顶点v染过色 看它与u的颜色是否相同来判断是否有矛盾
        if (!colors[v] && !dfs(v, 3 - c) || colors[v] == c) return false;
    }
    return true;
}

// 枚举所有顶点
for (int i = 1; i <= n; ++i) {
    if (!colors[i]) {
        // 尝试给还未被染色的顶点染色
        if (!dfs(i, 1)) // 染色矛盾 该图不是二分图
    }
}
// 枚举能走完就说明该图是二分图

匈牙利算法求二分图最大匹配

时间复杂度: O(NM)O(NM) (实际运行时间一般远小于 O(NM)O(NM) )( NN 表示顶点数,MM 表示边数)

模板代码:

int n1, n2, idx; // n1 n2 分别表示二分图的两部分的顶点数
int matchs[N]; // 记录匹配
bool vis[N]; // 记录在对某个节点的匹配尝试中 另一种颜色的节点是否被被用过

bool find(int u) {
    for (int i = heads[u]; i != -1; i = nexts[i]) {
        int v = vals[i];
        if (!vis[v]) {
            vis[v] = true;
            if (!matchs[v] || find(matchs[v])) {
                matchs[v] = u;
                return true;
            }
        }
    }
    return false;
}

int ans = 0; // 记录最大匹配
for (int i = 1; i <= n1; ++i) {
    memset(vis, false, sizeof(vis)); // 重置vis
    if (find(i)) ++ans;
}

最近公共祖先(LCA)

最近公共祖先简称 LCA(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远的那个。

爬山法

时间复杂度:

  • 单次查询: O(N)O(N)NN 指查询点的高度 )

模板题:【LCA, 递归】二叉树

模板代码:

// l[u]: 节点u的左子节点  r[u]: 节点u的右子节点  p[u]: 节点u的父节点
int[] l = new int[N], r = new int[N], p = new int[N];
int[] dep = new int[N]; // dep[u]: 根节点到节点u的距离

// 初始化
Arrays.fill(l, -1);
Arrays.fill(r, -1);

// 预处理到根节点的距离
static void dfs(int u, int d) {
    dep[u] = d;
    if (l[u] != -1) dfs(l[u], d + 1);
    if (r[u] != -1) dfs(r[u], d + 1);
}
dfs(1, 0);
    
static int lca(int a, int b) {
    if (dep[a] > dep[b]) {
        int t = a;
        a = b;
        b = t;
    }
    while (dep[b] != dep[a]) b = p[b];
    while (a != b) {
        a = p[a];
        b = p[b];
    }
    return a;
}
// 求节点a, b的最近公共祖先: lca(a, b)

倍增法

时间复杂度:

  • 预处理: O(NlogN)O(N\log{N})NN 指查询点的高度 )
  • 单次查询: O(logN)O(\log{N})

模板代码

NN 为节点数, MM 为树中的边数(由于树中都是无向边, 所以一般取 2N2N ), KK 一般取 log2N\lfloor \log_{2}{N} \rfloor

// 邻接表存树  无向边加两次
int[] h = new int[N], e = new int[M], ne = new int[M];
int idx, root;
int[] dep = new int[N];
int[][] fa = new int[N][K];

// 预处理
void bfs() {
    Arrays.fill(dep, -1);
    dep[0] = 0; // 哨兵
    dep[root] = 1;
    Queue<Integer> que = new LinkedList<>();
    que.offer(root);
    while (!que.isEmpty()) {
        int u = que.poll();
        for (int i = h[u]; i != -1; i = ne[i]) {
            int v = e[i];
            if (dep[v] == -1) {
                dep[v] = dep[u] + 1;
                que.offer(v);
                fa[v][0] = u;
                for (int k = 1; k < K; ++k) {
                    fa[v][k] = fa[fa[v][k - 1]][k - 1];
                }
            }
        }
    }
}
bfs();
    
int lca(int a, int b) {
    if (dep[a] > dep[b]) {
        int t = a;
        a = b;
        b = t;
    }
    for (int k = K - 1; k >= 0; --k) {
        if (dep[fa[b][k]] >= dep[a]) b = fa[b][k];
    }
    if (a == b) return a;
    for (int k = K - 1; k >= 0; --k) {
        if (fa[a][k] != fa[b][k]) {
            a = fa[a][k];
            b = fa[b][k];
        }
    }
    return fa[a][0];
}
// 求节点a, b的最近公共祖先: lca(a, b)

数学

向上取整

ab=a+b1b\lceil\frac{a}{b}\rceil = \lfloor \frac{a+b-1}{b} \rfloor

int ceil = (a + b - 1) / b

加法原理

完成一个工程可以有 nn 类办法, ai(1in)a_i(1 \le i \le n) 代表第 ii 类方法的数目。那么完成这件事共有 S=a1+a2++anS = a_1 + a_2 + \dots + a_n 种不同的方法。

乘法原理

完成一个工程需要分 nn 个步骤,ai(1in)a_i(1 \le i \le n) 代表第 ii 个步骤的不同方法数目。那么完成这件事共有 S=a1×a2××anS = a_1 \times a_2 \times \cdots \times a_n种不同的方法。

算术基本定理

任何一个大于 11正整数都能分解为有限个质数的乘积:

n=p1α1×p2α2××pkαk,p1<p2<<pkn = p_1^{\alpha_1} \times p_2^{\alpha_2} \times \cdots \times p_k^{\alpha_k}, \enspace p_1 < p_2 < \cdots < p_k

其中 αiZ\alpha_i \in \mathbb{Z}pip_i 都是质数,且 p1<p2<pkp_1 < p_2 < \dots p_k ;在不计次序的意义下,该表示唯一。

裴蜀定理

裴蜀定理,又称贝祖定理(Bézout’s lemma)。是一个关于最大公约数的定理。

其内容是:

a,ba,b不全为零的整数,则存在整数 x,yx,y ,使得 ax+by=gcd(a,b)ax+by=\gcd(a,b)

平方和公式

i=1ni2=12+22++n2=n(n+1)(2n+1)6\sum_{i=1}^{n}{i^2} = 1^2 + 2^2 + \dots + n^2 = \frac{n(n+1)(2n+1)}{6}

质数

试除法判质数

模板代码:

boolean isPrime(int n) {
    if (n < 2) return false;
    for (int i = 2; i <= n / i; ++i) {
        if (n % i == 0) return false;
    }
    return true;
}
bool is_prime(int n) {
    if (n < 2) return false;
    for (int i = 2; i <= n / i; ++i) {
        if (n % i == 0) return false;
    }
    return true;
}

时间复杂度:O(n)O(\sqrt{n})

试除法分解质因数

分解质因数(Prime Factorization):根据算术基本定理,每个合数都可以写成几个质数相乘的形式,其中每个质数都是这个合数的因数,把一个合数用质因数相乘的形式表示出来,叫做分解质因数。

模板题:【数学】分解质因数

原理:试除法,枚举 2n2 \dots n 中的所有质数,如果 iinn质因子,就把 nn 中的 ii 除干净。

模板代码:

void pf(int n) {
    for (int i = 2; i <= n / i; ++i) {
        if (n % i == 0) {
            int s = 0;
            while (n % i == 0) {
                n /= i;
                ++s;
            }
            System.out.println(i + " " + s);
        }
    }
    if (n > 1) System.out.println(n + " 1");
}
void pf(int n) {
    for (int i = 2; i <= n / i; ++i) {
        if (n % i == 0) {
            int s = 0;
            while (n % i == 0) {
                n /= i;
                ++s;
            }
            printf("%d %d\n", i, s);
        }
    }
    if (n > 1) printf("%d 1\n", n);
}

时间复杂度:

  • 最差: O(n)O(\sqrt{n})
  • 最好: O(logN)O(\log{N})

其中枚举的 ii 一定是质数,因为:当枚举到 ii 的时候,已经把 [2,i1][2, i-1] 中的质因子都除掉了,当 if (n % i == 0) 成立的时候,nn 中已经没有任何在 [2,i1][2, i-1] 范围内的质因子,又因为 nnii 的倍数,所以 ii 也没有任何在 [2,i1][2, i-1] 范围内的因子,所以此时 ii 一定是质数。

筛质数

模板题:【数学】筛质数

朴素筛

原理:

  • 枚举 x[2,n]x \in [2, n]xx 的倍数一定是合数,所以可以在枚举到 xx 的时候筛掉所有 小于等于 nnxx 的倍数
  • 而在枚举到 p[2,n]p \in [2, n] 的时候,如果 pp 还没有被筛掉,说明 pp 不是 x[2,p1]x \in [2, p-1] 中任意一个数的倍数,所以 pp 一定是质数。

模板代码:

int countPrimes(int n) {
    int cnt = 0;
    int[] primes = new int[n + 1];
    boolean[] isNotPrime = new boolean[n + 1];
    for (int x = 2; x <= n; ++x) {
        if (!isNotPrime[x]) primes[cnt++] = x;
        for (int k = 2; k * x <= n; ++k) isNotPrime[k * x] = true;
    }
    return cnt;
}

时间复杂度:O(NlogN)O(N\log{N})

埃氏筛

原理:要得到自然数 nn 以内的全部质数,把不大于 n\sqrt{n} 的所有质数倍数筛除,剩下的就都是质数。

证明:根据算术基本定理每一个大于 11 的正整数都能分解成有限个质数的幂的乘积,且由于是从小到大枚举,在枚举到某个合数之前,一定先会枚举到它的质因数,也就是说所有的合数都会被它的质因数筛掉。

New_Animation_Sieve_of_Eratosthenes.gif

模板代码:

int countPrimes(int n) {
    int cnt = 0;
    int[] primes = new int[n + 1];
    boolean[] isNotPrime = new boolean[n + 1];
    for (int x = 2; x <= n; ++x) {
        if (!isNotPrime[x]) {
            primes[cnt++] = x;
            for (int k = 2; k * x <= n; ++k) isNotPrime[k * x] = true;
        }
    }
    return cnt;
}

时间复杂度:O(NloglogN)O(N\log\log{N})

线性筛(欧拉筛)

原理:线性筛保证了每个合数都只被其最小质因子筛掉。

模板代码:

int countPrimes(int n) {
    int cnt = 0;
    int[] primes = new int[n + 1];
    boolean[] isNotPrime = new boolean[n + 1];
    for (int x = 2; x <= n; ++x) {
        if (!isNotPrime[i]) primes[cnt++] = x;
        for (int i = 0; primes[i] <= n / x; ++i) {
            isNotPrime[primes[i] * x] = true;
            if (x % primes[i] == 0) break;
        }
    }
    return cnt;
}

时间复杂度:O(N)O(N)

线性筛是如何保证每个合数只会被其最小质因子筛掉的?

  • x % primes[i] == 0 时,说明此时遍历到的 pip_i 不是 xx 的质因子,那么 **pip_i 一定小于 xx 的最小质因子 **,所以 pixp_i \cdot x 的最小质因子就是 pip_i
  • x % primes[i] != 0 时 ,说明 xx 的最小质因子是此时的 pip_i ,因此 pixp_i \cdot x 的最小质因子依旧应该是 pip_i ,但如果继续枚举的话,我们就把 pi+1xp_{i+1} \cdot x 这个数筛掉了,虽然这个数也是合数,但是筛掉它的时候并不是用它的最小质因数筛掉的 ,而是利用 pi+1p_{i + 1}xx 把它删掉的,而这个数的最小质因数其实是 pip_i ,如果不在这里退出循环的话,就会发现有些数是被重复筛了的。

勒让德定理

在正数 n!n! 的质因数分解中,质数 pp 的指数记作 νp(n!)\nu_p(n!) ,则:

νp(n!)=i=1logpnnpk\nu_p(n!) = \sum_{i=1}^{\left\lfloor\log_{p}{n}\right\rfloor}{\left\lfloor\frac{n}{p^k}\right\rfloor}

模板代码:

int count(int n, int p) {
    int res = 0;
    while (n) res += n /= p;
    return res;
}

约数

试除法求所有约数

模板代码:

set<int> get_divisors(int n) {
    set<int> divisors;
    for (int i = 1; i <= n / i; ++i) {
        if (n % i == 0) {
            divisors.insert(i);
            divisors.insert(n / i);
        }
    }
    return divisors;
}

时间复杂度:O(n)O(\sqrt{n})

int 范围内,约数最多的数其约数有 15361536 个。

约数个数

求约数个数

模板题:【数学】约数个数

原理:

  • 根据算术基本定理nn 质因数分解为:n=p1α1×p2α2××pkαkn = p_1^{\alpha_1} \times p_2^{\alpha_2} \times \cdots \times p_k^{\alpha_k}
  • 那么 nn 的约数个数为:f(n)=i=1k(αi+1)=(α1+1)(α2+1)(αk+1)f(n) = \prod_{i=1}^{k}{(\alpha_i + 1)} = (\alpha_1 + 1)(\alpha_2 + 1) \cdots (\alpha_k + 1)

证明:

  • 显然, p1α1p_1^{\alpha_1} 的约数为 p10,p11,,p1kp_1^{0}, p_1^{1}, \dots, p_1^{k} ,同理: piαip_i^{\alpha_i} 的约数为 pi0,pi1,,piip_i^{0}, p_i^{1}, \dots, p_i^{i}
  • 实际上 nn 的约数是在 p1α1,p2α2,pkαkp_1^{\alpha_1}, p_2^{\alpha_2}, \dots p_k^{\alpha_k} 每个数的约数分别挑一个相乘得来的,那么 nn 的每个约数就都可以被分解为 d=p1β1×p2β2××pkβkd = p_1^{\beta_1} \times p_2^{\beta_2} \times \cdots \times p_k^{\beta_k} (其中 0βiαi0 \le \beta_i \le \alpha_i );
  • 又根据算术基本定理:每一个数的质因数分解是唯一的,那么 βi\beta_i 的所有组合就对应了 nn 的所有约数
  • 而每个 βi\beta_i 都有 0αi0 \dots \alpha_iαi+1\alpha_i+1 种选择,所以 dd 共有 i=1k(αi+1)\prod_{i=1}^{k}{(\alpha_i + 1)} 个组合,那么 nn 也就有这么多个约数。

模板代码:

// 求一个数的约数个数
int count_divisors(int n) {
    int cnt = 1;
    for (int i = 2; i <= n / i; ++i) {
        if (n % i == 0) {
            int s = 0;
            while (n % i == 0) {
                n /= i;
                ++s;
            }
           	cnt *= s + 1;
        }
    }
    if (n > 1) cnt *= 2;
    return cnt;
}

时间复杂度:

  • 最差: O(n)O(\sqrt{n})
  • 最好: O(logN)O(\log{N})
一些与约数个数相关的性质
  • 1N1\sim{N} 中所有数的约数个数之和总和约为 NlogNN\log{N} 个;
  • [0,2×109][0, 2\times10^9] 范围内的约数个数最多的数的约数约有 1600 个。

约数之和

模板题:【数学】约数之和

原理:

  • 根据算术基本定理nn 质因数分解为:n=p1α1×p2α2××pkαkn = p_1^{\alpha_1} \times p_2^{\alpha_2} \times \cdots \times p_k^{\alpha_k}

  • 那么 nn 的约数之和为:

    σ(n)=i=1k(j=0αkpij)=(p10+p11+p12++p1α1)(p20+p21+p22++p2α2)(pk0+pk1+pk2++pkαk)\sigma(n) = \prod_{i=1}^{k}(\sum_{j=0}^{\alpha_k}{p_i^j}) = (p_1^0+p_1^1+p_1^2+ \cdots + p_1^{\alpha_1})(p_2^0+p_2^1+p_2^2+ \cdots + p_2^{\alpha_2}) \cdots (p_k^0+p_k^1+p_k^2+ \cdots + p_k^{\alpha_k})

证明:

  • 显然, p1α1p_1^{\alpha_1} 的约数为 p10,p11,,p1kp_1^{0}, p_1^{1}, \dots, p_1^{k} ,同理: piαip_i^{\alpha_i} 的约数为 pi0,pi1,,piip_i^{0}, p_i^{1}, \dots, p_i^{i}

  • 实际上 nn 的约数是在 p1α1,p2α2,pkαkp_1^{\alpha_1}, p_2^{\alpha_2}, \dots p_k^{\alpha_k} 每个数的约数分别挑一个相乘得来的;

  • 可知共有 (α1+1)(α2+1)(αk+1)(\alpha_1 + 1)(\alpha_2 + 1) \cdots (\alpha_k + 1) 种挑法,即约数个数。

  • 乘法原理可知它们的和为:

    σ(n)=i=1k(j=0αkpij)=(p10+p11+p12++p1α1)(p20+p21+p22++p2α2)(pk0+pk1+pk2++pkαk)\sigma(n) = \prod_{i=1}^{k}(\sum_{j=0}^{\alpha_k}{p_i^j}) = (p_1^0+p_1^1+p_1^2+ \cdots + p_1^{\alpha_1})(p_2^0+p_2^1+p_2^2+ \cdots + p_2^{\alpha_2}) \cdots (p_k^0+p_k^1+p_k^2+ \cdots + p_k^{\alpha_k})

模板代码:

// 求一个数的约数之和
int sum_divisors(int n) {
    int sum = 1;
    for (int i = 2; i <= n / i; ++i) {
        if (n % i == 0) {
			int tmp = 1;
            while (n % i == 0) {
                n /= i;
                tmp = i * tmp + 1;
            }
            sum *= tmp;
        }
    }
    if (n > 1) sum *= n + 1;
    return sum;
}

最大公约数

原理:辗转相除法(欧几里得算法): gcd(a,b)=gcd(b,amodb)\gcd(a, b) = \gcd(b, a \bmod b)

证明:

  • 已知 amodb=aabba \bmod b = a - \lfloor \frac{a}{b} \rfloor \cdot b ,设 c=abc = \lfloor \frac{a}{b} \rfloor ,则现在要证明 gcd(a,b)=gcd(b,acb)\gcd(a, b) = \gcd(b, a - c \cdot b)
  • 对于 gcd(a,b)\gcd(a, b) , 设 da,dbd \mid a, d \mid b
    • 则显然有 dax+by,x,yZd \mid ax+by, \enspace x, y \in \mathbb{Z} ,当 x=1,y=cx = 1, y = c 时, dacbd \mid a - c \cdot b
    • db,dacbd \mid b, d \mid a - c \cdot b 同时成立,则有 gcd(a,b)gcd(b,acb)\gcd(a, b) \subseteq \gcd(b, a - c \cdot b)
  • 对于 gcd(b,acb)\gcd(b, a-c \cdot b) ,设 db,dacbd \mid b, d \mid a - c \cdot b
    • 则显然有 dacb+cbdad \mid a - c \cdot b + c \cdot b \Rightarrow d \mid a
    • da,dbd \mid a, d \mid b 同时成立,则有 gcd(b,acb)gcd(a,b)\gcd(b, a - c \cdot b) \subseteq \gcd(a, b)
  • gcd(a,b)=gcd(b,acb)\gcd(a, b) = \gcd(b, a - c \cdot b) 得证,同时 gcd(a,b)=gcd(b,amodb)\gcd(a, b) = \gcd(b, a \bmod b) 也成立。

模板代码:(递归实现)

int gcd(int a, int b) {
    return b ? gcd(b, a % b) : a;
}
int gcd(int a, int b) {
    return b == 0 ? a : gcd(b, a % b);
}

欧拉函数

公式求欧拉函数

详细介绍及原理:欧拉函数 - OI Wiki

若在算术基本定理中:n=p1α1×p2α2××pkαkn = p_1^{\alpha_1} \times p_2^{\alpha_2} \times \cdots \times p_k^{\alpha_k}

那么 nn 的欧拉函数:ϕ(n)=n×i=1k(11pi)=n×p11p1×p21p2××pk1pk\phi(n) = n \times \prod_{i=1}^{k}{(1-\frac{1}{p_i})} = n \times \frac{p_1 - 1}{p_1} \times \frac{p_2 - 1}{p_2} \times \cdots \times \frac{p_k - 1}{p_k}

模板代码:

// 求一个数的欧拉函数
LL phi(int n) {
    LL res = n;
    for (int i = 2; i <= n / i; ++i) {
        if (n % i == 0) {
            while (n % i == 0) n /= i;
            res = res * (i - 1) / i;
        }
    }
    if (n > 1) res = res * (n - 1) / n;
    return res;
}

时间复杂度: O(N)O(\sqrt{N})

筛法求欧拉函数之和

1n1\sim n 中每个数的欧拉函数之和:

long sumEulers(int n) {
    int cnt = 0;
    int[] primes = new int[n + 1], phi = new int[n + 1];
    boolean[] isNotPrime = new boolean[n + 1];
    phi[1] = 1;
    for (int i = 2; i <= n; ++i) {
        if (!isNotPrime[i]) {
            primes[cnt++] = i;
            phi[i] = i - 1;
        }
        for (int j = 0; primes[j] <= n / i; ++j) {
            isNotPrime[primes[j] * i] = true;
            if (i % primes[j] == 0) {
                phi[primes[j] * i] = phi[i] * primes[j];
                break;
            }
            phi[primes[j] * i] = phi[i] * (primes[j] - 1);
        }
    }
    long res = 0L;
    for (int i = 1; i <= n; ++i) res += phi[i];
    return res;
}

时间复杂度: O(N)O(N)

快速幂

详细介绍及原理:快速幂 - OI Wiki

模板代码:

long quickPow(long a, long b, long p) {
    long c = 1L;
    while (b != 0) {
        if ((b & 1) != 0) c = c * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return c;
}
long long quick_pow(long long a, long long b, long long p) {
    long long c = 1LL;
    while (b) {
        if (b & 1) c = c * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return c;
}

慢速乘

防止乘数及模数很大但还在 long long 范围内时直接做乘爆 long long 的问题。

long slowMul(long a, long b, long p) {
    long c = 0L;
    while (b > 0) {
        if ((b & 1) != 0) c = (c + a) % p;
        a = (a + a) % p;
        b >>= 1;
    }
    return c;
}
long long slow_mul(long long a, long long b, long long p) {
    long long c = 0LL;
    while (b) {
        if (b & 1) c = (c + a) % p;
        a = (a + a) % p;
        b >>= 1;
    }
    return c;
}

扩展欧几里得算法

扩展欧几里得算法(Extended Euclidean algorithm, EXGCD),常用于求 ax+by=gcd(a,b)ax+by=\gcd(a,b)裴蜀定理)的一组可行解。

过程

设:

ax1+by1=gcd(a,b)ax_1 + by_1 = \gcd(a, b)

bx2+(amodb)y2=gcd(b,amodb)bx_2 + (a \bmod b)y_2 = \gcd(b, a \bmod b)

欧几里得算法可知: gcd(a,b)=gcd(b,amodb)\gcd(a, b) = \gcd(b, a \bmod b)

所以:

ax1+by1=bx2+(amodb)y2ax1+by1=bx2+(aabb)y2ax1+by1=bx2+ay2baby2ax1+by1=ay2+b(x2aby2)\begin{aligned} ax_1 + by_1 &= bx_2 + (a \bmod b)y_2 \\ ax_1 + by_1 &= bx_2 + (a - \lfloor \frac{a}{b} \rfloor \cdot b)y_2 \\ ax_1 + by_1 &= bx_2 + ay_2 - b \lfloor \frac{a}{b} \rfloor y_2 \\ ax_1 + by_1 &= ay_2 + b (x_2 - \lfloor \frac{a}{b} \rfloor y_2) \\ \end{aligned}

因为 a=a,b=ba = a, b = b ,所以:{x1=y2y1=x2aby2\begin{cases} x_1 = y_2 \\ y_1 = x_2 - \lfloor \frac{a}{b} \rfloor y_2 \end{cases}

x2,y2x_2,y_2 不断代入递归求解,直至 b=0b=0 时: ax+0y=aax + 0y = a ,显然 {x=1y=0\begin{cases} x = 1 \\ y= 0 \end{cases} 是一组解,此时退出递归。

模板代码

int x, y;
    
int exgcd(int a, int b) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    int d = exgcd(b, a % b);
    int t = x;
    x = y;
    y = t - a / b * y;
    return d;
}

线性同余方程

定义

形如:

axb(modn)ax\equiv b\pmod n

的方程称为 线性同余方程(Congruence Equation)。其中,a,b,na, b, n 为给定整数,xx 为未知数,需要从区间 [0,n1][0, n-1] 中求解,当解不唯一时需要求出全体解。

用扩展欧几里得算法求解

根据以下两个定理,可以求出线性同余方程 axb(modn)ax\equiv b \pmod n 的解。

定理 1

axb(modn)ax\equiv b \pmod n 可以改写为 ax=ny+bax = ny' + b ,移项得 axny=bax - ny' = b 。设 y=yy = -y' ,则有 ax+ny=bax + ny = b 。其中 xxyy 是未知数。这两个方程是等价的,有整数解的充要条件gcd(a,n)b\gcd(a,n) \mid b (因为 ax+nyax + ny 一定要是 gcd(a,n)gcd(a, n) 的倍数 )。

很容易发现,经过转换后的方程很像裴蜀定理的结论( ax+by=gcd(a,b)ax+by=\gcd(a,b) ),于是我们可以先用扩展欧几里得算法求出一组 x0,y0x'_0,y'_0 ,也就是 ax0+ny0=gcd(a,n)ax'_0+ny'_0=\gcd(a,n) ,然后两边同时除以 gcd(a,n)\gcd(a,n) ,再乘 bb 。就得到了方程:

abgcd(a,n)x0+nbgcd(a,n)y0=ba \cdot \frac{b}{\gcd(a,n)} \cdot x'_0 + n \cdot \frac{b}{\gcd(a,n)} \cdot y'_0 = b

于是找到方程的一个解:

{x0=bgcd(a,n)x0y0=bgcd(a,n)y0\begin{cases} x_0 = \frac{b}{\gcd(a,n)} \cdot x'_0 \\ y_0 = \frac{b}{\gcd(a,n)} \cdot y'_0 \end{cases}

定理 2

gcd(a,n)=1\gcd(a,n)=1 ,且 x0,y0x_0, y_0 为方程 ax+ny=bax+ny=b 的一组解,设 t=ngcd(a,n)t = \frac{n}{\gcd(a,n)} ,则该方程的任意解可表示为:

{x=x0+nty=y0at\begin{cases} x=x_0 + nt \\ y=y_0 - at \end{cases}

并且对任意整数 tt 都成立。

根据定理 2,可以从已求出的一个解,求出方程的所有解。实际问题中,往往要求出一个最小整数解,也就是一个特解(例如使用扩展欧几里得算法求解乘法逆元):

xmin=(x0modt+t)modtx_{\min}=(x_0 \bmod t+t) \bmod t

因为 gcd(a,n)=1\gcd(a, n) = 1 ,所以 t=ngcd(a,n)=nt = \frac{n}{\gcd(a,n)} = n ,于是就有:

xmin=(x0modn+n)modnx_{\min} = (x_0 \bmod n + n) \bmod n

模板题:

模板代码
int x, y;

int exgcd(int a, int b) {
    if (b == 0) {
        x = 1; y = 0;
        return a;
    }
    int d = exgcd(b, a % b);
    int t = x;
    x = y;
    y = t - a / b * y;
    return d;
}

// ax === b (mod n) 的解
int solve(int a, int b, int n) {
    int d = exgcd(a, n);
    if (b % d != 0) return -1; // 不存在整数解
    else return (int) (1L * x * (b / d) % n);
}

乘法逆元(模逆元)

定义

若整数 b,pb, p 互质[1:4],且对于任意一个整数 aa ,如果满足 bab \mid a [2] ,则存在一个整数 xx ,使得 abax(modp)\frac{a}{b} \equiv a \cdot x \pmod p [3] 。称 xxbb 在模 pp 意义下的乘法逆元,记作 b1(modp)b^{-1} \pmod p

作用

在要除以一个数,再取模时,把除法变成乘法运算,然后再取模

快速幂求逆元

只能在模数为质数的情况下使用。

根据逆元的定义,有: abab1(modp)\frac{a}{b} \equiv a \cdot b^{-1} \pmod p

两边同乘 bb 得: aabb1(modp)a \equiv a \cdot b \cdot b^{-1} \pmod p

所以: bb11(modp)b \cdot b^{-1} \equiv 1 \pmod p ①;

根据费马小定理 :若 pp 为质数且 b,pb, p 互质,则 bp11(modp)b^{p-1} \equiv 1 \pmod p

从上式 bp1b^{p-1} 中拆出一个 bb 得到: bbp21(modp)b \cdot b^{p-2} \equiv 1 \pmod p

结合①式得到:b1bp2(modp)b^{-1} \equiv b^{p-2} \pmod p

综上,在 bb 存在乘法逆元的条件下,求出 bp2modpb^{p-2} \bmod p 即为 bb 在模 pp 意义下的乘法逆元

bp2b^{p-2} 使用快速幂求解。

扩展欧几里得算法求逆元

根据逆元的定义,有: abab1(modp)\frac{a}{b} \equiv a \cdot b^{-1} \pmod p

两边同乘 bb 得: aabb1(modp)a \equiv a \cdot b \cdot b^{-1} \pmod p

所以: bb11(modp)b \cdot b^{-1} \equiv 1 \pmod p

a=b,x=b1,n=pa=b, x = b^{-1}, n = p ,上式表达为: ax1(modn)ax \equiv 1 \pmod n ,需要求解式中的 xx ,其实就是求解线性同余方程 axb(modn)ax\equiv b \pmod nb=1b = 1 的特殊情况下的解。根据逆元的定义有: gcd(a,n)=1\gcd(a, n) = 1 ,参考求解线性同余方程 - 用扩展欧几里得算法求解 - 定理 2xx 的一个最小整数解为:

x=(xmodn+n)modnx = (x \bmod n + n) \bmod n

模板代码:

int x, y;

int exgcd(int a, int b) {
    if (b == 0) {
        x = 1; y = 0;
        return a;
    }
    int d = exgcd(b, a % b);
    int t = x;
    x = y;
    y = t - a / b * y;
    return d;
}

// a在模n意义下的逆元
int inv(int a, int n) {
    exgcd(a, n);
    return (x % n + n) % n;
}

组合数

nn 个不同元素中,任取 m(mn)m(m\leq n) 个元素组成一个集合,叫做从 nn 个不同元素中取出 mm 个元素的一个组合;从 nn 个不同元素中取出 m(mn)m(m\leq n) 个元素的所有组合的个数,叫做从 nn 个不同元素中取出 mm 个元素的组合数。用符号 Cnm\mathrm{C}_n^m 来表示,组合数也常用 (nm)n \choose m 表示,读作「nnmm」,即 Cnm=(nm)\mathrm{C}_n^m={n \choose m}

组合数计算公式:

Cnm=n!m!(nm)!\mathrm{C}_n^m = \frac{n!}{m!(n-m)!}

特别地,规定当 m>nm>n 时,Anm=Cnm=0\mathrm A_n^m=\mathrm C_n^m=0

递推求组合数

模板题:求组合数 I

原理: Cnm=Cn1m+Cn1m1\mathrm{C}_n^m = \mathrm{C}_{n-1}^m + \mathrm{C}_{n-1}^{m-1}

证明:

nn 个物品中选择 mm 个物品( Cnm\mathrm{C}_n^m ),假如当前拿出来了一个物品:

  • 如果该物品不在要选择的 mm 个物品中,那么我们还需要在 n1n-1 个物品中选择 mm 个物品( Cn1m\mathrm{C}_{n-1}^{m} );
  • 如果该物品要选择的 mm 个物品中,那么我们还需要在 n1n-1 个物品中选择 m1m-1 个物品( Cn1m1\mathrm{C}_{n-1}^{m-1} )。

综上, Cnm=Cn1m+Cn1m1\mathrm{C}_n^m = \mathrm{C}_{n-1}^m + \mathrm{C}_{n-1}^{m-1} 成立。

适用数据范围:1T10000,1mn20001 \le T \le 10000, 1 \le m \le n \le 2000

模板代码:

static final int N;
static int[][] c = new int[N][N];

static void comb() {
    for (int n = 0; n < N; ++n) {
        for (int m = 0; m <= n; ++m) {
			if (m == 0) c[n][m] = 1;
            else c[n][m] = c[n - 1][m] + c[n - 1][m - 1];
        }
    }
}

// ans: c[a][b]

预处理阶乘求组合数

模板题:求组合数 II

原理:Cnm=n!×(m!)1×(n!m!)1(mod109)\mathrm{C}_n^m = n! \times (m!)^{-1} \times (n! - m!)^{-1} \pmod{10^9} (逆元的简单应用)。

适用数据范围:1T10000,1mn1051 \le T \le 10000, 1 \le m \le n \le 10^5

模板代码:

// MOD必须是质数才能用快速幂求逆元
static final int N, MOD;
// fact[i]表示i的阶乘(% MOD)  infact[i]表示i的阶乘的逆元(% MOD)
static int[] fact = new int[N], infact = new int[N];

// 预处理1~N的阶乘及其逆元
static {
    fact[0] = 1;
    for (int i = 1; i < N; ++i) {
        fact[i] = (int) (1L * fact[i - 1] * i % MOD);
        infact[i] = (int) qmi(fact[i], MOD - 2, MOD);
    }
}

// 快速幂 用作求逆元
static long qmi(long base, long exp, long mod) {
    long res = 1L;
    while (exp > 0) {
        if ((exp & 1) == 1) res = (res * base) % mod;
        base = (base * base) % mod;
        exp >>= 1;
    }
    return res;
}

static int C(int n, int m) {
    return n == m ? 1 : (int) (1L * fact[n] * infact[m] % MOD * infact[n - m] % MOD);
}

// ans: C(a, b)

卢卡斯定理求组合数

模板题:求组合数 III

原理:

  • a < p \and b < p ,那么直接从定义出发,使用公式求解:

Cnm=n!m!(nm)!=i=nm+1nim!=n(n1)(nm+2)(nm+1)1×2××m\mathrm{C}_n^{m} = \frac{n!}{m!(n-m)!} = \frac{\prod_{i=n-m+1}^{n} i}{m!} = \frac{n(n-1) \dots (n - m + 2)(n - m + 1)}{1 \times 2 \times \dots \times m}

  • 否则代入 Lucas 公式 :

    Cnmmodp=CnmodpmmodpCn/pm/pmodp\mathrm{C}_{n}^{m} \bmod p = \mathrm{C}_{n \bmod p}^{m \bmod p} \cdot \mathrm{C}_{\lfloor n / p \rfloor}^{\lfloor m / p \rfloor} \bmod p

    递归计算。

适用数据范围:1T20,1mn1018,1p1051 \le T \le 20, 1 \le m \le n \le 10^{18}, 1 \le p \le 10^5

模板代码:

// 快速幂 用作且求逆元
static long qmi(long base, long exp, long mod) {
    long res = 1L;
    while (exp > 0) {
        if ((exp & 1) == 1) res = (res * base) % mod;
        base = (base * base) % mod;
        exp >>= 1;
    }
    return res;
}

// 求组合数
static long C(long n, long m, long p) {
    long nume = 1L, deno = 1L;
    for (long a = n, b = 1; b <= m; --a, ++b) {
        nume = nume * a % p;
        deno = deno * b % p;
    }
    return nume * qmi(deno, p - 2, p) % p;
}

// 卢卡斯定理
static long lucas(long n, long m, long p) {
    if (n < p && m < p) return C(n, m, p);
    else return lucas(n % p, m % p, p) * lucas(n / p, m / p, p) % p;
}

// ans: lucas(a, b, p)

高精度+质因数分解求组合数

模板题:求组合数 IV

原理:

  • 根据公式: Cnm=n!m!(nm)!\mathrm{C}_n^{m} = \frac{n!}{m!(n-m)!} ,分别把 n!,m!,(nm)!n!, m!, (n-m)! 质因数分解,对于 1N1 \sim N 中的每个质数,维护 mp[i]mp[i] 表示: n!n! 质因数分解后 ii 的次数 - m!m! 质因数分解后 ii 的次数 - (nm)!(n - m)! 质因数分解后 ii 的次数。最后把 mpmp 中有次数的质数按照其次数乘起来就能得到答案;

  • 快速计算 nn 的阶乘的质因数分解中质因子 xx 的数量:勒让德定理

  • 高精度 Java 用 BigInteger 包。

适用数据范围:1mn50001 \le m \le n \le 5000

// 预处理筛出1~N之间的所有质数
static final int N;
static int[] primes = new int[N];
static int cnt = 0;
static boolean[] np = new boolean[N + 1];
// 线性筛
static {
    for (int x = 2; x <= N; ++x) {
        if (!np[x]) primes[cnt++] = x;
        for (int i = 0; primes[i] <= N / x; ++i) {
            np[primes[i] * x] = true;
            if (x % primes[i] == 0) break;
        }
    }
}

// 求n!中因子p的数量
static int count(int n, int d) {
    int cnt = 0;
    while (n > 0) cnt += n /= d;
    return cnt;
}

static BigInteger C(int n, int m) {
    BigInteger res = BigInteger.valueOf(1);
    for (int i = 0; i < cnt; ++i) {
        int p = primes[i];
        // 求每个质因子的幂次数
        int cnt = count(n, p) - count(m, p) - count(n - m, p);
        BigInteger bp = BigInteger.valueOf(p);
        while (cnt-- > 0) res = res.multiply(bp);
    }
    return res;
}

// ans: C(a, b)

卡特兰数

模板题:满足条件的01序列

C2nnC2nn1=C2nnn+1\mathrm{C}_{2n}^n - \mathrm{C}_{2n}^{n-1} = \frac{\mathrm{C}_{2n}^n}{n+1}

动态规划

0-1背包问题

以下代码中均用 n 代表物品种类, c 代表背包容量(capacity), v 代表物品体积(volume), w 代表物品价值(worth), s 代表物品种类。

模板代码

int n, c;
// 体积输入至v[1...n]  价值输入至w[1...n]
int[] v = new int[n + 1], w = new int[n + 1];
int[][] f = new int[n + 1][c + 1];

for (int i = 1; i <= n; ++i) {
    for (int j = 0; j <= c; ++j) {
        f[i][j] = f[i - 1][j];
        if (v[i] <= j) f[i][j] = Math.max(f[i][j], f[i - 1][j - v[i]] + w[i]);
    }
}

// ans: f[n][c]
一维数组优化
int n, c;
// 体积输入至v[0...n-1]  价值输入至w[0...n-1]
int[] v = new int[n], w = new int[n];
int[] f = new int[c + 1];

for (int i = 0; i < n; ++i) {
    for (int j = c; j >= v[i]; --j) {
        f[j] = Math.max(f[j], f[j - v[i]] + w[i]);
    }
}

// ans: f[c]

完全背包问题

模板代码

int n, c;
// 体积输入至v[1...n]  价值输入至w[1...n]
int[] v = new int[n + 1], w = new int[n + 1];
int[][] f = new int[n + 1][c + 1];

for (int i = 1; i <= n; ++i) {
    for (int j = 0; j <= c; ++j) {
        for (int k = 0; k * v[i] <= j; ++k) {
            f[i][j] = Math.max(f[i][j], f[i - 1][j - k * v[i]] + k * w[i]);
        }
    }
}

// ans: f[n][c]
一维数组优化
int n, c;
// 体积输入至v[0...n-1]  价值输入至w[0...n-1]
int[] v = new int[n], w = new int[n];
int[] f = new int[c + 1];

for (int i = 0; i < n; ++i) {
    for (int j = v[i]; j <= c; ++j) {
        f[j] = Math.max(f[j], f[j - v[i]] + w[i]);
    }
}

// ans: f[c]

多重背包问题

模板代码

朴素多重背包
int n, c;
// 体积输入至v[1...n]  价值输入至w[1...n]  数量输入至s[1...n]
int[] v = new int[n + 1], w = new int[n + 1], s = new int[n + 1];
int[][] f = new int[n + 1][c + 1];

for (int i = 1; i <= n; ++i) {
    for (int j = 0; j <= c; ++j) {
        f[i][j] = f[i - 1][j];
        for (int k = 0; k * v[i] <= j && k <= s[i]; ++k) {
            f[i][j] = Math.max(f[i][j], f[i - 1][j - k * v[i]] + k * w[i]);
        }
    }
}

// ans: f[n][c]
二进制优化+01背包一维优化
// N: 最大物品种类  S: 单个物品的最大数量
static final int N, S;
// M = N * ⌈log_2{S}⌉
static final int M = (int) (N * Math.ceil(Math.log(S) / Math.log(2)));
int n, c;
int v = new int[M], w = new int[M];
int cnt = 0;

// 输入处理
for (int i = 0; i < n; ++i) {
    // 输入每个物品的体积, 价值, 最大数量
    int vi, wi, si;
    for (int x = 1; x <= si; ++cnt, si -= x, x <<= 1) {
        v[cnt] = x * vi;
        w[cnt] = x * wi;
    }
    if (si > 0) {
        v[cnt] = vi * si;
        w[cnt] = wi * si;
        ++cnt;
    }
}
n = cnt;
int[] f = new int[c + 1];

// 01背包一维优化
for (int i = 0; i < n; ++i) {
    for (int j = c; j >= v[i]; --j) {
        f[j] = Math.max(f[j], f[j - v[i]] + w[i]);
    }
}

// ans: f[c]
单调队列优化+拷贝数组优化
int n, c;
int[] f = new int[c + 1], g;
int[] que = new int[c + 1];
int hh, tt;
for (int i = 0; i < n; ++i) {
    // 输入每个物品的体积, 价值, 最大数量
    int v, w, s;
    g = f.clone(); // 拷贝上一层状态
    // 枚举余数
    for (int r = 0; r < v; ++r) {
        hh = 0; tt = -1; // 清空优先队列
        // 从余数开始枚举空间
        for (int j = r; j <= c; j += v) {
            // 将超出窗口范围的元素出队(`j - que[hh] / v`表示滑动窗口中的元素数量)
            while (hh <= tt && (j - que[hh]) / v > s) ++hh;
            // 当前状态比队尾元素表示的状态更优 队尾元素没有存在必要 队尾出队
            // 注意: 队尾元素需要加上价值偏移量: `(j - que[tt]) / v * w`
            while (hh <= tt && g[j] >= g[que[tt]] + (j - que[tt]) / v * w) --tt;
            // 当前下标入队
            que[++tt] = j;
            // 更新当前这一层的状态(注意依旧要加上价值偏移量)
            f[j] = g[que[hh]] + (j - que[hh]) / v * w;
        }
    }
}

// ans: f[c]

分组背包问题

优化后模板

static final int S; // 单个物品的最大数量
int n, c;
// 数量输入至s[0...n]
int[] s = new int[n];
// 体积输入至v[0...n-1][0...s[i]-1]  价值输入至w[0...n-1][0...s[i]-1] 
int[][] v = new int[n][S], w = new int[n][S];
int[] f = new int[c + 1];

for (int i = 0; i < n; ++i) {
    for (int j = c; j >= 0; --j) {
        for (int k = 0; k < s[i]; ++k) {
            if (v[i][k] <= j) f[j] = Math.max(f[j], f[j - v[i][k]] + w[i][k]);
        }
    }
}

// ans: f[c]

最长上升子序列(LIS)

时间复杂度:O(n2)O(n^2)

// 序列长度为n  序列输入至a[0...n-1]
int f[n], mx;

for (int i = 0; i < n; ++i) {
    f[i] = 1;
    for (int j = 0; j < i; ++j) {
        if (a[j] < a[i]) f[i] = max(f[i], f[j] + 1);
    }
    mx = max(mx, f[i]);
}

// ans: mx

贪心算法优化

时间复杂度:O(nlogn)O(n\log{n})

// 序列长度为n  序列输入至a[0...n-1]
int b[n + 1], cnt = 0;

int search(int l, int r, int x) {
    while (l < r) {
        int m = l + r + 1>> 1;
        if (b[m] < x) l = m;
        else r = m - 1;
    }
    return r;
}

for (int i = 0; i < n; ++i) {
    int len = search(0, cnt, a[i]);
    cnt = max(cnt, ++len);
    b[len] = a[i];
}

// ans: cnt

最长相同子序列(LCS)

// 两字符串序列长度分别为n, m  输入至a[1...n], b[1...m]
int n, m;
char[] a, b;
int[][] f = new int[n + 1][m + 1];

for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= m; ++j) {
        f[i][j] = Math.max(f[i - 1][j], f[i][j - 1]);
        if (a[i] == b[j]) f[i][j] = Math.max(f[i][j], f[i - 1][j - 1] + 1);
    }
}

// ans: f[n][m]

数位DP

记忆化搜索模板(不考虑前导零)

const int N, K, R;
int len, digits[N];
int f[N][K];

int dfs(int pos, int info, bool lim) {
    if (pos < 0) // 递归出口
    if (!lim && f[pos][info] != -1) return f[pos][info];
    int res = 0, up = lim ? digits[pos] : /* 无限制时该位能填的最大数 */;
    for (int d = /* 下界 */; d <= up; ++d) {
        res += dfs(pos - 1, /* 信息 */, lim && d == digits[pos]);
    }
    return lim ? res : f[pos][info] = res;
}

int count(int n) {
    for (len = 0; n; n /= R) digits[len++] = n % R;
    memset(f, -1, sizeof(f));
    return dfs(len - 1, /* 信息 */, true);
}
带注释模板
// N为数据范围在需要考虑的进制下的最大位数 比如: 
//  - int范围内10进制最大9位 N就可以赋15
//  - int范围内2进制最大31位 N可以赋35
// K为需要携带的数字信息的最大数 比如:
//  - 携带的信息是10进制下的数上一位(0~9) K就可以赋15
//  - 携带的信息是2进制下的数中数位为1的数量(0~31) K就可以赋35
// R为需要考虑的进制系统(一般为2进制或10进制)
const int N, K, R;
int len, digits[N];
int f[N][K];

// pos: 当前枚举到的位  info: 携带的数字信息  lim: 之前的每一位数是否都到达了上界(n里的该数位)
int dfs(int pos, int info, bool lim) {
    if (pos < 0) // 递归出口
    if (!lim && f[pos][info] != -1) return f[pos][info];
    int res = 0, up = lim ? digits[pos] : /* 无限制时该位能填的最大数 */;
    // 枚举当前位上可以填的数字 下界根据题目情况来定
    for (int d = /* 下界 */; d <= up; ++d) {
        // 这一行可以判断非法条件并根据情况continue
        res += dfs(pos - 1, /* 信息 */, lim && d == digits[pos]); // 递归处理下一位
    }
    // 该位未到达上界时 在返回的同时写入记忆数组
    return lim ? res : f[pos][info] = res;
}

int count(int n) {
    len = 0;
    // 预处理n在R进制下的每一位 记录在digits[0...len-1]中
    while (n) {
        digits[len++] = n % R;
        n /= R;
    }
    memset(f, -1, sizeof(f)); // 初始化记忆数组
    return dfs(len - 1, /* 信息 */, true);
}

记忆化搜索模板(考虑前导零)

const int N, K, R;
int len, digits[N];
int f[N][K];

int dfs(int pos, int info, bool lead, bool lim) {
    if (pos < 0) // 递归出口
    if (!lim && !lead && ~f[pos][info]) return f[pos][info];
    int res = 0, up = lim ? digits[pos] : /* 无限制时该位能填的最大数 */;
    for (int d = /* 下界 */; d <= up; ++d) {
        bool lead_zero = lead && !d;
        res += dfs(pos - 1, /* 信息 */, lead_zero, lim && d == digits[pos]);
    }
    return lim || lead ? res : f[pos][info] = res;
}

int count(int n) {
    for (len = 0; n; n /= R) digits[len++] = n % R;
    memset(f, -1, sizeof(f));
    return dfs(len - 1, /* 信息 */, true, true);
}
带注释模板
// N为数据范围在需要考虑的进制下的最大位数 比如: 
//  - int范围内10进制最大9位 N就可以赋15
//  - int范围内2进制最大31位 N可以赋35
// K为需要携带的数字信息的最大数 比如:
//  - 携带的信息是10进制下的数上一位(0~9) K就可以赋15
//  - 携带的信息是2进制下的数中数位为1的数量(0~31) K就可以赋35
// R为需要考虑的进制系统(一般为2进制或10进制)
const int N, K, R;
int len, digits[N];
int f[N][K];

// pos: 当前枚举到的位  info: 携带的数字信息
// lead: 上一位是否为前导零  lim: 之前的每一位数是否都到达了上界(n里的该数位)
int dfs(int pos, int info, bool lead, bool lim) {
    if (pos < 0) // 递归出口
    if (!lim && !lead && ~f[pos][info]) return f[pos][info];
    int res = 0, up = lim ? digits[pos] : /* 无限制时该位能填的最大数 */;
    // 枚举当前位上可以填的数字 下界根据题目情况来定
    for (int d = /* 下界 */; d <= up; ++d) {
        // 这一行可以判断非法条件并根据情况continue
        bool lead_zero = lead && !d; // 判断该位是否为前导零
        res += dfs(pos - 1, /* 信息 */, lead_zero, lim && d == digits[pos]);
    }
    // 该位未到达上界时 在返回的同时写入记忆数组
    return lim || lead ? res : f[pos][info] = res;
}

int count(int n) {
    // 预处理n在R进制下的每一位 记录在digits[0...len-1]中
    for (len = 0; n; n /= R) digits[len++] = n % R;
    memset(f, -1, sizeof(f)); // 初始化记忆数组
    return dfs(len - 1, /* 信息 */, true, true);
}

树状数组

模板题:动态求连续区间和

树状数组是一种可以以 O(logn)O(\log{n}) 的时间复杂度完成 单点修改区间查询 的,代码量小的数据结构。

// 序列长度为n  序列输入至a[1...n]
int[] a, tr;

int lowbit(int n) {
    return n & -n;
}

// 把a[i]自增x
void add(int i, int x) {
    for (; i <= n; i += lowbit(i)) tr[i] += x;
}

// 求a[1...i]的累和
int query(int i) {
    int sum = 0;
    for (; i > 0; i -= lowbit(i)) sum += tr[i];
    return sum;
}

线段树

模板题:动态求连续区间和

// 维护区间和的线段树
// 序列长度为n  序列输入至a[1...n]
int[] a;
class Node {
    int l, r, sum;

    public Node(int l, int r, int sum) {
        this.l = l;
        this.r = r;
        this.sum = sum;
    }
}
Node[] tr;

// 用子节点的信息更新当前节点信息
void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 在区间[l...r]上初始化线段树  u为当前段根节点位置
void build(int u, int l, int r) {
    if (l == r) tr[u] = new Node(l, r, a[r]);
    else {
        tr[u] = new Node(l, r, 0);
        int m = l + r >> 1;
        build(u << 1, l, m);
        build(u << 1 | 1, m + 1, r);
        pushup(u);
    }
}

// 查询[l...r]  u为当前段根节点位置
int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        int sum = 0;
        if (l <= m) sum += query(u << 1, l, r);
        if (r >= m + 1) sum += query(u << 1 | 1, l, r);
        return sum;
    }
}

// 把位置i的数字修改为x  u为当前段根节点位置
void modify(int u, int i, int x) {
    if (tr[u].l == tr[u].r) tr[u].sum += x;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        if (i <= m) modify(u << 1, i, x);
        else modify(u << 1 | 1, i, x);
        pushup(u);
    }
}

  1. gcd(b,m)=1\gcd(b, m) = 1↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  2. aa 能被 bb 整除( ab\frac{a}{b} 是一个整数 )。 ↩︎

  3. ab\frac{a}{b} 在模 pp同余于 axa \cdot xabmodp=axmodp\frac{a}{b} \bmod p = a \cdot x \bmod p )。 ↩︎

10

评论区