Protect access to caller-supplied buffers with SEH in NtSetSecurityObject and NtQuerySecurityObject, and request the proper access rights when referencing the object handle.
Modified: trunk/reactos/ntoskrnl/include/internal/se.h
Modified: trunk/reactos/ntoskrnl/ob/security.c
Modified: trunk/reactos/ntoskrnl/se/semgr.c

Modified: trunk/reactos/ntoskrnl/include/internal/se.h
--- trunk/reactos/ntoskrnl/include/internal/se.h	2005-12-30 01:39:34 UTC (rev 20454)
+++ trunk/reactos/ntoskrnl/include/internal/se.h	2005-12-30 01:41:02 UTC (rev 20455)
@@ -274,6 +274,14 @@
     KeLeaveCriticalRegion();                                                   \
   while(0)
 
+VOID STDCALL
+SeQuerySecurityAccessMask(IN SECURITY_INFORMATION SecurityInformation,
+                          OUT PACCESS_MASK DesiredAccess);
+
+VOID STDCALL
+SeSetSecurityAccessMask(IN SECURITY_INFORMATION SecurityInformation,
+                        OUT PACCESS_MASK DesiredAccess);
+
 #endif /* __NTOSKRNL_INCLUDE_INTERNAL_SE_H */
 
 /* EOF */

