/// Value-indexed ("weight") segment tree over positions [0, N] whose nodes are
/// allocated on demand from a fixed pool.
///
/// Every node stores the total weight of its range. Index 0 is a sentinel that
/// stands for an empty subtree: its weight and both child links stay 0, so
/// query()/kth() can read through missing children without special cases.
///
/// @tparam T weight type; must be constructible from 0 and support `+=`, `-`
///           and `>=` (the last two are used by kth()).
/// @tparam N largest valid position; update() expects 0 <= p <= N.
template <class T, int N>
class WeightedSegmentTree {
  private:
    // A tree over N + 1 leaves has 2(N + 1) - 1 nodes. Indices start at 1,
    // so the pool needs 2(N + 1) slots (N << 1 is two slots short once every
    // position in [0, N] has been touched).
    static constexpr int kPoolSize = (N + 1) << 1;

    int cnt = 0;             // number of nodes handed out by new_node()
    int root = 0;            // 0 until the first update()
    int ch[2][kPoolSize]{};  // ch[0][x] / ch[1][x]: left / right child of node x
    T val[kPoolSize]{};      // val[x]: total weight stored under node x

    /// Takes the next node from the pool, resets it and returns its index.
    int new_node() {
        const int cur = ++cnt;
        val[cur] = T(0);
        ch[0][cur] = ch[1][cur] = 0;
        return cur;
    }

    /// Adds v at position p inside subtree rt, which covers [l, r]. Creates
    /// missing nodes on the way down; only the path to leaf p is touched.
    void update(int p, const T& v, int l, int r, int& rt) {
        if (!rt) rt = new_node();
        val[rt] += v;
        if (l == r) return;
        const int mid = (l + r) >> 1;
        if (p <= mid) update(p, v, l, mid, ch[0][rt]);
        else update(p, v, mid + 1, r, ch[1][rt]);
    }

    /// Total weight of positions in [tl, tr] within subtree rt covering [l, r].
    T query(int tl, int tr, int l, int r, int rt) {
        if (!rt) return T(0);  // empty subtree: nothing to add, no need to descend
        if (tl <= l && r <= tr) return val[rt];
        const int mid = (l + r) >> 1;
        T res(0);
        if (tl <= mid) res += query(tl, tr, l, mid, ch[0][rt]);
        if (tr > mid) res += query(tl, tr, mid + 1, r, ch[1][rt]);
        return res;
    }

    /// Descends from subtree rt (covering [l, r]) towards the leftmost leaf at
    /// which the running weight total reaches k.
    int kth(const T& k, int l, int r, int rt) {
        if (l == r) return l;
        const int mid = (l + r) >> 1;
        const T left = val[ch[0][rt]];
        if (left >= k) return kth(k, l, mid, ch[0][rt]);
        return kth(k - left, mid + 1, r, ch[1][rt]);
    }

  public:
    /// Adds weight v at position p (0 <= p <= N).
    void update(int p, const T& v) { update(p, v, 0, N, root); }

    /// Returns the total weight stored at positions l..r (inclusive).
    T query(int l, int r) { return query(l, r, 0, N, root); }

    /// Returns the smallest position p such that the total weight of [0, p] is
    /// at least k. Assumes non-negative weights. If k exceeds the total weight,
    /// N is returned.
    int kth(const T& k) { return kth(k, 0, N, root); }
};
可持久化的作用
现在让我们扩展一下这个问题:查询区间第 k 小。
区间问题有一个常用技巧：把区间拆成两个前缀之差，即 $[l,r]=[1,r]-[1,l-1]$。

于是我们开 $n+1$ 棵权值线段树：
- 第 $i$ 棵维护序列前缀 $[1,i]$ 中每个值的出现次数；
- 第 $0$ 棵为空树。

