알고리즘/알고리즘 문제 해결전략(종만북)

22장 삽입 정렬 뒤집기 INSERTION

pureworld 2019. 1. 1. 21:34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<iostream>
#include<string>
#include<vector>
#include<cassert>
#include<algorithm>
using namespace std;
 
int main(void) {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int cases;
    cin >> cases;
    while (cases--) {
        int n;
        cin >> n;
        vector<int> arr(n, 0);
        vector<int> reverse;
        for (int i = 0; i < n; i++)
            arr[i] = i + 1;
        for (int i = 0; i < n; i++) {
            int k;
            cin >> k;
            reverse.push_back(k);
        }
        for (int i = n - 1; i >= 0; i--) {
            int j = i;
            int move = reverse[j];
            //i>0 안걸어주면 vector out of range 오류 뜸.
            while (i>0&&reverse[j]--) {
                swap(arr[i-move], arr[i-move+1]);
                i++;
            }
            i = j;
        }
        for (int i = 0; i < n; i++) {
            cout << arr[i] << " ";
        }
        cout << "\n";
    }
    return 0;
}
cs

문제 출처:https://algospot.com/judge/problem/read/INSERTION


난이도가 중이라서 긴장타고 풀었는데 10분도 안되서 정답이 뜨네요 처음으로 쉽게 문제해결을 해서 감동이였습니다.

복잡한 생각을 안하고 그냥 배열로 삽입 정렬을 뒤집어서 코딩했습니다. 


풀이와는 물론 의도가 다르게 풀었기 때문에 문제 의도에 맞춘 풀이 해석을 올리겠습니다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#include<set>
#include<cassert>
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<string>
#include<vector>
using namespace std;
 
int n, shifted[50000];
int A[50000];
 
typedef int KeyType;
 
// 트립의 한 노드를 저장한다
struct Node {
    // 노드에 저장된 키
    KeyType key;
    // 이 노드의 우선순위 (priority)
    // 이 노드를 루트로 하는 서브트리의 크기 (size)
    int priority, size;
    // 두 자식 노드의 포인터
    Node *left, *right;
 
    // 생성자에서 난수 우선순위를 생성하고, size 와 left/right 를 초기화한다
    Node(const KeyType& _key) : key(_key), priority(rand()),
        size(1), left(NULL), right(NULL) {
    }
    void setLeft(Node* newLeft) { left = newLeft; calcSize(); }
    void setRight(Node* newRight) { right = newRight; calcSize(); }
    // size 멤버를 갱신한다
    void calcSize() {
        size = 1;
        if (left) size += left->size;
        if (right) size += right->size;
    }
};
 
// key 이상의 값 중 가장 작은 값의 위치를 반환한다
Node* lowerBound(Node* root, KeyType key) {
    if (root == NULL) return NULL;
    if (root->key < key) return lowerBound(root->right, key);
    Node* ret = lowerBound(root->left, key);
    if (!ret) ret = root;
    return ret;
}
 
bool exists(Node* root, KeyType key) {
    Node* node = lowerBound(root, key);
    return node != NULL && node->key == key;
}
 
// root 를 루트로 하는 트리 중에서 k번째 원소를 반환한다
Node* kth(Node* root, int k) {
    int ls = (root->left ? root->left->size : 0);
    int rs = (root->right ? root->right->size : 0);
    if (k <= ls) return kth(root->left, k);
    if (k == ls + 1) return root;
    return kth(root->right, k - ls - 1);
}
 
// key 보다 작은 키값의 수를 반환한다
int countLessThan(Node* root, KeyType key) {
    if (root == NULL) return 0;
    if (root->key >= key)
        return countLessThan(root->left, key);
    int ls = (root->left ? root->left->size : 0);
    return ls + 1 + countLessThan(root->right, key);
}
 
typedef pair<Node*, Node*> NodePair;
 