Modified: trunk/reactos/ntoskrnl/ob/security.c
--- trunk/reactos/ntoskrnl/ob/security.c	2005-12-30 01:39:34 UTC (rev 20454)
+++ trunk/reactos/ntoskrnl/ob/security.c	2005-12-30 01:41:02 UTC (rev 20455)
@@ -159,47 +159,74 @@
 		      IN ULONG Length,
 		      OUT PULONG ResultLength)
 {
-  POBJECT_HEADER Header;
-  PVOID Object;
-  NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode;
+    PVOID Object;
+    POBJECT_HEADER Header;
+    ACCESS_MASK DesiredAccess = (ACCESS_MASK)0;
+    NTSTATUS Status = STATUS_SUCCESS;
 
-  PAGED_CODE();
+    PAGED_CODE();
 
-  DPRINT("NtQuerySecurityObject() called\n");
+    PreviousMode = ExGetPreviousMode();
 
-  Status = ObReferenceObjectByHandle(Handle,
-				     (SecurityInformation & SACL_SECURITY_INFORMATION) ? ACCESS_SYSTEM_SECURITY : 0,
-				     NULL,
-				     KeGetPreviousMode(),
-				     &Object,
-				     NULL);
-  if (!NT_SUCCESS(Status))
+    if (PreviousMode != KernelMode)
     {
-      DPRINT1("ObReferenceObjectByHandle() failed (Status %lx)\n", Status);
-      return Status;
+        _SEH_TRY
+        {
+            ProbeForWrite(SecurityDescriptor,
+                          Length,
+                          sizeof(ULONG));
+            ProbeForWriteUlong(ResultLength);
+        }
+        _SEH_HANDLE
+        {
+            Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        if (!NT_SUCCESS(Status)) return Status;
     }
 
-  Header = BODY_TO_HEADER(Object);
-  if (Header->Type == NULL)
+    /* get the required access rights for the operation */
+    SeQuerySecurityAccessMask(SecurityInformation,
+                              &DesiredAccess);
+
+    Status = ObReferenceObjectByHandle(Handle,
+                                       DesiredAccess,
+                                       NULL,
+                                       PreviousMode,
+                                       &Object,
+                                       NULL);
+
+    if (NT_SUCCESS(Status))
     {
-      DPRINT1("Invalid object type\n");
-      ObDereferenceObject(Object);
-      return STATUS_UNSUCCESSFUL;
-    }
+        Header = BODY_TO_HEADER(Object);
+        ASSERT(Header->Type != NULL);
 
-      *ResultLength = Length;
-      Status = Header->Type->TypeInfo.SecurityProcedure(Object,
-					    QuerySecurityDescriptor,
-					    SecurityInformation,
-					    SecurityDescriptor,
-					    ResultLength,
-                        NULL,
-                        NonPagedPool,
-                        NULL);
+        Status = Header->Type->TypeInfo.SecurityProcedure(Object,
+                                                          QuerySecurityDescriptor,
+                                                          SecurityInformation,
+                                                          SecurityDescriptor,
+                                                          &Length,
+                                                          &Header->SecurityDescriptor,
+                                                          Header->Type->TypeInfo.PoolType,
+                                                          &Header->Type->TypeInfo.GenericMapping);
 
-  ObDereferenceObject(Object);
+        ObDereferenceObject(Object);
 
-  return Status;
+        /* return the required length */
+        _SEH_TRY
+        {
+            *ResultLength = Length;
+        }
+        _SEH_HANDLE
+        {
+            Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+    }
+
+    return Status;
 }
 
 
@@ -211,46 +238,81 @@
 		    IN SECURITY_INFORMATION SecurityInformation,
 		    IN PSECURITY_DESCRIPTOR SecurityDescriptor)
 {
-  POBJECT_HEADER Header;
-  PVOID Object;
-  NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode;
+    PVOID Object;
+    POBJECT_HEADER Header;
+    SECURITY_DESCRIPTOR_RELATIVE *CapturedSecurityDescriptor;
+    ACCESS_MASK DesiredAccess = (ACCESS_MASK)0;
+    NTSTATUS Status;
 
-  PAGED_CODE();
+    PAGED_CODE();
 
-  DPRINT("NtSetSecurityObject() called\n");
+    /* make sure the caller doesn't pass a NULL security descriptor! */
+    if (SecurityDescriptor == NULL)
+    {
+        return STATUS_ACCESS_DENIED;
+    }
 
-  Status = ObReferenceObjectByHandle(Handle,
-				     (SecurityInformation & SACL_SECURITY_INFORMATION) ? ACCESS_SYSTEM_SECURITY : 0,
-				     NULL,
-				     KeGetPreviousMode(),
-				     &Object,
-				     NULL);
-  if (!NT_SUCCESS(Status))
+    PreviousMode = ExGetPreviousMode();
+
+    /* capture and make a copy of the security descriptor */
+    Status = SeCaptureSecurityDescriptor(SecurityDescriptor,
+                                         PreviousMode,
+                                         PagedPool,
+                                         TRUE,
+                                         (PSECURITY_DESCRIPTOR*)&CapturedSecurityDescriptor);
+    if (!NT_SUCCESS(Status))
     {
-      DPRINT1("ObReferenceObjectByHandle() failed (Status %lx)\n", Status);
-      return Status;
+        DPRINT1("Capturing the security descriptor failed! Status: 0x%lx\n", Status);
+        return Status;
     }
 
-  Header = BODY_TO_HEADER(Object);
-  if (Header->Type == NULL)
+    /* make sure the security descriptor passed by the caller
+       is valid for the operation we're about to perform */
+    if (((SecurityInformation & OWNER_SECURITY_INFORMATION) &&
+         (CapturedSecurityDescriptor->Owner == 0)) ||
+        ((SecurityInformation & GROUP_SECURITY_INFORMATION) &&
+         (CapturedSecurityDescriptor->Group == 0)))
     {
-      DPRINT1("Invalid object type\n");
-      ObDereferenceObject(Object);
-      return STATUS_UNSUCCESSFUL;
+        Status = STATUS_INVALID_SECURITY_DESCR;
     }
+    else
+    {
+        /* get the required access rights for the operation */
+        SeSetSecurityAccessMask(SecurityInformation,
+                                &DesiredAccess);
 
-      Status = Header->Type->TypeInfo.SecurityProcedure(Object,
-					    SetSecurityDescriptor,
-					    SecurityInformation,
-					    SecurityDescriptor,
-					    NULL,
-                        NULL,
-                        NonPagedPool,
-                        NULL);
+        Status = ObReferenceObjectByHandle(Handle,
+                                           DesiredAccess,
+                                           NULL,
+                                           PreviousMode,
+                                           &Object,
+                                           NULL);
 
-  ObDereferenceObject(Object);
+        if (NT_SUCCESS(Status))
+        {
+            Header = BODY_TO_HEADER(Object);
+            ASSERT(Header->Type != NULL);
 
-  return Status;
+            Status = Header->Type->TypeInfo.SecurityProcedure(Object,
+                                                              SetSecurityDescriptor,
+                                                              SecurityInformation,
+                                                              (PSECURITY_DESCRIPTOR)SecurityDescriptor,
+                                                              NULL,
+                                                              &Header->SecurityDescriptor,
+                                                              Header->Type->TypeInfo.PoolType,
+                                                              &Header->Type->TypeInfo.GenericMapping);
+
+            ObDereferenceObject(Object);
+        }
+    }
+
+    /* release the descriptor */
+    SeReleaseSecurityDescriptor((PSECURITY_DESCRIPTOR)CapturedSecurityDescriptor,
+                                PreviousMode,
+                                TRUE);
+
+    return Status;
 }
 
 

