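- fixed RegSetValueExA/W to account for the null-terminating character of REG_SZ, REG_EXPAND_SZ and REG_MULTI_SZ data when the caller did not include it in cbData (as NT does)
- rewrote RegSetValueA to create/open the subkey with RegCreateKeyA and store the value via RegSetValueExA
- capture the driver service name string in NtLoadDriver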
Modified: trunk/reactos/lib/advapi32/reg/reg.c
Modified: trunk/reactos/ntoskrnl/io/driver.c

Modified: trunk/reactos/lib/advapi32/reg/reg.c
--- trunk/reactos/lib/advapi32/reg/reg.c	2005-05-05 00:07:27 UTC (rev 14983)
+++ trunk/reactos/lib/advapi32/reg/reg.c	2005-05-05 02:46:17 UTC (rev 14984)
@@ -3003,14 +3003,21 @@
       pValueName = NULL;
     }
 
-  if ((dwType == REG_SZ) ||
-      (dwType == REG_MULTI_SZ) ||
-      (dwType == REG_EXPAND_SZ))
+  if (((dwType == REG_SZ) ||
+       (dwType == REG_MULTI_SZ) ||
+       (dwType == REG_EXPAND_SZ)) &&
+      (cbData != 0))
     {
+      /* NT adds one if the caller forgot the NULL-termination character */
+      if (lpData[cbData - 1] != '\0')
+      {
+         cbData++;
+      }
+      
       RtlInitAnsiString (&AnsiString,
 			 NULL);
       AnsiString.Buffer = (PSTR)lpData;
-      AnsiString.Length = cbData;
+      AnsiString.Length = cbData - 1;
       AnsiString.MaximumLength = cbData;
       RtlAnsiStringToUnicodeString (&Data,
 				    &AnsiString,
@@ -3088,6 +3095,15 @@
       RtlInitUnicodeString (&ValueName, L"");
     }
   pValueName = &ValueName;
+  
+  if (((dwType == REG_SZ) ||
+       (dwType == REG_MULTI_SZ) ||
+       (dwType == REG_EXPAND_SZ)) &&
+      (cbData != 0) && (*(((PWCHAR)lpData) + (cbData / sizeof(WCHAR)) - 1) != L'\0'))
+    {
+      /* NT adds one if the caller forgot the NULL-termination character */
+      cbData += sizeof(WCHAR);
+    }
 
   Status = NtSetValueKey (KeyHandle,
 			  pValueName,
@@ -3118,51 +3134,41 @@
 	      LPCSTR lpData,
 	      DWORD cbData)
 {
-  WCHAR SubKeyNameBuffer[MAX_PATH+1];
-  UNICODE_STRING SubKeyName;
-  UNICODE_STRING Data;
-  ANSI_STRING AnsiString;
-  LONG DataSize;
-  LONG ErrorCode;
+  LONG ret;
+  HKEY hSubKey;
+  
+  if (dwType != REG_SZ)
+  {
+     return ERROR_INVALID_PARAMETER;
+  }
+  
+  if (lpSubKey != NULL && lpSubKey[0] != '\0')
+  {
+     ret = RegCreateKeyA(hKey,
+                         lpSubKey,
+                         &hSubKey);
 
-  if (lpData == NULL)
-    {
-      SetLastError (ERROR_INVALID_PARAMETER);
-      return ERROR_INVALID_PARAMETER;
-    }
+     if (ret != ERROR_SUCCESS)
+     {
+        return ret;
+     }
+  }
+  else
+     hSubKey = hKey;
+  
+  ret = RegSetValueExA(hSubKey,
+                       NULL,
+                       0,
+                       REG_SZ,
+                       lpData,
+                       strlen(lpData) + 1);
+  
+  if (hSubKey != hKey)
+  {
+     RegCloseKey(hSubKey);
+  }
 
-  RtlInitUnicodeString (&SubKeyName, NULL);
-  RtlInitUnicodeString (&Data, NULL);
-  if (lpSubKey != NULL && (strlen(lpSubKey) != 0))
-    {
-      RtlInitAnsiString (&AnsiString, (LPSTR)lpSubKey);
-      SubKeyName.Buffer = &SubKeyNameBuffer[0];
-      SubKeyName.MaximumLength = sizeof(SubKeyNameBuffer);
-      RtlAnsiStringToUnicodeString (&SubKeyName, &AnsiString, FALSE);
-    }
-
-  DataSize = cbData * sizeof(WCHAR);
-  Data.MaximumLength = DataSize;
-  Data.Buffer = RtlAllocateHeap (ProcessHeap,
-				 0,
-				 DataSize);
-  if (Data.Buffer == NULL)
-    {
-      SetLastError (ERROR_OUTOFMEMORY);
-      return ERROR_OUTOFMEMORY;
-    }
-
-  ErrorCode = RegSetValueW (hKey,
-			    (LPCWSTR)SubKeyName.Buffer,
-			    dwType,
-			    Data.Buffer,
-			    DataSize);
-
-  RtlFreeHeap (ProcessHeap,
-	       0,
-	       Data.Buffer);
-
-  return ErrorCode;
+  return ret;
 }
 
 

Modified: trunk/reactos/ntoskrnl/io/driver.c
--- trunk/reactos/ntoskrnl/io/driver.c	2005-05-05 00:07:27 UTC (rev 14983)
+++ trunk/reactos/ntoskrnl/io/driver.c	2005-05-05 02:46:17 UTC (rev 14984)
@@ -1311,7 +1311,7 @@
    IopBootLog(&Service->ImagePath, NT_SUCCESS(Status) ? TRUE : FALSE);
    if (!NT_SUCCESS(Status))
    {
-      DPRINT("NtLoadDriver() failed (Status %lx)\n", Status);
+      DPRINT("IopLoadDriver() failed (Status %lx)\n", Status);
 #if 0
       if (Service->ErrorControl == 1)
       {
@@ -1783,42 +1783,66 @@
    RTL_QUERY_REGISTRY_TABLE QueryTable[3];
    UNICODE_STRING ImagePath;
    UNICODE_STRING ServiceName;
+   UNICODE_STRING CapturedDriverServiceName;
+   KPROCESSOR_MODE PreviousMode;
    NTSTATUS Status;
    ULONG Type;
    PDEVICE_NODE DeviceNode;
    PMODULE_OBJECT ModuleObject;
    PDRIVER_OBJECT DriverObject;
-   LPWSTR Start;
+   WCHAR *cur;
+   
+   PAGED_CODE();
+   
+   PreviousMode = KeGetPreviousMode();
 
-   DPRINT("NtLoadDriver('%wZ')\n", DriverServiceName);
-
    /*
     * Check security privileges
     */
 
 /* FIXME: Uncomment when privileges will be correctly implemented. */
 #if 0
-   if (!SeSinglePrivilegeCheck(SeLoadDriverPrivilege, KeGetPreviousMode()))
+   if (!SeSinglePrivilegeCheck(SeLoadDriverPrivilege, PreviousMode))
    {
       DPRINT("Privilege not held\n");
       return STATUS_PRIVILEGE_NOT_HELD;
    }
 #endif
 
+   Status = RtlCaptureUnicodeString(&CapturedDriverServiceName,
+                                    PreviousMode,
+                                    PagedPool,
+                                    FALSE,
+                                    DriverServiceName);
+   if (!NT_SUCCESS(Status))
+   {
+      return Status;
+   }
+
+   DPRINT("NtLoadDriver('%wZ')\n", &CapturedDriverServiceName);
+
    RtlInitUnicodeString(&ImagePath, NULL);
 
    /*
     * Get the service name from the registry key name.
     */
+   ASSERT(CapturedDriverServiceName.Length >= sizeof(WCHAR));
 
-   Start = wcsrchr(DriverServiceName->Buffer, L'\\');
-   if (Start == NULL)
-      Start = DriverServiceName->Buffer;
-   else
-      Start++;
+   ServiceName = CapturedDriverServiceName;
+   cur = CapturedDriverServiceName.Buffer + (CapturedDriverServiceName.Length / sizeof(WCHAR)) - 1;
+   while (CapturedDriverServiceName.Buffer != cur)
+   {
+      if(*cur == L'\\')
+      {
+         ServiceName.Buffer = cur + 1;
+         ServiceName.Length = CapturedDriverServiceName.Length -
+                              (USHORT)((ULONG_PTR)ServiceName.Buffer -
+                                       (ULONG_PTR)CapturedDriverServiceName.Buffer);
+         break;
+      }
+      cur--;
+   }
 
-   RtlInitUnicodeString(&ServiceName, Start);
-
    /*
     * Get service type.
     */
@@ -1836,13 +1860,13 @@
    QueryTable[1].EntryContext = &ImagePath;
 
    Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE,
-      DriverServiceName->Buffer, QueryTable, NULL, NULL);
+      CapturedDriverServiceName.Buffer, QueryTable, NULL, NULL);
 
    if (!NT_SUCCESS(Status))
    {
       DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
       RtlFreeUnicodeString(&ImagePath);
-      return Status;
+      goto ReleaseCapturedString;
    }
 
    /*
@@ -1854,10 +1878,10 @@
    if (!NT_SUCCESS(Status))
    {
       DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status);
-      return Status;
+      goto ReleaseCapturedString;
    }
 
-   DPRINT("FullImagePath: '%S'\n", ImagePath.Buffer);
+   DPRINT("FullImagePath: '%wZ'\n", &ImagePath);
    DPRINT("Type: %lx\n", Type);
 
    /*
@@ -1868,7 +1892,8 @@
    if (ModuleObject != NULL)
    {
       DPRINT("Image already loaded\n");
-      return STATUS_IMAGE_ALREADY_LOADED;
+      Status = STATUS_IMAGE_ALREADY_LOADED;
+      goto ReleaseCapturedString;
    }
 
    /*
@@ -1881,7 +1906,7 @@
    if (!NT_SUCCESS(Status))
    {
       DPRINT("IopCreateDeviceNode() failed (Status %lx)\n", Status);
-      return Status;
+      goto ReleaseCapturedString;
    }
 
    /*
@@ -1894,19 +1919,14 @@
    {
       DPRINT("LdrLoadModule() failed (Status %lx)\n", Status);
       IopFreeDeviceNode(DeviceNode);
-      return Status;
+      goto ReleaseCapturedString;
    }
 
    /*
     * Set a service name for the device node
     */
 
-   Start = wcsrchr(DriverServiceName->Buffer, L'\\');
-   if (Start == NULL)
-      Start = DriverServiceName->Buffer;
-   else
-      Start++;
-   RtlpCreateUnicodeString(&DeviceNode->ServiceName, Start, NonPagedPool);
+   RtlpCreateUnicodeString(&DeviceNode->ServiceName, ServiceName.Buffer, NonPagedPool);
 
    /*
     * Initialize the driver module
@@ -1925,10 +1945,15 @@
       DPRINT("IopInitializeDriver() failed (Status %lx)\n", Status);
       LdrUnloadModule(ModuleObject);
       IopFreeDeviceNode(DeviceNode);
-      return Status;
+      goto ReleaseCapturedString;
    }
 
    IopInitializeDevice(DeviceNode, DriverObject);
+   
+ReleaseCapturedString:
+   RtlReleaseCapturedUnicodeString(&CapturedDriverServiceName,
+                                   PreviousMode,
+                                   FALSE);
 
    return Status;
 }