본문 바로가기
IT/Algorithm

Segment Tree

by 물통꿀꿀이 2019. 6. 7.

Segment Tree는 Tree의 일종으로 각 노드에 특정 범위에 대한 정보가 담겨있다. 

(범위를 Interval, Segment 등으로 표현한다.)


Tree의 응용은 당연스럽게 기존의 방식으로는 시간 복잡도가 높기 때문이다. 

한 가지 예를 든다면, [1,2,3,4,5]의 트리에서 (배열로 구성된 트리) 2~4 사이의 합을 알고 싶다면 어떻게 해야 할까?

누구나 접근할 수 있는 쉬운 방법은 전체를 다 훑어보는 것이다. 그렇다면 O(n) 정도의 시간 복잡도가 나온다. 그런데 만약 부분합과 같이 각 노드에 미리 합쳐진 값이 존재한다면 O(1)과 같이 상수의 시간 복잡도로 범위 값을 찾을 수 있다.


이렇듯 Segment Tree는 범위에 대한 정보를 미리 저장함으로써 다음 계산시에 사용할 수 있도록 한다. 

때문에 일반적인 트리 구조를 확장해야 한다.

먼저, 모든 입력 값은 트리의 leaf 노드가 되어야 한다. 그리고 각 범위에 정보들이 부모 노드가 된다. (부모 노드에 저장되는 값은 문제 정의에 따라 다르다.)


그림 1. Segment Tree


그런데 위의 그림에서 볼 수 있듯이, 완전한 Binary Tree가 아니기 때문에 배열로 구현할 경우 중간 중간 비어있는 노드로 인해 메모리 낭비가 발생하기도 한다.

[1,2,3,7,9,11] -> [36,9,27,4,5,16,11,1,3,null, null, 7, 9, null, null]

물론 배열 방식이 아닌 구조로 구현할 수 있겠지만 더 복잡해진다.


배열은 기본 구현 방식으로 본다면, Segment Tree의 크기는 당연히 노드의 개수보다 크다.

그 이유는 n개의 leaf 노드가 존재할 때 internal 노드 (leaf 노드를 제외한 노드)는 n-1개 이기 때문이다. 그림 1만 보아도 input 배열의 개수는 6이지만 internal 노드는 5개 인 것을 확인 할 수 있다. 그래서 Segment Tree의 전체 크기는 다음이 구할 수 있다.

1) n이 2의 거듭 제곱일 때

n이 2의 거듭 제곱이라는 것은 완전한 Binary Tree를 의미한다. 때문에 2n-1이다.

-> 2n - 1 = n (n개의 leaf 노드) + n - 1 (Internal 노드)

2) n이 2의 거듭 제곱이 아닐 때 

중간중간 null 값을 고려해야 한다. 때문에 n은 n보다 큰 2의 거듭 제곱이 되어야 한다. 그리고 갯수는 1)과 같다.

2의 거듭 제곱의 개수가 아닌 그림 1을 예로 들어보면 개수는 2 * 8 - 1 로 총 15개이다.


이를 바탕으로 세부 구현 사항을 확인해보면,

int getMid(int s, int e) { return s + (e -s)/2; }


int getSum(int *st, int n, int qs, int qe)  

{  

    // Check for erroneous input values  

    if (qs < 0 || qe > n-1 || qs > qe)  

    {  

        cout<<"Invalid Input";  

        return -1;  

    }  

  

    return getSumUtil(st, 0, n-1, qs, qe, 0);  

}  

  

// A recursive function that constructs Segment Tree for array[ss..se].  

// si is index of current node in segment tree st  

int constructSTUtil(int arr[], int ss, int se, int *st, int si)  

{  

    // If there is one element in array, store it in current node of  

    // segment tree and return  

    if (ss == se)  

    {  

        st[si] = arr[ss];  

        return arr[ss];  

    }  

  

    // If there are more than one elements, then recur for left and  

    // right subtrees and store the sum of values in this node  

    int mid = getMid(ss, se);  

    st[si] = constructSTUtil(arr, ss, mid, st, si*2+1) +  

            constructSTUtil(arr, mid+1, se, st, si*2+2);  

    return st[si];  

}   


int *constructST(int arr[], int n)  

{  

    // Allocate memory for the segment tree  

  

    //Height of segment tree  

    int x = (int)(ceil(log2(n)));  

  

    //Maximum size of segment tree  

    int max_size = 2*(int)pow(2, x) - 1;  

  

    // Allocate memory  

    int *st = new int[max_size];  

  

    // Fill the allocated memory st  

    constructSTUtil(arr, 0, n-1, st, 0);  

  

    // Return the constructed segment tree  

    return st;  

}  


위의 코드는 Segment Tree를 생성하는 코드이다. (어떻게 보면 Segment Tree의 가장 중요한 부분이다.) 흐름을 보면 다음과 같다.

1) Segment Tree에서 사용할 만큼의 배열을 만든다. 위에서도 알아보았지만 2n-1이다. (여기서는 height를 통해 값을 구하려고 log를 사용하였다.)

2) 재귀 함수를 사용하여 leaf 노드에 값을 할당하고 부모 노드로 올라오면서 자식 노드의 값을 합친다.