// root 를 루트로 하는 트립을 key 미만의 값과 이상의 값을 갖는
// 두 개의 트립으로 분리한다.
NodePair split(Node* root, KeyType key) {
    if (root == NULL) return NodePair(NULL, NULL);
    // 루트가 key 미만이면 오른쪽의 일부를 잘라낸다
    if (root->key < key) {
        NodePair rs = split(root->right, key);
        root->setRight(rs.first);
        return NodePair(root, rs.second);
    }
    // 루트가 key 이상이면 왼쪽의 일부를 잘라낸다
    NodePair ls = split(root->left, key);
    root->setLeft(ls.second);
    return NodePair(ls.first, root);
}
 
// root 를 루트로 하는 트립에 새 노드 node 를 삽입한 뒤 결과 트립의
// 루트를 반환한다.
Node* insert(Node* root, Node* node) {
    if (root == NULL) return node;
    // node 가 루트를 대체해야 한다: 해당 서브트립을 반으로 잘라
    // 각각 자손으로 한다
    if (root->priority < node->priority) {
        NodePair splitted = split(root, node->key);
        node->setLeft(splitted.first);
        node->setRight(splitted.second);
        return node;
    }
    else if (node->key < root->key)
        root->setLeft(insert(root->left, node));
    else
        root->setRight(insert(root->right, node));
    return root;
}
 
// a 와 b 가 두 개의 트립이고, max(a) < min(b) 일때 이 둘을 합친다
Node* merge(Node* a, Node* b) {
    if (a == NULL) return b;
    if (b == NULL) return a;
 
    if (a->priority < b->priority) {
        b->setLeft(merge(a, b->left));
        return b;
    }
    a->setRight(merge(a->right, b));
    return a;
}
 
// root 를 루트로 하는 트립에서 key 를 지운다
Node* erase(Node* root, KeyType key) {
    if (root == NULL) return root;
    // root 를 지우고 양 서브트립을 합친 뒤 반환한다
    if (root->key == key) {
        Node* ret = merge(root->left, root->right);
        delete root;
        return ret;
    }
    if (key < root->key)
        root->setLeft(erase(root->left, key));
    else
        root->setRight(erase(root->right, key));
    return root;
}
void solve() {
    //1~N까지의 숫자를 모두 저장하는 트립을 만든다.
    Node* candidates = NULL;
    for (int i = 0; i < n; i++)
        candidates = insert(candidates, new Node(i + 1));
    //뒤에서부터 A[]를 채워나간다.
    for (int i = n - 1; i >= 0--i) {
        //후보 중 이 수보다 큰 수가 larger개 있다.
        int larger = shifted[i];
        Node*= kth(candidates, i + 1- larger);
        A[i] = k->key;
        candidates = erase(candidates, k->key);
    }
}
 
int main(void) {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int cases;
    cin >> cases;
    while (cases--) {
 
        cin >> n;
        for (int i = 0; i < n; i++)
            cin >> shifted[i];
        solve();
        for (int i = 0; i < n; i++)
            cout << A[i] << " ";
        cout << "\n";
 
    }
    return 0;
}
cs

마지막 숫자 A[N-1]이 왼쪽으로 몇 칸 움직였는지를 보면 A[N-1]에 어떤 숫자가 들어가야 할지 알 수 있습니다.

예를 들어 문제에 적힌 예제에서 마지막 숫자 A[4]는 3칸 왼쪽으로 움직임

-> 이말은 1~5 범위의 숫자 중에 A[4]보다 큰 숫자가 3개 있다는 뜻


kth()와 erase() 함수의 수행 시간은 모두 O(lgN) 이므로 solve()의 전체 시간 복잡도는 O(NlgN)가 됩니다.


수행시간을 보면 트립을 사용했을때와 사용하지 않았을 때의 시간이 압도적으로 차이가 나네요. 

제가 짠 위의 코드의 최악의 경우는 O(n^2)이기 때문에 이런 결과가 나온 것 같습니다.