Modified: trunk/reactos/ntoskrnl/se/semgr.c
--- trunk/reactos/ntoskrnl/se/semgr.c	2005-12-30 01:39:34 UTC (rev 20454)
+++ trunk/reactos/ntoskrnl/se/semgr.c	2005-12-30 01:41:02 UTC (rev 20455)
@@ -1037,7 +1037,7 @@
 	      OUT PACCESS_MASK GrantedAccess,
 	      OUT PNTSTATUS AccessStatus)
 {
-  SECURITY_SUBJECT_CONTEXT SubjectSecurityContext;
+  SECURITY_SUBJECT_CONTEXT SubjectSecurityContext = {0};
   KPROCESSOR_MODE PreviousMode;
   PTOKEN Token;
   NTSTATUS Status;
@@ -1082,8 +1082,6 @@
       return STATUS_ACCESS_VIOLATION;
     }
 
-  RtlZeroMemory(&SubjectSecurityContext,
-		sizeof(SECURITY_SUBJECT_CONTEXT));
   SubjectSecurityContext.ClientToken = Token;
   SubjectSecurityContext.ImpersonationLevel = Token->ImpersonationLevel;
 
@@ -1118,4 +1116,37 @@
   return Status;
 }
 
+VOID STDCALL
+SeQuerySecurityAccessMask(IN SECURITY_INFORMATION SecurityInformation,
+                          OUT PACCESS_MASK DesiredAccess)
+{
+    if (SecurityInformation & (OWNER_SECURITY_INFORMATION |
+                               GROUP_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION))
+    {
+        *DesiredAccess |= READ_CONTROL;
+    }
+    if (SecurityInformation & SACL_SECURITY_INFORMATION)
+    {
+        *DesiredAccess |= ACCESS_SYSTEM_SECURITY;
+    }
+}
+
+VOID STDCALL
+SeSetSecurityAccessMask(IN SECURITY_INFORMATION SecurityInformation,
+                        OUT PACCESS_MASK DesiredAccess)
+{
+    if (SecurityInformation & (OWNER_SECURITY_INFORMATION | GROUP_SECURITY_INFORMATION))
+    {
+        *DesiredAccess |= WRITE_OWNER;
+    }
+    if (SecurityInformation & DACL_SECURITY_INFORMATION)
+    {
+        *DesiredAccess |= WRITE_DAC;
+    }
+    if (SecurityInformation & SACL_SECURITY_INFORMATION)
+    {
+        *DesiredAccess |= ACCESS_SYSTEM_SECURITY;
+    }
+}
+
 /* EOF */