Skip to content

线段树

最基本的线段树

模板题:动态求连续区间和

线段树模板,以维护区间和为例:

Java
// 维护区间和的线段树
// 序列长度为n 序列输入至a[1...n]
int[] a;
class Node {
    int l, r, sum;

    public Node(int l, int r, int sum) {
        this.l = l;
        this.r = r;
        this.sum = sum;
    }
}
// 线段树节点数组长度一般开到4n
Node[] tr;

// 用子节点的信息更新当前节点信息
void pushUp(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
    if (l == r) tr[u] = new Node(l, r, a[r]);
    else {
        tr[u] = new Node(l, r, 0);
        int m = l + r >> 1;
        build(u << 1, l, m);
        build(u << 1 | 1, m + 1, r);
        pushUp(u);
    }
}

// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        int sum = 0;
        if (m >= l) sum += query(u << 1, l, r);
        if (m + 1 <= r) sum += query(u << 1 | 1, l, r);
        return sum;
    }
}

// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
    if (tr[u].l == tr[u].r) tr[u].sum = x;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= i) modify(u << 1, i, x);
        else modify(u << 1 | 1, i, x);
        pushUp(u);
    }
}
// 维护区间和的线段树
// 序列长度为n 序列输入至a[1...n]
int[] a;
class Node {
    int l, r, sum;

    public Node(int l, int r, int sum) {
        this.l = l;
        this.r = r;
        this.sum = sum;
    }
}
// 线段树节点数组长度一般开到4n
Node[] tr;

// 用子节点的信息更新当前节点信息
void pushUp(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
    if (l == r) tr[u] = new Node(l, r, a[r]);
    else {
        tr[u] = new Node(l, r, 0);
        int m = l + r >> 1;
        build(u << 1, l, m);
        build(u << 1 | 1, m + 1, r);
        pushUp(u);
    }
}

// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        int sum = 0;
        if (m >= l) sum += query(u << 1, l, r);
        if (m + 1 <= r) sum += query(u << 1 | 1, l, r);
        return sum;
    }
}

// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
    if (tr[u].l == tr[u].r) tr[u].sum = x;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= i) modify(u << 1, i, x);
        else modify(u << 1 | 1, i, x);
        pushUp(u);
    }
}
C++
// 维护区间和的线段树
// 原序列长度为n 序列输入至a[1...n]
int a[N];
struct Node {
    int l, r, sum;
} tr[N * 4]; // 线段树节点数组长度一般开到4n

// 用子节点的信息更新当前节点信息
void push_up(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    if (l == r) return;
    int m = l + r >> 1;
    build(u << 1, l, m);
    build(u << 1 | 1, m + 1, r);
    push_up(u);
}

// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    int m = tr[u].l + tr[u].r >> 1;
    int res = 0;
    if (m >= l) res += query(u << 1, l, r);
    if (m + 1 <= r) res += query(u << 1 | 1, l, r);
    return res;
}

// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
    if (tr[u].l == i && tr[u].r == i) tr[u].sum = x;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= i) modify(u << 1, i, x);
        else modify(u << 1 | 1, i, x);
        push_up(u);
    }
}
// 维护区间和的线段树
// 原序列长度为n 序列输入至a[1...n]
int a[N];
struct Node {
    int l, r, sum;
} tr[N * 4]; // 线段树节点数组长度一般开到4n

// 用子节点的信息更新当前节点信息
void push_up(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 在区间[l...r]上初始化线段树 u为当前段根节点位置
void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    if (l == r) return;
    int m = l + r >> 1;
    build(u << 1, l, m);
    build(u << 1 | 1, m + 1, r);
    push_up(u);
}

// 查询[l...r] u为当前段根节点位置
int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    int m = tr[u].l + tr[u].r >> 1;
    int res = 0;
    if (m >= l) res += query(u << 1, l, r);
    if (m + 1 <= r) res += query(u << 1 | 1, l, r);
    return res;
}

// 把位置i的数字修改为x u为当前段根节点位置
void modify(int u, int i, int x) {
    if (tr[u].l == i && tr[u].r == i) tr[u].sum = x;
    else {
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= i) modify(u << 1, i, x);
        else modify(u << 1 | 1, i, x);
        push_up(u);
    }
}

可以区间修改的线段树与懒标记

模板题:一个简单的整数问题2

带懒标记的线段树模板,以区间加操作,查询区间和为例:

C++
int a[N];
struct Node {
    int l, r;
    int sum, lazy;
} tr[N * 4];

void push_up(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 用当前节点的信息更新子节点信息
void push_down(int u) {
    int& lazy = tr[u].lazy;
    if (lazy == 0) return;
    Node& ln = tr[u << 1], & rn = tr[u << 1 | 1];
    ln.lazy += lazy, ln.sum += (ln.r - ln.l + 1) * lazy;
    rn.lazy += lazy, rn.sum += (rn.r - rn.l + 1) * lazy;
    lazy = 0;
}

void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    if (l == r) tr[u].sum = a[l];
    else {
        int m = l + r >> 1;
        build(u << 1, l, m);
        build(u << 1 | 1, m + 1, r);
        push_up(u);
    }
}

