Securely access caller-supplied buffers (probing and SEH for non-kernel-mode callers) in NtOpenDirectoryObject(), NtQueryDirectoryObject() and NtCreateDirectoryObject()
Modified: trunk/reactos/ntoskrnl/ob/dirobj.c

Modified: trunk/reactos/ntoskrnl/ob/dirobj.c
--- trunk/reactos/ntoskrnl/ob/dirobj.c	2005-01-26 05:00:08 UTC (rev 13305)
+++ trunk/reactos/ntoskrnl/ob/dirobj.c	2005-01-26 12:47:38 UTC (rev 13306)
@@ -47,65 +47,57 @@
 		       IN ACCESS_MASK DesiredAccess,
 		       IN POBJECT_ATTRIBUTES ObjectAttributes)
 {
-   PVOID Object;
-   NTSTATUS Status;
-
-   *DirectoryHandle = 0;
+   HANDLE hDirectory;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS Status = STATUS_SUCCESS;
    
-   Status = ObReferenceObjectByName(ObjectAttributes->ObjectName,
-				    ObjectAttributes->Attributes,
-				    NULL,
-				    DesiredAccess,
-				    ObDirectoryType,
-				    UserMode,
-				    NULL,
-				    &Object);
-   if (!NT_SUCCESS(Status))
+   PreviousMode = ExGetPreviousMode();
+   
+   if(PreviousMode != KernelMode)
+   {
+     _SEH_TRY
      {
-	return Status;
+       ProbeForWrite(DirectoryHandle,
+                     sizeof(HANDLE),
+                     sizeof(ULONG));
      }
+     _SEH_HANDLE
+     {
+       Status = _SEH_GetExceptionCode();
+     }
+     _SEH_END;
+     
+     if(!NT_SUCCESS(Status))
+     {
+       DPRINT1("NtOpenDirectoryObject failed, Status: 0x%x\n", Status);
+       return Status;
+     }
+   }
    
-   Status = ObCreateHandle(PsGetCurrentProcess(),
-			   Object,
-			   DesiredAccess,
-			   FALSE,
-			   DirectoryHandle);
-   return STATUS_SUCCESS;
+   Status = ObOpenObjectByName(ObjectAttributes,
+                               ObDirectoryType,
+                               NULL,
+                               PreviousMode,
+                               DesiredAccess,
+                               NULL,
+                               &hDirectory);
+   if(NT_SUCCESS(Status))
+   {
+     _SEH_TRY
+     {
+       *DirectoryHandle = hDirectory;
+     }
+     _SEH_HANDLE
+     {
+       Status = _SEH_GetExceptionCode();
+     }
+     _SEH_END;
+   }
+   
+   return Status;
 }
 
