백준 - JAVA/세그먼트 트리
[백준] 2042번 구간 합 구하기 _ JAVA ( 주석 설명 )
wch_s
2023. 3. 30. 20:38
풀이
- 세그먼트 트리 구현
https://book.acmicpc.net/ds/segment-tree
세그먼트 트리
누적 합을 사용하면, 1번 연산의 시간 복잡도를 $O(1)$로 줄일 수 있습니다. 하지만, 2번 연산으로 수가 변경될 때마다 누적 합을 다시 구해야 하기 때문에, 2번 연산의 시간 복잡도는 $O(N)$입니다.
book.acmicpc.net
코드
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken()); //수의 개수
int M = Integer.parseInt(st.nextToken()); //수의 변경이 일어나는 횟수
int K = Integer.parseInt(st.nextToken()); //구간의 합을 구하는 횟수
long[] ary = new long[N]; //tree로의 값 복사를 위한 ary
int h = (int) Math.ceil(Math.log(N)/Math.log(2)); //트리의 높이
int tree_size = 1 << (h+1); //트리의 크기 : 2^(h+1)
//tree: 세그먼트 트리
// - 트리의 정보를 저장하기 위해 배열 사용
// - 깊이가 가장 깊은 리프 노드와 가장 깊지 않은 리프 노드의 깊이 차이가 1 이하이므로, 공간 낭비가 적음
long[] tree = new long[tree_size];
for(int i=0;i<N;++i){
ary[i] = Long.parseLong((br.readLine()));
}
//최상위노드(1)에 0~N-1까지의 구간합 정보,
//하위 노드에 각 구간합 정보를 재귀적으로 불러 segment_tree를 그린다.
init(ary, tree, 1, 0, N-1);
StringBuilder sb = new StringBuilder();
for(int i=0;i<M+K;++i){ //수의 변경 or 구간의 합
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
//b번째 수를 c로 바꾸기
if(a==1){
//입력값이 long 범위
long c = Long.parseLong(st.nextToken());
//1번째 수가 0인덱스에 있음 → b-1
//N을 매개변수로 쓰는 이유
//b번째 수를 바꿀 때 해당 노드를 포함하는 모든 노드의 합을 변경해주기 위함
update(ary, tree, N, b-1, c);
}
//b번째 수부터 c번째 수까지의 합 구하기
//[left,right]가 [start,end]를 완전히 포함하는 경우 1노드 반환
if(a==2){
//입력값이 N이하 범위
int c = Integer.parseInt(st.nextToken());
long query = query(tree, 1, 0, N - 1, b-1, c-1);
sb.append(query).append('\n');
}
}
System.out.println(sb);
}
//a: 크기가 N인 정수 배열
//tree: 세그먼트 트리
// - 트리의 정보를 저장하기 위해 배열 사용
// - 깊이가 가장 깊은 리프 노드와 가장 깊지 않은 리프 노드의 깊이 차이가 1 이하이므로, 공간 낭비가 적음
//node: 트리의 노드 번호
//start-end: 노드에 저장되어 있는 합의 범위, 노드에 저장된 구간
static void init(long[] a, long[] tree, int node, int start, int end){
//리프노드일 경우
if(start==end){
tree[node] = a[start]; //배열의 그 수를 바로 저장한다.
}
else{
//노드의 왼쪽 자식의 번호 : node*2
//[start ~ (start+end)/2]의 합이 저장된 구간
//노드의 오른쪽 자식의 번호 : node*2+1
//[(start+end)/2+1 ~ end]의 합이 저장된 구간
init(a, tree, node*2, start, (start+end)/2);
init(a, tree, node*2+1, (start+end)/2+1, end);
//왼쪽 자식 값 + 오른쪽 자식 값을 먼저 구해야 한다.
tree[node] = tree[node*2] + tree[node*2+1];
}
}
static long query(long[] tree, int node, int start, int end, int left, int right){
//[left,right]와 [start,end]가 겹치지 않는 경우
if(right<start || end<left) { //각각 왼쪽, 오른쪽으로 start~end 범위를 벗어난 경우
return 0;
}
//[left,right]가 [start,end]를 완전히 포함하는 경우
if(left<=start && end<=right){ //left~right가 start~end보다 범위가 큰 경우
return tree[node];
}
//[start,end]가 [left,right]를 완전히 포함하는 경우
//[left,right]와 [start,end]가 겹쳐져 있는 경우 (1, 2, 3 제외한 나머지 경우)
//왼쪽 자식과 오른쪽 자식을 루트로 하는 트리에서 다시 탐색을 시작해야 합니다.
long lsum = query(tree, node*2, start, (start+end)/2, left, right);
long rsum = query(tree, node*2+1,(start+end)/2+1, end, left, right);
return lsum + rsum;
}
//수 변경하기
//index번째 수를 val로 변경, index번째를 포함하는 노드에 들어있는 합만 변경
static void update(long[] a, long[] tree, int n, int index, long val){
long diff = val - a[index]; //구간합은 이만큼 변한다.
a[index] = val;
//최상위 노드는 update 시, 무조건 영향을 받음
update_tree(tree, 1, 0, n-1, index, diff);
}
static void update_tree(long[] tree, int node, int start, int end, int index, long diff){
//[start,end]에 index가 포함되지 않는 경우
if(index<start || index>end)
return;
//차이만큼 더해주기
tree[node] += diff;
//index번째를 포함하는 모든 노드의 합에 diff를 더해서 수를 변경해준다.
if(start!=end){
update_tree(tree, node*2, start, (start+end)/2, index, diff);
update_tree(tree, node*2+1, (start+end)/2+1, end, index, diff);
}
}
}
update 다른 방식
update(ary, tree, 1, 0, N-1, b-1, c);
//수 변경하기
//index번째 수를 val로 변경
//리프 노드를 찾을 때까지 계속 재귀 호출
//리프 노드일 경우 해당 리프 노드 변경
//이후 리턴 될 때마다 각 노드의 합을 자식에 저장된 합을 이용해 변경
static void update(long[] a, long[] tree, int node, int start, int end, int index, long val){
if(index<start || index>end)
return;
//리프노드일 경우
if(start==end){
a[index] = val;
tree[node] = val;
return;
}
update(a, tree, node*2, start, (start+end)/2, index, val);
update(a, tree, node*2+1, (start+end)/2+1, end, index, val);
tree[node] = tree[node*2] + tree[node*2+1];
}