CC BY 4.0 (除特别声明或转载文章外)
如果这篇博客帮助到你,可以请我喝一杯咖啡~
继续学习可持久化数据结构…
这次学习支持区间更新的可持久化线段树。具体来说,每次修改操作都先复制一份根节点,down 操作时再复制(已存在的)或新建(尚未创建的)子节点即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int NPOS = -1, N = 1e5 + 9;
// Prefix sums of the input sequence (1-based, a[0] stays 0). Newly created
// nodes read their initial range sum from this array.
ll a[N];

/**
 * Range-add / range-sum segment tree whose nodes live in one growing vector
 * and are created on demand.
 *
 * Nodes are never shared after being pushed down into: `down` always replaces
 * both children of a node with freshly appended nodes (a copy of the existing
 * child, or a newly built one) before the pending add is applied to them.
 * A caller obtains a new version by appending a copy of a root node and
 * calling `upd` on that copy.
 *
 * Note that `ask` also calls `down`, so a query appends nodes to `v` too.
 */
struct SegmentTree
{
    /// Closed index range [l, r] together with the sum over it.
    struct Seg
    {
        int l, r;
        ll sum;

        /// Adds `add` to every position of the range, i.e. add * length to the sum.
        void upd(ll add)
        {
            const ll width = r - l + 1;
            sum += add * width;
        }

        /// Concatenates two adjacent ranges: left bound of `lc`, right bound of `rc`.
        friend Seg up(const Seg &lc, const Seg &rc)
        {
            Seg merged;
            merged.l = lc.l;
            merged.r = rc.r;
            merged.sum = lc.sum + rc.sum;
            return merged;
        }
    };

    /// Tree node: a Seg plus child indices into `v` (NPOS = not created yet)
    /// and the add value that has not been pushed to the children.
    struct Node : Seg
    {
        int lc, rc;
        ll add;
    };

    vector<Node> v;

    /// Creates the root node for [l, r] at index 0 of `v`.
    SegmentTree(int l, int r) { build(l, r); }

    /// Appends a childless node for [l, r]; its sum is taken from the prefix sums in `a`.
    void build(int l, int r)
    {
        Node fresh;
        fresh.l = l;
        fresh.r = r;
        fresh.sum = a[r] - a[l - 1];
        fresh.lc = NPOS;
        fresh.rc = NPOS;
        fresh.add = 0;
        v.push_back(fresh);
    }

    /// Appends either a new node for [lo, hi] (when `child` is NPOS) or a copy
    /// of node `child`, and returns the index of the appended node.
    int spawn(int child, int lo, int hi)
    {
        const int slot = v.size();
        if (child == NPOS)
        {
            build(lo, hi);
        }
        else
        {
            // Copy out first so the source is not a reference into the
            // vector that is about to grow.
            const Node shared = v[child];
            v.push_back(shared);
        }
        return slot;
    }

    /// Gives node `rt` two freshly appended children (left first, then right),
    /// applies rt's pending add to both, and clears the pending add.
    void down(int rt)
    {
        const int mid = (v[rt].l + v[rt].r) >> 1;

        const int left = spawn(v[rt].lc, v[rt].l, mid);
        v[rt].lc = left;
        const int right = spawn(v[rt].rc, mid + 1, v[rt].r);
        v[rt].rc = right;

        // Each call covers the child's whole range, so it only updates that
        // child's sum and pending add without recursing further.
        const ll pending = v[rt].add;
        upd(v[left].l, v[left].r, pending, left);
        upd(v[right].l, v[right].r, pending, right);
        v[rt].add = 0;
    }

