Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions pyiceberg/catalog/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import socket
import time
from functools import cached_property
from types import TracebackType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -142,41 +143,42 @@
class _HiveClient:
"""Helper class to nicely open and close the transport."""

_transport: TTransport
_client: Client
_ugi: Optional[List[str]]

def __init__(self, uri: str, ugi: Optional[str] = None, kerberos_auth: Optional[bool] = HIVE_KERBEROS_AUTH_DEFAULT):
self._uri = uri
self._kerberos_auth = kerberos_auth
self._ugi = ugi.split(":") if ugi else None
self._transport = self._init_thrift_transport()

self._init_thrift_client()

def _init_thrift_client(self) -> None:
def _init_thrift_transport(self) -> TTransport:
url_parts = urlparse(self._uri)

socket = TSocket.TSocket(url_parts.hostname, url_parts.port)

if not self._kerberos_auth:
self._transport = TTransport.TBufferedTransport(socket)
return TTransport.TBufferedTransport(socket)
else:
self._transport = TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive")
return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive")

@cached_property
def _client(self) -> Client:
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)

self._client = Client(protocol)
client = Client(protocol)
if self._ugi:
client.set_ugi(*self._ugi)
return client

def __enter__(self) -> Client:
self._transport.open()
if self._ugi:
self._client.set_ugi(*self._ugi)
"""Make sure the transport is initialized and open."""
if not self._transport:
self._transport = self._init_thrift_transport()
if not self._transport.isOpen():
self._transport.open()
return self._client

def __exit__(
self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
) -> None:
self._transport.close()
"""Close transport if it was opened."""
if self._transport and self._transport.isOpen():
self._transport.close()


def _construct_hive_storage_descriptor(
Expand Down