3) 모든 구문이 끝나면 할당된 배열에는 Segment Tree로 만들어진 값이 존재한다.


이렇듯 전처리가 된 Segment Tree를 바탕으로 특정 범위의 값을 가져오는 방법은 아래의 코드와 같다.

int getSum(int *st, int n, int qs, int qe)  

{  

    // Check for erroneous input values  

    if (qs < 0 || qe > n-1 || qs > qe)  

    {  

        cout<<"Invalid Input";  

        return -1;  

    }  

  

    return getSumUtil(st, 0, n-1, qs, qe, 0);  

}   


int getSumUtil(int *st, int ss, int se, int qs, int qe, int si)  

{  

    // If segment of this node is a part of given range, then return  

    // the sum of the segment  

    if (qs <= ss && qe >= se)  

        return st[si];  

  

    // If segment of this node is outside the given range  

    if (se < qs || ss > qe)  

        return 0;  

  

    // If a part of this segment overlaps with the given range  

    int mid = getMid(ss, se);  

    return getSumUtil(st, ss, mid, qs, qe, 2*si+1) +  

        getSumUtil(st, mid+1, se, qs, qe, 2*si+2);  

}


그런데 범위를 Segment Tree의 배열 크기로 하지 않고 말 그대로 input 배열의 크기로 범위를 잡고 있다.

잘 이해하지 못하였다면 다음 그림과 함께 보는 편이 좀 더 나을 것 같다.


그림 2. Divide


그림 2는 배열을 Left와 Right로 분할하는 과정이다. 여기서 주의 깊게 봐야하는 부분은 배열의 그림이다. 즉, [7,2,5,9,6,4,1,3,8]이 루트 노드가 되고 n/2로 점차 분할되면서 자식 노드가 된다. 그렇게 가장 마지막 노드는 각각 한 개의 값을 가지게 된다. 말 그대로 트리 구조의 형태를 띄고 있다.

이제 그림 1을 다시 보면, 루트 노드는 모든 배열의 합이고 자식 노드는 n/2로 분할된 배열의 합을 나타낸다. 결과적으로 마지막 노드는 값을 하나만 가지고 있을 뿐만아니라 input 값이 된다. 정리해보면 Segment Tree는 배열을 Divide하는 과정과 동일하다는 것이다.


이를 바탕으로 코드를 확인해보면, 

Range Start <= Segment Start && Segment End <= Range End

Segment가 Range에 포함될 수 있을 때, 해당 인덱스의 값을 가져온다. (그렇지 않으면 n/2 분할을 계속하여 포함될 때까지)

이로써 특정 범위의 값을 빠르게 가져올 수 있다. 

물론 범위에 따라 leaf 노드를 찾아야 할 때가 있다. (예를 들어 그림 1에서 1-3의 범위를 찾을 때)


마지막으로 Segment Tree에서 값을 변경할 때이다.

void updateValueUtil(int *st, int ss, int se, int i, int diff, int si)  

{  

    // Base Case: If the input index lies outside the range of  

    // this segment  

    if (i < ss || i > se)  

        return;  

  

    // If the input index is in range of this node, then update  

    // the value of the node and its children  

    st[si] = st[si] + diff;  

    if (se != ss)  

    {  

        int mid = getMid(ss, se);  

        updateValueUtil(st, ss, mid, i, diff, 2*si + 1);  

        updateValueUtil(st, mid+1, se, i, diff, 2*si + 2);  

    }  

}  

  

// The function to update a value in input array and segment tree.  

// It uses updateValueUtil() to update the value in segment tree  

void updateValue(int arr[], int *st, int n, int i, int new_val)  

{  

    // Check for erroneous input index  

    if (i < 0 || i > n-1)  

    {  

        cout<<"Invalid Input";  

        return;  

    }  

  

    // Get the difference between new value and old value  

    int diff = new_val - arr[i];  

  

    // Update the value in array  

    arr[i] = new_val;  

  

    // Update the values of nodes in segment tree  

    updateValueUtil(st, 0, n-1, i, diff, 0);  

}   


Segment Tree에서 값의 변경은 생각보다 간단하다.

Tree의 각 노드는 자식 노드의 합이므로 기존 값과 변경된 값의 차이에 대해서만 해당 범위에 넣어주면 된다.

따라서 코드를 보면 diff 값을 찾아 preorder 방식으로 Tree 노드를 순회한다.


지금까지 알아본 Segment Tree의 시간 복잡도를 알아보면 

- Create : O(n) -> 모든 노드를 계산해야 하기 때문

- Query : O(Logn) -> Binary Search Tree와 비슷하다.

- Update : O(Logn) -> Query와 비슷하게 변경된 범위만 수정해야 하기 때문


Reference

https://janghw.tistory.com/entry/%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98-Divide-and-Conquer-%EB%B6%84%ED%95%A0%EC%A0%95%EB%B3%B5

https://www.geeksforgeeks.org/segment-tree-set-1-sum-of-given-range/

'IT > Algorithm' 카테고리의 다른 글

Trie  (0) 2019.06.11
Topological Sorting  (0) 2019.06.10
Boyer–Moore Majority Vote  (0) 2019.06.06
Insertion Sort  (0) 2019.06.06
Selection Sort  (0) 2019.06.06

댓글