https://git.reactos.org/?p=reactos.git;a=commitdiff;h=4c95339da0addaa59a0d584c240a34f708079ce5
commit 4c95339da0addaa59a0d584c240a34f708079ce5
Author:     Victor Perevertkin <victor.perevertkin@reactos.org>
AuthorDate: Wed Nov 25 03:18:31 2020 +0300
Commit:     Victor Perevertkin <victor.perevertkin@reactos.org>
CommitDate: Mon Jan 4 16:50:33 2021 +0300
[NTOS:IO] Refactoring of the driver initialization code (2)
- Do not hold the IopDriverLoadResource while trying to reference a driver object
  (but still acquire it when we actually need to load a driver)
- Change IopLoadDriver and IopInitializeDriverModule to use a registry handle
  instead of a service name string and/or a full registry path
- Do not try to reference a driver object inside IopLoadDriver. The caller is
  expected to do this before calling the function
---
 ntoskrnl/include/internal/io.h |   7 +-
 ntoskrnl/io/iomgr/driver.c     | 387 ++++++++++++++++++++++++-----------------
 ntoskrnl/io/pnpmgr/devaction.c |  37 +++-
 3 files changed, 267 insertions(+), 164 deletions(-)
diff --git a/ntoskrnl/include/internal/io.h b/ntoskrnl/include/internal/io.h index 5da5950ba67..ad51265ba3f 100644 --- a/ntoskrnl/include/internal/io.h +++ b/ntoskrnl/include/internal/io.h @@ -1127,14 +1127,13 @@ IopLoadServiceModule(
NTSTATUS IopLoadDriver( - _In_opt_ PCUNICODE_STRING RegistryPath, - _Inout_ PDRIVER_OBJECT *DriverObject -); + _In_ HANDLE ServiceHandle, + _Out_ PDRIVER_OBJECT *DriverObject);
NTSTATUS IopInitializeDriverModule( _In_ PLDR_DATA_TABLE_ENTRY ModuleObject, - _In_ PUNICODE_STRING ServiceName, + _In_ HANDLE ServiceHandle, _Out_ PDRIVER_OBJECT *DriverObject, _Out_ NTSTATUS *DriverEntryStatus);
diff --git a/ntoskrnl/io/iomgr/driver.c b/ntoskrnl/io/iomgr/driver.c index a4b39b77f6d..994d7375af8 100644 --- a/ntoskrnl/io/iomgr/driver.c +++ b/ntoskrnl/io/iomgr/driver.c @@ -28,6 +28,7 @@ KSPIN_LOCK DriverBootReinitListLock;
UNICODE_STRING IopHardwareDatabaseKey = RTL_CONSTANT_STRING(L"\REGISTRY\MACHINE\HARDWARE\DESCRIPTION\SYSTEM"); +static const WCHAR ServicesKeyName[] = L"\Registry\Machine\System\CurrentControlSet\Services\";
POBJECT_TYPE IoDriverObjectType = NULL;
@@ -47,7 +48,7 @@ PLIST_ENTRY IopGroupTable; typedef struct _LOAD_UNLOAD_PARAMS { NTSTATUS Status; - PCUNICODE_STRING RegistryPath; + PUNICODE_STRING RegistryPath; WORK_QUEUE_ITEM WorkItem; KEVENT Event; PDRIVER_OBJECT DriverObject; @@ -56,7 +57,7 @@ typedef struct _LOAD_UNLOAD_PARAMS
NTSTATUS IopDoLoadUnloadDriver( - _In_opt_ PCUNICODE_STRING RegistryPath, + _In_opt_ PUNICODE_STRING RegistryPath, _Inout_ PDRIVER_OBJECT *DriverObject);
/* PRIVATE FUNCTIONS **********************************************************/ @@ -135,7 +136,6 @@ IopGetDriverObject( DPRINT("IopGetDriverObject(%p '%wZ' %x)\n", DriverObject, ServiceName, FileSystem);
- ASSERT(ExIsResourceAcquiredExclusiveLite(&IopDriverLoadResource)); *DriverObject = NULL;
/* Create ModuleName string */ @@ -480,8 +480,8 @@ IopLoadServiceModule( * Module object representing the driver. It can be retrieved by IopLoadServiceModule. * Freed on failure, so in a such case this should not be accessed anymore * - * @param[in] ServiceName - * Name of the service (as in the registry) + * @param[in] ServiceHandle + * Handle to a driver's CCS/Services/<ServiceName> key * * @param[out] DriverObject * This contains the driver object if it was created (even with unsuccessfull result) @@ -495,56 +495,58 @@ IopLoadServiceModule( NTSTATUS IopInitializeDriverModule( _In_ PLDR_DATA_TABLE_ENTRY ModuleObject, - _In_ PUNICODE_STRING ServiceName, + _In_ HANDLE ServiceHandle, _Out_ PDRIVER_OBJECT *OutDriverObject, _Out_ NTSTATUS *DriverEntryStatus) { - static const WCHAR ServicesKeyName[] = L"\Registry\Machine\System\CurrentControlSet\Services\"; - UNICODE_STRING DriverName; - UNICODE_STRING RegistryKey; + UNICODE_STRING DriverName, RegistryPath, ServiceName; NTSTATUS Status;
PAGED_CODE();
- ASSERT(ServiceName && ServiceName->Length != 0); - - // Make the registry path for the driver - RegistryKey.Length = 0; - RegistryKey.MaximumLength = sizeof(ServicesKeyName) + ServiceName->Length; - RegistryKey.Buffer = ExAllocatePoolWithTag(PagedPool, - RegistryKey.MaximumLength, - TAG_IO); - if (RegistryKey.Buffer == NULL) + // get the ServiceName + PKEY_BASIC_INFORMATION basicInfo; + ULONG infoLength; + Status = ZwQueryKey(ServiceHandle, KeyBasicInformation, NULL, 0, &infoLength); + if (Status == STATUS_BUFFER_TOO_SMALL) { - return STATUS_INSUFFICIENT_RESOURCES; - } - RtlAppendUnicodeToString(&RegistryKey, ServicesKeyName); - RtlAppendUnicodeStringToString(&RegistryKey, ServiceName); + basicInfo = ExAllocatePoolWithTag(PagedPool, infoLength, TAG_IO); + if (!basicInfo) + { + MmUnloadSystemImage(ModuleObject); + return STATUS_INSUFFICIENT_RESOURCES; + }
- // Open the registry key for this driver (it has to exist) - HANDLE serviceHandle; - PKEY_VALUE_FULL_INFORMATION kvInfo; + Status = ZwQueryKey(ServiceHandle, KeyBasicInformation, basicInfo, infoLength, &infoLength); + if (!NT_SUCCESS(Status)) + { + ExFreePoolWithTag(basicInfo, TAG_IO); + MmUnloadSystemImage(ModuleObject); + return Status; + }
- Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, &RegistryKey, KEY_READ); - if (!NT_SUCCESS(Status)) + ServiceName.Length = basicInfo->NameLength; + ServiceName.MaximumLength = basicInfo->NameLength; + ServiceName.Buffer = basicInfo->Name; + } + else { - RtlFreeUnicodeString(&RegistryKey); MmUnloadSystemImage(ModuleObject); - return Status; + return NT_SUCCESS(Status) ? STATUS_UNSUCCESSFUL : Status; }
// Make the DriverName field of a DRIVER_OBJECT + PKEY_VALUE_FULL_INFORMATION kvInfo;
// 1. Check the "ObjectName" field in the driver's registry key (it has the priority) - Status = IopGetRegistryValue(serviceHandle, L"ObjectName", &kvInfo); + Status = IopGetRegistryValue(ServiceHandle, L"ObjectName", &kvInfo); if (NT_SUCCESS(Status)) { // we're got the ObjectName. Use it to create the DRIVER_OBJECT if (kvInfo->Type != REG_SZ || kvInfo->DataLength == 0) { ExFreePool(kvInfo); - ZwClose(serviceHandle); - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName MmUnloadSystemImage(ModuleObject); return STATUS_ILL_FORMED_SERVICE_ENTRY; } @@ -555,8 +557,7 @@ IopInitializeDriverModule( if (!DriverName.Buffer) { ExFreePool(kvInfo); - ZwClose(serviceHandle); - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName MmUnloadSystemImage(ModuleObject); return STATUS_INSUFFICIENT_RESOURCES; } @@ -571,12 +572,11 @@ IopInitializeDriverModule( // 2. there is no "ObjectName" - construct it ourselves. Depending on a driver type, // it will be either "\Driver<ServiceName>" or "\FileSystem<ServiceName>"
- Status = IopGetRegistryValue(serviceHandle, L"Type", &kvInfo); + Status = IopGetRegistryValue(ServiceHandle, L"Type", &kvInfo); if (!NT_SUCCESS(Status) || kvInfo->Type != REG_DWORD) { ExFreePool(kvInfo); - ZwClose(serviceHandle); - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName MmUnloadSystemImage(ModuleObject); return STATUS_ILL_FORMED_SERVICE_ENTRY; } @@ -587,14 +587,13 @@ IopInitializeDriverModule(
DriverName.Length = 0; if (driverType == SERVICE_RECOGNIZER_DRIVER || driverType == SERVICE_FILE_SYSTEM_DRIVER) - DriverName.MaximumLength = sizeof(FILESYSTEM_ROOT_NAME) + ServiceName->Length; + DriverName.MaximumLength = sizeof(FILESYSTEM_ROOT_NAME) + ServiceName.Length; else - DriverName.MaximumLength = sizeof(DRIVER_ROOT_NAME) + ServiceName->Length; + DriverName.MaximumLength = sizeof(DRIVER_ROOT_NAME) + ServiceName.Length; DriverName.Buffer = ExAllocatePoolWithTag(NonPagedPool, DriverName.MaximumLength, TAG_IO); if (!DriverName.Buffer) { - ZwClose(serviceHandle); - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName MmUnloadSystemImage(ModuleObject); return STATUS_INSUFFICIENT_RESOURCES; } @@ -604,13 +603,53 @@ IopInitializeDriverModule( else RtlAppendUnicodeToString(&DriverName, DRIVER_ROOT_NAME);
- RtlAppendUnicodeStringToString(&DriverName, ServiceName); + RtlAppendUnicodeStringToString(&DriverName, &ServiceName); }
- ZwClose(serviceHandle); - DPRINT("Driver name: '%wZ'\n", &DriverName);
+ // obtain the registry path for the DriverInit routine + PKEY_NAME_INFORMATION nameInfo; + Status = ZwQueryKey(ServiceHandle, KeyNameInformation, NULL, 0, &infoLength); + if (Status == STATUS_BUFFER_TOO_SMALL) + { + nameInfo = ExAllocatePoolWithTag(NonPagedPool, infoLength, TAG_IO); + if (nameInfo) + { + Status = ZwQueryKey(ServiceHandle, + KeyNameInformation, + nameInfo, + infoLength, + &infoLength); + if (NT_SUCCESS(Status)) + { + RegistryPath.Length = nameInfo->NameLength; + RegistryPath.MaximumLength = nameInfo->NameLength; + RegistryPath.Buffer = nameInfo->Name; + } + else + { + ExFreePoolWithTag(nameInfo, TAG_IO); + } + } + else + { + Status = STATUS_INSUFFICIENT_RESOURCES; + } + } + else + { + Status = NT_SUCCESS(Status) ? STATUS_UNSUCCESSFUL : Status; + } + + if (!NT_SUCCESS(Status)) + { + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName + RtlFreeUnicodeString(&DriverName); + MmUnloadSystemImage(ModuleObject); + return Status; + } + // create the driver object UINT32 ObjectSize = sizeof(DRIVER_OBJECT) + sizeof(EXTENDED_DRIVER_EXTENSION); OBJECT_ATTRIBUTES objAttrs; @@ -632,7 +671,8 @@ IopInitializeDriverModule( (PVOID*)&driverObject); if (!NT_SUCCESS(Status)) { - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName RtlFreeUnicodeString(&DriverName); MmUnloadSystemImage(ModuleObject); DPRINT1("Error while creating driver object "%wZ" status %x\n", &DriverName, Status); @@ -665,9 +705,9 @@ IopInitializeDriverModule( Status = ObInsertObject(driverObject, NULL, FILE_READ_DATA, 0, NULL, &hDriver); if (!NT_SUCCESS(Status)) { - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(nameInfo, TAG_IO); + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName RtlFreeUnicodeString(&DriverName); - MmUnloadSystemImage(ModuleObject); // TODO: is it needed here? return Status; }
@@ -684,7 +724,8 @@ IopInitializeDriverModule(
if (!NT_SUCCESS(Status)) { - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName RtlFreeUnicodeString(&DriverName); return Status; } @@ -693,7 +734,7 @@ IopInitializeDriverModule( UNICODE_STRING serviceKeyName; serviceKeyName.Length = 0; // put a NULL character at the end for Windows compatibility - serviceKeyName.MaximumLength = ServiceName->MaximumLength + sizeof(UNICODE_NULL); + serviceKeyName.MaximumLength = ServiceName.MaximumLength + sizeof(UNICODE_NULL); serviceKeyName.Buffer = ExAllocatePoolWithTag(NonPagedPool, serviceKeyName.MaximumLength, TAG_IO); @@ -701,13 +742,15 @@ IopInitializeDriverModule( { ObMakeTemporaryObject(driverObject); ObDereferenceObject(driverObject); - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName RtlFreeUnicodeString(&DriverName); return STATUS_INSUFFICIENT_RESOURCES; }
/* Copy the name and set it in the driver extension */ - RtlCopyUnicodeString(&serviceKeyName, ServiceName); + RtlCopyUnicodeString(&serviceKeyName, &ServiceName); + ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName driverObject->DriverExtension->ServiceKeyName = serviceKeyName;
/* Make a copy of the driver name to store in the driver object */ @@ -722,7 +765,7 @@ IopInitializeDriverModule( { ObMakeTemporaryObject(driverObject); ObDereferenceObject(driverObject); - RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath RtlFreeUnicodeString(&DriverName); return STATUS_INSUFFICIENT_RESOURCES; } @@ -731,7 +774,7 @@ IopInitializeDriverModule( driverObject->DriverName = driverNamePaged;
/* Finally, call its init function */ - Status = driverObject->DriverInit(driverObject, &RegistryKey); + Status = driverObject->DriverInit(driverObject, &RegistryPath); *DriverEntryStatus = Status; if (!NT_SUCCESS(Status)) { @@ -770,7 +813,7 @@ IopInitializeDriverModule(
// TODO: for legacy drivers, unload the driver if it didn't create any DO
- RtlFreeUnicodeString(&RegistryKey); + ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath RtlFreeUnicodeString(&DriverName);
if (!NT_SUCCESS(Status)) @@ -830,37 +873,62 @@ IopAttachFilterDriversCallback( ServiceName.MaximumLength = ServiceName.Length = (USHORT)wcslen(Filters) * sizeof(WCHAR);
- KeEnterCriticalRegion(); - ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE); + UNICODE_STRING RegistryPath; + + // Make the registry path for the driver + RegistryPath.Length = 0; + RegistryPath.MaximumLength = sizeof(ServicesKeyName) + ServiceName.Length; + RegistryPath.Buffer = ExAllocatePoolWithTag(PagedPool, RegistryPath.MaximumLength, TAG_IO); + if (RegistryPath.Buffer == NULL) + { + return STATUS_INSUFFICIENT_RESOURCES; + } + RtlAppendUnicodeToString(&RegistryPath, ServicesKeyName); + RtlAppendUnicodeStringToString(&RegistryPath, &ServiceName); + + HANDLE serviceHandle; + Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, &RegistryPath, KEY_READ); + RtlFreeUnicodeString(&RegistryPath); + if (!NT_SUCCESS(Status)) + { + return Status; + } + Status = IopGetDriverObject(&DriverObject, &ServiceName, FALSE); if (!NT_SUCCESS(Status)) { + KeEnterCriticalRegion(); + ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE); + /* Load and initialize the filter driver */ Status = IopLoadServiceModule(&ServiceName, &ModuleObject); if (!NT_SUCCESS(Status)) { ExReleaseResourceLite(&IopDriverLoadResource); KeLeaveCriticalRegion(); + ZwClose(serviceHandle); return Status; }
NTSTATUS driverEntryStatus; Status = IopInitializeDriverModule(ModuleObject, - &ServiceName, + serviceHandle, &DriverObject, &driverEntryStatus); + + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + if (!NT_SUCCESS(Status)) { - ExReleaseResourceLite(&IopDriverLoadResource); - KeLeaveCriticalRegion(); + ZwClose(serviceHandle); return Status; } }
- ExReleaseResourceLite(&IopDriverLoadResource); - KeLeaveCriticalRegion(); + ZwClose(serviceHandle);
Status = IopInitializeDevice(DeviceNode, DriverObject);
@@ -1104,6 +1172,28 @@ IopInitializeBuiltinDriver(IN PLDR_DATA_TABLE_ENTRY BootLdrEntry) FileExtension[0] = UNICODE_NULL; }
+ UNICODE_STRING RegistryPath; + + // Make the registry path for the driver + RegistryPath.Length = 0; + RegistryPath.MaximumLength = sizeof(ServicesKeyName) + ServiceName.Length; + RegistryPath.Buffer = ExAllocatePoolWithTag(PagedPool, RegistryPath.MaximumLength, TAG_IO); + if (RegistryPath.Buffer == NULL) + { + return STATUS_INSUFFICIENT_RESOURCES; + } + RtlAppendUnicodeToString(&RegistryPath, ServicesKeyName); + RtlAppendUnicodeStringToString(&RegistryPath, &ServiceName); + RtlFreeUnicodeString(&ServiceName); + + HANDLE serviceHandle; + Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, &RegistryPath, KEY_READ); + RtlFreeUnicodeString(&RegistryPath); + if (!NT_SUCCESS(Status)) + { + return Status; + } + /* Lookup the new Ldr entry in PsLoadedModuleList */ NextEntry = PsLoadedModuleList.Flink; while (NextEntry != &PsLoadedModuleList) @@ -1125,10 +1215,10 @@ IopInitializeBuiltinDriver(IN PLDR_DATA_TABLE_ENTRY BootLdrEntry) */ NTSTATUS driverEntryStatus; Status = IopInitializeDriverModule(LdrEntry, - &ServiceName, + serviceHandle, &DriverObject, &driverEntryStatus); - RtlFreeUnicodeString(&ServiceName); + ZwClose(serviceHandle);
if (!NT_SUCCESS(Status)) { @@ -2050,69 +2140,47 @@ IoGetDriverObjectExtension(IN PDRIVER_OBJECT DriverObject,
NTSTATUS IopLoadDriver( - _In_opt_ PCUNICODE_STRING RegistryPath, - _Inout_ PDRIVER_OBJECT *DriverObject) + _In_ HANDLE ServiceHandle, + _Out_ PDRIVER_OBJECT *DriverObject) { - RTL_QUERY_REGISTRY_TABLE QueryTable[3]; UNICODE_STRING ImagePath; - UNICODE_STRING ServiceName; NTSTATUS Status; - ULONG Type; PLDR_DATA_TABLE_ENTRY ModuleObject; PVOID BaseAddress; - WCHAR *cur; - - RtlInitUnicodeString(&ImagePath, NULL);
- /* - * Get the service name from the registry key name. - */ - ASSERT(RegistryPath->Length >= sizeof(WCHAR)); - - ServiceName = *RegistryPath; - cur = RegistryPath->Buffer + RegistryPath->Length / sizeof(WCHAR) - 1; - while (RegistryPath->Buffer != cur) + PKEY_VALUE_FULL_INFORMATION kvInfo; + Status = IopGetRegistryValue(ServiceHandle, L"ImagePath", &kvInfo); + if (NT_SUCCESS(Status)) { - if (*cur == L'\') + if (kvInfo->Type != REG_EXPAND_SZ || kvInfo->DataLength == 0) { - ServiceName.Buffer = cur + 1; - ServiceName.Length = RegistryPath->Length - - (USHORT)((ULONG_PTR)ServiceName.Buffer - - (ULONG_PTR)RegistryPath->Buffer); - break; + ExFreePool(kvInfo); + return STATUS_ILL_FORMED_SERVICE_ENTRY; } - cur--; - } - - /* - * Get service type. - */ - RtlZeroMemory(&QueryTable, sizeof(QueryTable)); - - RtlInitUnicodeString(&ImagePath, NULL); - - QueryTable[0].Name = L"Type"; - QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT | RTL_QUERY_REGISTRY_REQUIRED; - QueryTable[0].EntryContext = &Type;
- QueryTable[1].Name = L"ImagePath"; - QueryTable[1].Flags = RTL_QUERY_REGISTRY_DIRECT; - QueryTable[1].EntryContext = &ImagePath; + ImagePath.Length = kvInfo->DataLength - sizeof(UNICODE_NULL), + ImagePath.MaximumLength = kvInfo->DataLength, + ImagePath.Buffer = ExAllocatePoolWithTag(PagedPool, ImagePath.MaximumLength, TAG_RTLREGISTRY); + if (!ImagePath.Buffer) + { + ExFreePool(kvInfo); + return STATUS_INSUFFICIENT_RESOURCES; + }
- Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE, - RegistryPath->Buffer, - QueryTable, NULL, NULL); - if (!NT_SUCCESS(Status)) + RtlMoveMemory(ImagePath.Buffer, + (PVOID)((ULONG_PTR)kvInfo + kvInfo->DataOffset), + ImagePath.Length); + ExFreePool(kvInfo); + } + else { - DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status); - if (ImagePath.Buffer) ExFreePool(ImagePath.Buffer); return Status; }
/* * Normalize the image path for all later processing. */ - Status = IopNormalizeImagePath(&ImagePath, &ServiceName); + Status = IopNormalizeImagePath(&ImagePath, NULL); if (!NT_SUCCESS(Status)) { DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status); @@ -2120,67 +2188,65 @@ IopLoadDriver( }
DPRINT("FullImagePath: '%wZ'\n", &ImagePath); - DPRINT("Type: %lx\n", Type);
KeEnterCriticalRegion(); ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE); + /* - * Get existing DriverObject pointer (in case the driver - * has already been loaded and initialized). + * Load the driver module */ - Status = IopGetDriverObject(DriverObject, - &ServiceName, - (Type == SERVICE_FILE_SYSTEM_DRIVER || - Type == SERVICE_RECOGNIZER_DRIVER)); + DPRINT("Loading module from %wZ\n", &ImagePath); + Status = MmLoadSystemImage(&ImagePath, NULL, NULL, 0, (PVOID)&ModuleObject, &BaseAddress); + RtlFreeUnicodeString(&ImagePath);
if (!NT_SUCCESS(Status)) { - /* - * Load the driver module - */ - DPRINT("Loading module from %wZ\n", &ImagePath); - Status = MmLoadSystemImage(&ImagePath, NULL, NULL, 0, (PVOID)&ModuleObject, &BaseAddress); - if (!NT_SUCCESS(Status)) - { - DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status); - ExReleaseResourceLite(&IopDriverLoadResource); - KeLeaveCriticalRegion(); - return Status; - } - - /* - * Initialize the driver module if it's loaded for the first time - */ - IopDisplayLoadingMessage(&ServiceName); - - NTSTATUS driverEntryStatus; - Status = IopInitializeDriverModule(ModuleObject, - &ServiceName, - DriverObject, - &driverEntryStatus); - if (!NT_SUCCESS(Status)) - { - DPRINT1("IopInitializeDriverModule() failed (Status %lx)\n", Status); - ExReleaseResourceLite(&IopDriverLoadResource); - KeLeaveCriticalRegion(); - return Status; - } - + DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status); ExReleaseResourceLite(&IopDriverLoadResource); KeLeaveCriticalRegion(); + return Status; } - else + + // Display the loading message + ULONG infoLength; + Status = ZwQueryKey(ServiceHandle, KeyBasicInformation, NULL, 0, &infoLength); + if (Status == STATUS_BUFFER_TOO_SMALL) { - ExReleaseResourceLite(&IopDriverLoadResource); - KeLeaveCriticalRegion(); + PKEY_BASIC_INFORMATION servName = ExAllocatePoolWithTag(PagedPool, infoLength, TAG_IO); + if (servName) + { + Status = ZwQueryKey(ServiceHandle, + KeyBasicInformation, + servName, + infoLength, + &infoLength); + if (NT_SUCCESS(Status)) + { + UNICODE_STRING serviceName = { + .Length = servName->NameLength, + .MaximumLength = servName->NameLength, + .Buffer = servName->Name + };
- DPRINT("DriverObject already exist in ObjectManager\n"); - Status = STATUS_IMAGE_ALREADY_LOADED; + IopDisplayLoadingMessage(&serviceName); + } + ExFreePoolWithTag(servName, TAG_IO); + } + }
- /* IopGetDriverObject references the DriverObject, so dereference it */ - ObDereferenceObject(*DriverObject); + NTSTATUS driverEntryStatus; + Status = IopInitializeDriverModule(ModuleObject, + ServiceHandle, + DriverObject, + &driverEntryStatus); + if (!NT_SUCCESS(Status)) + { + DPRINT1("IopInitializeDriverModule() failed (Status %lx)\n", Status); }
+ ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + return Status; }
@@ -2203,7 +2269,18 @@ IopLoadUnloadDriverWorker( else { // load request - LoadParams->Status = IopLoadDriver(LoadParams->RegistryPath, &LoadParams->DriverObject); + HANDLE serviceHandle; + NTSTATUS status; + status = IopOpenRegistryKeyEx(&serviceHandle, NULL, LoadParams->RegistryPath, KEY_READ); + if (!NT_SUCCESS(status)) + { + LoadParams->Status = status; + } + else + { + LoadParams->Status = IopLoadDriver(serviceHandle, &LoadParams->DriverObject); + ZwClose(serviceHandle); + } }
if (LoadParams->SetEvent) @@ -2223,7 +2300,7 @@ IopLoadUnloadDriverWorker( */ NTSTATUS IopDoLoadUnloadDriver( - _In_opt_ PCUNICODE_STRING RegistryPath, + _In_opt_ PUNICODE_STRING RegistryPath, _Inout_ PDRIVER_OBJECT *DriverObject) { LOAD_UNLOAD_PARAMS LoadParams; diff --git a/ntoskrnl/io/pnpmgr/devaction.c b/ntoskrnl/io/pnpmgr/devaction.c index aebbf1a01c2..17347573216 100644 --- a/ntoskrnl/io/pnpmgr/devaction.c +++ b/ntoskrnl/io/pnpmgr/devaction.c @@ -47,6 +47,7 @@ WORK_QUEUE_ITEM IopDeviceActionWorkItem; BOOLEAN IopDeviceActionInProgress; KSPIN_LOCK IopDeviceActionLock; KEVENT PiEnumerationFinished; +static const WCHAR ServicesKeyName[] = L"\Registry\Machine\System\CurrentControlSet\Services\";
/* TYPES *********************************************************************/
@@ -1062,8 +1063,6 @@ IopActionInitChildServices(PDEVICE_NODE DeviceNode, PLDR_DATA_TABLE_ENTRY ModuleObject; PDRIVER_OBJECT DriverObject;
- KeEnterCriticalRegion(); - ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE); /* Get existing DriverObject pointer (in case the driver has already been loaded and initialized) */ Status = IopGetDriverObject( @@ -1073,17 +1072,44 @@ IopActionInitChildServices(PDEVICE_NODE DeviceNode,
if (!NT_SUCCESS(Status)) { + KeEnterCriticalRegion(); + ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE); + /* Driver is not initialized, try to load it */ Status = IopLoadServiceModule(&DeviceNode->ServiceName, &ModuleObject);
if (NT_SUCCESS(Status) || Status == STATUS_IMAGE_ALREADY_LOADED) { + UNICODE_STRING RegistryPath; + + // obtain a handle for driver's RegistryPath + RegistryPath.Length = 0; + RegistryPath.MaximumLength = sizeof(ServicesKeyName) + DeviceNode->ServiceName.Length; + RegistryPath.Buffer = ExAllocatePoolWithTag(PagedPool, RegistryPath.MaximumLength, TAG_IO); + if (RegistryPath.Buffer == NULL) + { + Status = STATUS_INSUFFICIENT_RESOURCES; + goto OpenHandleFail; + } + RtlAppendUnicodeToString(&RegistryPath, ServicesKeyName); + RtlAppendUnicodeStringToString(&RegistryPath, &DeviceNode->ServiceName); + + HANDLE serviceHandle; + Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, &RegistryPath, KEY_READ); + RtlFreeUnicodeString(&RegistryPath); + if (!NT_SUCCESS(Status)) + { + goto OpenHandleFail; + } + /* Initialize the driver */ NTSTATUS driverEntryStatus; Status = IopInitializeDriverModule(ModuleObject, - &DeviceNode->ServiceName, + serviceHandle, &DriverObject, &driverEntryStatus); + ZwClose(serviceHandle); + if (!NT_SUCCESS(Status)) DeviceNode->Problem = CM_PROB_FAILED_DRIVER_ENTRY; } @@ -1099,9 +1125,10 @@ IopActionInitChildServices(PDEVICE_NODE DeviceNode, if (!BootDrivers) DeviceNode->Problem = CM_PROB_DRIVER_FAILED_LOAD; } +OpenHandleFail: + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); } - ExReleaseResourceLite(&IopDriverLoadResource); - KeLeaveCriticalRegion();
/* Driver is loaded and initialized at this point */ if (NT_SUCCESS(Status))