    /// Adds `add` to every position in [l, r] within the subtree rooted at `rt`.
    /// [l, r] is expected to lie inside the range of `rt`.
    void upd(int l, int r, ll add, int rt)
    {
        const bool covered = l <= v[rt].l && v[rt].r <= r;
        if (covered)
        {
            v[rt].add += add;
            v[rt].upd(add);
            return;
        }

        down(rt);
        // The child indices of rt stay fixed while recursing below it.
        const int left = v[rt].lc, right = v[rt].rc;
        const int leftEnd = v[left].r, rightBegin = v[right].l;

        if (r <= leftEnd)
        {
            upd(l, r, add, left);
        }
        else if (l >= rightBegin)
        {
            upd(l, r, add, right);
        }
        else
        {
            upd(l, leftEnd, add, left);
            upd(rightBegin, r, add, right);
        }
        static_cast<Seg &>(v[rt]) = up(v[left], v[right]);
    }

    /// Returns the Seg describing [l, r] within the subtree rooted at `rt`.
    /// Appends nodes to `v` whenever it has to descend (see `down`).
    Seg ask(int l, int r, int rt)
    {
        const bool covered = l <= v[rt].l && v[rt].r <= r;
        if (covered)
            return v[rt];

        down(rt);
        const int left = v[rt].lc, right = v[rt].rc;
        const int leftEnd = v[left].r, rightBegin = v[right].l;

        if (r <= leftEnd)
            return ask(l, r, left);
        if (l >= rightBegin)
            return ask(l, r, right);

        const Seg lower = ask(l, leftEnd, left);
        const Seg upper = ask(rightBegin, r, right);
        return up(lower, upper);
    }
};
int main()
{
for (int n, m, l, r, d; ~scanf("%d%d", &n, &m);)
{
for (int i = 1; i <= n; ++i)
scanf("%lld", &a[i]), a[i] += a[i - 1];
SegmentTree t(1, n);
vector<int> root{0}, size{t.v.size()};
for (char s[9]; m--;)
{
scanf("%s%d", s, &l);
if (s[0] == 'C')
{
scanf("%d%d", &r, &d);
t.v.push_back(t.v[root.back()]);
root.push_back(t.v.size() - 1);
t.upd(l, r, d, root.back());
size.push_back(t.v.size());
}
else if (s[0] == 'B')
root.resize(l + 1), size.resize(l + 1);
else
{
scanf("%d", &r);
if (s[0] == 'Q')
d = root.size() - 1;
else
scanf("%d", &d);
t.v.push_back(t.v[root[d]]);
printf("%lld\n", t.ask(l, r, t.v.size() - 1).sum);
}
t.v.resize(size.back(), t.v[0]);
}
}
}
网上常见的解法是区间更新时不把标记下传,而是在查询时直接把标记计入返回结果,这样可以减少空间消耗(然而这里偷懒直接用 vector,还直接把左右端点维护在线段树节点内…)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll NPOS = -1, N = 1e5 + 9;

/**
 * Persistent range-add / range-sum segment tree that never pushes add tags
 * down: a node's `add` applies to its whole range, is already included in
 * its own `sum`, and is accounted for by `ask` while descending.
 *
 * `add` never modifies existing nodes; it appends new nodes (children before
 * their parent) and returns the index of the new root.
 */
struct HJTree
{
    struct Node
    {
        int l, r, lc, rc; // closed range [l, r]; child indices into `v` (NPOS for a leaf)
        ll sum, add;      // sum over [l, r]; add value applying to every position of [l, r]

        /// Recomputes range and sum from the two children plus this node's own add.
        void up(Node &lc, Node &rc)
        {
            l = lc.l;
            r = rc.r;
            const ll width = r - l + 1;
            sum = lc.sum + rc.sum + add * width;
        }

        /// Records an add of `a` on the whole range of this node.
        void _add(ll a)
        {
            const ll width = r - l + 1;
            add += a;
            sum += a * width;
        }
    };

    vector<Node> v;
    vector<int> root; // root[k] is the index in `v` of version k's root

    /// Builds version 0 over a[l..r].
    HJTree(ll a[], int l, int r) : root(1, build(a, l, r)) {}

