Advanced C++

[C++] 멀티쓰레드 환경에서의 스마트 포인터

로파이 2021. 8. 16. 20:18

스마트 포인터

객체의 포인터를 가지고 있는 대리자로서 메모리 해제 책임을 스마트 포인터에 위임하는 방법이다.

 

문제

멀티 스레드 환경에서 공유 카운트에 대한 쓰기 (참조 증가, 감소)가 병행적으로 일어나여 리소스를 다중 스레드를 이용하여 로드시 등과 같은 예제에서 공유 카운트가 동기화가 안 될 수도 있다. 

 

스핀락

스핀락의 경우 운영체제의 lock 방법을 사용하지 않고(lock-free) 현 쓰레드를 block하지 않고 while 루프를 돌며 재진입을 기다리기 때문에 잠그는 임계영역의 범위가 작은 경우 lock-free 동기화 기법이 유용할 수도 있다.

 

사용하고 있는 SpinLock

반복 횟수에 따라 유동적으로 문맥 교환을 유도하거나 다른 프로세서의 명령어 처리를 유도할 수 있는 장점이 있다.

class SpinLock
{
private:
	static constexpr int YIELD_ITERATION = 20;
	static constexpr int MAX_SLEEP_ITERATION = 30;
	volatile LONG m_iFlag;

public:
	SpinLock();
	~SpinLock();

public:
	NULL_COPY_AND_ASSIGN(SpinLock);
	DEFAULT_MOVE_CLASS(SpinLock);

public:
	void Lock();
	void UnLock();
};

#include "SpinLock.h"

void SpinLock::Lock()
{
	int iterations = 0;
	while (InterlockedCompareExchange(&m_iFlag, 1, 0))
	{
		// 더 높은 우선순위의 ready 스레드가 있으면 Context Switching이 일어난다.
		if (iterations + YIELD_ITERATION >= MAX_SLEEP_ITERATION)
		{
			Sleep(0);
		}

		if (iterations >= YIELD_ITERATION && iterations < MAX_SLEEP_ITERATION)
		{
			iterations = 0;
			SwitchToThread();
		}
		++iterations;
		//
		// Yield processor on multi-processor but if on single processor then give other thread the CPU
		YieldProcessor(/*no op*/);
	}
}

void SpinLock::UnLock()
{
	InterlockedDecrement(&m_iFlag);
}

SpinLock::SpinLock()
	:
	m_iFlag(0)
{
}

SpinLock::~SpinLock()
{
}

 

스마트 포인터 구현 (수정)

주요 구현점

1. 동기화 카운트는 힙에 생성한다.

2. static_pointer_cast와 같이 다형성 스마트 포인터 캐스팅을 지원하는 CastTo를 구현해본다.

- assert와 dynamic_cast(RTTI)를 사용하여 체크할 수도 있다.

3. AddRef()와 SubRef()의 내용에서 참조 카운트를 변경할 때 반드시 동기화를 신경쓰도록 한다.

4. STL 스마트 포인터의 reset() 기능을 구현해본다.

5. 생 포인터로부터 생성하는 것을 explicit으로 구현한다.

// SharedPtr<A> pA(new A); // ok

// SharedPtr<A> pA = new A; // not ok

// SharedPtr<A> pA = nullptr; // overload with SharedPtr(nullptr_t)

6. nullptr의 대입, 복사 생성 그리고 ==와 같은 연산자의 지원을 신경쓰도록 한다.

- nullptr의 대입 연산자 오버로딩은 기능 지원하도록 따로 구현.

7. CriticalSection을 이용하여 동기화

#pragma once
#include "CriticalSection.h"

template<typename T>
class SharedPtr
{
	template<typename S>
	friend class SharedPtr;
private:
	T*			  m_pRes;
	UINT*		  m_iRefCount;
	static CriticalSection m_tMutex;

public:
	T* operator->()
	{
		assert(m_pRes && "[SharedPtr] Access Null Obj");
		return m_pRes;
	}
	const T* operator->() const
	{
		assert(m_pRes && "[SharedPtr] Access Null Obj");
		return m_pRes;
	}

