Author: ion Date: Mon Oct 30 17:46:56 2006 New Revision: 24672
URL: http://svn.reactos.org/svn/reactos?rev=24672&view=rev Log: - Implement NtSecureConnectPort so that clients can connect to SMSS. Secure connections (requests that pass a server SID) are not supported yet and fail with STATUS_NOT_IMPLEMENTED, as do requests that pass a client view, because memory-mapped LPC is not supported yet either.
Modified: trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c
Modified: trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c URL: http://svn.reactos.org/svn/reactos/trunk/reactos/ntoskrnl/lpc/ntlpc/connect.... ============================================================================== --- trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c (original) +++ trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c Mon Oct 30 17:46:56 2006 @@ -14,6 +14,52 @@ #include <internal/debug.h>
/* PRIVATE FUNCTIONS *********************************************************/ + +PVOID +NTAPI +LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message, + IN OUT PLPCP_CONNECTION_MESSAGE *ConnectMessage, + IN PETHREAD CurrentThread) +{ + PVOID SectionToMap; + + /* Acquire the LPC lock */ + KeAcquireGuardedMutex(&LpcpLock); + + /* Check if the reply chain is not empty */ + if (!IsListEmpty(&CurrentThread->LpcReplyChain)) + { + /* Remove this entry and re-initialize it */ + RemoveEntryList(&CurrentThread->LpcReplyChain); + InitializeListHead(&CurrentThread->LpcReplyChain); + } + + /* Check if there's a reply message */ + if (CurrentThread->LpcReplyMessage) + { + /* Get the message */ + *Message = CurrentThread->LpcReplyMessage; + + /* Clear message data */ + CurrentThread->LpcReceivedMessageId = 0; + CurrentThread->LpcReplyMessage = NULL; + + /* Get the connection message and clear the section */ + *ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(*Message + 1); + SectionToMap = (*ConnectMessage)->SectionToMap; + (*ConnectMessage)->SectionToMap = NULL; + } + else + { + /* No message to return */ + *Message = NULL; + SectionToMap = NULL; + } + + /* Release the lock and return the section */ + KeReleaseGuardedMutex(&LpcpLock); + return SectionToMap; +}
/* PUBLIC FUNCTIONS **********************************************************/
@@ -32,8 +78,389 @@ IN OUT PVOID ConnectionInformation OPTIONAL, IN OUT PULONG ConnectionInformationLength OPTIONAL) { - UNIMPLEMENTED; - return STATUS_NOT_IMPLEMENTED; + /* Connects to the named LPC connection port: queues an LPC_CONNECTION_REQUEST message on it, waits on the thread's LpcReplySemaphore for the reply and, if the client port got a ConnectedPort, inserts the client port and returns its handle. */ ULONG ConnectionInfoLength = 0; + PLPCP_PORT_OBJECT Port, ClientPort; + KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); + NTSTATUS Status = STATUS_SUCCESS; + HANDLE Handle; + PVOID SectionToMap; + PLPCP_MESSAGE Message; + PLPCP_CONNECTION_MESSAGE ConnectMessage; + PETHREAD Thread = PsGetCurrentThread(); + ULONG PortMessageLength; + PAGED_CODE(); + LPCTRACE(LPC_CONNECT_DEBUG, + "Name: %wZ. Qos: %p. Views: %p/%p\n", + PortName, + Qos, + ClientView, + ServerView); + + /* Validate client view. NOTE(review): ClientView, ServerView, Qos, ConnectionInformationLength, ConnectionInformation, PortHandle and MaxMessageLength are all dereferenced directly in this function with no probing and no SEH, although PreviousMode may be UserMode; confirm, or add ProbeForRead/ProbeForWrite inside a try/except block. */ + if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW))) + { + /* Fail */ + return STATUS_INVALID_PARAMETER; + } + + /* Validate server view */ + if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW))) + { + /* Fail */ + return STATUS_INVALID_PARAMETER; + } + + /* Check if caller sent connection information length */ + if (ConnectionInformationLength) + { + /* Retrieve the input length */ + ConnectionInfoLength = *ConnectionInformationLength; + } + + /* Get the port */ + Status = ObReferenceObjectByName(PortName, + 0, + NULL, + PORT_ALL_ACCESS, + LpcPortObjectType, + PreviousMode, + NULL, + (PVOID *)&Port); + if (!NT_SUCCESS(Status)) return Status; + + /* This has to be a connection port */ + if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT) + { + /* It isn't, so fail */ + ObDereferenceObject(Port); + return STATUS_INVALID_PORT_HANDLE; + } + + /* Check if we have a SID */ + if (ServerSid) + { + /* FIXME: TODO. NOTE(review): this early return does not dereference Port, which ObReferenceObjectByName referenced above, so that reference leaks. */ + UNIMPLEMENTED; + return STATUS_NOT_IMPLEMENTED; + } + + /* Create the client port */ + Status = ObCreateObject(PreviousMode, + LpcPortObjectType, + NULL, + PreviousMode, + NULL, + sizeof(LPCP_PORT_OBJECT), + 0, + 0, + (PVOID *)&ClientPort); + if (!NT_SUCCESS(Status)) + { + /* Failed, dereference the server port and return */ + ObDereferenceObject(Port); +
return Status; + } + + /* Setup the client port. NOTE(review): Qos is dereferenced below (and in the dynamic-security check) without a NULL check. */ + RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT)); + ClientPort->Flags = LPCP_CLIENT_PORT; + ClientPort->ConnectionPort = Port; + ClientPort->MaxMessageLength = Port->MaxMessageLength; + ClientPort->SecurityQos = *Qos; + InitializeListHead(&ClientPort->LpcReplyChainHead); + InitializeListHead(&ClientPort->LpcDataInfoChainHead); + + /* Check if we have dynamic security */ + if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING) + { + /* Remember that */ + ClientPort->Flags |= LPCP_SECURITY_DYNAMIC; + } + else + { + /* Create our own client security */ + Status = SeCreateClientSecurity(Thread, + Qos, + FALSE, + &ClientPort->StaticSecurity); + if (!NT_SUCCESS(Status)) + { + /* Security failed, dereference and return */ + ObDereferenceObject(ClientPort); + return Status; + } + } + + /* Initialize the port queue */ + Status = LpcpInitializePortQueue(ClientPort); + if (!NT_SUCCESS(Status)) + { + /* Failed */ + ObDereferenceObject(ClientPort); + return Status; + } + + /* Check if we have a client view */ + if (ClientView) + { + /* FIXME: TODO. NOTE(review): this early return does not dereference ClientPort, which was created above, so it leaks. */ + UNIMPLEMENTED; + return STATUS_NOT_IMPLEMENTED; + } + else + { + /* No section */ + SectionToMap = NULL; + } + + /* Normalize connection information */ + if (ConnectionInfoLength > Port->MaxConnectionInfoLength) + { + /* Use the port's maximum allowed value */ + ConnectionInfoLength = Port->MaxConnectionInfoLength; + } + + /* Allocate a message from the port zone while holding the lock */ + KeAcquireGuardedMutex(&LpcpLock); + Message = ExAllocateFromPagedLookasideList(&LpcpMessagesLookaside); + if (!Message) + { + /* Fail if we couldn't allocate a message */ + KeReleaseGuardedMutex(&LpcpLock); + if (SectionToMap) ObDereferenceObject(SectionToMap); + ObDereferenceObject(ClientPort); + return STATUS_NO_MEMORY; + } + + /* Initialize it */ + InitializeListHead(&Message->Entry); + Message->RepliedToThread = NULL; + Message->Request.u2.ZeroInit = 0; + + /* Release the lock */
+ KeReleaseGuardedMutex(&LpcpLock); + + /* Set pointer to the connection message and fill in the CID */ + ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); + Message->Request.ClientId = Thread->Cid; + + /* Check if we have a client view */ + if (ClientView) + { + /* FIXME: TODO. NOTE(review): currently unreachable, because the earlier ClientView branch already returned; once it becomes reachable, this early return would also leak Message and ClientPort. */ + UNIMPLEMENTED; + return STATUS_NOT_IMPLEMENTED; + } + else + { + /* Set the size to 0 and clear the connect message */ + Message->Request.ClientViewSize = 0; + RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE)); + } + + /* Set the section and client port. Port is NULL for now */ + ConnectMessage->ClientPort = NULL; + ConnectMessage->SectionToMap = SectionToMap; + + /* Set the data for the connection request message */ + Message->Request.u1.s1.DataLength = sizeof(LPCP_CONNECTION_MESSAGE) + + ConnectionInfoLength; + Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) + + Message->Request.u1.s1.DataLength; + Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST; + + /* Check if we have connection information */ + if (ConnectionInformation) + { + /* Copy it in */ + RtlMoveMemory(ConnectMessage + 1, + ConnectionInformation, + ConnectionInfoLength); + } + + /* Acquire the global LPC lock */ + KeAcquireGuardedMutex(&LpcpLock); + + /* Check if someone already deleted the port name. NOTE(review): this path jumps to Cleanup before Thread->LpcReplyMessage is set to Message, so the LpcpFreeConMsg call there (which takes Thread->LpcReplyMessage) will not hand back this Message; it and SectionToMap appear to leak here. Confirm. */ + if (Port->Flags & LPCP_NAME_DELETED) + { + /* Fail the request */ + KeReleaseGuardedMutex(&LpcpLock); + Status = STATUS_OBJECT_NAME_NOT_FOUND; + goto Cleanup; + } + + /* Associate no thread yet */ + Message->RepliedToThread = NULL; + + /* Generate the Message ID and set it */ + Message->Request.MessageId = LpcpNextMessageId++; + if (!LpcpNextMessageId) LpcpNextMessageId = 1; + Thread->LpcReplyMessageId = Message->Request.MessageId; + + /* Insert the message into the queue and thread chain */ + InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry); + InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain); + Thread->LpcReplyMessage = Message; + + /* Now we can finally reference
the client port and link it*/ + ObReferenceObject(ClientPort); + ConnectMessage->ClientPort = ClientPort; + + /* Release the lock */ + KeReleaseGuardedMutex(&LpcpLock); + LPCTRACE(LPC_CONNECT_DEBUG, + "Messages: %p/%p. Ports: %p/%p. Status: %lx\n", + Message, + ConnectMessage, + Port, + ClientPort, + Status); + + /* If this is a waitable port, set the event */ + if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent, + 1, + FALSE); + + /* Release the queue semaphore */ + LpcpCompleteWait(Port->MsgQueue.Semaphore); + + /* Now wait for a reply. NOTE(review): no visible statement stores the wait result in Status; the check below is only meaningful if LpcpConnectWait assigns Status (e.g. as a macro). Confirm. */ + LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode); + + /* Check if our wait ended in success */ + if (Status != STATUS_SUCCESS) goto Cleanup; + + /* Take the reply message and its section back from the thread */ + SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); + + /* Check if we got a message back */ + if (Message) + { + /* Check for new return length */ + if ((Message->Request.u1.s1.DataLength - + sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength) + { + /* Set new normalized connection length */ + ConnectionInfoLength = Message->Request.u1.s1.DataLength - + sizeof(LPCP_CONNECTION_MESSAGE); + } + + /* Check if we had connection information */ + if (ConnectionInformation) + { + /* Check if we had a length pointer */ + if (ConnectionInformationLength) + { + /* Return the length */ + *ConnectionInformationLength = ConnectionInfoLength; + } + + /* Return the connection information */ + RtlMoveMemory(ConnectionInformation, + ConnectMessage + 1, + ConnectionInfoLength ); + } + + /* Make sure we had a connected port */ + if (ClientPort->ConnectedPort) + { + /* Get the message length before the port might get killed */ + PortMessageLength = Port->MaxMessageLength; + + /* Insert the client port */ + Status = ObInsertObject(ClientPort, + NULL, + PORT_ALL_ACCESS, + 0, + (PVOID *)NULL, + &Handle); + if (NT_SUCCESS(Status)) + { + /* Return the handle */ + *PortHandle = Handle; + LPCTRACE(LPC_CONNECT_DEBUG, + "Handle: %lx.
Length: %lx\n", + Handle, + PortMessageLength); + + /* Check if maximum length was requested */ + if (MaxMessageLength) *MaxMessageLength = PortMessageLength; + + /* Check if we had a client view */ + if (ClientView) + { + /* Copy it back */ + RtlMoveMemory(ClientView, + &ConnectMessage->ClientView, + sizeof(PORT_VIEW)); + } + + /* Check if we had a server view */ + if (ServerView) + { + /* Copy it back */ + RtlMoveMemory(ServerView, + &ConnectMessage->ServerView, + sizeof(REMOTE_PORT_VIEW)); + } + } + } + else + { + /* No connection port, we failed */ + if (SectionToMap) ObDereferenceObject(SectionToMap); + + /* Check if it's because the name got deleted */ + if (Port->Flags & LPCP_NAME_DELETED) + { + /* Set the correct status */ + Status = STATUS_OBJECT_NAME_NOT_FOUND; + } + else + { + /* Otherwise, the server refused our connection request */ + Status = STATUS_PORT_CONNECTION_REFUSED; + } + + /* Kill the port */ + ObDereferenceObject(ClientPort); + } + + /* Free the message */ + LpcpFreeToPortZone(Message, FALSE); + return Status; + } + + /* No reply message, fail */ + if (SectionToMap) ObDereferenceObject(SectionToMap); + ObDereferenceObject(ClientPort); + return STATUS_PORT_CONNECTION_REFUSED; + +Cleanup: + /* We failed, take the message and section back from the thread */ + SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); + + /* Check if the semaphore got signaled */ + if (KeReadStateSemaphore(&Thread->LpcReplySemaphore)) + { + /* Consume the pending signal. NOTE(review): KeWaitForSingleObject takes (Object, WaitReason, WaitMode, Alertable, Timeout), so KernelMode and Executive below appear swapped; they should read Executive, KernelMode. */ + KeWaitForSingleObject(&Thread->LpcReplySemaphore, + KernelMode, + Executive, + FALSE, + NULL); + } + + /* Check if we had a message and free it */ + if (Message) LpcpFreeToPortZone(Message, FALSE); + + /* Dereference other objects */ + if (SectionToMap) ObDereferenceObject(SectionToMap); + ObDereferenceObject(ClientPort); + + /* Return status */ + return Status; }
/*