    /// Appends the subtree for a[l..r] (left subtree, right subtree, then the
    /// node itself) and returns the index of its top node.
    int build(ll a[], int l, int r)
    {
        Node node;
        node.l = l;
        node.r = r;
        node.sum = 0;
        node.add = 0;
        if (l >= r)
        {
            node.lc = static_cast<int>(NPOS);
            node.rc = static_cast<int>(NPOS);
            node.sum = a[l];
        }
        else
        {
            const int mid = (l + r) >> 1;
            node.lc = build(a, l, mid);
            node.rc = build(a, mid + 1, r);
            node.up(v[node.lc], v[node.rc]);
        }
        v.push_back(node);
        return static_cast<int>(v.size()) - 1;
    }

    /// Adds `val` to [l, r] starting from node `rt` without touching existing
    /// nodes; returns the index of the replacement for `rt`, which is the
    /// last node appended by the call.
    int add(int l, int r, int val, int rt)
    {
        Node next = v[rt]; // private copy of the node being replaced

        if (l <= next.l && next.r <= r)
        {
            next._add(val);
        }
        else
        {
            const int mid = (next.l + next.r) >> 1;
            if (r <= mid)
            {
                next.lc = add(l, r, val, next.lc);
            }
            else if (l > mid)
            {
                next.rc = add(l, r, val, next.rc);
            }
            else
            {
                next.lc = add(l, mid, val, next.lc);
                next.rc = add(mid + 1, r, val, next.rc);
            }
            next.up(v[next.lc], v[next.rc]);
        }

        v.push_back(next);
        return static_cast<int>(v.size()) - 1;
    }

    /// Sum of [l, r] in the subtree rooted at `rt`. Does not modify the tree.
    ll ask(int l, int r, int rt)
    {
        const Node &node = v[rt]; // stable: nothing is appended during a query
        if (l <= node.l && node.r <= r)
            return node.sum;

        const int mid = (node.l + node.r) >> 1;
        // This node's add was never pushed down, so the part of it that falls
        // inside [l, r] is added here.
        ll total = node.add * (r - l + 1);
        if (r <= mid)
            total += ask(l, r, node.lc);
        else if (l > mid)
            total += ask(l, r, node.rc);
        else
            total += ask(l, mid, node.lc) + ask(mid + 1, r, node.rc);
        return total;
    }
};
// Input values, 1-based; passed to HJTree to build version 0.
ll a[N];

/**
 * Reads test cases until the "n m" header can no longer be read.
 *
 * Each case: n values, then m commands:
 *   C l r d  - create a new version with d added to [l, r]
 *   B t      - drop every version after version t, together with its nodes
 *   Q l r    - print the sum of [l, r] in the latest version
 *   <other> l r t - print the sum of [l, r] in version t
 */
int main()
{
    int n, m;
    // Require both header fields; comparing against EOF alone would carry on
    // with uninitialised values after a malformed header.
    while (scanf("%d%d", &n, &m) == 2)
    {
        for (int i = 1; i <= n; ++i)
            scanf("%lld", &a[i]);

        HJTree t(a, 1, n);
        while (m--)
        {
            char s[9];
            int l, r, d;
            // %8s keeps the command token inside the 9-byte buffer.
            if (scanf("%8s%d", s, &l) != 2)
                break;

            if (s[0] == 'C')
            {
                scanf("%d%d", &r, &d);
                const int next = t.add(l, r, d, t.root.back());
                t.root.push_back(next);
            }
            else if (s[0] == 'B')
            {
                t.root.resize(l + 1);
                // A version's root is the last node appended for it, so
                // everything after that index belongs to dropped versions.
                t.v.resize(t.root[l] + 1);
            }
            else
            {
                scanf("%d", &r);
                if (s[0] == 'Q')
                    d = static_cast<int>(t.root.size()) - 1;
                else
                    scanf("%d", &d);
                printf("%lld\n", t.ask(l, r, t.root[d]));
            }
        }
    }
}