https://git.reactos.org/?p=reactos.git;a=commitdiff;h=d12880829ffee337c46bc4cedbcc37a7a702cffb
commit d12880829ffee337c46bc4cedbcc37a7a702cffb
Author:     Mark Jansen <mark.jansen@reactos.org>
AuthorDate: Sat Apr 15 22:21:33 2023 +0200
Commit:     Mark Jansen <mark.jansen@reactos.org>
CommitDate: Sat Apr 22 21:23:55 2023 +0200
[ATL] Add OBJECT_ENTRY_AUTO for simpler com object registration
Of course gcc needs a nasty hack to include the symbol.

CORE-18936
---
 sdk/lib/atl/atlbase.h | 177 ++++++++++++++++++++++++++++++++++++++++++++------
 sdk/lib/atl/atlcom.h  |  58 +++++++++++++++++
 2 files changed, 216 insertions(+), 19 deletions(-)
diff --git a/sdk/lib/atl/atlbase.h b/sdk/lib/atl/atlbase.h index 4bc00d2e9bf..1b0b0361b6a 100644 --- a/sdk/lib/atl/atlbase.h +++ b/sdk/lib/atl/atlbase.h @@ -59,6 +59,15 @@ class CAtlComModule; __declspec(selectany) CAtlModule *_pAtlModule = NULL; __declspec(selectany) CComModule *_pModule = NULL;
+template <bool isDll, typename T> struct CAtlValidateModuleConfiguration +{ +#if !defined(_WINDLL) && !defined(_USRDLL) + static_assert(!isDll, "_WINDLL or _USRDLL must be defined when 'CAtlDllModuleT<T>' is used"); +#else + static_assert(isDll, "_WINDLL or _USRDLL must be defined when 'CAtlExeModuleT<T>' is used"); +#endif +}; +
struct _ATL_CATMAP_ENTRY { @@ -173,6 +182,46 @@ struct _ATL_WIN_MODULE70 }; typedef _ATL_WIN_MODULE70 _ATL_WIN_MODULE;
+ +// Auto object map + +#if defined(_MSC_VER) +#pragma section("ATL$__a", read, write) +#pragma section("ATL$__z", read, write) +#pragma section("ATL$__m", read, write) +#define _ATLALLOC(x) __declspec(allocate(x)) + +#if defined(_M_IX86) +#define OBJECT_ENTRY_PRAGMA(class) __pragma(comment(linker, "/include:___pobjMap_" #class)); +#elif defined(_M_IA64) || defined(_M_AMD64) || (_M_ARM) || defined(_M_ARM64) +#define OBJECT_ENTRY_PRAGMA(class) __pragma(comment(linker, "/include:__pobjMap_" #class)); +#else +#error Your platform is not supported. +#endif + +#elif defined(__GNUC__) + +// GCC completely ignores __attribute__((unused)) on the __pobjMap_ pointer, so we pass it to a function that is not allowed to be optimized.... +static int __attribute__((optimize("O0"), unused)) hack_for_gcc(const _ATL_OBJMAP_ENTRY * const *) +{ + return 1; +} + +#define _ATLALLOC(x) __attribute__((section(x))) +#define OBJECT_ENTRY_PRAGMA(class) static int __pobjMap_hack_##class = hack_for_gcc(&__pobjMap_##class); + +#else +#error Your compiler is not supported. +#endif + + +extern "C" +{ + __declspec(selectany) _ATLALLOC("ATL$__a") _ATL_OBJMAP_ENTRY *__pobjMapEntryFirst = NULL; + __declspec(selectany) _ATLALLOC("ATL$__z") _ATL_OBJMAP_ENTRY *__pobjMapEntryLast = NULL; +} + + struct _ATL_REGMAP_ENTRY { LPCOLESTR szKey; @@ -551,8 +600,9 @@ public: CAtlComModule() { GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, (LPCWSTR)this, &m_hInstTypeLib); - m_ppAutoObjMapFirst = NULL; - m_ppAutoObjMapLast = NULL; + + m_ppAutoObjMapFirst = &__pobjMapEntryFirst + 1; + m_ppAutoObjMapLast = &__pobjMapEntryLast; if (FAILED(m_csObjMap.Init())) { ATLASSERT(0); @@ -577,17 +627,37 @@ public: return AtlComModuleUnregisterServer(this, bUnRegTypeLib, pCLSID); }
- void Term() { if (cbSize != 0) { - ATLASSERT(m_ppAutoObjMapFirst == NULL); - ATLASSERT(m_ppAutoObjMapLast == NULL); + for (_ATL_OBJMAP_ENTRY **iter = m_ppAutoObjMapFirst; iter < m_ppAutoObjMapLast; iter++) + { + _ATL_OBJMAP_ENTRY *ptr = *iter; + if (!ptr) + continue; + + if (!ptr->pCF) + continue; + + ptr->pCF->Release(); + ptr->pCF = NULL; + } m_csObjMap.Term(); cbSize = 0; } } + + void ExecuteObjectMain(bool bStarting) + { + for (_ATL_OBJMAP_ENTRY **iter = m_ppAutoObjMapFirst; iter < m_ppAutoObjMapLast; iter++) + { + if (!*iter) + continue; + + (*iter)->pfnObjectMain(bStarting); + } + } };
__declspec(selectany) CAtlComModule _AtlComModule; @@ -606,11 +676,20 @@ HRESULT CAtlModuleT<T>::UnregisterServer(BOOL bUnRegTypeLib, const CLSID *pCLSID }
template <class T> -class CAtlDllModuleT : public CAtlModuleT<T> +class CAtlDllModuleT + : public CAtlModuleT<T> + , private CAtlValidateModuleConfiguration<true, T> + { public: CAtlDllModuleT() { + _AtlComModule.ExecuteObjectMain(true); + } + + ~CAtlDllModuleT() + { + _AtlComModule.ExecuteObjectMain(false); }
HRESULT DllCanUnloadNow() @@ -659,7 +738,9 @@ public:
template <class T> -class CAtlExeModuleT : public CAtlModuleT<T> +class CAtlExeModuleT + : public CAtlModuleT<T> + , private CAtlValidateModuleConfiguration<false, T> { public: DWORD m_dwMainThreadID; @@ -670,10 +751,12 @@ public: CAtlExeModuleT() :m_dwMainThreadID(::GetCurrentThreadId()) { + _AtlComModule.ExecuteObjectMain(true); }
~CAtlExeModuleT() { + _AtlComModule.ExecuteObjectMain(false); }
int WinMain(int nShowCmd) @@ -815,12 +898,19 @@ public: } } } + + for (_ATL_OBJMAP_ENTRY **iter = _AtlComModule.m_ppAutoObjMapFirst; iter < _AtlComModule.m_ppAutoObjMapLast; iter++) + { + if (*iter != NULL) + (*iter)->pfnObjectMain(true); + } + return S_OK; }
void Term() { - _ATL_OBJMAP_ENTRY *objectMapEntry; + _ATL_OBJMAP_ENTRY *objectMapEntry;
if (m_pObjMap != NULL) { @@ -834,12 +924,19 @@ public: objectMapEntry++; } } + + for (_ATL_OBJMAP_ENTRY **iter = _AtlComModule.m_ppAutoObjMapFirst; iter < _AtlComModule.m_ppAutoObjMapLast; iter++) + { + if (*iter != NULL) + (*iter)->pfnObjectMain(false); + } + }
HRESULT GetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv) { - _ATL_OBJMAP_ENTRY *objectMapEntry; - HRESULT hResult; + _ATL_OBJMAP_ENTRY *objectMapEntry; + HRESULT hResult;
ATLASSERT(ppv != NULL); if (ppv == NULL) @@ -869,8 +966,7 @@ public: } if (hResult == S_OK && *ppv == NULL) { - // FIXME: call AtlComModuleGetClassObject - hResult = CLASS_E_CLASSNOTAVAILABLE; + hResult = AtlComModuleGetClassObject(&_AtlComModule, rclsid, riid, ppv); } return hResult; } @@ -1480,9 +1576,9 @@ inline HRESULT __stdcall AtlAdvise(IUnknown *pUnkCP, IUnknown *pUnk, const IID &
inline HRESULT __stdcall AtlUnadvise(IUnknown *pUnkCP, const IID &iid, DWORD dw) { - CComPtr<IConnectionPointContainer> container; - CComPtr<IConnectionPoint> connectionPoint; - HRESULT hResult; + CComPtr<IConnectionPointContainer> container; + CComPtr<IConnectionPoint> connectionPoint; + HRESULT hResult;
if (pUnkCP == NULL) return E_INVALIDARG; @@ -1809,14 +1905,18 @@ inline HRESULT WINAPI AtlComModuleRegisterClassObjects(_ATL_COM_MODULE *module,
for (iter = module->m_ppAutoObjMapFirst; iter < module->m_ppAutoObjMapLast; iter++) { - if (!(*iter)->pfnGetClassObject) + _ATL_OBJMAP_ENTRY *ptr = *iter; + if (!ptr) continue;
- hr = (*iter)->pfnGetClassObject((void*)(*iter)->pfnCreateInstance, IID_IUnknown, (void**)&unk); + if (!ptr->pfnGetClassObject) + continue; + + hr = ptr->pfnGetClassObject((void*)ptr->pfnCreateInstance, IID_IUnknown, (void**)&unk); if (FAILED(hr)) return hr;
- hr = CoRegisterClassObject(*(*iter)->pclsid, unk, context, flags, &(*iter)->dwRegister); + hr = CoRegisterClassObject(*ptr->pclsid, unk, context, flags, &ptr->dwRegister); unk->Release(); if (FAILED(hr)) return hr; @@ -1837,7 +1937,11 @@ inline HRESULT WINAPI AtlComModuleRevokeClassObjects(_ATL_COM_MODULE *module)
for (iter = module->m_ppAutoObjMapFirst; iter < module->m_ppAutoObjMapLast; iter++) { - hr = CoRevokeClassObject((*iter)->dwRegister); + _ATL_OBJMAP_ENTRY *ptr = *iter; + if (!ptr) + continue; + + hr = CoRevokeClassObject(ptr->dwRegister); if (FAILED(hr)) return hr; } @@ -1845,6 +1949,41 @@ inline HRESULT WINAPI AtlComModuleRevokeClassObjects(_ATL_COM_MODULE *module) return S_OK; }
+// Adapted from dll/win32/atl/atl.c +inline HRESULT WINAPI +AtlComModuleGetClassObject(_ATL_COM_MODULE *pm, REFCLSID rclsid, REFIID riid, void **ppv) +{ + if (!pm) + return E_INVALIDARG; + + for (_ATL_OBJMAP_ENTRY **iter = pm->m_ppAutoObjMapFirst; iter < pm->m_ppAutoObjMapLast; iter++) + { + _ATL_OBJMAP_ENTRY *ptr = *iter; + if (!ptr) + continue; + + if (IsEqualCLSID(*ptr->pclsid, rclsid) && ptr->pfnGetClassObject) + { + HRESULT hr = CLASS_E_CLASSNOTAVAILABLE; + + if (!ptr->pCF) + { + CComCritSecLock<CComCriticalSection> lock(_AtlComModule.m_csObjMap, true); + if (!ptr->pCF) + { + hr = ptr->pfnGetClassObject((void *)ptr->pfnCreateInstance, IID_IUnknown, (void **)&ptr->pCF); + } + } + if (ptr->pCF) + hr = ptr->pCF->QueryInterface(riid, ppv); + return hr; + } + } + + return CLASS_E_CLASSNOTAVAILABLE; +} + + }; // namespace ATL
#ifndef _ATL_NO_AUTOMATIC_NAMESPACE diff --git a/sdk/lib/atl/atlcom.h b/sdk/lib/atl/atlcom.h index c9d92edbd9c..024d85a0bef 100644 --- a/sdk/lib/atl/atlcom.h +++ b/sdk/lib/atl/atlcom.h @@ -30,7 +30,12 @@ namespace ATL template <class Base, const IID *piid, class T, class Copy, class ThreadModel = CComObjectThreadModel> class CComEnum;
+#if defined(_WINDLL) | defined(_USRDLL) #define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectCached<cf> > _ClassFactoryCreatorClass; +#else +// Class factory should not change lock count +#define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectNoLock<cf>> _ClassFactoryCreatorClass; +#endif #define DECLARE_CLASSFACTORY() DECLARE_CLASSFACTORY_EX(ATL::CComClassFactory) #define DECLARE_CLASSFACTORY_SINGLETON(obj) DECLARE_CLASSFACTORY_EX(ATL::CComClassFactorySingleton<obj>)
@@ -539,6 +544,40 @@ public: } };
+ +template <class Base> +class CComObjectNoLock : public Base +{ + public: + CComObjectNoLock(void* = NULL) + { + } + + virtual ~CComObjectNoLock() + { + this->FinalRelease(); + } + + STDMETHOD_(ULONG, AddRef)() + { + return this->InternalAddRef(); + } + + STDMETHOD_(ULONG, Release)() + { + ULONG newRefCount = this->InternalRelease(); + if (newRefCount == 0) + delete this; + return newRefCount; + } + + STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject) + { + return this->_InternalQueryInterface(iid, ppvObject); + } +}; + + #define BEGIN_COM_MAP(x) \ public: \ typedef x _ComMapClass; \ @@ -663,6 +702,24 @@ public: class::GetCategoryMap, \ class::ObjectMain },
+ + +#define OBJECT_ENTRY_AUTO(clsid, class) \ + ATL::_ATL_OBJMAP_ENTRY __objMap_##class = { \ + &clsid, \ + class ::UpdateRegistry, \ + class ::_ClassFactoryCreatorClass::CreateInstance, \ + class ::_CreatorClass::CreateInstance, \ + NULL, \ + 0, \ + class ::GetObjectDescription, \ + class ::GetCategoryMap, \ + class ::ObjectMain}; \ + extern "C" _ATLALLOC("ATL$__m") ATL::_ATL_OBJMAP_ENTRY *const __pobjMap_##class = &__objMap_##class; \ + OBJECT_ENTRY_PRAGMA(class) + + + class CComClassFactory : public IClassFactory, public CComObjectRootEx<CComGlobalsThreadModel> @@ -772,6 +829,7 @@ class CComCoClass { public: DECLARE_CLASSFACTORY() + //DECLARE_AGGREGATABLE(T) // This should be here, but gcc...
static LPCTSTR WINAPI GetObjectDescription() {