void modify(int u, int l, int r, int x) {
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].sum += (tr[u].r - tr[u].l + 1) * x;
        tr[u].lazy += x;
    } else {
        push_down(u);
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= l) modify(u << 1, l, r, x);
        if (m + 1 <= r) modify(u << 1 | 1, l, r, x);
        push_up(u);
    }
}

int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    push_down(u);
    int m = tr[u].l + tr[u].r >> 1;
    int res = 0;
    if (m >= l) res += query(u << 1, l, r);
    if (m + 1 <= r) res += query(u << 1 | 1, l, r);
    return res;
}
int a[N];
struct Node {
    int l, r;
    int sum, lazy;
} tr[N * 4];

void push_up(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 用当前节点的信息更新子节点信息
void push_down(int u) {
    int& lazy = tr[u].lazy;
    if (lazy == 0) return;
    Node& ln = tr[u << 1], & rn = tr[u << 1 | 1];
    ln.lazy += lazy, ln.sum += (ln.r - ln.l + 1) * lazy;
    rn.lazy += lazy, rn.sum += (rn.r - rn.l + 1) * lazy;
    lazy = 0;
}

void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    if (l == r) tr[u].sum = a[l];
    else {
        int m = l + r >> 1;
        build(u << 1, l, m);
        build(u << 1 | 1, m + 1, r);
        push_up(u);
    }
}

void modify(int u, int l, int r, int x) {
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].sum += (tr[u].r - tr[u].l + 1) * x;
        tr[u].lazy += x;
    } else {
        push_down(u);
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= l) modify(u << 1, l, r, x);
        if (m + 1 <= r) modify(u << 1 | 1, l, r, x);
        push_up(u);
    }
}

int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    push_down(u);
    int m = tr[u].l + tr[u].r >> 1;
    int res = 0;
    if (m >= l) res += query(u << 1, l, r);
    if (m + 1 <= r) res += query(u << 1 | 1, l, r);
    return res;
}
区间加、乘,查询区间和模数的线段树模板
C++
int a[N];
struct Node {
    int l, r;
    LL sum, add, mul = 1LL;
} tr[N * 4];

void push_up(int u) {
    tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}

// 把mul,add懒标记应用到区间节点u上(p为模数)
void apply(Node& u, LL mul, LL add) {
    u.add = (u.add * mul % p + add) % p;
    u.mul = u.mul * mul % p;
    u.sum = (u.sum * mul % p + (u.r - u.l + 1) * add % p) % p;
}

void push_down(int u) {
    if (tr[u].add == 0LL && tr[u].mul == 1LL) return;
    apply(tr[u << 1], tr[u].mul, tr[u].add);
    apply(tr[u << 1 | 1], tr[u].mul, tr[u].add);
    tr[u].add = 0LL, tr[u].mul = 1LL;
}

void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    if (l == r) tr[u].sum = a[l];
    else {
        int m = l + r >> 1;
        build(u << 1, l, m);
        build(u << 1 | 1, m + 1, r);
        push_up(u);
    }
}

// [l, r] * mul + add
void modify(int u, int l, int r, LL mul, LL add) {
    if (tr[u].l >= l && tr[u].r <= r) {
        apply(tr[u], mul, add);
    } else {
        push_down(u);
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= l) modify(u << 1, l, r, mul, add);
        if (m + 1 <= r) modify(u << 1 | 1, l, r, mul, add);
        push_up(u);
    }
}

LL query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    push_down(u);
    int m = tr[u].l + tr[u].r >> 1;
    LL res = 0LL;
    if (m >= l) res = (res + query(u << 1, l, r)) % p;
    if (m + 1 <= r) res = (res + query(u << 1 | 1, l, r)) % p;
    return res;
}
int a[N];
struct Node {
    int l, r;
    LL sum, add, mul = 1LL;
} tr[N * 4];

void push_up(int u) {
    tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}

// 把mul,add懒标记应用到区间节点u上(p为模数)
void apply(Node& u, LL mul, LL add) {
    u.add = (u.add * mul % p + add) % p;
    u.mul = u.mul * mul % p;
    u.sum = (u.sum * mul % p + (u.r - u.l + 1) * add % p) % p;
}

void push_down(int u) {
    if (tr[u].add == 0LL && tr[u].mul == 1LL) return;
    apply(tr[u << 1], tr[u].mul, tr[u].add);
    apply(tr[u << 1 | 1], tr[u].mul, tr[u].add);
    tr[u].add = 0LL, tr[u].mul = 1LL;
}

void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    if (l == r) tr[u].sum = a[l];
    else {
        int m = l + r >> 1;
        build(u << 1, l, m);
        build(u << 1 | 1, m + 1, r);
        push_up(u);
    }
}

// [l, r] * mul + add
void modify(int u, int l, int r, LL mul, LL add) {
    if (tr[u].l >= l && tr[u].r <= r) {
        apply(tr[u], mul, add);
    } else {
        push_down(u);
        int m = tr[u].l + tr[u].r >> 1;
        if (m >= l) modify(u << 1, l, r, mul, add);
        if (m + 1 <= r) modify(u << 1 | 1, l, r, mul, add);
        push_up(u);
    }
}

LL query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    push_down(u);
    int m = tr[u].l + tr[u].r >> 1;
    LL res = 0LL;
    if (m >= l) res = (res + query(u << 1, l, r)) % p;
    if (m + 1 <= r) res = (res + query(u << 1 | 1, l, r)) % p;
    return res;
}