https://git.reactos.org/?p=reactos.git;a=commitdiff;h=d12880829ffee337c46bc…
commit d12880829ffee337c46bc4cedbcc37a7a702cffb
Author: Mark Jansen <mark.jansen(a)reactos.org>
AuthorDate: Sat Apr 15 22:21:33 2023 +0200
Commit: Mark Jansen <mark.jansen(a)reactos.org>
CommitDate: Sat Apr 22 21:23:55 2023 +0200
[ATL] Add OBJECT_ENTRY_AUTO for simpler com object registration
GCC needs a workaround (hack_for_gcc) to keep the __pobjMap_ symbols in the output.
CORE-18936
---
sdk/lib/atl/atlbase.h | 177 ++++++++++++++++++++++++++++++++++++++++++++------
sdk/lib/atl/atlcom.h | 58 +++++++++++++++++
2 files changed, 216 insertions(+), 19 deletions(-)
diff --git a/sdk/lib/atl/atlbase.h b/sdk/lib/atl/atlbase.h
index 4bc00d2e9bf..1b0b0361b6a 100644
--- a/sdk/lib/atl/atlbase.h
+++ b/sdk/lib/atl/atlbase.h
@@ -59,6 +59,15 @@ class CAtlComModule;
__declspec(selectany) CAtlModule *_pAtlModule = NULL;
__declspec(selectany) CComModule *_pModule = NULL;
+// Compile-time check that the module base class matches the build type:
+// CAtlDllModuleT<T> requires _WINDLL or _USRDLL to be defined, while
+// CAtlExeModuleT<T> requires that neither is defined. The assertions depend
+// on the template parameter, so they are only evaluated on instantiation.
+template <bool isDll, typename T> struct CAtlValidateModuleConfiguration
+{
+#if !defined(_WINDLL) && !defined(_USRDLL)
+ static_assert(!isDll, "_WINDLL or _USRDLL must be defined when 'CAtlDllModuleT<T>' is used");
+#else
+ static_assert(isDll, "_WINDLL or _USRDLL must not be defined when 'CAtlExeModuleT<T>' is used");
+#endif
+};
+
struct _ATL_CATMAP_ENTRY
{
@@ -173,6 +182,46 @@ struct _ATL_WIN_MODULE70
};
typedef _ATL_WIN_MODULE70 _ATL_WIN_MODULE;
+
+// Auto object map
+//
+// OBJECT_ENTRY_AUTO places a pointer to each object map entry in the
+// "ATL$__m" section. The "ATL$__a" and "ATL$__z" sections hold the
+// __pobjMapEntryFirst / __pobjMapEntryLast sentinels, and CAtlComModule
+// iterates the pointers strictly between them. This relies on the linker
+// ordering the grouped sections by the suffix after '$'.
+
+#if defined(_MSC_VER)
+#pragma section("ATL$__a", read, write)
+#pragma section("ATL$__z", read, write)
+#pragma section("ATL$__m", read, write)
+#define _ATLALLOC(x) __declspec(allocate(x))
+
+// Force the linker to keep the otherwise unreferenced __pobjMap_<class>
+// symbol. On x86 the C symbol name carries an extra leading underscore.
+#if defined(_M_IX86)
+#define OBJECT_ENTRY_PRAGMA(class) __pragma(comment(linker, "/include:___pobjMap_" #class));
+#elif defined(_M_IA64) || defined(_M_AMD64) || defined(_M_ARM) || defined(_M_ARM64)
+#define OBJECT_ENTRY_PRAGMA(class) __pragma(comment(linker, "/include:__pobjMap_" #class));
+#else
+#error Your platform is not supported.
+#endif
+
+#elif defined(__GNUC__)
+
+// Marking the __pobjMap_ pointer __attribute__((unused)) does not keep it alive
+// (that attribute only suppresses warnings), so its address is passed to a
+// function that is not allowed to be optimized.
+// NOTE(review): __attribute__((used)) is the GCC attribute intended to keep an
+// unreferenced variable; consider it instead of this hack - verify with mingw ld.
+static int __attribute__((optimize("O0"), unused)) hack_for_gcc(const _ATL_OBJMAP_ENTRY * const *)
+{
+ return 1;
+}
+
+#define _ATLALLOC(x) __attribute__((section(x)))
+#define OBJECT_ENTRY_PRAGMA(class) static int __pobjMap_hack_##class = hack_for_gcc(&__pobjMap_##class);
+
+#else
+#error Your compiler is not supported.
+#endif
+
+
+// Sentinels bracketing the ATL$__m entries (see CAtlComModule constructor).
+extern "C"
+{
+ __declspec(selectany) _ATLALLOC("ATL$__a") _ATL_OBJMAP_ENTRY *__pobjMapEntryFirst = NULL;
+ __declspec(selectany) _ATLALLOC("ATL$__z") _ATL_OBJMAP_ENTRY *__pobjMapEntryLast = NULL;
+}
+
+
struct _ATL_REGMAP_ENTRY
{
LPCOLESTR szKey;
@@ -551,8 +600,9 @@ public:
CAtlComModule()
{
GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, (LPCWSTR)this,
&m_hInstTypeLib);
- m_ppAutoObjMapFirst = NULL;
- m_ppAutoObjMapLast = NULL;
+
+ m_ppAutoObjMapFirst = &__pobjMapEntryFirst + 1;
+ m_ppAutoObjMapLast = &__pobjMapEntryLast;
if (FAILED(m_csObjMap.Init()))
{
ATLASSERT(0);
@@ -577,17 +627,37 @@ public:
return AtlComModuleUnregisterServer(this, bUnRegTypeLib, pCLSID);
}
-
void Term()
{
+ // cbSize is reset to 0 below, so a second Term() call does nothing.
if (cbSize != 0)
{
- ATLASSERT(m_ppAutoObjMapFirst == NULL);
- ATLASSERT(m_ppAutoObjMapLast == NULL);
+ // Release the class factories that AtlComModuleGetClassObject cached in
+ // pCF, skipping empty slots and entries that never created a factory.
+ for (_ATL_OBJMAP_ENTRY **iter = m_ppAutoObjMapFirst; iter <
m_ppAutoObjMapLast; iter++)
+ {
+ _ATL_OBJMAP_ENTRY *ptr = *iter;
+ if (!ptr)
+ continue;
+
+ if (!ptr->pCF)
+ continue;
+
+ ptr->pCF->Release();
+ ptr->pCF = NULL;
+ }
m_csObjMap.Term();
cbSize = 0;
}
}
+
+ // Call ObjectMain on every OBJECT_ENTRY_AUTO entry. The module classes pass
+ // true from their constructors and false from their destructors.
+ void ExecuteObjectMain(bool bStarting)
+ {
+ for (_ATL_OBJMAP_ENTRY **ppSlot = m_ppAutoObjMapFirst; ppSlot < m_ppAutoObjMapLast; ++ppSlot)
+ {
+ _ATL_OBJMAP_ENTRY *pEntry = *ppSlot;
+ if (pEntry != NULL)
+ pEntry->pfnObjectMain(bStarting);
+ }
+ }
};
__declspec(selectany) CAtlComModule _AtlComModule;
@@ -606,11 +676,20 @@ HRESULT CAtlModuleT<T>::UnregisterServer(BOOL bUnRegTypeLib,
const CLSID *pCLSID
}
template <class T>
-class CAtlDllModuleT : public CAtlModuleT<T>
+class CAtlDllModuleT
+ : public CAtlModuleT<T>
+ , private CAtlValidateModuleConfiguration<true, T>
+
{
public:
CAtlDllModuleT()
{
+ _AtlComModule.ExecuteObjectMain(true);
+ }
+
+ ~CAtlDllModuleT()
+ {
+ _AtlComModule.ExecuteObjectMain(false);
}
HRESULT DllCanUnloadNow()
@@ -659,7 +738,9 @@ public:
template <class T>
-class CAtlExeModuleT : public CAtlModuleT<T>
+class CAtlExeModuleT
+ : public CAtlModuleT<T>
+ , private CAtlValidateModuleConfiguration<false, T>
{
public:
DWORD m_dwMainThreadID;
@@ -670,10 +751,12 @@ public:
CAtlExeModuleT()
:m_dwMainThreadID(::GetCurrentThreadId())
{
+ _AtlComModule.ExecuteObjectMain(true);
}
~CAtlExeModuleT()
{
+ _AtlComModule.ExecuteObjectMain(false);
}
int WinMain(int nShowCmd)
@@ -815,12 +898,19 @@ public:
}
}
}
+
+ for (_ATL_OBJMAP_ENTRY **iter = _AtlComModule.m_ppAutoObjMapFirst; iter <
_AtlComModule.m_ppAutoObjMapLast; iter++)
+ {
+ if (*iter != NULL)
+ (*iter)->pfnObjectMain(true);
+ }
+
return S_OK;
}
void Term()
{
- _ATL_OBJMAP_ENTRY *objectMapEntry;
+ _ATL_OBJMAP_ENTRY *objectMapEntry;
if (m_pObjMap != NULL)
{
@@ -834,12 +924,19 @@ public:
objectMapEntry++;
}
}
+
+ for (_ATL_OBJMAP_ENTRY **iter = _AtlComModule.m_ppAutoObjMapFirst; iter <
_AtlComModule.m_ppAutoObjMapLast; iter++)
+ {
+ if (*iter != NULL)
+ (*iter)->pfnObjectMain(false);
+ }
+
}
HRESULT GetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
{
- _ATL_OBJMAP_ENTRY *objectMapEntry;
- HRESULT hResult;
+ _ATL_OBJMAP_ENTRY *objectMapEntry;
+ HRESULT hResult;
ATLASSERT(ppv != NULL);
if (ppv == NULL)
@@ -869,8 +966,7 @@ public:
}
if (hResult == S_OK && *ppv == NULL)
{
- // FIXME: call AtlComModuleGetClassObject
- hResult = CLASS_E_CLASSNOTAVAILABLE;
+ hResult = AtlComModuleGetClassObject(&_AtlComModule, rclsid, riid, ppv);
}
return hResult;
}
@@ -1480,9 +1576,9 @@ inline HRESULT __stdcall AtlAdvise(IUnknown *pUnkCP, IUnknown *pUnk,
const IID &
inline HRESULT __stdcall AtlUnadvise(IUnknown *pUnkCP, const IID &iid, DWORD dw)
{
- CComPtr<IConnectionPointContainer> container;
- CComPtr<IConnectionPoint> connectionPoint;
- HRESULT hResult;
+ CComPtr<IConnectionPointContainer> container;
+ CComPtr<IConnectionPoint> connectionPoint;
+ HRESULT hResult;
if (pUnkCP == NULL)
return E_INVALIDARG;
@@ -1809,14 +1905,18 @@ inline HRESULT WINAPI
AtlComModuleRegisterClassObjects(_ATL_COM_MODULE *module,
for (iter = module->m_ppAutoObjMapFirst; iter < module->m_ppAutoObjMapLast;
iter++)
{
- if (!(*iter)->pfnGetClassObject)
+ _ATL_OBJMAP_ENTRY *ptr = *iter;
+ if (!ptr)
continue;
- hr = (*iter)->pfnGetClassObject((void*)(*iter)->pfnCreateInstance,
IID_IUnknown, (void**)&unk);
+ if (!ptr->pfnGetClassObject)
+ continue;
+
+ hr = ptr->pfnGetClassObject((void*)ptr->pfnCreateInstance, IID_IUnknown,
(void**)&unk);
if (FAILED(hr))
return hr;
- hr = CoRegisterClassObject(*(*iter)->pclsid, unk, context, flags,
&(*iter)->dwRegister);
+ hr = CoRegisterClassObject(*ptr->pclsid, unk, context, flags,
&ptr->dwRegister);
unk->Release();
if (FAILED(hr))
return hr;
@@ -1837,7 +1937,11 @@ inline HRESULT WINAPI
AtlComModuleRevokeClassObjects(_ATL_COM_MODULE *module)
for (iter = module->m_ppAutoObjMapFirst; iter < module->m_ppAutoObjMapLast;
iter++)
{
- hr = CoRevokeClassObject((*iter)->dwRegister);
+ _ATL_OBJMAP_ENTRY *ptr = *iter;
+ if (!ptr)
+ continue;
+
+ hr = CoRevokeClassObject(ptr->dwRegister);
if (FAILED(hr))
return hr;
}
@@ -1845,6 +1949,41 @@ inline HRESULT WINAPI
AtlComModuleRevokeClassObjects(_ATL_COM_MODULE *module)
return S_OK;
}
+// Adapted from dll/win32/atl/atl.c
+//
+// Looks up rclsid among the OBJECT_ENTRY_AUTO entries of pm and returns the
+// requested interface of its class factory in *ppv. The factory is created
+// through pfnGetClassObject on first use and cached in the entry's pCF.
+// Returns E_INVALIDARG if pm is NULL, E_POINTER if ppv is NULL, and
+// CLASS_E_CLASSNOTAVAILABLE if no entry matches rclsid.
+inline HRESULT WINAPI
+AtlComModuleGetClassObject(_ATL_COM_MODULE *pm, REFCLSID rclsid, REFIID riid, void **ppv)
+{
+ if (!pm)
+ return E_INVALIDARG;
+ if (!ppv)
+ return E_POINTER;
+ *ppv = NULL;
+
+ for (_ATL_OBJMAP_ENTRY **iter = pm->m_ppAutoObjMapFirst; iter < pm->m_ppAutoObjMapLast; iter++)
+ {
+ _ATL_OBJMAP_ENTRY *ptr = *iter;
+ // Empty slots in the object map section are skipped.
+ if (!ptr)
+ continue;
+
+ if (IsEqualCLSID(*ptr->pclsid, rclsid) && ptr->pfnGetClassObject)
+ {
+ HRESULT hr = CLASS_E_CLASSNOTAVAILABLE;
+
+ if (!ptr->pCF)
+ {
+ // Serialize on the map of the module we were given (not the global
+ // _AtlComModule) and re-check, so the factory is created only once.
+ CComCritSecLock<CComCriticalSection> lock(pm->m_csObjMap, true);
+ if (!ptr->pCF)
+ {
+ hr = ptr->pfnGetClassObject((void *)ptr->pfnCreateInstance, IID_IUnknown, (void **)&ptr->pCF);
+ }
+ }
+ if (ptr->pCF)
+ hr = ptr->pCF->QueryInterface(riid, ppv);
+ return hr;
+ }
+ }
+
+ return CLASS_E_CLASSNOTAVAILABLE;
+}
+
+
}; // namespace ATL
#ifndef _ATL_NO_AUTOMATIC_NAMESPACE
diff --git a/sdk/lib/atl/atlcom.h b/sdk/lib/atl/atlcom.h
index c9d92edbd9c..024d85a0bef 100644
--- a/sdk/lib/atl/atlcom.h
+++ b/sdk/lib/atl/atlcom.h
@@ -30,7 +30,12 @@ namespace ATL
template <class Base, const IID *piid, class T, class Copy, class ThreadModel =
CComObjectThreadModel>
class CComEnum;
+#if defined(_WINDLL) || defined(_USRDLL)
#define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectCached<cf> > _ClassFactoryCreatorClass;
+#else
+// EXE server: the class factory object must not change the module lock
+// count, so it is wrapped in CComObjectNoLock instead of CComObjectCached.
+#define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectNoLock<cf>> _ClassFactoryCreatorClass;
+#endif
#define DECLARE_CLASSFACTORY() DECLARE_CLASSFACTORY_EX(ATL::CComClassFactory)
#define DECLARE_CLASSFACTORY_SINGLETON(obj)
DECLARE_CLASSFACTORY_EX(ATL::CComClassFactorySingleton<obj>)
@@ -539,6 +544,40 @@ public:
}
};
+
+// COM object wrapper whose AddRef/Release only adjust the object's own
+// reference count; nothing here touches the module lock count. Used for
+// class factories in EXE servers (see DECLARE_CLASSFACTORY_EX).
+template <class Base>
+class CComObjectNoLock : public Base
+{
+public:
+ // The unused void* argument presumably lets CComCreator construct this
+ // like the other CComObject variants - TODO confirm.
+ CComObjectNoLock(void* = NULL)
+ {
+ }
+
+ virtual ~CComObjectNoLock()
+ {
+ this->FinalRelease();
+ }
+
+ STDMETHOD_(ULONG, AddRef)()
+ {
+ return this->InternalAddRef();
+ }
+
+ // Deletes the object when the last reference is released.
+ STDMETHOD_(ULONG, Release)()
+ {
+ const ULONG cRemaining = this->InternalRelease();
+ if (cRemaining != 0)
+ return cRemaining;
+
+ delete this;
+ return 0;
+ }
+
+ STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
+ {
+ return this->_InternalQueryInterface(iid, ppvObject);
+ }
+};
+
+
#define BEGIN_COM_MAP(x) \
public: \
typedef x _ComMapClass; \
@@ -663,6 +702,24 @@ public:
class::GetCategoryMap, \
class::ObjectMain },
+
+
+// Adds a COM class to the auto object map without an explicit OBJECT_MAP.
+// Defines the entry __objMap_<class> and a pointer to it, __pobjMap_<class>,
+// placed in the "ATL$__m" section; OBJECT_ENTRY_PRAGMA keeps that pointer
+// from being discarded. The NULL / 0 initializers presumably fill pCF (the
+// cached class factory) and dwRegister (the class registration cookie).
+#define OBJECT_ENTRY_AUTO(clsid, class) \
+ ATL::_ATL_OBJMAP_ENTRY __objMap_##class = { \
+ &clsid, \
+ class::UpdateRegistry, \
+ class::_ClassFactoryCreatorClass::CreateInstance, \
+ class::_CreatorClass::CreateInstance, \
+ NULL, \
+ 0, \
+ class::GetObjectDescription, \
+ class::GetCategoryMap, \
+ class::ObjectMain}; \
+ extern "C" _ATLALLOC("ATL$__m") ATL::_ATL_OBJMAP_ENTRY *const __pobjMap_##class = &__objMap_##class; \
+ OBJECT_ENTRY_PRAGMA(class)
+
+
+
class CComClassFactory :
public IClassFactory,
public CComObjectRootEx<CComGlobalsThreadModel>
@@ -772,6 +829,7 @@ class CComCoClass
{
public:
DECLARE_CLASSFACTORY()
+ //DECLARE_AGGREGATABLE(T) // This should be here, but gcc...
static LPCTSTR WINAPI GetObjectDescription()
{