-static NTSTATUS
-CopyDirectoryString(PUNICODE_STRING UnsafeTarget, PUNICODE_STRING Source, PUCHAR *Buffer)
-{
-    UNICODE_STRING Target;
-    NTSTATUS Status;
-    WCHAR NullWchar;
 
-    Target.Length        = Source->Length;
-    Target.MaximumLength = (Source->Length + sizeof (WCHAR));
-    Target.Buffer        = (PWCHAR) *Buffer;
-    Status = MmCopyToCaller(UnsafeTarget, &Target, sizeof(UNICODE_STRING));
-    if (! NT_SUCCESS(Status))
-      {
-	return Status;
-      }
-    Status = MmCopyToCaller(*Buffer, Source->Buffer, Source->Length);
-    if (! NT_SUCCESS(Status))
-      {
-	return Status;
-      }
-    *Buffer += Source->Length;
-    NullWchar = L'\0';
-    Status = MmCopyToCaller(*Buffer, &NullWchar, sizeof(WCHAR));
-    if (! NT_SUCCESS(Status))
-      {
-	return Status;
-      }
-    *Buffer += sizeof(WCHAR);
-
-    return STATUS_SUCCESS;
-}
-
-
 /**********************************************************************
  * NAME							EXPORTED
  *	NtQueryDirectoryObject
@@ -169,202 +161,229 @@
 			IN ULONG BufferLength,
 			IN BOOLEAN ReturnSingleEntry,
 			IN BOOLEAN RestartScan,
-			IN OUT PULONG UnsafeContext,
-			OUT PULONG UnsafeReturnLength OPTIONAL)
+			IN OUT PULONG Context,
+			OUT PULONG ReturnLength OPTIONAL)
 {
-    PDIRECTORY_OBJECT   dir = NULL;
-    PLIST_ENTRY         current_entry = NULL;
-    PLIST_ENTRY         start_entry;
-    POBJECT_HEADER      current = NULL;
-    NTSTATUS            Status = STATUS_SUCCESS;
-    ULONG               DirectoryCount = 0;
-    ULONG               DirectoryIndex = 0;
-    POBJECT_DIRECTORY_INFORMATION current_odi = (POBJECT_DIRECTORY_INFORMATION) Buffer;
-    OBJECT_DIRECTORY_INFORMATION ZeroOdi;
-    PUCHAR              FirstFree = (PUCHAR) Buffer;
-    ULONG               Context;
-    ULONG               RequiredSize;
-    ULONG               NewValue;
-    KIRQL               OldLevel;
+  PDIRECTORY_OBJECT Directory;
+  KPROCESSOR_MODE PreviousMode;
+  ULONG SkipEntries = 0;
+  ULONG NextEntry = 0;
+  NTSTATUS Status = STATUS_SUCCESS;
+  
+  PreviousMode = ExGetPreviousMode();
 
-    DPRINT("NtQueryDirectoryObject(DirectoryHandle %x)\n", DirectoryHandle);
-
-    /* Check Context is not NULL */
-    if (NULL == UnsafeContext)
+  if(PreviousMode != KernelMode)
+  {
+    _SEH_TRY
+    {
+      /* a test showed that the Buffer pointer just has to be 16 bit aligned,
+         propably due to the fact that most information that needs to be copied
+         is unicode strings */
+      ProbeForWrite(Buffer,
+                    BufferLength,
+                    sizeof(WCHAR));
+      ProbeForWrite(Context,
+                    sizeof(ULONG),
+                    sizeof(ULONG));
+      if(!RestartScan)
       {
-        return STATUS_INVALID_PARAMETER;
+        SkipEntries = *Context;
       }
-
-    /* Reference the DIRECTORY_OBJECT */
-    Status = ObReferenceObjectByHandle(DirectoryHandle,
-				      DIRECTORY_QUERY,
-				      ObDirectoryType,
-				      UserMode,
-				      (PVOID*)&dir,
-				      NULL);
-    if (!NT_SUCCESS(Status))
+      if(ReturnLength != NULL)
       {
-        return Status;
+        ProbeForWrite(ReturnLength,
+                      sizeof(ULONG),
+                      sizeof(ULONG));
       }
+    }
+    _SEH_HANDLE
+    {
+      Status = _SEH_GetExceptionCode();
+    }
+    _SEH_END;
 
-    KeAcquireSpinLock(&dir->Lock, &OldLevel);
+    if(!NT_SUCCESS(Status))
+    {
+      DPRINT1("NtQueryDirectoryObject failed, Status: 0x%x\n", Status);
+      return Status;
+    }
+  }
+  else if(!RestartScan)
+  {
+    SkipEntries = *Context;
+  }
+  
+  Status = ObReferenceObjectByHandle(DirectoryHandle,
+                                     DIRECTORY_QUERY,
+                                     ObDirectoryType,
+                                     PreviousMode,
+                                     (PVOID*)&Directory,
+                                     NULL);
+  if(NT_SUCCESS(Status))
+  {
+    PVOID TemporaryBuffer = ExAllocatePool(PagedPool,
+                                           BufferLength);
+    if(TemporaryBuffer != NULL)
+    {
+      POBJECT_HEADER EntryHeader;
+      PLIST_ENTRY ListEntry;
+      KIRQL OldLevel;
+      ULONG RequiredSize = 0;
+      ULONG nDirectories = 0;
+      POBJECT_DIRECTORY_INFORMATION DirInfo = (POBJECT_DIRECTORY_INFORMATION)TemporaryBuffer;
 
-    /*
-     * Optionally, skip over some entries at the start of the directory
-     * (use *ObjectIndex value)
-     */
-    start_entry = dir->head.Flink;
-    if (! RestartScan)
+      KeAcquireSpinLock(&Directory->Lock, &OldLevel);
+
+      for(ListEntry = Directory->head.Flink;
+          ListEntry != &Directory->head;
+          ListEntry = ListEntry->Flink)
       {
-        register ULONG EntriesToSkip;
+        NextEntry++;
+        if(SkipEntries == 0)
+        {
+          PUNICODE_STRING Name, Type;
+          ULONG EntrySize;
 
-	Status = MmCopyFromCaller(&Context, UnsafeContext, sizeof(ULONG));
-	if (! NT_SUCCESS(Status))
-	  {
-	    KeReleaseSpinLock(&dir->Lock, OldLevel);
-            ObDereferenceObject(dir);
-	    return Status;
-	  }
-	EntriesToSkip = Context;
+          EntryHeader = CONTAINING_RECORD(ListEntry, OBJECT_HEADER, Entry);
 
-	CHECKPOINT;
-	
-	for (; 0 != EntriesToSkip-- && start_entry != &dir->head;
-	     start_entry = start_entry->Flink)
-	  {
-	    ;
-	  }
-	if ((0 != EntriesToSkip) && (start_entry == &dir->head))
-	  {
-	    KeReleaseSpinLock(&dir->Lock, OldLevel);
-            ObDereferenceObject(dir);
-            return STATUS_NO_MORE_ENTRIES;
-	  }
-      }
+          /* calculate the size of the required buffer space for this entry */
+          Name = (EntryHeader->Name.Length != 0 ? &EntryHeader->Name : NULL);
+          Type = &EntryHeader->ObjectType->TypeName;
+          EntrySize = sizeof(OBJECT_DIRECTORY_INFORMATION) +
+                      ((Name != NULL) ? ((ULONG)Name->Length + sizeof(WCHAR)) : 0) +
+                      (ULONG)EntryHeader->ObjectType->TypeName.Length + sizeof(WCHAR);
 
-    /*
-     * Compute number of entries that we will copy into the buffer and
-     * the total size of all entries (even if larger than the buffer size)
-     */
-    DirectoryCount = 0;
-    /* For the end sentenil */
-    RequiredSize = sizeof(OBJECT_DIRECTORY_INFORMATION);
-    for (current_entry = start_entry;
-         current_entry != &dir->head;
-         current_entry = current_entry->Flink)
-      {
-	current = CONTAINING_RECORD(current_entry, OBJECT_HEADER, Entry);
+          if(RequiredSize + EntrySize <= BufferLength)
+          {
+            /* the buffer is large enough to receive this entry. It would've
+               been much easier if the strings were directly appended to the
+               OBJECT_DIRECTORY_INFORMATION structured written into the buffer */
+            if(Name != NULL)
+              DirInfo->ObjectName = *Name;
+            else
+            {
+              DirInfo->ObjectName.Length = DirInfo->ObjectName.MaximumLength = 0;
+              DirInfo->ObjectName.Buffer = NULL;
+            }
+            DirInfo->ObjectTypeName = *Type;
 
-	RequiredSize += sizeof(OBJECT_DIRECTORY_INFORMATION) +
-	                current->Name.Length + sizeof(WCHAR) +
-	                current->ObjectType->TypeName.Length + sizeof(WCHAR);
-	if (RequiredSize <= BufferLength &&
-	    (! ReturnSingleEntry || DirectoryCount < 1))
-	  {
-	    DirectoryCount++;
-	  }
-      }
+            nDirectories++;
+            RequiredSize += EntrySize;
 
-    /*
-     * If there's no room to even copy a single entry then return error
-     * status.
-     */
-    if (0 == DirectoryCount && 
-        !(IsListEmpty(&dir->head) && BufferLength >= RequiredSize))
-      {
-	KeReleaseSpinLock(&dir->Lock, OldLevel);
-	ObDereferenceObject(dir);
-	if (NULL != UnsafeReturnLength)
-	  {
-	    Status = MmCopyToCaller(UnsafeReturnLength, &RequiredSize, sizeof(ULONG));
-	  }
-	return NT_SUCCESS(Status) ? STATUS_BUFFER_TOO_SMALL : Status;
+            if(ReturnSingleEntry)
+            {
+              /* we're only supposed to query one entry, so bail and copy the
+                 strings to the buffer */
+              break;
+            }
+            DirInfo++;
+          }
+          else
+          {
+            if(ReturnSingleEntry)
+            {
+              /* the buffer is too small, so return the number of bytes that
+                 would've been required for this query */
+              RequiredSize += EntrySize;
+              Status = STATUS_BUFFER_TOO_SMALL;
+            }
+            else
+            {
+              /* just copy the entries that fit into the buffer */
+              Status = STATUS_NO_MORE_ENTRIES;
+            }
+            break;
+          }
+        }
+        else
+        {
+          /* skip the entry */
+          SkipEntries--;
+        }
       }
 
-    /*
-     * Move FirstFree to point to the Unicode strings area
-     */
-    FirstFree += (DirectoryCount + 1) * sizeof(OBJECT_DIRECTORY_INFORMATION);
-
-    /* Scan the directory */
-    current_entry = start_entry;
-    for (DirectoryIndex = 0; DirectoryIndex < DirectoryCount; DirectoryIndex++) 
+      if(NT_SUCCESS(Status))
       {
-	current = CONTAINING_RECORD(current_entry, OBJECT_HEADER, Entry);
+        if(SkipEntries > 0 || nDirectories == 0)
+        {
+          /* we skipped more entries than the directory contains, nothing more to do */
+          Status = STATUS_NO_MORE_ENTRIES;
+        }
+        else
+        {
+          _SEH_TRY
+          {
+            POBJECT_DIRECTORY_INFORMATION DestDirInfo = (POBJECT_DIRECTORY_INFORMATION)Buffer;
+            PWSTR strbuf = (PWSTR)((POBJECT_DIRECTORY_INFORMATION)Buffer + nDirectories);
 
-	/*
-	 * Copy the current directory entry's data into the buffer
-	 * and update the OBJDIR_INFORMATION entry in the array.
-	 */
-	/* --- Object's name --- */
-	Status = CopyDirectoryString(&current_odi->ObjectName, &current->Name, &FirstFree);
-	if (! NT_SUCCESS(Status))
-	  {
-	    KeReleaseSpinLock(&dir->Lock, OldLevel);
-	    ObDereferenceObject(dir);
-	    return Status;
-	  }
-	/* --- Object type's name --- */
-	Status = CopyDirectoryString(&current_odi->ObjectTypeName, &current->ObjectType->TypeName, &FirstFree);
-	if (! NT_SUCCESS(Status))
-	  {
-	    KeReleaseSpinLock(&dir->Lock, OldLevel);
-	    ObDereferenceObject(dir);
-	    return Status;
-	  }
+            /* copy all OBJECT_DIRECTORY_INFORMATION structures to the buffer and
+               just append all strings (whose pointers are stored in the buffer!)
+               and replace the pointers */
+            for(DirInfo = (POBJECT_DIRECTORY_INFORMATION)TemporaryBuffer;
+                nDirectories > 0;
+                nDirectories--, DirInfo++, DestDirInfo++)
+            {
+              if(DirInfo->ObjectName.Length > 0)
+              {
+                DestDirInfo->ObjectName.Length = DirInfo->ObjectName.Length;
+                DestDirInfo->ObjectName.MaximumLength = DirInfo->ObjectName.MaximumLength;
+                DestDirInfo->ObjectName.Buffer = strbuf;
+                RtlCopyMemory(strbuf,
+                              DirInfo->ObjectName.Buffer,
+                              DirInfo->ObjectName.Length);
+                /* NULL-terminate the string */
+                strbuf[DirInfo->ObjectName.Length / sizeof(WCHAR)] = L'\0';
+                strbuf += (DirInfo->ObjectName.Length / sizeof(WCHAR)) + 1;
+              }
+              
+              DestDirInfo->ObjectTypeName.Length = DirInfo->ObjectTypeName.Length;
+              DestDirInfo->ObjectTypeName.MaximumLength = DirInfo->ObjectTypeName.MaximumLength;
+              DestDirInfo->ObjectTypeName.Buffer = strbuf;
+              RtlCopyMemory(strbuf,
+                            DirInfo->ObjectTypeName.Buffer,
+                            DirInfo->ObjectTypeName.Length);
+              /* NULL-terminate the string */
+              strbuf[DirInfo->ObjectTypeName.Length / sizeof(WCHAR)] = L'\0';
+              strbuf += (DirInfo->ObjectTypeName.Length / sizeof(WCHAR)) + 1;
+            }
+          }
+          _SEH_HANDLE
+          {
+            Status = _SEH_GetExceptionCode();
+          }
+          _SEH_END;
+        }
+      }
 
-	/* Next entry in the array */
-	current_odi++;
-	/* Next object in the directory */
-	current_entry = current_entry->Flink;
-    }
+      KeReleaseSpinLock(&Directory->Lock, OldLevel);
+      ObDereferenceObject(Directory);
+      
+      ExFreePool(TemporaryBuffer);
 
-    /*
-     * Don't need dir object anymore
-     */
-    KeReleaseSpinLock(&dir->Lock, OldLevel);
-    ObDereferenceObject(dir);
-
-    /* Terminate with all zero entry */
-    memset(&ZeroOdi, '\0', sizeof(OBJECT_DIRECTORY_INFORMATION));
-    Status = MmCopyToCaller(current_odi, &ZeroOdi, sizeof(OBJECT_DIRECTORY_INFORMATION));
-    if (! NT_SUCCESS(Status))
+      if(NT_SUCCESS(Status) || ReturnSingleEntry)
       {
-        return Status;
+        _SEH_TRY
+        {
+          *Context = NextEntry;
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredSize;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
       }
-
-    /*
-     * Store current index in Context
-     */
-    if (RestartScan)
-      {
-	Context = DirectoryCount;
-      }
+    }
     else
-      {
-	Context += DirectoryCount;
-      }
-    Status = MmCopyToCaller(UnsafeContext, &Context, sizeof(ULONG));
-    if (! NT_SUCCESS(Status))
-      {
-        return Status;
-      }
-
-    /*
-     * Report to the caller how much bytes
-     * we wrote in the user buffer.
-     */
-    if (NULL != UnsafeReturnLength)
-      {
-	NewValue = FirstFree - (PUCHAR) Buffer;
-	Status = MmCopyToCaller(UnsafeReturnLength, &NewValue, sizeof(ULONG));
-	if (! NT_SUCCESS(Status))
-	  {
-	    return Status;
-	  }
-      }
-
-    return Status;
+    {
+      Status = STATUS_INSUFFICIENT_RESOURCES;
+    }
+  }
+  
+  return Status;
 }
 
 
