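- probe and capture the caller-supplied parameters of NtCreateKey (KeyHandle, Disposition and Class)
- added ProbeForWriteUnicodeString and ProbeForReadUnicodeString macros, and moved the ProbeForWrite*/ProbeForRead* macros ahead of the inline probing helpers that now use them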
Modified: trunk/reactos/ntoskrnl/cm/ntfunc.c
Modified: trunk/reactos/ntoskrnl/include/internal/ntoskrnl.h

Modified: trunk/reactos/ntoskrnl/cm/ntfunc.c
--- trunk/reactos/ntoskrnl/cm/ntfunc.c	2005-10-29 15:05:37 UTC (rev 18852)
+++ trunk/reactos/ntoskrnl/cm/ntfunc.c	2005-10-29 16:00:00 UTC (rev 18853)
@@ -184,35 +184,70 @@
 	    IN ULONG CreateOptions,
 	    OUT PULONG Disposition)
 {
-  UNICODE_STRING RemainingPath;
+  UNICODE_STRING RemainingPath = {0};
+  BOOLEAN FreeRemainingPath = TRUE;
+  ULONG LocalDisposition;
   PKEY_OBJECT KeyObject;
-  NTSTATUS Status;
-  PVOID Object;
+  NTSTATUS Status = STATUS_SUCCESS;
+  PVOID Object = NULL;
   PWSTR Start;
   UNICODE_STRING ObjectName;
   OBJECT_CREATE_INFORMATION ObjectCreateInfo;
   unsigned i;
   REG_PRE_CREATE_KEY_INFORMATION PreCreateKeyInfo;
   REG_POST_CREATE_KEY_INFORMATION PostCreateKeyInfo;
+  KPROCESSOR_MODE PreviousMode;
+  UNICODE_STRING CapturedClass = {0};
+  HANDLE hKey;
 
   PAGED_CODE();
 
-  DPRINT("NtCreateKey (Name %wZ  KeyHandle 0x%p  Root 0x%p)\n",
-	 ObjectAttributes->ObjectName,
-	 KeyHandle,
-	 ObjectAttributes->RootDirectory);
+  PreviousMode = KeGetPreviousMode();
 
+  if (PreviousMode != KernelMode)
+  {
+      _SEH_TRY
+      {
+          ProbeForWriteHandle(KeyHandle);
+          if (Disposition != NULL)
+          {
+              ProbeForWriteUlong(Disposition);
+          }
+      }
+      _SEH_HANDLE
+      {
+          Status = _SEH_GetExceptionCode();
+      }
+      _SEH_END;
+      
+      if (!NT_SUCCESS(Status))
+      {
+          return Status;
+      }
+  }
+  
+  if (Class != NULL)
+  {
+      Status = ProbeAndCaptureUnicodeString(&CapturedClass,
+                                            PreviousMode,
+                                            Class);
+      if (!NT_SUCCESS(Status))
+      {
+          return Status;
+      }
+  }
+
   /* Capture all the info */
   DPRINT("Capturing Create Info\n");
   Status = ObpCaptureObjectAttributes(ObjectAttributes,
-                                      KeGetPreviousMode(),
+                                      PreviousMode,
                                       CmiKeyType,
                                       &ObjectCreateInfo,
                                       &ObjectName);
   if (!NT_SUCCESS(Status))
     {
       DPRINT("ObpCaptureObjectAttributes() failed (Status %lx)\n", Status);
-      return Status;
+      goto Cleanup;
     }
 
   PostCreateKeyInfo.CompleteName = &ObjectName;
@@ -221,13 +256,12 @@
   if (!NT_SUCCESS(Status))
     {
       ObpReleaseCapturedAttributes(&ObjectCreateInfo);
-      if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
-      return Status;
+      goto Cleanup;
     }
     
   Status = ObFindObject(&ObjectCreateInfo,
                         &ObjectName,
-			(PVOID*)&Object,
+                        (PVOID*)&Object,
                         &RemainingPath,
                         CmiKeyType);
   ObpReleaseCapturedAttributes(&ObjectCreateInfo);
@@ -236,9 +270,9 @@
       PostCreateKeyInfo.Object = NULL;
       PostCreateKeyInfo.Status = Status;
       CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-      if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
+
       DPRINT("CmpFindObject failed, Status: 0x%x\n", Status);
-      return(Status);
+      goto Cleanup;
     }
 
   DPRINT("RemainingPath %wZ\n", &RemainingPath);
@@ -248,33 +282,29 @@
       /* Fail if the key has been deleted */
       if (((PKEY_OBJECT) Object)->Flags & KO_MARKED_FOR_DELETE)
 	{
-	  ObDereferenceObject(Object);
-	  RtlFreeUnicodeString(&RemainingPath);
           PostCreateKeyInfo.Object = NULL;
           PostCreateKeyInfo.Status = STATUS_UNSUCCESSFUL;
           CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-          if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
+
 	  DPRINT("Object marked for delete!\n");
-	  return(STATUS_UNSUCCESSFUL);
+	  Status = STATUS_UNSUCCESSFUL;
+	  goto Cleanup;
 	}
 
-      if (Disposition)
-	*Disposition = REG_OPENED_EXISTING_KEY;
-
       Status = ObpCreateHandle(PsGetCurrentProcess(),
 			      Object,
 			      DesiredAccess,
 			      TRUE,
-			      KeyHandle);
+			      &hKey);
 
       DPRINT("ObpCreateHandle failed Status 0x%x\n", Status);
-      ObDereferenceObject(Object);
-      RtlFreeUnicodeString(&RemainingPath);
+
       PostCreateKeyInfo.Object = NULL;
       PostCreateKeyInfo.Status = Status;
       CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-      if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
-      return Status;
+
+      LocalDisposition = REG_OPENED_EXISTING_KEY;
+      goto SuccessReturn;
     }
 
   /* If RemainingPath contains \ we must return error
@@ -287,23 +317,23 @@
     {
       if (L'\\' == RemainingPath.Buffer[i])
         {
-          ObDereferenceObject(Object);
           DPRINT1("NtCreateKey() doesn't create trees! (found \'\\\' in remaining path: \"%wZ\"!)\n", &RemainingPath);
-          RtlFreeUnicodeString(&RemainingPath);
+
           PostCreateKeyInfo.Object = NULL;
           PostCreateKeyInfo.Status = STATUS_OBJECT_NAME_NOT_FOUND;
           CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-          if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
-          return STATUS_OBJECT_NAME_NOT_FOUND;
+
+          Status = STATUS_OBJECT_NAME_NOT_FOUND;
+          goto Cleanup;
         }
     }
 
   DPRINT("RemainingPath %S  ParentObject 0x%p\n", RemainingPath.Buffer, Object);
 
-  Status = ObCreateObject(ExGetPreviousMode(),
+  Status = ObCreateObject(PreviousMode,
 			  CmiKeyType,
 			  NULL,
-			  ExGetPreviousMode(),
+			  PreviousMode,
 			  NULL,
 			  sizeof(KEY_OBJECT),
 			  0,
@@ -315,8 +345,8 @@
       PostCreateKeyInfo.Object = NULL;
       PostCreateKeyInfo.Status = Status;
       CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-      if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
-      return(Status);
+
+      goto Cleanup;
     }
 
   Status = ObInsertObject((PVOID)KeyObject,
@@ -324,17 +354,17 @@
 			  DesiredAccess,
 			  0,
 			  NULL,
-			  KeyHandle);
+			  &hKey);
   if (!NT_SUCCESS(Status))
     {
       ObDereferenceObject(KeyObject);
-      RtlFreeUnicodeString(&RemainingPath);
       DPRINT1("ObInsertObject() failed!\n");
+
       PostCreateKeyInfo.Object = NULL;
       PostCreateKeyInfo.Status = Status;
       CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-      if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
-      return(Status);
+
+      goto Cleanup;
     }
 
   KeyObject->ParentKey = Object;
@@ -361,7 +391,7 @@
 			KeyObject,
 			&RemainingPath,
 			TitleIndex,
-			Class,
+			&CapturedClass,
 			CreateOptions);
   if (!NT_SUCCESS(Status))
     {
@@ -370,23 +400,23 @@
       ExReleaseResourceLite(&CmiRegistryLock);
       KeLeaveCriticalRegion();
       ObDereferenceObject(KeyObject);
-      ObDereferenceObject(Object);
-      RtlFreeUnicodeString(&RemainingPath);
+
       PostCreateKeyInfo.Object = NULL;
       PostCreateKeyInfo.Status = STATUS_UNSUCCESSFUL;
       CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-      if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
-      return STATUS_UNSUCCESSFUL;
+
+      Status = STATUS_UNSUCCESSFUL;
+      goto Cleanup;
     }
 
   if (Start == RemainingPath.Buffer)
     {
       KeyObject->Name = RemainingPath;
+      FreeRemainingPath = FALSE;
     }
   else
     {
       RtlpCreateUnicodeString(&KeyObject->Name, Start, NonPagedPool);
-      RtlFreeUnicodeString(&RemainingPath);
     }
 
   if (KeyObject->RegistryHive == KeyObject->ParentKey->RegistryHive)
@@ -400,10 +430,7 @@
       KeyObject->KeyCell->SecurityKeyOffset = -1;
       /* This key must remain in memory unless it is deleted
 	 or file is unloaded */
-      ObReferenceObjectByPointer(KeyObject,
-				 STANDARD_RIGHTS_REQUIRED,
-				 NULL,
-				 UserMode);
+      ObReferenceObject(KeyObject);
     }
 
   CmiAddKeyToList(KeyObject->ParentKey, KeyObject);
@@ -414,19 +441,39 @@
   ExReleaseResourceLite(&CmiRegistryLock);
   KeLeaveCriticalRegion();
 
-
-  ObDereferenceObject(Object);
-
-  if (Disposition)
-    *Disposition = REG_CREATED_NEW_KEY;
-
   PostCreateKeyInfo.Object = KeyObject;
   PostCreateKeyInfo.Status = Status;
   CmiCallRegisteredCallbacks(RegNtPostCreateKey, &PostCreateKeyInfo);
-  if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
 
   CmiSyncHives();
+  
+  LocalDisposition = REG_CREATED_NEW_KEY;
 
+SuccessReturn:
+  _SEH_TRY
+  {
+      *KeyHandle = hKey;
+      if (Disposition != NULL)
+      {
+          *Disposition = LocalDisposition;
+      }
+  }
+  _SEH_HANDLE
+  {
+      Status = _SEH_GetExceptionCode();
+  }
+  _SEH_END;
+
+Cleanup:
+  if (Class != NULL)
+  {
+    ReleaseCapturedUnicodeString(&CapturedClass,
+                                 PreviousMode);
+  }
+  if (ObjectName.Buffer) ExFreePool(ObjectName.Buffer);
+  if (FreeRemainingPath) RtlFreeUnicodeString(&RemainingPath);
+  if (Object != NULL) ObDereferenceObject(Object);
+
   return Status;
 }
 

Modified: trunk/reactos/ntoskrnl/include/internal/ntoskrnl.h
--- trunk/reactos/ntoskrnl/include/internal/ntoskrnl.h	2005-10-29 15:05:37 UTC (rev 18852)
+++ trunk/reactos/ntoskrnl/include/internal/ntoskrnl.h	2005-10-29 16:00:00 UTC (rev 18853)
@@ -83,7 +83,64 @@
 
 #define ExRaiseStatus RtlRaiseStatus
 
+static const UNICODE_STRING __emptyUnicodeString = {0};
+
+/*
+ * NOTE: Alignment of the pointers is not verified!
+ * A bad pointer raises STATUS_ACCESS_VIOLATION, so callers must handle it (e.g. _SEH_TRY).
+ */
+#define ProbeForWriteGenericType(Ptr, Type)                                    \
+    do {                                                                       \
+        if ((ULONG_PTR)(Ptr) + sizeof(Type) - 1 < (ULONG_PTR)(Ptr) ||          \
+            (ULONG_PTR)(Ptr) + sizeof(Type) - 1 >= (ULONG_PTR)MmUserProbeAddress) { \
+            RtlRaiseStatus (STATUS_ACCESS_VIOLATION);                          \
+        }                                                                      \
+        *(volatile Type *)(Ptr) = *(volatile Type *)(Ptr);                     \
+    } while (0)
+
+#define ProbeForWriteBoolean(Ptr) ProbeForWriteGenericType(Ptr, BOOLEAN)
+#define ProbeForWriteUchar(Ptr) ProbeForWriteGenericType(Ptr, UCHAR)
+#define ProbeForWriteChar(Ptr) ProbeForWriteGenericType(Ptr, CHAR)
+#define ProbeForWriteUshort(Ptr) ProbeForWriteGenericType(Ptr, USHORT)
+#define ProbeForWriteShort(Ptr) ProbeForWriteGenericType(Ptr, SHORT)
+#define ProbeForWriteUlong(Ptr) ProbeForWriteGenericType(Ptr, ULONG)
+#define ProbeForWriteLong(Ptr) ProbeForWriteGenericType(Ptr, LONG)
+#define ProbeForWriteUint(Ptr) ProbeForWriteGenericType(Ptr, UINT)
+#define ProbeForWriteInt(Ptr) ProbeForWriteGenericType(Ptr, INT)
+#define ProbeForWriteUlonglong(Ptr) ProbeForWriteGenericType(Ptr, ULONGLONG)
+#define ProbeForWriteLonglong(Ptr) ProbeForWriteGenericType(Ptr, LONGLONG)
+#define ProbeForWritePointer(Ptr) ProbeForWriteGenericType(Ptr, PVOID)
+#define ProbeForWriteHandle(Ptr) ProbeForWriteGenericType(Ptr, HANDLE)
+#define ProbeForWriteLangid(Ptr) ProbeForWriteGenericType(Ptr, LANGID)
+#define ProbeForWriteLargeInteger(Ptr) ProbeForWriteGenericType(&(Ptr)->QuadPart, LONGLONG)
+#define ProbeForWriteUlargeInteger(Ptr) ProbeForWriteGenericType(&(Ptr)->QuadPart, ULONGLONG)
+#define ProbeForWriteUnicodeString(Ptr) ProbeForWriteGenericType(Ptr, UNICODE_STRING)
+
+#define ProbeForReadGenericType(Ptr, Type, Default)                            \
+    (((ULONG_PTR)(Ptr) + sizeof(Type) - 1 < (ULONG_PTR)(Ptr) ||                \
+	 (ULONG_PTR)(Ptr) + sizeof(Type) - 1 >= (ULONG_PTR)MmUserProbeAddress) ?   \
+	     ExRaiseStatus (STATUS_ACCESS_VIOLATION), Default :                    \
+	     *(volatile Type *)(Ptr))
+
+#define ProbeForReadBoolean(Ptr) ProbeForReadGenericType(Ptr, BOOLEAN, FALSE)
+#define ProbeForReadUchar(Ptr) ProbeForReadGenericType(Ptr, UCHAR, 0)
+#define ProbeForReadChar(Ptr) ProbeForReadGenericType(Ptr, CHAR, 0)
+#define ProbeForReadUshort(Ptr) ProbeForReadGenericType(Ptr, USHORT, 0)
+#define ProbeForReadShort(Ptr) ProbeForReadGenericType(Ptr, SHORT, 0)
+#define ProbeForReadUlong(Ptr) ProbeForReadGenericType(Ptr, ULONG, 0)
+#define ProbeForReadLong(Ptr) ProbeForReadGenericType(Ptr, LONG, 0)
+#define ProbeForReadUint(Ptr) ProbeForReadGenericType(Ptr, UINT, 0)
+#define ProbeForReadInt(Ptr) ProbeForReadGenericType(Ptr, INT, 0)
+#define ProbeForReadUlonglong(Ptr) ProbeForReadGenericType(Ptr, ULONGLONG, 0)
+#define ProbeForReadLonglong(Ptr) ProbeForReadGenericType(Ptr, LONGLONG, 0)
+#define ProbeForReadPointer(Ptr) ProbeForReadGenericType(Ptr, PVOID, NULL)
+#define ProbeForReadHandle(Ptr) ProbeForReadGenericType(Ptr, HANDLE, NULL)
+#define ProbeForReadLangid(Ptr) ProbeForReadGenericType(Ptr, LANGID, 0)
+#define ProbeForReadLargeInteger(Ptr) ((LARGE_INTEGER)ProbeForReadGenericType(&(Ptr)->QuadPart, LONGLONG, 0))
+#define ProbeForReadUlargeInteger(Ptr) ((ULARGE_INTEGER)ProbeForReadGenericType(&(Ptr)->QuadPart, ULONGLONG, 0))
+#define ProbeForReadUnicodeString(Ptr) ProbeForReadGenericType(Ptr, UNICODE_STRING, __emptyUnicodeString)
+
+/*
  * Inlined Probing Macros
  */
 static __inline
@@ -102,10 +159,7 @@
     {
         _SEH_TRY
         {
-            ProbeForRead(UnsafeSrc,
-                         sizeof(UNICODE_STRING),
-                         sizeof(ULONG));
-            *Dest = *UnsafeSrc;
+            *Dest = ProbeForReadUnicodeString(UnsafeSrc);
             if(Dest->Buffer != NULL)
             {
                 if (Dest->Length != 0)
@@ -175,59 +229,6 @@
 }
 
 /*
- * NOTE: Alignment of the pointers is not verified!
- */
-#define ProbeForWriteGenericType(Ptr, Type)                                    \
-    do {                                                                       \
-        if ((ULONG_PTR)(Ptr) + sizeof(Type) - 1 < (ULONG_PTR)(Ptr) ||          \
-            (ULONG_PTR)(Ptr) + sizeof(Type) - 1 >= (ULONG_PTR)MmUserProbeAddress) { \
-            RtlRaiseStatus (STATUS_ACCESS_VIOLATION);                          \
-        }                                                                      \
-        *(volatile Type *)(Ptr) = *(volatile Type *)(Ptr);                     \
-    } while (0)
-
-#define ProbeForWriteBoolean(Ptr) ProbeForWriteGenericType(Ptr, BOOLEAN)
-#define ProbeForWriteUchar(Ptr) ProbeForWriteGenericType(Ptr, UCHAR)
-#define ProbeForWriteChar(Ptr) ProbeForWriteGenericType(Ptr, Char)
-#define ProbeForWriteUshort(Ptr) ProbeForWriteGenericType(Ptr, USHORT)
-#define ProbeForWriteShort(Ptr) ProbeForWriteGenericType(Ptr, SHORT)
-#define ProbeForWriteUlong(Ptr) ProbeForWriteGenericType(Ptr, ULONG)
-#define ProbeForWriteLong(Ptr) ProbeForWriteGenericType(Ptr, LONG)
-#define ProbeForWriteUint(Ptr) ProbeForWriteGenericType(Ptr, UINT)
-#define ProbeForWriteInt(Ptr) ProbeForWriteGenericType(Ptr, INT)
-#define ProbeForWriteUlonglong(Ptr) ProbeForWriteGenericType(Ptr, ULONGLONG)
-#define ProbeForWriteLonglong(Ptr) ProbeForWriteGenericType(Ptr, LONGLONG)
-#define ProbeForWriteLonglong(Ptr) ProbeForWriteGenericType(Ptr, LONGLONG)
-#define ProbeForWritePointer(Ptr) ProbeForWriteGenericType(Ptr, PVOID)
-#define ProbeForWriteHandle(Ptr) ProbeForWriteGenericType(Ptr, HANDLE)
-#define ProbeForWriteLangid(Ptr) ProbeForWriteGenericType(Ptr, LANGID)
-#define ProbeForWriteLargeInteger(Ptr) ProbeForWriteGenericType(&(Ptr)->QuadPart, LONGLONG)
-#define ProbeForWriteUlargeInteger(Ptr) ProbeForWriteGenericType(&(Ptr)->QuadPart, ULONGLONG)
-
-#define ProbeForReadGenericType(Ptr, Type, Default)                            \
-    (((ULONG_PTR)(Ptr) + sizeof(Type) - 1 < (ULONG_PTR)(Ptr) ||                \
-	 (ULONG_PTR)(Ptr) + sizeof(Type) - 1 >= (ULONG_PTR)MmUserProbeAddress) ?   \
-	     ExRaiseStatus (STATUS_ACCESS_VIOLATION), Default :                    \
-	     *(volatile Type *)(Ptr))
-
-#define ProbeForReadBoolean(Ptr) ProbeForReadGenericType(Ptr, BOOLEAN, FALSE)
-#define ProbeForReadUchar(Ptr) ProbeForReadGenericType(Ptr, UCHAR, 0)
-#define ProbeForReadChar(Ptr) ProbeForReadGenericType(Ptr, CHAR, 0)
-#define ProbeForReadUshort(Ptr) ProbeForReadGenericType(Ptr, USHORT, 0)
-#define ProbeForReadShort(Ptr) ProbeForReadGenericType(Ptr, SHORT, 0)
-#define ProbeForReadUlong(Ptr) ProbeForReadGenericType(Ptr, ULONG, 0)
-#define ProbeForReadLong(Ptr) ProbeForReadGenericType(Ptr, LONG, 0)
-#define ProbeForReadUint(Ptr) ProbeForReadGenericType(Ptr, UINT, 0)
-#define ProbeForReadInt(Ptr) ProbeForReadGenericType(Ptr, INT, 0)
-#define ProbeForReadUlonglong(Ptr) ProbeForReadGenericType(Ptr, ULONGLONG, 0)
-#define ProbeForReadLonglong(Ptr) ProbeForReadGenericType(Ptr, LONGLONG, 0)
-#define ProbeForReadPointer(Ptr) ProbeForReadGenericType(Ptr, PVOID, NULL)
-#define ProbeForReadHandle(Ptr) ProbeForReadGenericType(Ptr, HANDLE, NULL)
-#define ProbeForReadLangid(Ptr) ProbeForReadGenericType(Ptr, LANGID, 0)
-#define ProbeForReadLargeInteger(Ptr) ((LARGE_INTEGER)ProbeForReadGenericType(&(Ptr)->QuadPart, LONGLONG, 0))
-#define ProbeForReadUlargeInteger(Ptr) ((ULARGE_INTEGER)ProbeForReadGenericType(&(Ptr)->QuadPart, ULONGLONG, 0))
-
-/*
  * generic information class probing code
  */