	T* Get() { return m_pRes; }
	const T* Get() const { return m_pRes; }
	T** GetAddressOf() { return &m_pRes; }
	T* const* GetAddressOf() const { return &m_pRes; }

	UINT UseCount() const 
	{ 
		if (m_iRefCount == nullptr)
			return 0;
		return (*m_iRefCount); 
	}

	void Reset(T* _ptr = nullptr)
	{
		assert(_ptr != m_pRes);

		if (m_pRes)
		{
			SubRef();
		}
		
		m_pRes = _ptr;

		if (m_pRes)
		{
			assert(m_iRefCount == nullptr);
			m_iRefCount = new UINT(1);
		}
	}

	template<typename S>
	SharedPtr<S> CastTo() const
	{
		assert(m_pRes != nullptr && dynamic_cast<S*>(m_pRes) != nullptr);

		SharedPtr<S> pCast;
		pCast.m_iRefCount = m_iRefCount;
		pCast.m_pRes = (S*)m_pRes;
		pCast.AddRef();

		return pCast;
	}

	constexpr operator bool() const { return m_pRes != nullptr; }
	bool operator==(nullptr_t) const { return m_pRes == nullptr; }
	bool operator!=(nullptr_t) const { return m_pRes != nullptr; }
	bool operator==(const SharedPtr<T>& rhs) const { return m_pRes == rhs.m_pRes; }
	bool operator!=(const SharedPtr<T>& rhs) const { return m_pRes != rhs.m_pRes; }

private:
	void AddRef()
	{
		m_tMutex.Lock();
		++(*m_iRefCount);
		m_tMutex.UnLock();
	}
	void SubRef()
	{
		m_tMutex.Lock();
		--(*m_iRefCount);

		if ((*m_iRefCount) == 0)
		{
			SAFE_DELETE(m_iRefCount);
			SAFE_DELETE(m_pRes);
		}

		m_tMutex.UnLock();
	}
public:
	constexpr SharedPtr() noexcept
		:
		m_pRes(nullptr),
		m_iRefCount(nullptr)
	{}
	constexpr SharedPtr(nullptr_t) noexcept
		:
		m_pRes(nullptr),
		m_iRefCount(nullptr)
	{}
	explicit SharedPtr(T* _ptr)
		:
		m_pRes(_ptr),
		m_iRefCount(nullptr)
	{
		if (m_pRes)
		{
			m_iRefCount = new UINT(1);
		}
	}
	~SharedPtr()
	{
		if (m_pRes)
		{
			SubRef();
		}
	}
	SharedPtr& operator=(nullptr_t)
	{
		if (m_pRes)
		{
			SubRef();
		}

		m_pRes = nullptr;
		m_iRefCount = nullptr;
		return *this;
	}
	SharedPtr(const SharedPtr<T>& rhs)
		:
		m_pRes(rhs.m_pRes),
		m_iRefCount(rhs.m_iRefCount)
	{
		if (m_pRes)
		{
			AddRef();
		}
	}
	SharedPtr& operator=(const SharedPtr<T>& _ptr)
	{
		if (m_pRes)
		{
			SubRef();
		}

		m_pRes = _ptr.m_pRes;
		m_iRefCount = _ptr.m_iRefCount;

		if (m_pRes)
		{
			AddRef();
		}

		return *this;
	}
	SharedPtr(SharedPtr<T>&& rhs) noexcept
		:
		m_pRes(rhs.m_pRes),
		m_iRefCount(rhs.m_iRefCount)
	{
		rhs.m_pRes = nullptr;
		rhs.m_iRefCount = nullptr;
	}
	SharedPtr& operator=(SharedPtr<T>&& _ptr) noexcept
	{
		if (m_pRes)
		{
			SubRef();
		}

		m_pRes = _ptr.m_pRes;
		m_iRefCount = _ptr.m_iRefCount;

		_ptr.m_iRefCount = nullptr;
		_ptr.m_pRes = nullptr;

		return *this;
	}
};

template<typename T>
CriticalSection SharedPtr<T>::m_tMutex = {};