一种利用线段树来在 $O(n \log n)$ 的时间复杂度内维护区间最值和区间历史最值的线段树变体。
区间最值
维护区间最值操作,说白了就是需要让一个区间内的所有数字和一个给定的数字 $x$ 取 $\min$ 或取 $\max$。
或者,我们可以将这个理解为对当前区间内大于(或小于)$x$ 的数字同意修改为 $x$。
以弱化的洛谷P6242【模板】线段树3为例,我们示范一下取min的操作。
弱化后的题面如下:
给出一个长度为 $n$ 的数列 $a$,接下来进行了 $m$ 次操作,操作有四种类型,按以下格式给出:
1 l r k
:对于所有的 $i \in [l,r]$,将 $a_i$ 加上 $k$($k$ 可以为负数)。
2 l r k
:对于所有的 $i \in [l,r]$,将 $a_i$ 变成 $\min(a_i,k)$。
3 l r
:求 $\sum_{i=l}^{r} a_i$。
4 l r
:对于所有的 $i \in [l,r]$,求 $a_i$ 的最大值。
我们需要首先处理的就是这个取 $\min$ 操作。
我们按照后一种理解,将其解释为将操作区间内所有大于等于 $k$ 的数都赋成 $k$。
那我们如果维护一下区间内所有数的类型就可以了,一个值域树状数组或许可以解决,不过维护的代价太大了。
尝试只维护当前的最大值和次大值,发现维护不了了就直接向下摊派,稍微模拟一下我们发现这样维护是完全可以的,空间复杂度十分宽裕,时间复杂度也没有那么多。如果尝试证明的话可以用势能分析,可以证明均摊时间复杂度是 $O(n \log n)$ 的。
具体过程就是,我们访问到当前区间的时候,看一下我们的这个 $k$ 与当前区间最大值和次大值的关系。
- 如果大于最大值,那么这个区间和其下面的区间也就没有必要维护了,直接返回。
- 如果小于最大值,但大于次大值,那么我们只需要修改该区间即可,剩下的信息我们通过懒标记来维护。
- 如果小于次大值,那么当前区间的这个体量不是直接修改能够维护的,我们选择向下递归维护。
下面是示例代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| void segmin(int p, int l, int r, ll k) { if(k >= tr[p].maxn)return; if(tr[p].l >= l && tr[p].r <= r && k > tr[p].maxs) { add(p, k - tr[p].maxn, k - tr[p].maxn, 0, 0); return; } pushdown(p); int mid = (tr[p].l + tr[p].r) >> 1; if(l <= mid)segmin(p << 1, l, r, k); if(r > mid)segmin(p << 1 | 1, l, r, k); pushup(p); }
|
此时我们发现我们需要维护的东西变成了下面这些:
1 2 3 4 5 6 7
| struct SegTree { int l, r; ll sum; ll maxn, maxs, cmax; ll addn, adds; };
|
懒标记的维护方法稍微有些繁琐:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| void pushup(int p) { tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum; if(tr[p << 1].maxn > tr[p << 1 | 1].maxn) { tr[p].maxn = tr[p << 1].maxn; tr[p].maxs = max(tr[p << 1].maxs, tr[p << 1 | 1].maxn); tr[p].cmax = tr[p << 1].cmax; } else if(tr[p << 1].maxn < tr[p << 1 | 1].maxn) { tr[p].maxn = tr[p << 1 | 1].maxn; tr[p].maxs = max(tr[p << 1].maxn, tr[p << 1 | 1].maxs); tr[p].cmax = tr[p << 1 | 1].cmax; } else { tr[p].maxn = tr[p << 1].maxn; tr[p].maxs = max(tr[p << 1].maxs, tr[p << 1 | 1].maxs); tr[p].cmax = tr[p << 1].cmax + tr[p << 1 | 1].cmax; } }
|
1 2 3 4 5 6 7
| void add(int p, ll vn, ll vs) { tr[p].sum += 1ll * vn * tr[p].cmax; tr[p].sum += 1ll * vs * (tr[p].r - tr[p].l + 1 - tr[p].cmax); tr[p].addn += vn, tr[p].adds += vs; tr[p].maxn += vn, tr[p].maxs += vs; }
|
1 2 3 4 5 6 7 8 9 10 11
| void pushdown(int p) { int maxn = max(tr[p << 1].maxn, tr[p << 1 | 1].maxn); if(maxn == tr[p << 1].maxn) add(p << 1, tr[p].addn, tr[p].adds); else add(p << 1, tr[p].adds, tr[p].adds); if(maxn == tr[p << 1 | 1].maxn) add(p << 1 | 1, tr[p].addn, tr[p].adds); else add(p << 1 | 1, tr[p].adds, tr[p].adds); tr[p].addn = tr[p].adds = 0; }
|
其他函数照常。
历史区间最值
历史区间最值并不是可持久化,而是给这个区间内所有出现过的不同的值取最值。
看一下上面那道线段树3的原题,我们可以发现还有一个操作我们没有维护,那就是维护一个历史最值数组,每一次更新原数组的值之后都会让历史最值数组的每一位与原数组的对应位置取 $\max$。
题目中的第5个操作是对这个历史最值数组询问区间最大值,这种操作对应的就是询问区间历史最大值。
对此,我们只需要多维护几个值即可。
1 2 3 4 5 6 7 8 9
| struct SegTree { int l, r; ll sum; ll maxn, maxs, cmax; ll hmaxn, hmaxs; ll addn, adds; ll haddn, hadds; };
|
维护的过程也添了不少东西:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| void pushup(int p) { tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum; tr[p].hmaxn = max(tr[p << 1].hmaxn, tr[p << 1 | 1].hmaxn); if(tr[p << 1].maxn > tr[p << 1 | 1].maxn) { tr[p].maxn = tr[p << 1].maxn; tr[p].maxs = max(tr[p << 1].maxs, tr[p << 1 | 1].maxn); tr[p].cmax = tr[p << 1].cmax; } else if(tr[p << 1].maxn < tr[p << 1 | 1].maxn) { tr[p].maxn = tr[p << 1 | 1].maxn; tr[p].maxs = max(tr[p << 1].maxn, tr[p << 1 | 1].maxs); tr[p].cmax = tr[p << 1 | 1].cmax; } else { tr[p].maxn = tr[p << 1].maxn; tr[p].maxs = max(tr[p << 1].maxs, tr[p << 1 | 1].maxs); tr[p].cmax = tr[p << 1].cmax + tr[p << 1 | 1].cmax; } }
|
1 2 3 4 5 6 7 8 9 10
| void add(int p, ll vn, ll hvn, ll vs, ll hvs) { tr[p].hmaxn = max(tr[p].hmaxn, tr[p].maxn + hvn); tr[p].haddn = max(tr[p].haddn, tr[p].addn + hvn); tr[p].hadds = max(tr[p].hadds, tr[p].adds + hvs); tr[p].sum += 1ll * vn * tr[p].cmax; tr[p].sum += 1ll * vs * (tr[p].r - tr[p].l + 1 - tr[p].cmax); tr[p].addn += vn, tr[p].adds += vs; tr[p].maxn += vn, tr[p].maxs += vs; }
|
1 2 3 4 5 6 7 8 9 10 11
| void pushdown(int p) { int maxn = max(tr[p << 1].maxn, tr[p << 1 | 1].maxn); if(maxn == tr[p << 1].maxn) add(p << 1, tr[p].addn, tr[p].haddn, tr[p].adds, tr[p].hadds); else add(p << 1, tr[p].adds, tr[p].hadds, tr[p].adds, tr[p].hadds); if(maxn == tr[p << 1 | 1].maxn) add(p << 1 | 1, tr[p].addn, tr[p].haddn, tr[p].adds, tr[p].hadds); else add(p << 1 | 1, tr[p].adds, tr[p].hadds, tr[p].adds, tr[p].hadds); tr[p].addn = tr[p].haddn = tr[p].adds = tr[p].hadds = 0; }
|