세그먼트 트리
세그먼트 트리
세그먼트 트리는 어떤 구간 내의 값 들의 연속 연산을 다룰 수 있는 트리 기반 알고리즘이다.
누적 합을 그냥 더하게 된다면 수가 변경될 때마다 누적 합을 다시 구해야 하기 때문에 연산 당 시간 복잡도가
이 되어버린다. 이를 해결하기 위해 세그먼트 트리를 만들어 연산당 시간복잡도를
수준으로 낮출 수 있다.
세그먼트 트리는 트리에 자신의 아래 자식들을 모두 더한 값을 저장하고 있다. 예를 들어 0부터 5 까지의 수로 예를 들어보면
mermaid
graph TD
A["0-5"]
B["0-2"]
C["3-5"]
D["0-1"]
E["2"]
F["3-4"]
G["5"]
H["0"]
I["1"]
J["3"]
K["4"]
A --> B
A --> C
B --> D
B --> E
D --> H
D --> I
C --> F
C --> G
F --> J
F --> K
이런 식으로 세그먼트 트리에 저장되게 된다.
맨 위의 노드는 0번부터 5번까지 모든 값의 합을 더한 값이 저장되어있고, 그 아래 두 노드는 각각 0부터 2까지의 합, 3부터 5까지의 합을 각각 저장하고 있다.
이 그래프에서 2 부터 4 까지 부분 합을 구한다면
mermaid
graph TD
A["0-5"]
B["0-2"]
C["3-5"]
D["0-1"]
E["2"]
F["3-4"]
G["5"]
H["0"]
I["1"]
J["3"]
K["4"]
A --> B
A --> C
B --> D
B --> E
D --> H
D --> I
C --> F
C --> G
F --> J
F --> K
style E fill:#ccf
style F fill:#ccf
위의 색칠된 2개를 더하면 된다. 이것을 재귀 함수로 잘 구현하는 것이 중요하다. 재귀함수를 호출 할 때 현재 노드의 번호를 n이라고 하면 왼쪽 자식 노드는 2n, 오른쪽 자식 노드는 2n+1로 호출한다.
아래는 부분합을 구하는 세그먼트 트리를 C++로 구현한 코드이다.
c++
vector<long long> num;
vector<long long> segTree;
// 세그먼트 트리 만드는 함수
void build(int node, int start, int end) {
if(start == end) {
segTree[node] = num[start];
// leaf nodes (단일 숫자)
} else {
int mid = (start + end) / 2;
build(node * 2, start, mid); // 왼쪽 부분
build(node * 2 + 1, mid + 1, end); // 오른쪽 부분
segTree[node] = segTree[node * 2] + segTree[node * 2 + 1]; // 왼쪽 + 오른쪽
}
}이렇게 트리를 만든 뒤에는 재귀 함수를 사용하여 탐색하면 된다.
아래는 세그먼트 트리를 C++로 구현한 코드이다.
c++
long long query(int node, int start, int end, int left, int right){
if(left > end || right < start) {
return 0;
} // 범위를 벗어나면 0을 반환
if(left <= start && end <= right) {
return segTree[node];
} // 범위 안에 있는 합이라면 그 값을 그대로 반환
// ex) 1부터 9까지의 합 중 5부터 9까지의 합이 저장되어있는 노드일 경우 그대로 반환
long long lsum = query(node*2, start, (start+end)/2, left,right); // 왼쪽 자식 탐색
long long rsum = query(node*2+1, (start+end)/2 + 1, end, left,right); // 오른쪽 자식 탐색
return lsum+rsum; // 좌우를 탐색하고 그 합을 반환
}탐색 알고리즘의 기본은 현재 있는 노드의 양쪽 자식 노드를 재귀 함수로 호출하고, 만약 leaf 노드에 도달하면 그 값부터 순서대로 반환해나가기 시작한다.
예시) 백준 #2042 구간 합 구하기
c++
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
int n,m,k;
vector<long long> num;
vector<long long> segTree;
void build(int node, int start, int end) {
if(start == end) {
segTree[node] = num[start];
} else {
int mid = (start + end) / 2;
build(node * 2, start, mid);
build(node * 2 + 1, mid + 1, end);
segTree[node] = segTree[node * 2] + segTree[node * 2 + 1];
}
}
long long query(int node, int start, int end, int left, int right){
if(left > end || right < start) {
return 0;
}
if(left <= start && end <= right) {
return segTree[node];
}
long long lsum = query(node*2, start, (start+end)/2, left,right);
long long rsum = query(node*2+1, (start+end)/2 + 1, end, left,right);
return lsum+rsum;
}
void update(int node, int start, int end, int idx, long long newValue) {
if(idx < start || idx > end) return;
if(start == end) {
segTree[node] = newValue;
num[idx] = newValue;
return;
}
int mid = (start + end) / 2;
if(idx <= mid)
update(node * 2, start, mid, idx, newValue);
else
update(node * 2 + 1, mid + 1, end, idx, newValue);
segTree[node] = segTree[node * 2] + segTree[node * 2 + 1];
}
void input() {
cin>>n>>m>>k;
num.resize(n);
for(int i = 0 ; i < n ; i++) {
cin>>num[i];
}
}
void solve() {
segTree.assign(4*n, 0);
build(1, 0, n-1);
for(int i = 0 ; i < m + k ; i++) {
int first;
long long x,y;
cin>>first;
if(first == 1) {
cin>>x>>y;
update(1,0,n-1,x-1,y);
} else if(first == 2) {
cin>>x>>y;
cout<<query(1,0,n-1,x-1,y-1)<<endl;
}
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
input();
solve();
}