@@ -396,43 +415,69 @@
 			 IN ACCESS_MASK DesiredAccess,
 			 IN POBJECT_ATTRIBUTES ObjectAttributes)
 {
-  PDIRECTORY_OBJECT DirectoryObject;
-  NTSTATUS Status;
-
+  PDIRECTORY_OBJECT Directory;
+  HANDLE hDirectory;
+  KPROCESSOR_MODE PreviousMode;
+  NTSTATUS Status = STATUS_SUCCESS;
+  
   DPRINT("NtCreateDirectoryObject(DirectoryHandle %x, "
-	 "DesiredAccess %x, ObjectAttributes %x, "
-	 "ObjectAttributes->ObjectName %wZ)\n",
-	 DirectoryHandle, DesiredAccess, ObjectAttributes,
-	 ObjectAttributes->ObjectName);
+	 "DesiredAccess %x, ObjectAttributes %x\n",
+	 DirectoryHandle, DesiredAccess, ObjectAttributes);
 
-  Status = NtOpenDirectoryObject (DirectoryHandle,
-		                  DesiredAccess,
-		                  ObjectAttributes);
+  PreviousMode = ExGetPreviousMode();
 
-  if (Status == STATUS_OBJECT_NAME_NOT_FOUND)
+  if(PreviousMode != KernelMode)
   {
-     Status = ObCreateObject (ExGetPreviousMode(),
-			      ObDirectoryType,
-			      ObjectAttributes,
-			      ExGetPreviousMode(),
-			      NULL,
-			      sizeof(DIRECTORY_OBJECT),
-			      0,
-			      0,
-			      (PVOID*)&DirectoryObject);
-     if (!NT_SUCCESS(Status))
-     {
-        return Status;
-     }
+    _SEH_TRY
+    {
+      ProbeForWrite(DirectoryHandle,
+                    sizeof(HANDLE),
+                    sizeof(ULONG));
+    }
+    _SEH_HANDLE
+    {
+      Status = _SEH_GetExceptionCode();
+    }
+    _SEH_END;
 
-     Status = ObInsertObject ((PVOID)DirectoryObject,
-			      NULL,
-			      DesiredAccess,
-			      0,
-			      NULL,
-			      DirectoryHandle);
+    if(!NT_SUCCESS(Status))
+    {
+      DPRINT1("NtCreateDirectoryObject failed, Status: 0x%x\n", Status);
+      return Status;
+    }
+  }
 
-     ObDereferenceObject(DirectoryObject);
+  Status = ObCreateObject(PreviousMode,
+                          ObDirectoryType,
+                          ObjectAttributes,
+                          PreviousMode,
+                          NULL,
+                          sizeof(DIRECTORY_OBJECT),
+                          0,
+                          0,
+                          (PVOID*)&Directory);
+  if(NT_SUCCESS(Status))
+  {
+    Status = ObInsertObject((PVOID)Directory,
+                            NULL,
+                            DesiredAccess,
+                            0,
+                            NULL,
+                            &hDirectory);
+    ObDereferenceObject(Directory);
+
+    if(NT_SUCCESS(Status))
+    {
+      _SEH_TRY
+      {
+        *DirectoryHandle = hDirectory;
+      }
+      _SEH_HANDLE
+      {
+        Status = _SEH_GetExceptionCode();
+      }
+      _SEH_END;
+    }
   }
 
   return Status;