[c++] CAS ๊ตฌํ ๋ฐ ABA ๋ฌธ์ ํด๊ฒฐ :: ABA ํด๊ฒฐ_3. Counter ๊ตฌํ
๋ชฉ์ฐจ
- ๋ฌธ์ ์ ์
- lock free ๊ตฌํ
- ABA ํด๊ฒฐ
- intํ ๊ตฌํ(+ Hazard pointer)
- Counter
- ๊ทธ ์ธ์ ๋ฐฉ๋ฒ๋ค
- mutex lock(spin lock)๊ณผ์ ๋น๊ต
Counter
#include <iostream>
#include <thread>
#include <vector>
#include <atomic>
#include <random>
#include <mutex>
using namespace std;
class Node {
public:
int value;
Node* next;
Node() : value(0) { next = NULL; }
Node(int k_value) {
next = NULL;
value = k_value;
}
};
int node_n;
//random ๊ตฌํ
random_device rd;
mt19937 gen(rd());
uniform_int_distribution<> dis(0, 10);
//์ถ๋ ฅ์ ์ํ mutex
mutex mut;
class LFStack {
private:
long long CountedPtr;
public:
void printNodes() {
int size = 0;
Node* head = reinterpret_cast<Node*>(CountedPtr);
while (head != nullptr) {
size++;
head = head->next;
}
cout << size << endl;
}
long long getCount(long long cp) {
return ((unsigned long long)cp >> 52); // ์ผ์ชฝ ๋นํธ 0์ผ๋ก ์ฑ์ฐ๊ธฐ
}
long long setCount(long long count) {
return count << 52;
}
void push(Node* newNode) {
long long oldCP, newCP;
unsigned long long newCount;
do {
oldCP = this->CountedPtr;
newCount = getCount(oldCP) + 1; // count 1 ์ฆ๊ฐ
newNode->next = reinterpret_cast<Node*>(oldCP); // list์ head ๋
ธ๋๋ฅผ ๊ฐ๋ฆฌํค๊ฒ ํ๊ธฐ
newCP = setCount(newCount) + reinterpret_cast<long long>(newNode); // ๋ฐ๊พผ ์นด์ดํฐ์ node ๊ฒฐํฉ
} while (!(_InterlockedCompareExchange64(&(this->CountedPtr),newCP,oldCP)==oldCP)); //CAS
}
Node* pop() {
long long oldCP, newCP;
do {
oldCP = this->CountedPtr;
Node* popNode = reinterpret_cast<Node*>(oldCP); // ๋งจ ์ ๋
ธ๋ ๊ตฌํ๊ธฐ
if (popNode == nullptr)
return nullptr;
newCP = (long long)(setCount(getCount(oldCP))) + reinterpret_cast<long long>(popNode->next); // counter๋ ๊ทธ๋๋ก, node๋ ๋ค์ ๋
ธ๋
} while (!(_InterlockedCompareExchange64(&(this->CountedPtr),newCP,oldCP)==oldCP)); // CAS
return reinterpret_cast<Node*>(oldCP);
}
};
LFStack* FreeList;
LFStack* HeadList;
void subThread() {
// random n, m
int n = dis(gen);
int m = dis(gen);
int num = 10000; // while๋ฌธ ๋ฐ๋ณต ํ์
while (num--) {
for (int i = 0; i < n; i++)
{
// free_list์ ์์ผ๋ฉด head_list์ ์ฎ๊ธฐ๋ ์์
Node* tmp = FreeList->pop();
if (tmp == nullptr)
continue;
HeadList->push(tmp);
}
for (int i = 0; i<m; i++) {
// head_list์ ์์ผ๋ฉด free_list์ ์ฎ๊ธฐ๋ ์์
Node* tmp = HeadList->pop();
if (tmp == nullptr)
continue;
FreeList->push(tmp);
}
}
}
int main() {
// free, head list ์์ฑ
FreeList = new LFStack();
HeadList = new LFStack();
// ๋
ธ๋ ์์ฑ ๋ฐ free_list์ ์ฝ์
node_n = 100000; // ๋
ธ๋ ๊ฐฏ์
for (int i = node_n - 1; i >= 0; i--) {
Node* node = new Node(i);
FreeList->push(node);
}
// ์ค๋ ๋ ์์ฑ ๋ฐ ๋์ ์์
vector<thread> threads;
int n = 5; // ์ค๋ ๋ ๊ฐฏ์
for (int i = 0; i < n; i++) {
threads.emplace_back(subThread);
}
for (auto &t : threads)
t.join();
threads.clear();
FreeList->printNodes();
HeadList->printNodes();
cin >> n;
}
๋๋ถ๋ถ์ 64bit CPU์์ ๊ฐ์ ์ฃผ์๋ฅผ 52bit ๋ฐ์ ์ฌ์ฉํ์ง ์๋๋ค๋ ์ ์ ์ด์ฉํ ๋ฐฉ๋ฒ์ด๋ค.
๊ฐ List๋ long long ํ CountedPtr์ด๋ผ๋ ์๋ฃํ ํ๋๋ก ์ด๋ฃจ์ด์ ธ์๋ค. long long์ 8byte ์๋ฃํ์ด๋ฏ๋ก 64bit์ ์ฃผ์๋ฅผ ๋ชจ๋ ํํํ ์ ์๋ค. CountedPtr์ ํ์ 52bit๋ ์ฃผ์๋ฅผ ๋ํ๋ด๋๋ฐ์ ์ฐ๊ณ , ์์ 12bit๋ Count๋ฅผ ๋ํ๋ด๋ ๋ฐ์ ์ฐ๊ธฐ๋ก ํ์.
List๊ฐ ๊ฐ๋ฆฌํค๋ ์ฒซ๋ฒ์งธ ๋
ธ๋์ ์ฃผ์์ List์ push๊ฐ ์ฌ์ฉ๋ ํ์(Count)๋ฅผ ๊ฐ๊ฐ ํ์ 52bit, ์์ 12bit์ ํ ๋นํ์ฌ CountedPtr์ ๋ง๋ ๋ค. ABA ๋ฌธ์ ๋ push๊ฐ ๋ฐ์ํด์ผ์ง๋ง ์ผ์ด๋ ๊ฐ๋ฅ์ฑ์ด ์์ผ๋ฏ๋ก pushํจ์ ์ํ ์์๋ง Count๋ฅผ ์ฆ๊ฐ์์ผ์ฃผ์๋ค. push ํจ์๊ฐ ์คํ๋ ๋๋ง๋ค Count๋ 1์ฉ ์ฆ๊ฐํ๋ฏ๋ก _InterlockedCompareExchange64
์์ ๋น๊ต ์ Count๊ฐ ๋ฌ๋ผ์ง๋ฉด ์ฐ์ฐ์ด ์ํ๋์ง ์๋๋ค. ๋ฐ๋ผ์ ABA ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋ค.
pop ํจ์์์ List์ ์ฒซ๋ฒ์งธ ๋
ธ๋๊ฐ ๊ฐ๋ฆฌํค๋ ๋ค์ ๋
ธ๋๋ฅผ ๋ฐ์์ค๋ ค๋ฉด CountedPtr์ด๋ผ๋ ์ ์ํ์์ Node์ ์ฃผ์์ธ Node*๋ฅผ ์ถ์ถํ ํ์๊ฐ ์๋ค. ์ด ๋๋ reinterpret_cast<Node*>(๋์)
์ด๋ผ๋ ์บ์คํ
์ฐ์ฐ์ ์ฌ์ฉํ์ฌ Node๋ฅผ ๋ถ๋ฌ์จ๋ค.
๋ํ Count๋ 1 ๋๋ฆฌ๊ธฐ ์ํด ์ถ์ถํ๋ ๊ณผ์ ์ด ํ์ํ๋ฐ, Count๋ getCount์ setCount ํจ์๋ฅผ ๋ฐ๋ก ๊ตฌํํ์ฌ ๋ค๋ฃจ์๋ค. Count ์ถ์ถ ์ ์ ์ํ ์ ์ CPU๋ง๋ค right shift ์ฐ์ฐ ์ 0๊ณผ 1 ์ค์ ์ด๋ค ๊ฐ์ด ์ฑ์์ง๋ ์ง๊ฐ ๋ค๋ฅด๋ค๋ ๊ฒ์ด๋ค. ๋ฐ๋ผ์ unsigned ์๋ฃํ์ ์ด์ฉํ์ฌ ์์๊ฐ ๋์ง ์๋๋ก(0์ผ๋ก ์ฑ์์ง๋๋ก) ๋ง๋ค์ด์ฃผ์ด์ผํ๋ค. ์๋๋ฉด 0์ผ๋ก ์ฑ์์ฃผ๋ >>>
์ฐ์ฐ์ ์ฌ์ฉํด๋ ๋๋๋ฐ c++์๋ ํด๋น ์ฐ์ฐ์ด ์กด์ฌํ์ง ์๋๋ค.
=> ABA๋ฅผ ํด๊ฒฐํ๋ ๊ธฐํ ๋ฐฉ๋ฒ๋ค์ ๋ค์ ํฌ์คํ ์