Author: ion
Date: Mon Oct 30 17:42:07 2006
New Revision: 24669
URL: 
http://svn.reactos.org/svn/reactos?rev=24669&view=rev
Log:
- Implement NtReplyWaitReceivePortEx (and LpcpSaveDataInfoMessage,
LpcpFreeDataInfoMessage and LpcpMoveMessage). SMSS can now respond to connection attempts.
Modified:
    trunk/reactos/ntoskrnl/lpc/ntlpc/reply.c
Modified: trunk/reactos/ntoskrnl/lpc/ntlpc/reply.c
URL:
http://svn.reactos.org/svn/reactos/trunk/reactos/ntoskrnl/lpc/ntlpc/reply.c…
==============================================================================
--- trunk/reactos/ntoskrnl/lpc/ntlpc/reply.c (original)
+++ trunk/reactos/ntoskrnl/lpc/ntlpc/reply.c Mon Oct 30 17:42:07 2006
@@ -15,6 +15,115 @@
 /* PRIVATE FUNCTIONS *********************************************************/
+/*
+ * LpcpFreeDataInfoMessage
+ *
+ * Searches the data-info chain for the first message whose MessageId and
+ * CallbackId both match, then unlinks it and returns it to the port zone.
+ * Port types above LPCP_UNCONNECTED_PORT keep the chain on their connection
+ * port, so the search is redirected there first.
+ *
+ * Port       - Port the reply was issued on.
+ * MessageId  - MessageId of the saved request to release.
+ * CallbackId - CallbackId of the saved request to release.
+ *
+ * NOTE(review): this routine walks and edits the chain without acquiring
+ * LpcpLock itself. The visible caller already holds the lock, and the
+ * message is released with LpcpFreeToPortZone(..., TRUE), presumably the
+ * "lock held" variant. Callers appear to be required to hold LpcpLock.
+ */
+VOID
+NTAPI
+LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
+                        IN ULONG MessageId,
+                        IN ULONG CallbackId)
+{
+    PLPCP_MESSAGE Message;
+    PLIST_ENTRY ListHead, NextEntry;
+
+    /* Check if the port we want is the connection port */
+    if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT)
+    {
+        /* Use it */
+        Port = Port->ConnectionPort;
+
+        /* Without a connection port there is no chain to search */
+        if (!Port) return;
+    }
+
+    /* Loop the list */
+    ListHead = &Port->LpcDataInfoChainHead;
+    NextEntry = ListHead->Flink;
+    while (ListHead != NextEntry)
+    {
+        /* Get the message */
+        Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
+
+        /* Make sure it matches */
+        if ((Message->Request.MessageId == MessageId) &&
+            (Message->Request.CallbackId == CallbackId))
+        {
+            /* Unlink and free it; only the first match is released */
+            RemoveEntryList(&Message->Entry);
+            InitializeListHead(&Message->Entry);
+            LpcpFreeToPortZone(Message, TRUE);
+            return;
+        }
+
+        /* Go to the next entry */
+        NextEntry = NextEntry->Flink;
+    }
+}
+
+/*
+ * LpcpSaveDataInfoMessage
+ *
+ * Appends Message to the data-info chain so it can later be located and
+ * released by LpcpFreeDataInfoMessage. Port types above
+ * LPCP_UNCONNECTED_PORT keep the chain on their connection port.
+ *
+ * This routine acquires and releases LpcpLock itself, so the caller must not
+ * already hold it.
+ *
+ * Ownership of Message passes to this routine. It is either linked into the
+ * chain or, if there is no connection port to link it to, freed here.
+ */
+VOID
+NTAPI
+LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
+                        IN PLPCP_MESSAGE Message)
+{
+    PAGED_CODE();
+
+    /* Acquire the lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Check if the port we want is the connection port */
+    if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT)
+    {
+        /* Use it */
+        Port = Port->ConnectionPort;
+        if (!Port)
+        {
+            /*
+             * There is no chain to link into. The caller has already given
+             * up its pointer to the message, so free it here instead of
+             * leaking it (the lock is still held, hence TRUE).
+             */
+            LpcpFreeToPortZone(Message, TRUE);
+            KeReleaseGuardedMutex(&LpcpLock);
+            return;
+        }
+    }
+
+    /* Link the message */
+    InsertTailList(&Port->LpcDataInfoChainHead, &Message->Entry);
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+}
+
+/*
+ * LpcpMoveMessage
+ *
+ * Fills in the header of Destination from Origin, then copies the payload
+ * from Data into the space immediately after Destination's header.
+ *
+ * MessageType - When non-zero, its low 16 bits replace Origin's type.
+ *               Zero keeps Origin's type.
+ * ClientId    - When non-NULL, it replaces Origin's client ID.
+ *
+ * The payload size is the low 16 bits of the copied u1.Length, rounded up
+ * to a multiple of 4 bytes. Both Data and the area after Destination's
+ * header must be valid for that many bytes.
+ */
+VOID
+NTAPI
+LpcpMoveMessage(IN PPORT_MESSAGE Destination,
+                IN PPORT_MESSAGE Origin,
+                IN PVOID Data,
+                IN ULONG MessageType,
+                IN PCLIENT_ID ClientId)
+{
+    ULONG PayloadLength;
+
+    LPCTRACE((LPC_REPLY_DEBUG | LPC_SEND_DEBUG),
+             "Destination/Origin: %p/%p. Data: %p. Length: %lx\n",
+             Destination,
+             Origin,
+             Data,
+             Origin->u1.Length);
+
+    /* Both 16-bit length fields are carried across in the packed u1 word */
+    Destination->u1.Length = Origin->u1.Length;
+
+    /* An explicit message type overrides the origin's; zero keeps it */
+    if (MessageType)
+    {
+        Destination->u2.s2.Type = MessageType & 0xFFFF;
+    }
+    else
+    {
+        Destination->u2.s2.Type = Origin->u2.s2.Type;
+    }
+
+    /* A caller-supplied client ID wins over the one in the origin */
+    if (ClientId == NULL)
+    {
+        Destination->ClientId.UniqueProcess = Origin->ClientId.UniqueProcess;
+        Destination->ClientId.UniqueThread = Origin->ClientId.UniqueThread;
+    }
+    else
+    {
+        Destination->ClientId.UniqueProcess = ClientId->UniqueProcess;
+        Destination->ClientId.UniqueThread = ClientId->UniqueThread;
+    }
+
+    /* These header fields are copied over unchanged */
+    Destination->MessageId = Origin->MessageId;
+    Destination->ClientViewSize = Origin->ClientViewSize;
+
+    /* Payload length: low 16 bits of the length word, rounded up to 4 */
+    PayloadLength = Destination->u1.Length & 0xFFFF;
+    PayloadLength = (PayloadLength + 3) & ~3;
+    RtlMoveMemory(Destination + 1, Data, PayloadLength);
+}
+
 /* PUBLIC FUNCTIONS **********************************************************/
 /*
@@ -30,7 +139,7 @@
 }
 /*
- * @unimplemented
+ * @implemented
  */
 NTSTATUS
 NTAPI
