Author: ion
Date: Mon Oct 30 17:46:56 2006
New Revision: 24672
URL:
http://svn.reactos.org/svn/reactos?rev=24672&view=rev
Log:
- Implement NtSecureConnectPort so that clients can connect to SMSS. Actual secure connections
(with a server SID) are not supported yet, and such requests fail. Memory-mapped LPC (client and
server port views) is not supported yet either.
Modified:
trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c
Modified: trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c
URL:
http://svn.reactos.org/svn/reactos/trunk/reactos/ntoskrnl/lpc/ntlpc/connect…
==============================================================================
--- trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c (original)
+++ trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c Mon Oct 30 17:46:56 2006
@@ -14,6 +14,52 @@
#include <internal/debug.h>
/* PRIVATE FUNCTIONS *********************************************************/
+
+/*
+ * Detach the pending connection reply from the current thread.
+ *
+ * Under LpcpLock this unlinks CurrentThread from any reply chain it is on,
+ * takes the thread's LpcReplyMessage (clearing LpcReplyMessage and
+ * LpcReceivedMessageId), and extracts the SectionToMap pointer stored in
+ * the connection message that follows the LPC message header.
+ *
+ * Message        - receives the thread's reply message, or NULL if none.
+ * ConnectMessage - receives the connection message located right after
+ *                  *Message; not written when there is no reply message.
+ * CurrentThread  - thread whose LPC reply state is being torn down.
+ *
+ * Returns the section pointer taken from the connection message (the
+ * message's copy is cleared), or NULL when there was no reply message.
+ */
+PVOID
+NTAPI
+LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
+               IN OUT PLPCP_CONNECTION_MESSAGE *ConnectMessage,
+               IN PETHREAD CurrentThread)
+{
+    PLPCP_MESSAGE ReplyMessage;
+    PLPCP_CONNECTION_MESSAGE ConnectionData;
+    PVOID MappedSection = NULL;
+
+    /* All thread LPC reply state below is protected by the global LPC lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Unlink the thread from the reply chain it may still be queued on */
+    if (!IsListEmpty(&CurrentThread->LpcReplyChain))
+    {
+        RemoveEntryList(&CurrentThread->LpcReplyChain);
+        InitializeListHead(&CurrentThread->LpcReplyChain);
+    }
+
+    /* Hand the pending reply (possibly NULL) back to the caller */
+    ReplyMessage = CurrentThread->LpcReplyMessage;
+    *Message = ReplyMessage;
+
+    if (ReplyMessage)
+    {
+        /* The thread no longer owns a reply message */
+        CurrentThread->LpcReceivedMessageId = 0;
+        CurrentThread->LpcReplyMessage = NULL;
+
+        /* The connection message immediately follows the message header */
+        ConnectionData = (PLPCP_CONNECTION_MESSAGE)(ReplyMessage + 1);
+        *ConnectMessage = ConnectionData;
+
+        /* Move the section pointer out so it is only returned once */
+        MappedSection = ConnectionData->SectionToMap;
+        ConnectionData->SectionToMap = NULL;
+    }
+
+    KeReleaseGuardedMutex(&LpcpLock);
+    return MappedSection;
+}
/* PUBLIC FUNCTIONS **********************************************************/
@@ -32,8 +78,389 @@
IN OUT PVOID ConnectionInformation OPTIONAL,
IN OUT PULONG ConnectionInformationLength OPTIONAL)
{
- UNIMPLEMENTED;
- return STATUS_NOT_IMPLEMENTED;
+ ULONG ConnectionInfoLength = 0;
+ PLPCP_PORT_OBJECT Port, ClientPort;
+ KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+ NTSTATUS Status = STATUS_SUCCESS;
+ HANDLE Handle;
+ PVOID SectionToMap;
+ PLPCP_MESSAGE Message;
+ PLPCP_CONNECTION_MESSAGE ConnectMessage;
+ PETHREAD Thread = PsGetCurrentThread();
+ ULONG PortMessageLength;
+ PAGED_CODE();
+ LPCTRACE(LPC_CONNECT_DEBUG,
+ "Name: %wZ. Qos: %p. Views: %p/%p\n",
+ PortName,
+ Qos,
+ ClientView,
+ ServerView);
+
+ /* Validate client view */
+ if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
+ {
+ /* Fail */
+ return STATUS_INVALID_PARAMETER;
+ }
+
+ /* Validate server view */
+ if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
+ {
+ /* Fail */
+ return STATUS_INVALID_PARAMETER;
+ }
+
+ /* Check if caller sent connection information length */
+ if (ConnectionInformationLength)
+ {
+ /* Retrieve the input length */
+ ConnectionInfoLength = *ConnectionInformationLength;
+ }
+
+ /* Get the port */
+ Status = ObReferenceObjectByName(PortName,
+ 0,
+ NULL,
+ PORT_ALL_ACCESS,
+ LpcPortObjectType,
+ PreviousMode,
+ NULL,
+ (PVOID *)&Port);
+ if (!NT_SUCCESS(Status)) return Status;
+
+ /* This has to be a connection port */
+ if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
+ {
+ /* It isn't, so fail */
+ ObDereferenceObject(Port);
+ return STATUS_INVALID_PORT_HANDLE;
+ }
+
+ /* Check if we have a SID */
+ if (ServerSid)
+ {
+ /* FIXME: TODO */
+ UNIMPLEMENTED;
+ return STATUS_NOT_IMPLEMENTED;
+ }
+
+ /* Create the client port */
+ Status = ObCreateObject(PreviousMode,
+ LpcPortObjectType,
+ NULL,
+ PreviousMode,
+ NULL,
+ sizeof(LPCP_PORT_OBJECT),
+ 0,
+ 0,
+ (PVOID *)&ClientPort);
+ if (!NT_SUCCESS(Status))
+ {
+ /* Failed, dereference the server port and return */
+ ObDereferenceObject(Port);
+ return Status;
+ }
+
+ /* Setup the client port */
+ RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
+ ClientPort->Flags = LPCP_CLIENT_PORT;
+ ClientPort->ConnectionPort = Port;
+ ClientPort->MaxMessageLength = Port->MaxMessageLength;
+ ClientPort->SecurityQos = *Qos;
+ InitializeListHead(&ClientPort->LpcReplyChainHead);
+ InitializeListHead(&ClientPort->LpcDataInfoChainHead);
+
+ /* Check if we have dynamic security */
+ if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
+ {
+ /* Remember that */
+ ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
+ }
+ else
+ {
+ /* Create our own client security */
+ Status = SeCreateClientSecurity(Thread,
+ Qos,
+ FALSE,
+ &ClientPort->StaticSecurity);
+ if (!NT_SUCCESS(Status))
+ {
+ /* Security failed, dereference and return */
+ ObDereferenceObject(ClientPort);
+ return Status;
+ }
+ }
+
+ /* Initialize the port queue */
+ Status = LpcpInitializePortQueue(ClientPort);
+ if (!NT_SUCCESS(Status))
+ {
+ /* Failed */
+ ObDereferenceObject(ClientPort);
+ return Status;
+ }
+
+ /* Check if we have a client view */
+ if (ClientView)
+ {
+ /* FIXME: TODO */
+ UNIMPLEMENTED;
+ return STATUS_NOT_IMPLEMENTED;
+ }
+ else
+ {
+ /* No section */
+ SectionToMap = NULL;
+ }
+
+ /* Normalize connection information */
+ if (ConnectionInfoLength > Port->MaxConnectionInfoLength)
+ {
+ /* Use the port's maximum allowed value */
+ ConnectionInfoLength = Port->MaxConnectionInfoLength;
+ }
+
+ /* Allocate a message from the port zone while holding the lock */
+ KeAcquireGuardedMutex(&LpcpLock);
+ Message = ExAllocateFromPagedLookasideList(&LpcpMessagesLookaside);
+ if (!Message)
+ {
+ /* Fail if we couldn't allocate a message */
+ KeReleaseGuardedMutex(&LpcpLock);
+ if (SectionToMap) ObDereferenceObject(SectionToMap);
+ ObDereferenceObject(ClientPort);
+ return STATUS_NO_MEMORY;
+ }
+
+ /* Initialize it */
+ InitializeListHead(&Message->Entry);
+ Message->RepliedToThread = NULL;
+ Message->Request.u2.ZeroInit = 0;
+
+ /* Release the lock */
+ KeReleaseGuardedMutex(&LpcpLock);
+
+ /* Set pointer to the connection message and fill in the CID */
+ ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+ Message->Request.ClientId = Thread->Cid;
+
+ /* Check if we have a client view */
+ if (ClientView)
+ {
+ /* FIXME: TODO */
+ UNIMPLEMENTED;
+ return STATUS_NOT_IMPLEMENTED;
+ }
+ else
+ {
+ /* Set the size to 0 and clear the connect message */
+ Message->Request.ClientViewSize = 0;
+ RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE));
+ }
+
+ /* Set the section and client port. Port is NULL for now */
+ ConnectMessage->ClientPort = NULL;
+ ConnectMessage->SectionToMap = SectionToMap;
+
+ /* Set the data for the connection request message */
+ Message->Request.u1.s1.DataLength = sizeof(LPCP_CONNECTION_MESSAGE) +
+ ConnectionInfoLength;
+ Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
+ Message->Request.u1.s1.DataLength;
+ Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST;
+
+ /* Check if we have connection information */
+ if (ConnectionInformation)
+ {
+ /* Copy it in */
+ RtlMoveMemory(ConnectMessage + 1,
+ ConnectionInformation,
+ ConnectionInfoLength);
+ }
+
+ /* Acquire the port lock */
+ KeAcquireGuardedMutex(&LpcpLock);
+
+ /* Check if someone already deleted the port name */
+ if (Port->Flags & LPCP_NAME_DELETED)
+ {
+ /* Fail the request */
+ KeReleaseGuardedMutex(&LpcpLock);
+ Status = STATUS_OBJECT_NAME_NOT_FOUND;
+ goto Cleanup;
+ }
+
+ /* Associate no thread yet */
+ Message->RepliedToThread = NULL;
+
+ /* Generate the Message ID and set it */
+ Message->Request.MessageId = LpcpNextMessageId++;
+ if (!LpcpNextMessageId) LpcpNextMessageId = 1;
+ Thread->LpcReplyMessageId = Message->Request.MessageId;
+
+ /* Insert the message into the queue and thread chain */
+ InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
+ InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
+ Thread->LpcReplyMessage = Message;
+
+ /* Now we can finally reference the client port and link it*/
+ ObReferenceObject(ClientPort);
+ ConnectMessage->ClientPort = ClientPort;
+
+ /* Release the lock */
+ KeReleaseGuardedMutex(&LpcpLock);
+ LPCTRACE(LPC_CONNECT_DEBUG,
+ "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
+ Message,
+ ConnectMessage,
+ Port,
+ ClientPort,
+ Status);
+
+ /* If this is a waitable port, set the event */
+ if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
+ 1,
+ FALSE);
+
+ /* Release the queue semaphore */
+ LpcpCompleteWait(Port->MsgQueue.Semaphore);
+
+ /* Now wait for a reply */
+ LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
+
+ /* Check if our wait ended in success */
+ if (Status != STATUS_SUCCESS) goto Cleanup;
+
+ /* Free the connection message */
+ SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
+
+ /* Check if we got a message back */
+ if (Message)
+ {
+ /* Check for new return length */
+ if ((Message->Request.u1.s1.DataLength -
+ sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength)
+ {
+ /* Set new normalized connection length */
+ ConnectionInfoLength = Message->Request.u1.s1.DataLength -
+ sizeof(LPCP_CONNECTION_MESSAGE);
+ }
+
+ /* Check if we had connection information */
+ if (ConnectionInformation)
+ {
+ /* Check if we had a length pointer */
+ if (ConnectionInformationLength)
+ {
+ /* Return the length */
+ *ConnectionInformationLength = ConnectionInfoLength;
+ }
+
+ /* Return the connection information */
+ RtlMoveMemory(ConnectionInformation,
+ ConnectMessage + 1,
+ ConnectionInfoLength );
+ }
+
+ /* Make sure we had a connected port */
+ if (ClientPort->ConnectedPort)
+ {
+ /* Get the message length before the port might get killed */
+ PortMessageLength = Port->MaxMessageLength;
+
+ /* Insert the client port */
+ Status = ObInsertObject(ClientPort,
+ NULL,
+ PORT_ALL_ACCESS,
+ 0,
+ (PVOID *)NULL,
+ &Handle);
+ if (NT_SUCCESS(Status))
+ {
+ /* Return the handle */
+ *PortHandle = Handle;
+ LPCTRACE(LPC_CONNECT_DEBUG,
+ "Handle: %lx. Length: %lx\n",
+ Handle,
+ PortMessageLength);
+
+ /* Check if maximum length was requested */
+ if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
+
+ /* Check if we had a client view */
+ if (ClientView)
+ {
+ /* Copy it back */
+ RtlMoveMemory(ClientView,
+ &ConnectMessage->ClientView,
+ sizeof(PORT_VIEW));
+ }
+
+ /* Check if we had a server view */
+ if (ServerView)
+ {
+ /* Copy it back */
+ RtlMoveMemory(ServerView,
+ &ConnectMessage->ServerView,
+ sizeof(REMOTE_PORT_VIEW));
+ }
+ }
+ }
+ else
+ {
+ /* No connection port, we failed */
+ if (SectionToMap) ObDereferenceObject(SectionToMap);
+
+ /* Check if it's because the name got deleted */
+ if (Port->Flags & LPCP_NAME_DELETED)
+ {
+ /* Set the correct status */
+ Status = STATUS_OBJECT_NAME_NOT_FOUND;
+ }
+ else
+ {
+ /* Otherwise, the caller refused us */
+ Status = STATUS_PORT_CONNECTION_REFUSED;
+ }
+
+ /* Kill the port */
+ ObDereferenceObject(ClientPort);
+ }
+
+ /* Free the message */
+ LpcpFreeToPortZone(Message, FALSE);
+ return Status;
+ }
+
+ /* No reply message, fail */
+ if (SectionToMap) ObDereferenceObject(SectionToMap);
+ ObDereferenceObject(ClientPort);
+ return STATUS_PORT_CONNECTION_REFUSED;
+
+Cleanup:
+ /* We failed, free the message */
+ SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
+
+ /* Check if the semaphore got signaled */
+ if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
+ {
+ /* Wait on it */
+ KeWaitForSingleObject(&Thread->LpcReplySemaphore,
+ KernelMode,
+ Executive,
+ FALSE,
+ NULL);
+ }
+
+ /* Check if we had a message and free it */
+ if (Message) LpcpFreeToPortZone(Message, FALSE);
+
+ /* Dereference other objects */
+ if (SectionToMap) ObDereferenceObject(SectionToMap);
+ ObDereferenceObject(ClientPort);
+
+ /* Return status */
+ return Status;
}
/*