알고리즘

[백준] 2042번 구간 합 구하기 _ JAVA ( 주석 설명 ) 본문

백준 - JAVA/세그먼트 트리

[백준] 2042번 구간 합 구하기 _ JAVA ( 주석 설명 )

wch_s 2023. 3. 30. 20:38

풀이

- 세그먼트 트리 구현

https://book.acmicpc.net/ds/segment-tree

 

세그먼트 트리

누적 합을 사용하면, 1번 연산의 시간 복잡도를 $O(1)$로 줄일 수 있습니다. 하지만, 2번 연산으로 수가 변경될 때마다 누적 합을 다시 구해야 하기 때문에, 2번 연산의 시간 복잡도는 $O(N)$입니다.

book.acmicpc.net

 

 

코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        int N = Integer.parseInt(st.nextToken()); //수의 개수
        int M = Integer.parseInt(st.nextToken()); //수의 변경이 일어나는 횟수
        int K = Integer.parseInt(st.nextToken()); //구간의 합을 구하는 횟수

        long[] ary = new long[N]; //tree로의 값 복사를 위한 ary
        int h = (int) Math.ceil(Math.log(N)/Math.log(2)); //트리의 높이
        int tree_size = 1 << (h+1); //트리의 크기 : 2^(h+1)
        //tree: 세그먼트 트리
        // - 트리의 정보를 저장하기 위해 배열 사용
        // - 깊이가 가장 깊은 리프 노드와 가장 깊지 않은 리프 노드의 깊이 차이가 1 이하이므로, 공간 낭비가 적음
        long[] tree = new long[tree_size];
        for(int i=0;i<N;++i){
            ary[i] = Long.parseLong((br.readLine()));
        }

        //최상위노드(1)에 0~N-1까지의 구간합 정보,
        //하위 노드에 각 구간합 정보를 재귀적으로 불러 segment_tree를 그린다.
        init(ary, tree, 1, 0, N-1);

        StringBuilder sb = new StringBuilder();
        for(int i=0;i<M+K;++i){ //수의 변경 or 구간의 합
            st = new StringTokenizer(br.readLine());

            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            
            //b번째 수를 c로 바꾸기
            if(a==1){
                //입력값이 long 범위
                long c = Long.parseLong(st.nextToken());
                //1번째 수가 0인덱스에 있음 → b-1
                //N을 매개변수로 쓰는 이유
                    //b번째 수를 바꿀 때 해당 노드를 포함하는 모든 노드의 합을 변경해주기 위함
                update(ary, tree, N, b-1, c);
            }

            //b번째 수부터 c번째 수까지의 합 구하기
            //[left,right]가 [start,end]를 완전히 포함하는 경우 1노드 반환
            if(a==2){
                //입력값이 N이하 범위
                int c = Integer.parseInt(st.nextToken());
                long query = query(tree, 1, 0, N - 1, b-1, c-1);
                sb.append(query).append('\n');
            }
        }
        System.out.println(sb);
    }

    //a: 크기가 N인 정수 배열
    //tree: 세그먼트 트리
        // - 트리의 정보를 저장하기 위해 배열 사용
        // - 깊이가 가장 깊은 리프 노드와 가장 깊지 않은 리프 노드의 깊이 차이가 1 이하이므로, 공간 낭비가 적음
    //node: 트리의 노드 번호
    //start-end: 노드에 저장되어 있는 합의 범위, 노드에 저장된 구간
    static void init(long[] a, long[] tree, int node, int start, int end){
        //리프노드일 경우
        if(start==end){
            tree[node] = a[start]; //배열의 그 수를 바로 저장한다.
        }

        else{
            //노드의 왼쪽 자식의 번호 : node*2
                //[start ~ (start+end)/2]의 합이 저장된 구간
            //노드의 오른쪽 자식의 번호 : node*2+1
                //[(start+end)/2+1 ~ end]의 합이 저장된 구간
            init(a, tree, node*2, start, (start+end)/2);
            init(a, tree, node*2+1, (start+end)/2+1, end);

            //왼쪽 자식 값 + 오른쪽 자식 값을 먼저 구해야 한다.
            tree[node] = tree[node*2] + tree[node*2+1];
        }
    }

    static long query(long[] tree, int node, int start, int end, int left, int right){
        //[left,right]와 [start,end]가 겹치지 않는 경우
        if(right<start || end<left) { //각각 왼쪽, 오른쪽으로 start~end 범위를 벗어난 경우
            return 0;
        }
        //[left,right]가 [start,end]를 완전히 포함하는 경우
        if(left<=start && end<=right){ //left~right가 start~end보다 범위가 큰 경우
            return tree[node];
        }

        //[start,end]가 [left,right]를 완전히 포함하는 경우
        //[left,right]와 [start,end]가 겹쳐져 있는 경우 (1, 2, 3 제외한 나머지 경우)

        //왼쪽 자식과 오른쪽 자식을 루트로 하는 트리에서 다시 탐색을 시작해야 합니다.
        long lsum = query(tree, node*2, start, (start+end)/2, left, right);
        long rsum = query(tree, node*2+1,(start+end)/2+1, end, left, right);
        return lsum + rsum;
    }

    //수 변경하기
    //index번째 수를 val로 변경, index번째를 포함하는 노드에 들어있는 합만 변경
    static void update(long[] a, long[] tree, int n, int index, long val){
        long diff = val - a[index]; //구간합은 이만큼 변한다.
        a[index] = val;
        //최상위 노드는 update 시, 무조건 영향을 받음
        update_tree(tree, 1, 0, n-1, index, diff);
    }

    static void update_tree(long[] tree, int node, int start, int end, int index, long diff){
        //[start,end]에 index가 포함되지 않는 경우
        if(index<start || index>end)
            return;

        //차이만큼 더해주기
        tree[node] += diff;

        //index번째를 포함하는 모든 노드의 합에 diff를 더해서 수를 변경해준다.
        if(start!=end){
            update_tree(tree, node*2, start, (start+end)/2, index, diff);
            update_tree(tree, node*2+1, (start+end)/2+1, end, index, diff);
        }
    }
}

 

update 다른 방식

    update(ary, tree, 1, 0, N-1, b-1, c);
    
    //수 변경하기
    //index번째 수를 val로 변경
    //리프 노드를 찾을 때까지 계속 재귀 호출
        //리프 노드일 경우 해당 리프 노드 변경
        //이후 리턴 될 때마다 각 노드의 합을 자식에 저장된 합을 이용해 변경
    static void update(long[] a, long[] tree, int node, int start, int end, int index, long val){
        if(index<start || index>end)
            return;
        //리프노드일 경우
        if(start==end){
            a[index] = val;
            tree[node] = val;
            return;
        }

        update(a, tree, node*2, start, (start+end)/2, index, val);
        update(a, tree, node*2+1, (start+end)/2+1, end, index, val);

        tree[node] = tree[node*2] + tree[node*2+1];
    }