@@ -40,8 +149,287 @@
                          OUT PPORT_MESSAGE ReceiveMessage,
                          IN PLARGE_INTEGER Timeout OPTIONAL)
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    /*
+     * If ReplyMessage is given, first hands it to the thread waiting for
+     * that reply. Then blocks on the receive port's queue and returns the
+     * next queued message in ReceiveMessage, plus its port context in
+     * *PortContext if requested.
+     *
+     * NOTE(review): two issues remain before this routine is safe and
+     * complete for user-mode callers:
+     *  - ReplyMessage, ReceiveMessage and PortContext are accessed below
+     *    without probing or exception handling.
+     *  - Timeout is never passed to the receive wait.
+     */
+    PLPCP_PORT_OBJECT Port, ReceivePort;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
+    NTSTATUS Status;
+    PLPCP_MESSAGE Message;
+    PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
+    PLPCP_CONNECTION_MESSAGE ConnectMessage;
+    ULONG ConnectionInfoLength;
+    PAGED_CODE();
+    LPCTRACE(LPC_REPLY_DEBUG,
+             "Handle: %p. Messages: %p/%p. Context: %p\n",
+             PortHandle,
+             ReplyMessage,
+             ReceiveMessage,
+             PortContext);
+
+    /* If this is a system thread, then let it page out its stack */
+    if (Thread->SystemThread) WaitMode = UserMode;
+
+    /* Check if caller has a reply message */
+    if (ReplyMessage)
+    {
+        /* Validate its length: header plus data must fit in TotalLength */
+        if ((ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+            ReplyMessage->u1.s1.TotalLength)
+        {
+            /* Fail */
+            return STATUS_INVALID_PARAMETER;
+        }
+
+        /* Make sure it has a valid ID */
+        if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Get the Port object */
+    Status = ObReferenceObjectByHandle(PortHandle,
+                                       0,
+                                       LpcPortObjectType,
+                                       PreviousMode,
+                                       (PVOID*)&Port,
+                                       NULL);
+    if (!NT_SUCCESS(Status)) return Status;
+
+    /* Check if the caller has a reply message */
+    if (ReplyMessage)
+    {
+        /* Validate its length in respect to the port object */
+        if ((ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
+            (ReplyMessage->u1.s1.TotalLength <= ReplyMessage->u1.s1.DataLength))
+        {
+            /* Too large, fail */
+            ObDereferenceObject(Port);
+            return STATUS_PORT_MESSAGE_TOO_LONG;
+        }
+    }
+
+    /* Check if this is anything but a client port */
+    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CLIENT_PORT)
+    {
+        /* Use the connection port */
+        ReceivePort = Port->ConnectionPort;
+    }
+    else
+    {
+        /* Otherwise, use the port itself */
+        ReceivePort = Port;
+    }
+
+    /*
+     * Without a receive port there is no queue to wait on, so fail here
+     * instead of dereferencing NULL below.
+     * NOTE(review): presumably ConnectionPort is cleared once the port is
+     * disconnected; confirm.
+     */
+    if (!ReceivePort)
+    {
+        ObDereferenceObject(Port);
+        return STATUS_PORT_DISCONNECTED;
+    }
+
+    /* Check if the caller gave a reply message */
+    if (ReplyMessage)
+    {
+        /* Get the ETHREAD corresponding to it */
+        Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
+                                            NULL,
+                                            &WakeupThread);
+        if (!NT_SUCCESS(Status))
+        {
+            /* No thread found, fail */
+            ObDereferenceObject(Port);
+            return Status;
+        }
+
+        /* Acquire the LPC Lock */
+        KeAcquireGuardedMutex(&LpcpLock);
+
+        /* Allocate a new message */
+        Message = ExAllocateFromPagedLookasideList(&LpcpMessagesLookaside);
+        if (!Message)
+        {
+            /* Out of memory, fail */
+            KeReleaseGuardedMutex(&LpcpLock);
+            ObDereferenceObject(WakeupThread);
+            ObDereferenceObject(Port);
+            return STATUS_NO_MEMORY;
+        }
+
+        /* Initialize the header */
+        InitializeListHead(&Message->Entry);
+        Message->RepliedToThread = NULL;
+        Message->Request.u2.ZeroInit = 0;
+
+        /* Make sure this is the reply the thread is waiting for */
+        if (WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId)
+        {
+            /* It isn't, fail */
+            LpcpFreeToPortZone(Message, TRUE);
+            KeReleaseGuardedMutex(&LpcpLock);
+            ObDereferenceObject(WakeupThread);
+            ObDereferenceObject(Port);
+            return STATUS_REPLY_MESSAGE_MISMATCH;
+        }
+
+        /* Copy the message */
+        LpcpMoveMessage(&Message->Request,
+                        ReplyMessage,
+                        ReplyMessage + 1,
+                        LPC_REPLY,
+                        NULL);
+
+        /* Free any data information (the LPC lock is already held) */
+        LpcpFreeDataInfoMessage(Port,
+                                ReplyMessage->MessageId,
+                                ReplyMessage->CallbackId);
+
+        /* The message keeps its own reference to the thread it replies to */
+        ObReferenceObject(WakeupThread);
+        Message->RepliedToThread = WakeupThread;
+
+        /* Set this as the reply message */
+        WakeupThread->LpcReplyMessageId = 0;
+        WakeupThread->LpcReplyMessage = (PVOID)Message;
+
+        /* Check if we have messages on the reply chain */
+        if (!(WakeupThread->LpcExitThreadCalled) &&
+            !(IsListEmpty(&WakeupThread->LpcReplyChain)))
+        {
+            /* Remove us from it and reinitialize it */
+            RemoveEntryList(&WakeupThread->LpcReplyChain);
+            InitializeListHead(&WakeupThread->LpcReplyChain);
+        }
+
+        /* Check if this is the message the thread had received */
+        if ((Thread->LpcReceivedMsgIdValid) &&
+            (Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
+        {
+            /* Clear this data */
+            Thread->LpcReceivedMessageId = 0;
+            Thread->LpcReceivedMsgIdValid = FALSE;
+        }
+
+        /* Release the lock and release the LPC semaphore to wake up waiters */
+        KeReleaseGuardedMutex(&LpcpLock);
+        LpcpCompleteWait(&WakeupThread->LpcReplySemaphore);
+
+        /* Drop the reference taken by PsLookupProcessThreadByCid */
+        ObDereferenceObject(WakeupThread);
+    }
+
+    /* Now wait for a message to arrive on the receive queue */
+    LpcpReceiveWait(ReceivePort->MsgQueue.Semaphore, WaitMode);
+
+    /*
+     * Status is not assigned on the line above in this function;
+     * LpcpReceiveWait presumably stores the wait result in it
+     * (TODO confirm against its definition).
+     */
+    if (Status != STATUS_SUCCESS) goto Cleanup;
+
+    /* Wait done, get the LPC lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Check if we've received nothing */
+    if (IsListEmpty(&ReceivePort->MsgQueue.ReceiveHead))
+    {
+        /* Check if this was a waitable port */
+        if (ReceivePort->Flags & LPCP_WAITABLE_PORT)
+        {
+            /* Reset its event */
+            KeResetEvent(&ReceivePort->WaitEvent);
+        }
+
+        /* Release the lock and fail */
+        KeReleaseGuardedMutex(&LpcpLock);
+        ObDereferenceObject(Port);
+        return STATUS_UNSUCCESSFUL;
+    }
+
+    /* Get the message on the queue */
+    Message = CONTAINING_RECORD(RemoveHeadList(&ReceivePort->
+                                               MsgQueue.ReceiveHead),
+                                LPCP_MESSAGE,
+                                Entry);
+
+    /* Check if the queue is empty now */
+    if (IsListEmpty(&ReceivePort->MsgQueue.ReceiveHead))
+    {
+        /* Check if this was a waitable port */
+        if (ReceivePort->Flags & LPCP_WAITABLE_PORT)
+        {
+            /* Reset its event */
+            KeResetEvent(&ReceivePort->WaitEvent);
+        }
+    }
+
+    /* Re-initialize the message's list entry */
+    InitializeListHead(&Message->Entry);
+
+    /* Set this as the received message */
+    Thread->LpcReceivedMessageId = Message->Request.MessageId;
+    Thread->LpcReceivedMsgIdValid = TRUE;
+
+    /* Done touching global data, release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+
+    /* Check if this was a connection request */
+    if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
+    {
+        /* Get the connection message */
+        ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+        LPCTRACE(LPC_REPLY_DEBUG,
+                 "Request Messages: %p/%p\n",
+                 Message,
+                 ConnectMessage);
+
+        /* Get its length */
+        ConnectionInfoLength = Message->Request.u1.s1.DataLength -
+                               sizeof(LPCP_CONNECTION_MESSAGE);
+
+        /* Return it as the receive message */
+        *ReceiveMessage = Message->Request;
+
+        /* Clear our stack variable so the message doesn't get freed */
+        Message = NULL;
+
+        /*
+         * Set up the receive message. TotalLength is the PORT_MESSAGE header
+         * plus the data, the same invariant validated for ReplyMessage above.
+         * The kernel-internal LPCP_MESSAGE wrapper is not part of it.
+         */
+        ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(PORT_MESSAGE) +
+                                                     ConnectionInfoLength);
+        ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength;
+        RtlMoveMemory(ReceiveMessage + 1,
+                      ConnectMessage + 1,
+                      ConnectionInfoLength);
+
+        /* Clear the port context if the caller requested one */
+        if (PortContext) *PortContext = NULL;
+    }
+    else if (LpcpGetMessageType(&Message->Request) != LPC_REPLY)
+    {
+        /* Otherwise, this is a new message or event */
+        LPCTRACE(LPC_REPLY_DEBUG,
+                 "Non-Reply Messages: %p/%p\n",
+                 &Message->Request,
+                 (&Message->Request) + 1);
+
+        /* Copy it */
+        LpcpMoveMessage(ReceiveMessage,
+                        &Message->Request,
+                        (&Message->Request) + 1,
+                        0,
+                        NULL);
+
+        /* Return its context */
+        if (PortContext) *PortContext = Message->PortContext;
+
+        /* And check if it has data information */
+        if (Message->Request.u2.s2.DataInfoOffset)
+        {
+            /* It does, save it, and don't free the message below */
+            LpcpSaveDataInfoMessage(Port, Message);
+            Message = NULL;
+        }
+    }
+    else
+    {
+        /* This is a reply message, should never happen! */
+        ASSERT(FALSE);
+    }
+
+    /* If we have a message pointer here, free it (lock not held) */
+    if (Message) LpcpFreeToPortZone(Message, FALSE);
+
+Cleanup:
+    /* All done, dereference the port and return the status */
+    LPCTRACE(LPC_REPLY_DEBUG,
+             "Port: %p. Status: %lx\n",
+             Port,
+             Status);
+    ObDereferenceObject(Port);
+    return Status;
 }
 /*