https://git.reactos.org/?p=reactos.git;a=commitdiff;h=c8fb3f75145b94cdab5cf…
commit c8fb3f75145b94cdab5cfe908d46f8ed7326494d
Author: Jérôme Gardou <jerome.gardou(a)reactos.org>
AuthorDate: Fri May 28 16:58:59 2021 +0200
Commit: Jérôme Gardou <zefklop(a)users.noreply.github.com>
CommitDate: Wed Jun 9 11:27:18 2021 +0200
[NTOS:MM] Implement proper refcounting of page tables on amd64
CORE-17552
---
ntoskrnl/mm/ARM3/mdlsup.c | 27 ++----
ntoskrnl/mm/ARM3/miarm.h | 215 +++++++++++++++++++++++++++++++++++++-------
ntoskrnl/mm/ARM3/pagfault.c | 5 ++
ntoskrnl/mm/ARM3/pfnlist.c | 2 +
ntoskrnl/mm/ARM3/session.c | 2 +-
ntoskrnl/mm/ARM3/virtual.c | 15 ++--
ntoskrnl/mm/i386/page.c | 8 +-
sdk/include/ndk/mmtypes.h | 11 +++
8 files changed, 217 insertions(+), 68 deletions(-)
diff --git a/ntoskrnl/mm/ARM3/mdlsup.c b/ntoskrnl/mm/ARM3/mdlsup.c
index 10a2aa02f68..9332ec3153e 100644
--- a/ntoskrnl/mm/ARM3/mdlsup.c
+++ b/ntoskrnl/mm/ARM3/mdlsup.c
@@ -248,6 +248,7 @@ MiMapLockedPagesInUserSpace(
/* Acquire a share count */
Pfn1 = MI_PFN_ELEMENT(PointerPde->u.Hard.PageFrameNumber);
+ DPRINT("Incrementing %p from %p\n", Pfn1, _ReturnAddress());
OldIrql = MiAcquirePfnLock();
Pfn1->u2.ShareCount++;
MiReleasePfnLock(OldIrql);
@@ -330,9 +331,6 @@ MiUnmapLockedPagesInUserSpace(
ASSERT(MiAddressToPte(PointerPte)->u.Hard.Valid == 1);
ASSERT(PointerPte->u.Hard.Valid == 1);
- /* Dereference the page */
- MiDecrementPageTableReferences(BaseAddress);
-
/* Invalidate it */
MI_ERASE_PTE(PointerPte);
@@ -341,28 +339,17 @@ MiUnmapLockedPagesInUserSpace(
PageTablePage = PointerPde->u.Hard.PageFrameNumber;
MiDecrementShareCount(MiGetPfnEntry(PageTablePage), PageTablePage);
+ if (MiDecrementPageTableReferences(BaseAddress) == 0)
+ {
+ ASSERT(MiIsPteOnPdeBoundary(PointerPte + 1) || (NumberOfPages == 1));
+ MiDeletePde(PointerPde, Process);
+ }
+
/* Next page */
PointerPte++;
NumberOfPages--;
BaseAddress = (PVOID)((ULONG_PTR)BaseAddress + PAGE_SIZE);
MdlPages++;
-
- /* Moving to a new PDE? */
- if (PointerPde != MiAddressToPde(BaseAddress))
- {
- /* See if we should delete it */
- KeFlushProcessTb();
- PointerPde = MiPteToPde(PointerPte - 1);
- ASSERT(PointerPde->u.Hard.Valid == 1);
- if (MiQueryPageTableReferences(BaseAddress) == 0)
- {
- ASSERT(PointerPde->u.Long != 0);
- MiDeletePte(PointerPde,
- MiPteToAddress(PointerPde),
- Process,
- NULL);
- }
- }
}
KeFlushProcessTb();
diff --git a/ntoskrnl/mm/ARM3/miarm.h b/ntoskrnl/mm/ARM3/miarm.h
index a784d08e15f..8c165759277 100644
--- a/ntoskrnl/mm/ARM3/miarm.h
+++ b/ntoskrnl/mm/ARM3/miarm.h
@@ -1823,40 +1823,7 @@ MiReferenceUnusedPageAndBumpLockCount(IN PMMPFN Pfn1)
}
}
-FORCEINLINE
-VOID
-MiIncrementPageTableReferences(IN PVOID Address)
-{
- PUSHORT RefCount;
-
- RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
-
- *RefCount += 1;
- ASSERT(*RefCount <= PTE_PER_PAGE);
-}
-FORCEINLINE
-VOID
-MiDecrementPageTableReferences(IN PVOID Address)
-{
- PUSHORT RefCount;
-
- RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
-
- *RefCount -= 1;
- ASSERT(*RefCount < PTE_PER_PAGE);
-}
-
-FORCEINLINE
-USHORT
-MiQueryPageTableReferences(IN PVOID Address)
-{
- PUSHORT RefCount;
-
- RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
-
- return *RefCount;
-}
CODE_SEG("INIT")
BOOLEAN
@@ -2484,8 +2451,190 @@ MiSynchronizeSystemPde(PMMPDE PointerPde)
}
#endif
+#if _MI_PAGING_LEVELS == 2
+FORCEINLINE
+USHORT
+MiIncrementPageTableReferences(IN PVOID Address)
+{
+ PUSHORT RefCount;
+
+ RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
+
+ *RefCount += 1;
+ ASSERT(*RefCount <= PTE_PER_PAGE);
+ return *RefCount;
+}
+
+FORCEINLINE
+USHORT
+MiDecrementPageTableReferences(IN PVOID Address)
+{
+ PUSHORT RefCount;
+
+ RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
+
+ *RefCount -= 1;
+ ASSERT(*RefCount < PTE_PER_PAGE);
+ return *RefCount;
+}
+
+FORCEINLINE
+USHORT
+MiQueryPageTableReferences(IN PVOID Address)
+{
+ PUSHORT RefCount;
+
+ RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
+
+ return *RefCount;
+}
+#else
+FORCEINLINE
+USHORT
+MiIncrementPageTableReferences(IN PVOID Address)
+{
+ PMMPDE PointerPde = MiAddressToPde(Address);
+ PMMPFN Pfn;
+
+ /* We should not tinker with this one. */
+ ASSERT(PointerPde != (PMMPDE)PXE_SELFMAP);
+ DPRINT("Incrementing %p from %p\n", Address, _ReturnAddress());
+
+ /* Make sure we're locked */
+ ASSERT(PsGetCurrentThread()->OwnsProcessWorkingSetExclusive);
+
+ /* If we're bumping refcount, then it must be valid! */
+ ASSERT(PointerPde->u.Hard.Valid == 1);
+
+ /* This lies on the PFN */
+ Pfn = MiGetPfnEntry(PFN_FROM_PDE(PointerPde));
+ Pfn->OriginalPte.u.Soft.UsedPageTableEntries++;
+
+ ASSERT(Pfn->OriginalPte.u.Soft.UsedPageTableEntries <= PTE_PER_PAGE);
+
+ return Pfn->OriginalPte.u.Soft.UsedPageTableEntries;
+}
+
+FORCEINLINE
+USHORT
+MiDecrementPageTableReferences(IN PVOID Address)
+{
+ PMMPDE PointerPde = MiAddressToPde(Address);
+ PMMPFN Pfn;
+
+ /* We should not tinker with this one. */
+ ASSERT(PointerPde != (PMMPDE)PXE_SELFMAP);
+
+ DPRINT("Decrementing %p from %p\n", PointerPde, _ReturnAddress());
+
+ /* Make sure we're locked */
+ ASSERT(PsGetCurrentThread()->OwnsProcessWorkingSetExclusive);
+
+ /* If we're decreasing refcount, then it must be valid! */
+ ASSERT(PointerPde->u.Hard.Valid == 1);
+
+ /* This lies on the PFN */
+ Pfn = MiGetPfnEntry(PFN_FROM_PDE(PointerPde));
+
+ ASSERT(Pfn->OriginalPte.u.Soft.UsedPageTableEntries != 0);
+ Pfn->OriginalPte.u.Soft.UsedPageTableEntries--;
+
+ ASSERT(Pfn->OriginalPte.u.Soft.UsedPageTableEntries < PTE_PER_PAGE);
+
+ return Pfn->OriginalPte.u.Soft.UsedPageTableEntries;
+}
+
+FORCEINLINE
+USHORT
+MiQueryPageTableReferences(IN PVOID Address)
+{
+ PMMPDE PointerPde;
+ PMMPPE PointerPpe;
+#if _MI_PAGING_LEVELS == 4
+ PMMPXE PointerPxe;
+#endif
+ PMMPFN Pfn;
+
+ /* Make sure we're locked */
+ ASSERT((PsGetCurrentThread()->OwnsProcessWorkingSetExclusive) || (PsGetCurrentThread()->OwnsProcessWorkingSetShared));
+
+ /* Check if PXE or PPE have references first. */
+#if _MI_PAGING_LEVELS == 4
+ PointerPxe = MiAddressToPxe(Address);
+ if ((PointerPxe->u.Hard.Valid == 1) || (PointerPxe->u.Soft.Transition == 1))
+ {
+ Pfn = MiGetPfnEntry(PFN_FROM_PXE(PointerPxe));
+ if (Pfn->OriginalPte.u.Soft.UsedPageTableEntries == 0)
+ return 0;
+ }
+ else if (PointerPxe->u.Soft.UsedPageTableEntries == 0)
+ {
+ return 0;
+ }
+
+ if (PointerPxe->u.Hard.Valid == 0)
+ {
+ MiMakeSystemAddressValid(MiPteToAddress(PointerPxe), PsGetCurrentProcess());
+ }
+#endif
+
+ PointerPpe = MiAddressToPpe(Address);
+ if ((PointerPpe->u.Hard.Valid == 1) || (PointerPpe->u.Soft.Transition == 1))
+ {
+ Pfn = MiGetPfnEntry(PFN_FROM_PPE(PointerPpe));
+ if (Pfn->OriginalPte.u.Soft.UsedPageTableEntries == 0)
+ return 0;
+ }
+ else if (PointerPpe->u.Soft.UsedPageTableEntries == 0)
+ {
+ return 0;
+ }
+
+ if (PointerPpe->u.Hard.Valid == 0)
+ {
+ MiMakeSystemAddressValid(MiPteToAddress(PointerPpe), PsGetCurrentProcess());
+ }
+
+ PointerPde = MiAddressToPde(Address);
+ if ((PointerPde->u.Hard.Valid == 0) && (PointerPde->u.Soft.Transition == 0))
+ {
+ return PointerPde->u.Soft.UsedPageTableEntries;
+ }
+
+ /* This lies on the PFN */
+ Pfn = MiGetPfnEntry(PFN_FROM_PDE(PointerPde));
+ return Pfn->OriginalPte.u.Soft.UsedPageTableEntries;
+}
+#endif
+
#ifdef __cplusplus
} // extern "C"
#endif
+FORCEINLINE
+VOID
+MiDeletePde(
+ _In_ PMMPDE PointerPde,
+ _In_ PEPROCESS CurrentProcess)
+{
+ /* Only for user-mode ones */
+ ASSERT(MiIsUserPde(PointerPde));
+
+ /* Kill this one as a PTE */
+ MiDeletePte((PMMPTE)PointerPde, MiPdeToPte(PointerPde), CurrentProcess, NULL);
+#if _MI_PAGING_LEVELS >= 3
+ /* Cascade down */
+ if (MiDecrementPageTableReferences(MiPdeToPte(PointerPde)) == 0)
+ {
+ MiDeletePte(MiPdeToPpe(PointerPde), PointerPde, CurrentProcess, NULL);
+#if _MI_PAGING_LEVELS == 4
+ if (MiDecrementPageTableReferences(PointerPde) == 0)
+ {
+ MiDeletePte(MiPdeToPxe(PointerPde), MiPdeToPpe(PointerPde), CurrentProcess, NULL);
+ }
+#endif
+ }
+#endif
+}
+
/* EOF */
diff --git a/ntoskrnl/mm/ARM3/pagfault.c b/ntoskrnl/mm/ARM3/pagfault.c
index 87c789c1742..b6e2f9e8287 100644
--- a/ntoskrnl/mm/ARM3/pagfault.c
+++ b/ntoskrnl/mm/ARM3/pagfault.c
@@ -2145,6 +2145,7 @@ UserFault:
/* We should come back with a valid PPE */
ASSERT(PointerPpe->u.Hard.Valid == 1);
+ MiIncrementPageTableReferences(PointerPde);
}
#endif
@@ -2184,6 +2185,10 @@ UserFault:
MM_EXECUTE_READWRITE,
CurrentProcess,
MM_NOIRQL);
+#if _MI_PAGING_LEVELS >= 3
+ MiIncrementPageTableReferences(PointerPte);
+#endif
+
#if MI_TRACE_PFNS
UserPdeFault = FALSE;
#endif
diff --git a/ntoskrnl/mm/ARM3/pfnlist.c b/ntoskrnl/mm/ARM3/pfnlist.c
index b9838f29602..726143870f4 100644
--- a/ntoskrnl/mm/ARM3/pfnlist.c
+++ b/ntoskrnl/mm/ARM3/pfnlist.c
@@ -1027,6 +1027,8 @@ MiInitializePfn(IN PFN_NUMBER PageFrameIndex,
ASSERT(PageFrameIndex != 0);
Pfn1->u4.PteFrame = PageFrameIndex;
+ DPRINT("Incrementing share count of %lp from %p\n", PageFrameIndex, _ReturnAddress());
+
/* Increase its share count so we don't get rid of it */
Pfn1 = MI_PFN_ELEMENT(PageFrameIndex);
Pfn1->u2.ShareCount++;
diff --git a/ntoskrnl/mm/ARM3/session.c b/ntoskrnl/mm/ARM3/session.c
index cb4de5f9b5f..70ae1ec9889 100644
--- a/ntoskrnl/mm/ARM3/session.c
+++ b/ntoskrnl/mm/ARM3/session.c
@@ -477,7 +477,7 @@ MiSessionInitializeWorkingSetList(VOID)
/* Fill out the two pointers */
MmSessionSpace->Vm.VmWorkingSetList = WorkingSetList;
- MmSessionSpace->Wsle = (PMMWSLE)WorkingSetList->UsedPageTableEntries;
+ MmSessionSpace->Wsle = (PMMWSLE)((&WorkingSetList->VadBitMapHint) + 1);
/* Get the PDE for the working set, and check if it's already allocated */
PointerPde = MiAddressToPde(WorkingSetList);
diff --git a/ntoskrnl/mm/ARM3/virtual.c b/ntoskrnl/mm/ARM3/virtual.c
index 1f897c4e10d..f43351b7933 100644
--- a/ntoskrnl/mm/ARM3/virtual.c
+++ b/ntoskrnl/mm/ARM3/virtual.c
@@ -727,18 +727,15 @@ MiDeleteVirtualAddresses(IN ULONG_PTR Va,
/* Check remaining PTE count (go back 1 page due to above loop) */
if (MiQueryPageTableReferences((PVOID)(Va - PAGE_SIZE)) == 0)
{
- if (PointerPde->u.Long != 0)
- {
- /* Delete the PTE proper */
- MiDeletePte(PointerPde,
- MiPteToAddress(PointerPde),
- CurrentProcess,
- NULL);
- }
+ ASSERT(PointerPde->u.Long != 0);
+
+ /* Delete the PDE proper */
+ MiDeletePde(PointerPde, CurrentProcess);
}
- /* Release the lock and get out if we're done */
+ /* Release the lock */
MiReleasePfnLock(OldIrql);
+
if (Va > EndingAddress) return;
/* Otherwise, we exited because we hit a new PDE boundary, so start over */
diff --git a/ntoskrnl/mm/i386/page.c b/ntoskrnl/mm/i386/page.c
index d15a9f74964..7e7db5bd431 100644
--- a/ntoskrnl/mm/i386/page.c
+++ b/ntoskrnl/mm/i386/page.c
@@ -238,11 +238,10 @@ MmDeleteVirtualMapping(PEPROCESS Process, PVOID Address,
if (Address < MmSystemRangeStart)
{
/* Remove PDE reference */
- MiDecrementPageTableReferences(Address);
- if (MiQueryPageTableReferences(Address) == 0)
+ if (MiDecrementPageTableReferences(Address) == 0)
{
KIRQL OldIrql = MiAcquirePfnLock();
- MiDeletePte(MiAddressToPte(PointerPte), PointerPte, Process, NULL);
+ MiDeletePde(MiAddressToPde(Address), Process);
MiReleasePfnLock(OldIrql);
}
@@ -293,8 +292,7 @@ MmDeletePageFileMapping(
}
/* This used to be a non-zero PTE, now we can let the PDE go. */
- MiDecrementPageTableReferences(Address);
- if (MiQueryPageTableReferences(Address) == 0)
+ if (MiDecrementPageTableReferences(Address) == 0)
{
/* We can let it go */
KIRQL OldIrql = MiAcquirePfnLock();
diff --git a/sdk/include/ndk/mmtypes.h b/sdk/include/ndk/mmtypes.h
index 28d6cb00006..f33d5cc8b14 100644
--- a/sdk/include/ndk/mmtypes.h
+++ b/sdk/include/ndk/mmtypes.h
@@ -879,8 +879,19 @@ typedef struct _MMWSL
PVOID HighestPermittedHashAddress;
ULONG NumberOfImageWaiters;
ULONG VadBitMapHint;
+#ifndef _M_AMD64
USHORT UsedPageTableEntries[768];
ULONG CommittedPageTables[24];
+#else
+ VOID* HighestUserAddress;
+ ULONG MaximumUserPageTablePages;
+ ULONG MaximumUserPageDirectoryPages;
+ ULONG* CommittedPageTables;
+ ULONG NumberOfCommittedPageDirectories;
+ ULONG* CommittedPageDirectories;
+ ULONG NumberOfCommittedPageDirectoryParents;
+ ULONGLONG CommittedPageDirectoryParents[1];
+#endif
} MMWSL, *PMMWSL;
//