查询时在第 $l-1$ 棵和第 $r$ 棵上同步向下递归，用两棵树对应节点的计数之差来查找第 $k$ 小……
然而稍作分析就会发现：$n$ 棵独立的权值线段树共需 $O(nV)$ 的空间（$V$ 为值域大小），这已经无法接受；逐棵建树的时间同样是 $O(nV)$，也不够优。
如何优化？注意到相邻两棵树之间只差一次插入：第 $i$ 棵树就是在第 $i-1$ 棵树上插入 $a_i$ 得到的，而这个操作的性质很好。考虑插入的代码：
// Excerpt of the member WeightedSegmentTree::update, repeated for discussion.
// Each call descends into exactly one child per level, so a single insertion only
// visits (and at most creates) the nodes on one root-to-leaf path:
// O(log(r - l + 1)) nodes. Every other node of the tree is left untouched.
void update(int p, const T& v, int l, int r, int& rt) {
    if (!rt) rt = new_node();
    val[rt] += v;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (p <= mid) update(p, v, l, mid, ch[0][rt]);
    else update(p, v, mid + 1, r, ch[1][rt]);
}
// Global state for the k-th smallest program.
int n, q;      // sequence length, number of queries
int m;         // size of the value domain (distinct values after compression)
int a[N];      // input sequence, 1-indexed
int d[N];      // scratch copy of the values, sorted for coordinate compression
int mp[N];     // compressed value -> original value
int l, r, k;   // fields of the query currently being read
int cnt, roots[N], sum[N * 20], ls[N * 20], rs[N * 20]; intbuild(int l, int r){ int p = ++cnt; if (l == r) return p; int mid = (l + r) >> 1; return ls[p] = build(l, mid), rs[p] = build(mid + 1, r), p; } intcopy(int rt){ // 复制原树节点 int p = ++cnt; return ls[p] = ls[rt], rs[p] = rs[rt], sum[p] = sum[rt], p; } intupdate(int x, int l, int r, int rt){ int p = copy(rt); sum[p]++; if (l == r) return p; int mid = (l + r) >> 1; if (x <= mid) ls[p] = update(x, l, mid, ls[p]); else rs[p] = update(x, mid + 1, r, rs[p]); return p; } intquery(int l, int r, int k, int rt1, int rt2){ int mid = (l + r) >> 1, num = sum[ls[rt2]] - sum[ls[rt1]]; if (l == r) return l; if (k <= num) returnquery(l, mid, k, ls[rt1], ls[rt2]); elsereturnquery(mid + 1, r, k - num, rs[rt1], rs[rt2]); }
void _main() { cin >> n >> q; for (int i = 1; i <= n; i++) cin >> a[i], d[i] = a[i]; sort(d + 1, d + n + 1); m = unique(d + 1, d + n + 1) - d - 1; for (int i = 1; i <= n; i++) { int val = lower_bound(d + 1, d + m + 1, a[i]) - d; mp[val] = a[i], a[i] = val; } roots[0] = build(1, m); for (int i = 1; i <= n; i++) roots[i] = update(a[i], 1, m, roots[i - 1]); while (q--) { cin >> l >> r >> k; cout << mp[query(1, m, k, roots[l - 1], roots[r])] << '\n'; } }
int cnt, roots[N], last[N * 20], ls[N * 20], rs[N * 20]; intcopy(int rt){ int p = ++cnt; return ls[p] = ls[rt], rs[p] = rs[rt], last[p] = last[rt], p; } intupdate(int x, int c, int l, int r, int rt){ int p = copy(rt); if (l == r) return last[p] = c, p; int mid = (l + r) >> 1; if (x <= mid) ls[p] = update(x, c, l, mid, ls[p]); else rs[p] = update(x, c, mid + 1, r, rs[p]); last[p] = min(last[ls[p]], last[rs[p]]); return p; } intquery(int x, int l, int r, int rt){ if (l == r) return l; int mid = (l + r) >> 1; if (last[ls[rt]] < x) returnquery(x, l, mid, ls[rt]); elsereturnquery(x, mid + 1, r, rs[rt]); }
// Range mex: reads n values, then answers q queries "l r" with the smallest value
// in [0, n + 1] missing from a[l..r] (assuming a[i] >= 0).
// Version i of the tree maps each value to the last position <= i where it
// occurred. A value is absent from a[l..r] exactly when that position is < l.
void _main() {
    cin >> n >> q;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        roots[i] = update(a[i], i, 0, n + 1, roots[i - 1]);
    }
    while (q--) {
        cin >> l >> r;
        const int mex = query(l, 0, n + 1, roots[r]);
        cout << mex << '\n';
    }
}
int cnt, roots[N], sum[N << 5], ls[N << 5], rs[N << 5]; intcopy(int rt){ int p = ++cnt; return ls[p] = ls[rt], rs[p] = rs[rt], sum[p] = sum[rt], p; } intupdate(int x, int c, int l, int r, int rt){ int p = copy(rt); sum[p] += c; if (l == r) return p; int mid = (l + r) >> 1; if (x <= mid) ls[p] = update(x, c, l, mid, ls[p]); else rs[p] = update(x, c, mid + 1, r, rs[p]); return p; } intquery(int tl, int tr, int l, int r, int rt1, int rt2){ if (tl <= l && r <= tr) return sum[rt1] - sum[rt2]; int mid = (l + r) >> 1, res = 0; if (tl <= mid) res += query(tl, tr, l, mid, ls[rt1], ls[rt2]); if (tr > mid) res += query(tl, tr, mid + 1, r, rs[rt1], rs[rt2]); return res; }
// Reads n values, then q queries "l r". For each query, res is grown to a fixed
// point:
// - sum the elements of a[l..r] that are <= res;
// - if that sum s is >= res, set res = s + 1 and repeat;
// - otherwise print res.
// This is the usual "smallest positive integer that is not a subset sum"
// iteration — NOTE(review): confirm against the problem statement.
void _main() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        // Insert a[i] at key a[i] with weight a[i], so range sums are value sums.
        roots[i] = update(a[i], a[i], 1, 1000000000, roots[i - 1]);
    }
    cin >> q;
    while (q--) {
        cin >> l >> r;
        int res = 1, covered;
        while ((covered = query(1, res, 1, 1000000000, roots[r], roots[l - 1])) >= res)
            res = covered + 1;
        cout << res << '\n';
    }
}