diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py index bc11e4f907..e656fbed64 100644 --- a/pyiceberg/catalog/sql.py +++ b/pyiceberg/catalog/sql.py @@ -497,10 +497,16 @@ def commit_table( def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: namespace_tuple = Catalog.identifier_to_tuple(identifier) namespace = Catalog.namespace_to_string(namespace_tuple, NoSuchNamespaceError) + namespace_starts_with = namespace.replace("!", "!!").replace("_", "!_").replace("%", "!%") + ".%" + with Session(self.engine) as session: stmt = ( select(IcebergTables) - .where(IcebergTables.catalog_name == self.name, IcebergTables.table_namespace == namespace) + .where( + IcebergTables.catalog_name == self.name, + (IcebergTables.table_namespace == namespace) + | (IcebergTables.table_namespace.like(namespace_starts_with, escape="!")), + ) .limit(1) ) result = session.execute(stmt).all() @@ -510,7 +516,8 @@ def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: select(IcebergNamespaceProperties) .where( IcebergNamespaceProperties.catalog_name == self.name, - IcebergNamespaceProperties.namespace == namespace, + (IcebergNamespaceProperties.namespace == namespace) + | (IcebergNamespaceProperties.namespace.like(namespace_starts_with, escape="!")), ) .limit(1) ) diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index d2800363a6..33a76f7308 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -1110,6 +1110,24 @@ def test_create_namespace_with_empty_identifier(catalog: SqlCatalog, empty_names catalog.create_namespace(empty_namespace) +@pytest.mark.parametrize( + "catalog", + [ + lazy_fixture("catalog_memory"), + lazy_fixture("catalog_sqlite"), + ], +) +def test_namespace_exists(catalog: SqlCatalog) -> None: + for ns in [("db1",), ("db1", "ns1"), ("db2", "ns1"), ("db3", "ns1", "ns2")]: + catalog.create_namespace(ns) + assert catalog._namespace_exists(ns) + + assert catalog._namespace_exists("db2") # `db2` exists because `db2.ns1` exists + assert catalog._namespace_exists("db3.ns1") # `db3.ns1` exists because `db3.ns1.ns2` exists + assert not catalog._namespace_exists("db_") # make sure '_' is escaped in the query + assert not catalog._namespace_exists("db%") # make sure '%' is escaped in the query + + @pytest.mark.parametrize( "catalog", [