线段树
最基本的线段树
模板题:动态求连续区间和。
线段树模板,以维护区间和为例:
Java
// 维护区间和的线段树
// 序列长度为n 序列输入至a[1...n]
int[] a;
class Node {
int l, r, sum;
public Node(int l, int r, int sum) {
this.l = l;
this.r = r;
this.sum = sum;
}
}
// 线段树节点数组长度一般开到4n
Node[] tr;
// 用子节点的信息更新当前节点信息
void pushUp(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
if (l == r) tr[u] = new Node(l, r, a[r]);
else {
tr[u] = new Node(l, r, 0);
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
pushUp(u);
}
}
// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
else {
int m = tr[u].l + tr[u].r >> 1;
int sum = 0;
if (m >= l) sum += query(u << 1, l, r);
if (m + 1 <= r) sum += query(u << 1 | 1, l, r);
return sum;
}
}
// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
if (tr[u].l == tr[u].r) tr[u].sum = x;
else {
int m = tr[u].l + tr[u].r >> 1;
if (m >= i) modify(u << 1, i, x);
else modify(u << 1 | 1, i, x);
pushUp(u);
}
}
// 维护区间和的线段树
// 序列长度为n 序列输入至a[1...n]
int[] a;
class Node {
int l, r, sum;
public Node(int l, int r, int sum) {
this.l = l;
this.r = r;
this.sum = sum;
}
}
// 线段树节点数组长度一般开到4n
Node[] tr;
// 用子节点的信息更新当前节点信息
void pushUp(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
if (l == r) tr[u] = new Node(l, r, a[r]);
else {
tr[u] = new Node(l, r, 0);
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
pushUp(u);
}
}
// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
else {
int m = tr[u].l + tr[u].r >> 1;
int sum = 0;
if (m >= l) sum += query(u << 1, l, r);
if (m + 1 <= r) sum += query(u << 1 | 1, l, r);
return sum;
}
}
// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
if (tr[u].l == tr[u].r) tr[u].sum = x;
else {
int m = tr[u].l + tr[u].r >> 1;
if (m >= i) modify(u << 1, i, x);
else modify(u << 1 | 1, i, x);
pushUp(u);
}
}
C++
// 维护区间和的线段树
// 原序列长度为n 序列输入至a[1...n]
int a[N];
struct Node {
int l, r, sum;
} tr[N * 4]; // 线段树节点数组长度一般开到4n
// 用子节点的信息更新当前节点信息
void push_up(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) return;
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
push_up(u);
}
// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
int m = tr[u].l + tr[u].r >> 1;
int res = 0;
if (m >= l) res += query(u << 1, l, r);
if (m + 1 <= r) res += query(u << 1 | 1, l, r);
return res;
}
// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
if (tr[u].l == i && tr[u].r == i) tr[u].sum = x;
else {
int m = tr[u].l + tr[u].r >> 1;
if (m >= i) modify(u << 1, i, x);
else modify(u << 1 | 1, i, x);
push_up(u);
}
}
// 维护区间和的线段树
// 原序列长度为n 序列输入至a[1...n]
int a[N];
struct Node {
int l, r, sum;
} tr[N * 4]; // 线段树节点数组长度一般开到4n
// 用子节点的信息更新当前节点信息
void push_up(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) return;
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
push_up(u);
}
// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
int m = tr[u].l + tr[u].r >> 1;
int res = 0;
if (m >= l) res += query(u << 1, l, r);
if (m + 1 <= r) res += query(u << 1 | 1, l, r);
return res;
}
// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
if (tr[u].l == i && tr[u].r == i) tr[u].sum = x;
else {
int m = tr[u].l + tr[u].r >> 1;
if (m >= i) modify(u << 1, i, x);
else modify(u << 1 | 1, i, x);
push_up(u);
}
}
可以区间修改的线段树与懒标记
模板题:一个简单的整数问题2。
带懒标记的线段树模板,以区间加操作,查询区间和为例:
C++
int a[N];
struct Node {
int l, r;
int sum, lazy;
} tr[N * 4];
void push_up(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
// 用当前节点的信息更新子节点信息
void push_down(int u) {
int& lazy = tr[u].lazy;
if (lazy == 0) return;
Node& ln = tr[u << 1], & rn = tr[u << 1 | 1];
ln.lazy += lazy, ln.sum += (ln.r - ln.l + 1) * lazy;
rn.lazy += lazy, rn.sum += (rn.r - rn.l + 1) * lazy;
lazy = 0;
}
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) tr[u].sum = a[l];
else {
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
push_up(u);
}
}
void modify(int u, int l, int r, int x) {
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].sum += (tr[u].r - tr[u].l + 1) * x;
tr[u].lazy += x;
} else {
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
if (m >= l) modify(u << 1, l, r, x);
if (m + 1 <= r) modify(u << 1 | 1, l, r, x);
push_up(u);
}
}
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
int res = 0;
if (m >= l) res += query(u << 1, l, r);
if (m + 1 <= r) res += query(u << 1 | 1, l, r);
return res;
}
int a[N];
struct Node {
int l, r;
int sum, lazy;
} tr[N * 4];
void push_up(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
// 用当前节点的信息更新子节点信息
void push_down(int u) {
int& lazy = tr[u].lazy;
if (lazy == 0) return;
Node& ln = tr[u << 1], & rn = tr[u << 1 | 1];
ln.lazy += lazy, ln.sum += (ln.r - ln.l + 1) * lazy;
rn.lazy += lazy, rn.sum += (rn.r - rn.l + 1) * lazy;
lazy = 0;
}
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) tr[u].sum = a[l];
else {
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
push_up(u);
}
}
void modify(int u, int l, int r, int x) {
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].sum += (tr[u].r - tr[u].l + 1) * x;
tr[u].lazy += x;
} else {
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
if (m >= l) modify(u << 1, l, r, x);
if (m + 1 <= r) modify(u << 1 | 1, l, r, x);
push_up(u);
}
}
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
int res = 0;
if (m >= l) res += query(u << 1, l, r);
if (m + 1 <= r) res += query(u << 1 | 1, l, r);
return res;
}
区间加、乘,查询区间和模数的线段树模板
C++
int a[N];
struct Node {
int l, r;
LL sum, add, mul = 1LL;
} tr[N * 4];
void push_up(int u) {
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
// 把mul,add懒标记应用到区间节点u上(p为模数)
void apply(Node& u, LL mul, LL add) {
u.add = (u.add * mul % p + add) % p;
u.mul = u.mul * mul % p;
u.sum = (u.sum * mul % p + (u.r - u.l + 1) * add % p) % p;
}
void push_down(int u) {
if (tr[u].add == 0LL && tr[u].mul == 1LL) return;
apply(tr[u << 1], tr[u].mul, tr[u].add);
apply(tr[u << 1 | 1], tr[u].mul, tr[u].add);
tr[u].add = 0LL, tr[u].mul = 1LL;
}
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) tr[u].sum = a[l];
else {
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
push_up(u);
}
}
// [l, r] * mul + add
void modify(int u, int l, int r, LL mul, LL add) {
if (tr[u].l >= l && tr[u].r <= r) {
apply(tr[u], mul, add);
} else {
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
if (m >= l) modify(u << 1, l, r, mul, add);
if (m + 1 <= r) modify(u << 1 | 1, l, r, mul, add);
push_up(u);
}
}
LL query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
LL res = 0LL;
if (m >= l) res = (res + query(u << 1, l, r)) % p;
if (m + 1 <= r) res = (res + query(u << 1 | 1, l, r)) % p;
return res;
}
int a[N];
struct Node {
int l, r;
LL sum, add, mul = 1LL;
} tr[N * 4];
void push_up(int u) {
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
// 把mul,add懒标记应用到区间节点u上(p为模数)
void apply(Node& u, LL mul, LL add) {
u.add = (u.add * mul % p + add) % p;
u.mul = u.mul * mul % p;
u.sum = (u.sum * mul % p + (u.r - u.l + 1) * add % p) % p;
}
void push_down(int u) {
if (tr[u].add == 0LL && tr[u].mul == 1LL) return;
apply(tr[u << 1], tr[u].mul, tr[u].add);
apply(tr[u << 1 | 1], tr[u].mul, tr[u].add);
tr[u].add = 0LL, tr[u].mul = 1LL;
}
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) tr[u].sum = a[l];
else {
int m = l + r >> 1;
build(u << 1, l, m);
build(u << 1 | 1, m + 1, r);
push_up(u);
}
}
// [l, r] * mul + add
void modify(int u, int l, int r, LL mul, LL add) {
if (tr[u].l >= l && tr[u].r <= r) {
apply(tr[u], mul, add);
} else {
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
if (m >= l) modify(u << 1, l, r, mul, add);
if (m + 1 <= r) modify(u << 1 | 1, l, r, mul, add);
push_up(u);
}
}
LL query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
push_down(u);
int m = tr[u].l + tr[u].r >> 1;
LL res = 0LL;
if (m >= l) res = (res + query(u << 1, l, r)) % p;
if (m + 1 <= r) res = (res + query(u << 1 | 1, l, r)) % p;
return res;
}