diff --git a/dissect/database/ese/btree.py b/dissect/database/ese/btree.py deleted file mode 100644 index f4c64e7..0000000 --- a/dissect/database/ese/btree.py +++ /dev/null @@ -1,177 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from dissect.database.ese.exception import KeyNotFoundError, NoNeighbourPageError - -if TYPE_CHECKING: - from dissect.database.ese.ese import ESE - from dissect.database.ese.page import Node, Page - - -class BTree: - """A simple implementation for searching the ESE B+Trees. - - This is a stateful interactive class that moves an internal cursor to a position within the BTree. - - Args: - db: An instance of :class:`~dissect.database.ese.ese.ESE`. - page: The page to open the :class:`BTree` on. - """ - - def __init__(self, db: ESE, root: int | Page): - self.db = db - - if isinstance(root, int): - page_num = root - root = db.page(page_num) - else: - page_num = root.num - - self.root = root - - self._page = root - self._page_num = page_num - self._node_num = 0 - - def reset(self) -> None: - """Reset the internal state to the root of the BTree.""" - self._page = self.root - self._page_num = self._page.num - self._node_num = 0 - - def node(self) -> Node: - """Return the node the BTree is currently on. - - Returns: - A :class:`~dissect.database.ese.page.Node` object of the current node. - """ - return self._page.node(self._node_num) - - def next(self) -> Node: - """Move the BTree to the next node and return it. - - Can move the BTree to the next page as a side effect. - - Returns: - A :class:`~dissect.database.ese.page.Node` object of the next node. - """ - if self._node_num + 1 > self._page.node_count - 1: - self.next_page() - else: - self._node_num += 1 - - return self.node() - - def next_page(self) -> None: - """Move the BTree to the next page in the tree. - - Raises: - NoNeighbourPageError: If the current page has no next page. - """ - if self._page.next_page: - self._page = self.db.page(self._page.next_page) - self._node_num = 0 - else: - raise NoNeighbourPageError(f"{self._page} has no next page") - - def prev(self) -> Node: - """Move the BTree to the previous node and return it. - - Can move the BTree to the previous page as a side effect. - - Returns: - A :class:`~dissect.database.ese.page.Node` object of the previous node. - """ - if self._node_num - 1 < 0: - self.prev_page() - else: - self._node_num -= 1 - - return self.node() - - def prev_page(self) -> None: - """Move the BTree to the previous page in the tree. - - Raises: - NoNeighbourPageError: If the current page has no previous page. - """ - if self._page.previous_page: - self._page = self.db.page(self._page.previous_page) - self._node_num = self._page.node_count - 1 - else: - raise NoNeighbourPageError(f"{self._page} has no previous page") - - def search(self, key: bytes, exact: bool = True) -> Node: - """Search the tree for the given ``key``. - - Moves the BTree to the matching node, or on the last node that is less than the requested key. - - Args: - key: The key to search for. - exact: Whether to only return successfully on an exact match. - - Raises: - KeyNotFoundError: If an ``exact`` match was requested but not found. - """ - page = self._page - while True: - num = find_node(page, key) - node = page.node(num) - - if page.is_branch: - page = self.db.page(node.child) - else: - self._page = page - self._page_num = page.num - self._node_num = node.num - break - - if exact and key != node.key: - raise KeyNotFoundError(f"Can't find key: {key}") - - return self.node() - - -def find_node(page: Page, key: bytes) -> int: - """Search a page for a node matching ``key``. - - Referencing Extensible-Storage-Engine source, they bail out early if they find an exact match. - However, we prefer to always find the _first_ node that is greater than or equal to the key, - so we can handle cases where there are duplicate index keys. This is important for "range" searches - where we want to find all keys matching a certain prefix, and not end up somewhere in the middle of the range. - - Args: - page: The page to search. - key: The key to search. - - Returns: - The node number of the first node that's greater than or equal to the key. - """ - lo, hi = 0, page.node_count - 1 - res = 0 - - node = None - while lo < hi: - mid = (lo + hi) // 2 - node = page.node(mid) - - # It turns out that the way BTree keys are compared matches 1:1 with how Python compares bytes - # First compare data, then length - res = (key < node.key) - (key > node.key) - - if res < 0: - lo = mid + 1 - else: - hi = mid - - # Final comparison on the last node - node = page.node(lo) - res = (key < node.key) - (key > node.key) - - if page.is_branch and res == 0: - # If there's an exact match on a key on a branch page, the actual leaf nodes are in the next branch - # Page keys for branch pages appear to be non-inclusive upper bounds - lo = min(lo + 1, page.node_count - 1) - - return lo diff --git a/dissect/database/ese/cursor.py b/dissect/database/ese/cursor.py index 529453b..d7882e3 100644 --- a/dissect/database/ese/cursor.py +++ b/dissect/database/ese/cursor.py @@ -2,8 +2,7 @@ from typing import TYPE_CHECKING -from dissect.database.ese.btree import BTree -from dissect.database.ese.exception import KeyNotFoundError, NoNeighbourPageError +from dissect.database.ese.exception import KeyNotFoundError from dissect.database.ese.record import Record if TYPE_CHECKING: @@ -11,13 +10,14 @@ from typing_extensions import Self + from dissect.database.ese.ese import ESE from dissect.database.ese.index import Index - from dissect.database.ese.page import Node + from dissect.database.ese.page import Node, Page from dissect.database.ese.util import RecordValue class Cursor: - """A simple cursor implementation for searching the ESE indexes. + """A simple cursor implementation for searching the ESE indexes on their records. Args: index: The :class:`~dissect.database.ese.index.Index` to create the cursor for. @@ -28,10 +28,13 @@ def __init__(self, index: Index): self.table = index.table self.db = index.db - self._primary = BTree(self.db, index.root) - self._secondary = None if index.is_primary else BTree(self.db, self.table.root) + self._primary = RawCursor(self.db, index.root) + self._secondary = None if index.is_primary else RawCursor(self.db, self.table.root) def __iter__(self) -> Iterator[Record]: + if self._primary._page.is_branch: + self._primary.first() + record = self.record() while record is not None: yield record @@ -44,9 +47,8 @@ def _node(self) -> Node: A :class:`~dissect.database.ese.page.Node` object of the current node. """ node = self._primary.node() - if self._secondary: - self._secondary.reset() - node = self._secondary.search(node.data.tobytes(), exact=True) + if self._secondary is not None: + node = self._secondary.search(node.data.tobytes(), exact=True).node() return node def record(self) -> Record: @@ -67,30 +69,22 @@ def reset(self) -> Self: def next(self) -> Record | None: """Move the cursor to the next record and return it. - Can move the cursor to the next page as a side effect. - Returns: A :class:`~dissect.database.ese.record.Record` object of the next record. """ - try: - self._primary.next() - except NoNeighbourPageError: - return None - return self.record() + if self._primary.next(): + return self.record() + return None def prev(self) -> Record | None: """Move the cursor to the previous node and return it. - Can move the cursor to the previous page as a side effect. - Returns: A :class:`~dissect.database.ese.record.Record` object of the previous record. """ - try: - self._primary.prev() - except NoNeighbourPageError: - return None - return self.record() + if self._primary.prev(): + return self.record() + return None def make_key(self, *args: RecordValue, **kwargs: RecordValue) -> bytes: """Generate a key for this index from the given values. @@ -137,7 +131,7 @@ def search_key(self, key: bytes, exact: bool = True) -> Record: exact: If ``True``, search for an exact match. If ``False``, sets the cursor on the next record that is greater than or equal to the key. """ - self._primary.search(key, exact) + self._primary.search(key, exact=exact) return self.record() def seek(self, *args: RecordValue, **kwargs: RecordValue) -> Self: @@ -189,18 +183,6 @@ def find_all(self, **kwargs: RecordValue) -> Iterator[Record]: return current_key = self._primary.node().key - - # Check if we need to move the cursor back to find the first record - while True: - if current_key != self._primary.node().key: - self._primary.next() - break - - try: - self._primary.prev() - except NoNeighbourPageError: - break - while True: # Entries with the same indexed columns are guaranteed to be adjacent if current_key != self._primary.node().key: @@ -224,7 +206,228 @@ def find_all(self, **kwargs: RecordValue) -> Iterator[Record]: else: yield record - try: - self._primary.next() - except NoNeighbourPageError: + if not self._primary.next(): break + + +class RawCursor: + """A simple cursor implementation for searching the ESE B+Trees on their raw nodes. + + Args: + db: An instance of :class:`~dissect.database.ese.ese.ESE`. + root: The page to open the raw cursor on. + """ + + def __init__(self, db: ESE, root: Page | int): + self.db = db + self.root = db.page(root) if isinstance(root, int) else root + + self._page = self.root + self._idx = 0 + + # Stack of (page, idx, stack[:]) for traversing back up the tree when doing in-order traversal + self._stack = [] + + @property + def state(self) -> tuple[Page, int, list[tuple[Page, int]]]: + """Get the current cursor state.""" + return self._page, self._idx, self._stack[:] + + @state.setter + def state(self, value: tuple[Page, int, list[tuple[Page, int]]]) -> None: + """Set the current cursor state.""" + self._page, self._idx, self._stack = value[0], value[1], value[2][:] + + def reset(self) -> Self: + """Reset the cursor to the root of the B+Tree.""" + self._page = self.root + self._idx = 0 + self._stack = [] + + return self + + def node(self) -> Node: + """Return the node the cursor is currently on. + + Returns: + A :class:`~dissect.database.ese.page.Node` object of the current node. + """ + return self._page.node(self._idx) + + def first(self) -> bool: + """Move the cursor to the first leaf node in the B+Tree.""" + self.reset() + while self._page.is_branch and self._page.node_count > 0: + self.push() + + return self._page.node_count != 0 + + def last(self) -> bool: + """Move the cursor to the last leaf node in the B+Tree.""" + self.reset() + while self._page.is_branch and self._page.node_count > 0: + self._idx = self._page.node_count - 1 + self.push() + + self._idx = self._page.node_count - 1 + return self._page.node_count != 0 + + def next(self) -> bool: + """Move the cursor to the next leaf node.""" + if self._page.is_branch: + # Treat as if we were at the first node + self.first() + return self._page.node_count != 0 + + if self._idx + 1 < self._page.node_count: + self._idx += 1 + elif self._stack: + # End of current page, traverse to the next leaf page + + # First pop until we find a page with unvisited nodes + while self._idx + 1 >= self._page.node_count: + if not self._stack: + return False + self.pop() + + self._idx += 1 + + # Then push down to the next page + while self._page.is_branch: + self.push() + else: + return False + + return True + + def prev(self) -> bool: + """Move the cursor to the previous leaf node.""" + if self._page.is_branch: + # Treat as if we were at the last node + self.last() + return self._page.node_count != 0 + + if self._idx - 1 >= 0: + self._idx -= 1 + elif self._stack: + # Start of current page, traverse to the previous leaf page + + # First pop until we find a page with unvisited nodes + while self._idx - 1 < 0: + if not self._stack: + # Start of B+Tree reached + return False + self.pop() + + self._idx -= 1 + + # Then push down to the rightmost leaf + while self._page.is_branch: + self._idx = self._page.node_count - 1 + self.push() + else: + # Start of B+Tree reached + return False + + return True + + def push(self) -> Self: + """Push down to the child page at the current index.""" + child_page = self.db.page(self._page.node(self._idx).child) + + self._stack.append((self._page, self._idx)) + self._page = child_page + self._idx = 0 + + return self + + def pop(self) -> Self: + """Pop back to the parent page.""" + if not self._stack: + raise IndexError("Cannot pop from an empty stack") + + self._page, self._idx = self._stack.pop() + + return self + + def walk(self) -> Iterator[Node]: + """Walk the B+Tree in order, yielding nodes.""" + if self.first(): + yield self.node() + + while self.next(): + yield self.node() + + def search(self, key: bytes, *, exact: bool = True) -> Self: + """Search the tree for the given ``key``. + + Moves the cursor to the matching node, or on the last node that is less than the requested key. + + Args: + key: The key to search for. + exact: Whether to only return successfully on an exact match. + + Raises: + KeyNotFoundError: If an ``exact`` match was requested but not found. + """ + self.reset() + + while self._page.is_branch: + self._idx = find_node(self._page, key, exact=False) + self.push() + + self._idx = find_node(self._page, key, exact=exact) + if self._idx >= self._page.node_count or self._idx == -1: + raise KeyNotFoundError(f"Key not found: {key!r}") + + return self + + +def find_node(page: Page, key: bytes, *, exact: bool) -> int: + """Search a page for a node matching the given key. + + Referencing Extensible-Storage-Engine source, they bail out early if they find an exact match. + However, we prefer to always find the _first_ node that is greater than or equal to the key, + so we can handle cases where there are duplicate index keys. This is important for "range" searches + where we want to find all keys matching a certain prefix, and not end up somewhere in the middle of the range. + + Args: + page: The page to search. + key: The key to search. + exact: Whether to only return successfully on an exact match. + + Returns: + The node number of the first node that's greater than or equal to the key, or the last node on the page if + the key is larger than all nodes. If ``exact`` is ``True`` and an exact match is not found, returns -1. + """ + if page.node_count == 0: + return -1 + + lo, hi = 0, page.node_count - 1 + + node = None + while lo < hi: + mid = (lo + hi) // 2 + node = page.node(mid) + + # It turns out that the way BTree keys are compared matches 1:1 with how Python compares bytes + # First compare data, then length + if key > node.key: + lo = mid + 1 + else: + hi = mid + + # Final comparison on the last node + node = page.node(lo) + + if key == node.key: + if page.is_branch: + # If there's an exact match on a key on a branch page, the actual leaf nodes are in the next branch + # Page keys for branch pages appear to be non-inclusive upper bounds + lo = min(lo + 1, page.node_count - 1) + + # key != node.key + elif exact: + return -1 + + return lo diff --git a/dissect/database/ese/ntds/database.py b/dissect/database/ese/ntds/database.py index 8c9fdfe..7d080c6 100644 --- a/dissect/database/ese/ntds/database.py +++ b/dissect/database/ese/ntds/database.py @@ -114,10 +114,17 @@ def walk(self) -> Iterator[Object]: yield (obj := stack.pop()) stack.extend(obj.children()) - def iter(self) -> Iterator[Object]: - """Iterate over all objects in the NTDS database.""" + def iter(self, raw: bool = False) -> Iterator[Object]: + """Iterate over all objects in the NTDS database. + + Args: + raw: Whether to return base :class:`Object` instances without upcasting to more specific types + based on the objectClass. + """ + from_record = Object if raw else Object.from_record + for record in self.table.records(): - yield Object.from_record(self.db, record) + yield from_record(self.db, record) def get(self, dnt: int) -> Object: """Retrieve an object by its Directory Number Tag (DNT) value. diff --git a/dissect/database/ese/ntds/query.py b/dissect/database/ese/ntds/query.py index e48d1b0..a95e18a 100644 --- a/dissect/database/ese/ntds/query.py +++ b/dissect/database/ese/ntds/query.py @@ -1,7 +1,9 @@ from __future__ import annotations +import fnmatch import logging -from typing import TYPE_CHECKING, Any +import re +from typing import TYPE_CHECKING from dissect.util.ldap import LogicalOperator, SearchFilter @@ -33,30 +35,25 @@ def process(self) -> Iterator[Object]: """ yield from self._process_query(self._filter) - def _process_query(self, filter: SearchFilter, records: list[Record] | None = None) -> Iterator[Record]: + def _process_query(self, filter: SearchFilter, records: Iterator[Record] | None = None) -> Iterator[Record]: """Process LDAP query recursively, handling nested logical operations. Args: filter: The LDAP search filter to process. - records: Optional list of records to filter instead of querying the database. + records: Optional iterable of records to filter instead of querying the database. Yields: Records matching the search filter. """ - if not filter.is_nested(): - if records is None: - try: - yield from self._query_database(filter) - except IndexError: - log.debug("No records found for filter: %s", filter) - else: - yield from self._filter_records(filter, records) - return - - if filter.operator == LogicalOperator.AND: - yield from self._process_and_operation(filter, records) - elif filter.operator == LogicalOperator.OR: - yield from self._process_or_operation(filter, records) + if filter.is_nested(): + if filter.operator == LogicalOperator.AND: + yield from self._process_and_operation(filter, records) + elif filter.operator == LogicalOperator.OR: + yield from self._process_or_operation(filter, records) + elif records is not None: + yield from self._filter_records(filter, records) + else: + yield from self._query_database(filter) def _query_database(self, filter: SearchFilter) -> Iterator[Record]: """Execute a simple LDAP filter against the database. @@ -73,25 +70,33 @@ def _query_database(self, filter: SearchFilter) -> Iterator[Record]: # Get the database index for this attribute if (index := self.db.data.table.find_index([schema.column])) is None: - raise ValueError(f"Index for attribute {schema.column!r} not found in the NTDS database") - - if "*" in filter.value: - # Handle wildcard searches differently - if filter.value.endswith("*"): - yield from _process_wildcard_tail(index, filter.value) - else: - raise NotImplementedError("Wildcards in the middle or start of the value are not yet supported") + # If no index is available, we have to scan the entire table + log.debug("No index found for attribute %s (%s), scanning entire table", filter.attribute, schema.column) + yield from self._filter_records(filter, self.db.data.table.records()) else: - # Exact match query - encoded_value = encode_value(self.db, schema, filter.value) - yield from index.cursor().find_all(**{schema.column: encoded_value}) + if "*" in filter.value: + # Handle wildcard searches differently + if filter.value.endswith("*"): + yield from _process_wildcard_tail(index, filter.value) + else: + # For more complex wildcard patterns, we need to scan the index and apply the filter + log.debug( + "Complex wildcard search for attribute %s (%s), scanning entire index", + filter.attribute, + schema.column, + ) + yield from self._filter_records(filter, index.cursor()) + else: + # Exact match query + encoded_value = encode_value(self.db, schema, filter.value) + yield from index.cursor().find_all(**{schema.column: encoded_value}) - def _process_and_operation(self, filter: SearchFilter, records: list[Record] | None) -> Iterator[Record]: + def _process_and_operation(self, filter: SearchFilter, records: Iterator[Record] | None) -> Iterator[Record]: """Process AND logical operation. Args: filter: The LDAP search filter with AND operator. - records: Optional list of records to filter. + records: Optional iterable of records to filter. Yields: Records matching all conditions in the AND operation. @@ -102,19 +107,19 @@ def _process_and_operation(self, filter: SearchFilter, records: list[Record] | N else: # Use the first child as base query, then filter with remaining children base_query, *remaining_children = filter.children - records_to_process = list(self._process_query(base_query)) + records_to_process = self._process_query(base_query) children_to_check = remaining_children for record in records_to_process: if all(any(self._process_query(child, records=[record])) for child in children_to_check): yield record - def _process_or_operation(self, filter: SearchFilter, records: list[Record] | None) -> Iterator[Record]: + def _process_or_operation(self, filter: SearchFilter, records: Iterator[Record] | None) -> Iterator[Record]: """Process OR logical operation. Args: filter: The LDAP search filter with OR operator. - records: Optional list of records to filter. + records: Optional iterable of records to filter. Yields: Records matching any condition in the OR operation. @@ -122,12 +127,12 @@ def _process_or_operation(self, filter: SearchFilter, records: list[Record] | No for child in filter.children: yield from self._process_query(child, records=records) - def _filter_records(self, filter: SearchFilter, records: list[Record]) -> Iterator[Record]: - """Filter a list of records against a simple LDAP filter. + def _filter_records(self, filter: SearchFilter, records: Iterator[Record]) -> Iterator[Record]: + """Filter an iterable of records against a simple LDAP filter. Args: filter: The LDAP search filter to apply. - records: The list of records to filter. + records: The iterable of records to filter. Yields: Records that match the filter criteria. @@ -136,14 +141,26 @@ def _filter_records(self, filter: SearchFilter, records: list[Record]) -> Iterat return encoded_value = encode_value(self.db, schema, filter.value) + re_encoded_value = None - has_wildcard = "*" in filter.value - wildcard_prefix = filter.value.replace("*", "").lower() if has_wildcard else None + has_wildcard = "*" in filter.value and isinstance(encoded_value, str) + if has_wildcard: + re_encoded_value = re.compile(fnmatch.translate(encoded_value), re.IGNORECASE) for record in records: record_value = record.get(schema.column) - if _value_matches_filter(record_value, encoded_value, has_wildcard, wildcard_prefix): + if isinstance(record_value, list): + # Currently assume that we can only search for single values, not lists + if has_wildcard and record_value and isinstance(record_value[0], str): + if any(re_encoded_value.match(rv) for rv in record_value): + yield record + elif encoded_value in record_value: + yield record + + elif ( + has_wildcard and isinstance(record_value, str) and re_encoded_value.match(record_value) + ) or record_value == encoded_value: yield record @@ -174,26 +191,6 @@ def _process_wildcard_tail(index: Index, filter_value: str) -> Iterator[Record]: record = cursor.next() -def _value_matches_filter( - record_value: Any, encoded_value: Any, has_wildcard: bool, wildcard_prefix: str | None -) -> bool: - """Return whether a record value matches the filter criteria. - - Args: - record_value: The value from the database record. - encoded_value: The encoded filter value to match against. - has_wildcard: Whether the filter contains wildcard characters. - wildcard_prefix: The prefix to match for wildcard searches. - """ - if isinstance(record_value, list): - return encoded_value in record_value - - if has_wildcard and wildcard_prefix and isinstance(record_value, str): - return record_value.lower().startswith(wildcard_prefix) - - return encoded_value == record_value - - def _increment_last_char(value: str) -> str: """Increment the last character in a string to find the next lexicographically sortable key. diff --git a/dissect/database/ese/page.py b/dissect/database/ese/page.py index 28b942b..41f1e16 100644 --- a/dissect/database/ese/page.py +++ b/dissect/database/ese/page.py @@ -49,6 +49,9 @@ def __init__(self, db: ESE, num: int, buf: bytes): self._node_cls = LeafNode if self.is_leaf else BranchNode self._node_cache = {} + def __repr__(self) -> str: + return f"" + @cached_property def is_small_page(self) -> bool: return self.db.has_small_pages @@ -123,7 +126,7 @@ def node(self, num: int) -> BranchNode | LeafNode: IndexError: If the node number is out of bounds. """ if num < 0 or num > self.node_count - 1: - raise IndexError(f"Node number exceeds boundaries: 0-{self.node_count - 1}") + raise IndexError(f"Node number exceeds boundaries 0-{self.node_count - 1}: {num}") if num not in self._node_cache: self._node_cache[num] = self._node_cls(self.tag(num + 1)) @@ -161,9 +164,6 @@ def iter_leaf_nodes(self) -> Iterator[LeafNode]: if self.is_root and leaf and leaf.tag.page.next_page: yield from db.page(leaf.tag.page.next_page).iter_leaf_nodes() - def __repr__(self) -> str: - return f"" - class Tag: """A tag is the "physical" data entry of a page. diff --git a/dissect/database/ese/table.py b/dissect/database/ese/table.py index f862e48..03ba67f 100644 --- a/dissect/database/ese/table.py +++ b/dissect/database/ese/table.py @@ -5,14 +5,13 @@ from typing import TYPE_CHECKING, Any from dissect.database.ese import compression -from dissect.database.ese.btree import BTree from dissect.database.ese.c_ese import ( CODEPAGE, FIELDFLAG, SYSOBJ, JET_coltyp, ) -from dissect.database.ese.exception import NoNeighbourPageError +from dissect.database.ese.cursor import RawCursor from dissect.database.ese.index import Index from dissect.database.ese.record import Record from dissect.database.ese.util import COLUMN_TYPE_MAP, ColumnType, RecordValue @@ -189,19 +188,16 @@ def get_long_value(self, key: bytes) -> bytes: key: The lookup key for the long value. """ rkey = key[::-1] - btree = BTree(self.db, self.lv_page) - header = btree.search(rkey) + cursor = RawCursor(self.db, self.lv_page) + header = cursor.search(rkey).node() _, size = struct.unpack("<2I", header.data) chunks = [] chunk_offsets = [] - while True: - try: - node = btree.next() - if not node.key.startswith(rkey): - break - except NoNeighbourPageError: + while cursor.next(): + node = cursor.node() + if not node.key.startswith(rkey): break chunks.append(node.data) diff --git a/tests/ese/ntds/test_query.py b/tests/ese/ntds/test_query.py index 31b8f02..5ed0652 100644 --- a/tests/ese/ntds/test_query.py +++ b/tests/ese/ntds/test_query.py @@ -86,6 +86,18 @@ def test_simple_wildcard(goad: NTDS) -> None: assert len(records) == 1 assert mock_fetch.call_count == 1 + query = Query(goad.db, "(&(sAMAccountName=*odor)(objectCategory=person))") + with patch.object(query, "_query_database", wraps=query._query_database) as mock_fetch: + records = list(query.process()) + assert len(records) == 1 + assert mock_fetch.call_count == 1 + + query = Query(goad.db, "(&(sAMAccountName=h*d*r)(objectCategory=person))") + with patch.object(query, "_query_database", wraps=query._query_database) as mock_fetch: + records = list(query.process()) + assert len(records) == 1 + assert mock_fetch.call_count == 1 + def test_simple_wildcard_in_AND(goad: NTDS) -> None: query = Query(goad.db, "(&(objectCategory=person)(sAMAccountName=hod*))") @@ -102,14 +114,13 @@ def test_invalid_attribute(goad: NTDS) -> None: list(query.process()) -def test_invalid_index(goad: NTDS) -> None: - """Test index not found for attribute.""" - query = Query(goad.db, "(cn=ThisIsNotExistingInTheDB)") - with ( - patch.object(goad.db.data.table, "find_index", return_value=None), - pytest.raises(ValueError, match=r"Index for attribute.*not found in the NTDS database"), - ): - list(query.process()) +def test_no_index(goad: NTDS) -> None: + """Test searching for attribute with no index.""" + schema = goad.db.data.schema.lookup_attribute(name="description") + assert goad.db.data.table.find_index([schema.column]) is None + + query = Query(goad.db, "(description=Brainless*)") + assert len(list(query.process())) == 1 def test_increment_last_char() -> None: