diff --git a/src/main/java/io/weaviate/client6/v1/api/Authentication.java b/src/main/java/io/weaviate/client6/v1/api/Authentication.java index 1c8776c1c..eb6a83ec6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Authentication.java +++ b/src/main/java/io/weaviate/client6/v1/api/Authentication.java @@ -56,6 +56,26 @@ public static Authentication resourceOwnerPassword(String username, String passw }; } + /** + * Authenticate using Resource Owner Password Credentials authorization grant. + * + * @param clientSecret Client secret. + * @param username Resource owner username. + * @param password Resource owner password. + * @param scopes Client scopes. + * + * @return Authentication provider. + * @throws WeaviateOAuthException if an error occurred at any point of the token + * exchange process. + */ + public static Authentication resourceOwnerPasswordCredentials(String clientSecret, String username, String password, + List scopes) { + return transport -> { + OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes).withScopes("offline_access"); + return TokenProvider.resourceOwnerPasswordCredentials(oidc, clientSecret, username, password); + }; + } + /** * Authenticate using Client Credentials authorization grant. * diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index b52b37066..945765e41 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -9,6 +9,7 @@ import io.weaviate.client6.v1.internal.BuildInfo; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.Timeout; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; @@ -23,7 +24,8 @@ public record Config( Map headers, Authentication authentication, TrustManagerFactory trustManagerFactory, - Timeout timeout) { + Timeout timeout, + Proxy proxy) { public static Config of(Function> fn) { return fn.apply(new Custom()).build(); @@ -39,7 +41,8 @@ private Config(Builder builder) { builder.headers, builder.authentication, builder.trustManagerFactory, - builder.timeout); + builder.timeout, + builder.proxy); } RestTransportOptions restTransportOptions() { @@ -47,7 +50,7 @@ RestTransportOptions restTransportOptions() { } RestTransportOptions restTransportOptions(TokenProvider tokenProvider) { - return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout); + return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout, proxy); } GrpcChannelOptions grpcTransportOptions() { @@ -55,7 +58,7 @@ GrpcChannelOptions grpcTransportOptions() { } GrpcChannelOptions grpcTransportOptions(TokenProvider tokenProvider) { - return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout); + return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout, proxy); } private abstract static class Builder> implements ObjectBuilder { @@ -69,6 +72,7 @@ private abstract static class Builder> implements O protected TrustManagerFactory trustManagerFactory; protected Timeout timeout = new Timeout(); protected Map headers = new HashMap<>(); + protected Proxy proxy; /** * Set URL scheme. Subclasses may increase the visibility of this method to @@ -174,6 +178,15 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) { return (SelfT) this; } + /** + * Set proxy for all requests. + */ + @SuppressWarnings("unchecked") + public SelfT proxy(Proxy proxy) { + this.proxy = proxy; + return (SelfT) this; + } + /** * Weaviate will use the URL in this header to call Weaviate Embeddings * Service if an appropriate vectorizer is configured for collection. diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java index 63e14c3bf..68f793de9 100644 --- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java @@ -20,6 +20,7 @@ import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; import io.weaviate.client6.v1.internal.rest.RestTransport; import io.weaviate.client6.v1.internal.rest.RestTransportOptions; +import lombok.Getter; public class WeaviateClient implements AutoCloseable { /** Store this for {@link #async()} helper. */ @@ -63,14 +64,13 @@ public class WeaviateClient implements AutoCloseable { public final WeaviateClusterClient cluster; public WeaviateClient(Config config) { - RestTransportOptions restOpt; + RestTransportOptions restOpt = config.restTransportOptions(); GrpcChannelOptions grpcOpt; if (config.authentication() == null) { - restOpt = config.restTransportOptions(); grpcOpt = config.grpcTransportOptions(); } else { TokenProvider tokenProvider; - try (final var noAuthRest = new DefaultRestTransport(config.restTransportOptions())) { + try (final var noAuthRest = new DefaultRestTransport(restOpt)) { tokenProvider = config.authentication().getTokenProvider(noAuthRest); } catch (Exception e) { // Generally exceptions are caught in TokenProvider internals. @@ -126,6 +126,10 @@ public WeaviateClient(Config config) { this.config = config; } + public Config getConfig() { + return config; + } + /** * Create {@link WeaviateClientAsync} with identical configurations. * It is a shorthand for: diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java index c0fca9bbc..e6101f682 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java @@ -75,7 +75,7 @@ public CollectionHandle> use( return use(CollectionDescriptor.ofMap(collectionName), fn); } - private CollectionHandle use(CollectionDescriptor collection, + public CollectionHandle use(CollectionDescriptor collection, Function> fn) { return new CollectionHandle<>(restTransport, grpcTransport, collection, CollectionHandleDefaults.of(fn)); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/Proxy.java b/src/main/java/io/weaviate/client6/v1/internal/Proxy.java new file mode 100644 index 000000000..5637cfeae --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/Proxy.java @@ -0,0 +1,15 @@ +package io.weaviate.client6.v1.internal; + +import javax.annotation.Nullable; + +public record Proxy( + String scheme, + String host, + int port, + @Nullable String username, + @Nullable String password +) { + public Proxy(String host, int port) { + this("http", host, port, null, null); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java index 7b28a5e7a..e88f48563 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java @@ -141,6 +141,24 @@ public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String userna return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY)); } + /** + * Create a TokenProvider that uses Resource Owner Password Credentials authorization grant. + * + * @param oidc OIDC config. + * @param clientSecret Client secret. + * @param username Resource owner username. + * @param password Resource owner password. + * + * @return Internal TokenProvider implementation. + * @throws WeaviateOAuthException if an error occurred at any point of the token + * exchange process. + */ + public static TokenProvider resourceOwnerPasswordCredentials(OidcConfig oidc, String clientSecret, String username, + String password) { + final var passwordGrant = NimbusTokenProvider.resouceOwnerPasswordCredentials(oidc, clientSecret, username, password); + return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY)); + } + /** * Create a TokenProvider that uses Client Credentials authorization grant. * diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java index 897bb28cd..e52e5811e 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java @@ -11,9 +11,10 @@ public abstract class TransportOptions { protected final H headers; protected final TrustManagerFactory trustManagerFactory; protected final Timeout timeout; + protected final Proxy proxy; protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider, - TrustManagerFactory tmf, Timeout timeout) { + TrustManagerFactory tmf, Timeout timeout, Proxy proxy) { this.scheme = scheme; this.host = host; this.port = port; @@ -21,6 +22,7 @@ protected TransportOptions(String scheme, String host, int port, H headers, Toke this.headers = headers; this.timeout = timeout; this.trustManagerFactory = tmf; + this.proxy = proxy; } public boolean isSecure() { @@ -57,4 +59,9 @@ public H headers() { public TrustManagerFactory trustManagerFactory() { return this.trustManagerFactory; } + + @Nullable + public Proxy proxy() { + return this.proxy; + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index d12255d22..1e28cb975 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -1,14 +1,9 @@ package io.weaviate.client6.v1.internal.grpc; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; - -import javax.net.ssl.SSLException; - import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; - +import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.ManagedChannel; import io.grpc.StatusRuntimeException; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; @@ -17,10 +12,17 @@ import io.grpc.stub.AbstractStub; import io.grpc.stub.MetadataUtils; import io.weaviate.client6.v1.api.WeaviateApiException; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import javax.net.ssl.SSLException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + public final class DefaultGrpcTransport implements GrpcTransport { private final ManagedChannel channel; @@ -88,7 +90,7 @@ public CompletableFuture perf var method = rpc.methodAsync(); var stub = applyTimeout(futureStub, rpc); var reply = method.apply(stub, message); - return toCompletableFuture(reply).thenApply(r -> rpc.unmarshal(r)); + return toCompletableFuture(reply).thenApply(rpc::unmarshal); } /** @@ -139,6 +141,27 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) channel.sslContext(sslCtx); } + if (transportOptions.proxy() != null) { + Proxy proxy = transportOptions.proxy(); + if ("http".equals(proxy.scheme())) { + final SocketAddress proxyAddress = new InetSocketAddress(proxy.host(), proxy.port()); + channel.proxyDetector(targetAddress -> { + if (targetAddress instanceof InetSocketAddress) { + HttpConnectProxiedSocketAddress.Builder builder = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress((InetSocketAddress) targetAddress); + + if (proxy.username() != null && proxy.password() != null) { + builder.setUsername(proxy.username()); + builder.setPassword(proxy.password()); + } + return builder.build(); + } + return null; + }); + } + } + channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); return channel.build(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java index 5e4453d7f..708242d4a 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java @@ -5,6 +5,7 @@ import javax.net.ssl.TrustManagerFactory; import io.grpc.Metadata; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.Timeout; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.TransportOptions; @@ -13,19 +14,19 @@ public class GrpcChannelOptions extends TransportOptions { private final Integer maxMessageSize; public GrpcChannelOptions(String scheme, String host, int port, Map headers, - TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) { - this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout); + TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout, Proxy proxy) { + this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout, proxy); } private GrpcChannelOptions(String scheme, String host, int port, Metadata headers, - TokenProvider tokenProvider, TrustManagerFactory tmf, Integer maxMessageSize, Timeout timeout) { - super(scheme, host, port, headers, tokenProvider, tmf, timeout); + TokenProvider tokenProvider, TrustManagerFactory tmf, Integer maxMessageSize, Timeout timeout, Proxy proxy) { + super(scheme, host, port, headers, tokenProvider, tmf, timeout, proxy); this.maxMessageSize = maxMessageSize; } public GrpcChannelOptions withMaxMessageSize(int maxMessageSize) { return new GrpcChannelOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, maxMessageSize, - timeout); + timeout, proxy); } public Integer maxMessageSize() { diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java index 858fd4d82..f0a5a2ac7 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java @@ -1,5 +1,7 @@ package io.weaviate.client6.v1.internal.oidc; +import io.weaviate.client6.v1.internal.Proxy; + import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -8,19 +10,42 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static java.util.Objects.requireNonNull; + public record OidcConfig( String clientId, String providerMetadata, - Set scopes) { + Set scopes, + OidcProxy proxy) { - public OidcConfig(String clientId, String providerMetadata, Set scopes) { + public record OidcProxy( + String scheme, + String host, + int port) { + + public OidcProxy(Proxy proxy) { + this(requireNonNull(proxy, "proxy is null").scheme(), proxy.host(), proxy.port()); + } + + } + + public OidcConfig(String clientId, String providerMetadata, Set scopes, OidcProxy proxy) { this.clientId = clientId; this.providerMetadata = providerMetadata; this.scopes = scopes != null ? Set.copyOf(scopes) : Collections.emptySet(); + this.proxy = proxy; + } + + public OidcConfig(String clientId, String providerMetadata, Set scopes) { + this(clientId, providerMetadata, scopes, null); } public OidcConfig(String clientId, String providerMetadata, List scopes) { - this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes)); + this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), null); + } + + public OidcConfig(String clientId, String providerMetadata, List scopes, OidcProxy proxy) { + this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), proxy); } /** Create a new OIDC config with extended scopes. */ @@ -31,6 +56,6 @@ public OidcConfig withScopes(String... scopes) { /** Create a new OIDC config with extended scopes. */ public OidcConfig withScopes(List scopes) { var newScopes = Stream.concat(this.scopes.stream(), scopes.stream()).collect(Collectors.toSet()); - return new OidcConfig(clientId, providerMetadata, newScopes); + return new OidcConfig(clientId, providerMetadata, newScopes, proxy); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java index cafcc1289..160e845fb 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java @@ -7,6 +7,7 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.WeaviateOAuthException; +import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.ExternalEndpoint; import io.weaviate.client6.v1.internal.rest.RestTransport; @@ -28,7 +29,7 @@ private OidcUtils() { private static final Endpoint GET_PROVIDER_METADATA_ENDPOINT = new ExternalEndpoint<>( request -> "GET", request -> request, // URL is the request body. - requesf -> Collections.emptyMap(), + request -> Collections.emptyMap(), request -> null, (__, response) -> response); @@ -54,6 +55,8 @@ public static final OidcConfig getConfig(RestTransport transport) { throw new WeaviateOAuthException("fetch provider metadata", e); } - return new OidcConfig(openid.clientId(), providerMetadata, openid.scopes()); + OidcConfig.OidcProxy proxy = new OidcConfig.OidcProxy(transport.getProxy()); + + return new OidcConfig(openid.clientId(), providerMetadata, openid.scopes(), proxy); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java index ad12561ff..4d0c3c32f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java @@ -27,6 +27,21 @@ static Flow resourceOwnerPassword(String username, String password) { return () -> grant; // Reuse cached authorization grant } + static Flow resourceOwnerPasswordCredentials(String clientId, String clientSecret, String username, String password) { + return new Flow() { + private final AuthorizationGrant GRANT = new ResourceOwnerPasswordCredentialsGrant(username, new Secret(password)); + @Override + public AuthorizationGrant getAuthorizationGrant() { + return GRANT; + } + + @Override + public ClientAuthentication getClientAuthentication() { + return new ClientSecretPost(new ClientID(clientId), new Secret(clientSecret)); + } + }; + } + static Flow clientCredentials(String clientId, String clientSecret) { return new Flow() { private static final AuthorizationGrant GRANT = new ClientCredentialsGrant(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java index 4bef50e71..5c240e4f2 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.internal.oidc.nimbus; import java.io.IOException; +import java.net.Proxy; import javax.annotation.concurrent.NotThreadSafe; @@ -8,7 +9,7 @@ import com.nimbusds.oauth2.sdk.Scope; import com.nimbusds.oauth2.sdk.TokenRequest; import com.nimbusds.oauth2.sdk.id.ClientID; -import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import com.nimbusds.oauth2.sdk.as.AuthorizationServerMetadata; import com.nimbusds.openid.connect.sdk.token.OIDCTokens; import io.weaviate.client6.v1.api.WeaviateOAuthException; @@ -17,10 +18,11 @@ @NotThreadSafe public final class NimbusTokenProvider implements TokenProvider { - private final OIDCProviderMetadata metadata; + private final AuthorizationServerMetadata metadata; private final ClientID clientId; private final Scope scope; private final Flow flow; + private Proxy proxy; /** * Create a TokenProvider that uses Refresh Token authorization grant. @@ -51,6 +53,23 @@ public static NimbusTokenProvider resourceOwnerPassword(OidcConfig oidc, String return new NimbusTokenProvider(oidc, Flow.resourceOwnerPassword(username, password)); } + /** + * Create a TokenProvider that uses Resource Owner Password Credentials authorization grant. + * + * @param oidc OIDC config. + * @param clientSecret Client secret. + * @param username Resource owner username. + * @param password Resource owner password. + * + * @return A new instance of NimbusTokenProvider. Instances are never cached. + * @throws WeaviateOAuthException if an error occured at any point of the + * exchange process. + */ + public static NimbusTokenProvider resouceOwnerPasswordCredentials(OidcConfig oidc, String clientSecret, String username, + String password) { + return new NimbusTokenProvider(oidc, Flow.resourceOwnerPasswordCredentials(oidc.clientId(), clientSecret, username, password)); + } + /** * Create a TokenProvider that uses Client Credentials authorization grant. * @@ -70,6 +89,10 @@ private NimbusTokenProvider(OidcConfig oidc, Flow flow) { this.clientId = new ClientID(oidc.clientId()); this.scope = new Scope(oidc.scopes().toArray(String[]::new)); this.flow = flow; + var proxy = oidc.proxy(); + if (proxy != null) { + this.proxy = new java.net.Proxy(java.net.Proxy.Type.HTTP, new java.net.InetSocketAddress(proxy.host(), proxy.port())); + } } @Override @@ -83,6 +106,10 @@ public Token getToken() { : new TokenRequest(uri, clientAuth, grant, scope); var request = tokenRequest.toHTTPRequest(); + if (proxy != null) { + request.setProxy(proxy); + } + OIDCTokens tokens; try { var response = request.send(); @@ -116,9 +143,9 @@ public static ProviderMetadata parseProviderMetadata(String providerMetadata) { return new ProviderMetadata(metadata.getTokenEndpointURI()); } - private static OIDCProviderMetadata _parseProviderMetadata(String providerMetadata) { + private static AuthorizationServerMetadata _parseProviderMetadata(String providerMetadata) { try { - return OIDCProviderMetadata.parse(providerMetadata); + return AuthorizationServerMetadata.parse(providerMetadata); } catch (ParseException ex) { throw new WeaviateOAuthException("parse provider metadata: ", ex); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java index a77d9da29..b80880841 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java @@ -8,6 +8,7 @@ import javax.net.ssl.SSLContext; +import org.apache.hc.core5.http.HttpHost; import org.apache.hc.client5.http.async.methods.SimpleHttpRequest; import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; import org.apache.hc.client5.http.config.RequestConfig; @@ -31,6 +32,7 @@ import io.weaviate.client6.v1.api.WeaviateApiException; import io.weaviate.client6.v1.api.WeaviateTransportException; +import io.weaviate.client6.v1.internal.Proxy; public class DefaultRestTransport implements RestTransport { private final CloseableHttpClient httpClient; @@ -43,9 +45,9 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { this.transportOptions = transportOptions; // TODO: doesn't make sense to spin up both? - var httpClient = HttpClients.custom() + var httpClient = HttpClients.custom().useSystemProperties() .setDefaultHeaders(transportOptions.headers()); - var httpClientAsync = HttpAsyncClients.custom() + var httpClientAsync = HttpAsyncClients.custom().useSystemProperties() .setDefaultHeaders(transportOptions.headers()); // Apply custom SSL context @@ -68,6 +70,13 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { httpClientAsync.setConnectionManager(asyncManager); } + if (transportOptions.proxy() != null) { + Proxy proxy = transportOptions.proxy(); + HttpHost proxyHost = new HttpHost(proxy.scheme(), proxy.host(), proxy.port()); + httpClient.setProxy(proxyHost); + httpClientAsync.setProxy(proxyHost); + } + if (transportOptions.timeout() != null) { var config = RequestConfig.custom() .setResponseTimeout(transportOptions.timeout().querySeconds(), TimeUnit.SECONDS) @@ -111,8 +120,6 @@ private ClassicHttpRequest prepareClassicRequest(RequestT if (body != null) { req.setEntity(body, ContentType.APPLICATION_JSON); } - if (true) { - } return req.build(); } @@ -196,6 +203,10 @@ private ResponseT _handleResponse(Endpoint endpoint, S throw new WeaviateTransportException("Unhandled endpoint type " + endpoint.getClass().getSimpleName()); } + public Proxy getProxy() { + return transportOptions.proxy(); + } + @Override public void close() throws Exception { httpClient.close(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java index da26c9f12..62f20871c 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java @@ -3,6 +3,8 @@ import java.io.IOException; import java.util.concurrent.CompletableFuture; +import io.weaviate.client6.v1.internal.Proxy; + public interface RestTransport extends AutoCloseable { ResponseT performRequest(RequestT request, Endpoint endpoint) @@ -10,4 +12,6 @@ ResponseT performRequest(RequestT request, CompletableFuture performRequestAsync(RequestT request, Endpoint endpoint); + + Proxy getProxy(); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java index 5da4cdd5f..4635a9368 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java @@ -8,6 +8,7 @@ import org.apache.hc.core5.http.message.BasicHeader; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.Timeout; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.TransportOptions; @@ -16,17 +17,21 @@ public final class RestTransportOptions extends TransportOptions headers, - TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout) { - super(scheme, host, port, buildHeaders(headers), tokenProvider, trust, timeout); + TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout, Proxy proxy) { + super(scheme, host, port, buildHeaders(headers), tokenProvider, trust, timeout, proxy); } private RestTransportOptions(String scheme, String host, int port, Collection headers, - TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout) { - super(scheme, host, port, headers, tokenProvider, trust, timeout); + TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout, Proxy proxy) { + super(scheme, host, port, headers, tokenProvider, trust, timeout, proxy); + } + + public RestTransportOptions(String http, String localhost, Integer localPort, Map headers, Object tokenProvider, Object trust, Timeout timeout) { + super(http, localhost, localPort, buildHeaders(headers), (TokenProvider) tokenProvider, (TrustManagerFactory) trust, timeout, null); } public final RestTransportOptions withTimeout(Timeout timeout) { - return new RestTransportOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, timeout); + return new RestTransportOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, timeout, proxy); } private static final Collection buildHeaders(Map headers) { diff --git a/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java b/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java index 54472875e..124789280 100644 --- a/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java @@ -60,6 +60,81 @@ public void testAuthentication_apiKey() throws Exception { .withHeader("Authorization", "Bearer my-api-key")); } + @Test + public void testAuthentication_resourceOwnerPasswordWithClientSecret() throws Exception { + // 1. Mock /.well-known/openid-configuration + mockServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/.well-known/openid-configuration") + ).respond( + org.mockserver.model.HttpResponse.response() + .withStatusCode(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"clientId\": \"my-client-id\", \"href\": \"http://localhost:" + mockServer.getLocalPort() + "/oidc-provider\"}") + ); + + // 2. Mock OIDC provider metadata + mockServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/oidc-provider") + ).respond( + org.mockserver.model.HttpResponse.response() + .withStatusCode(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"issuer\": \"http://localhost:" + mockServer.getLocalPort() + "\", \"token_endpoint\": \"http://localhost:" + mockServer.getLocalPort() + "/token\"}") + ); + + // 3. Mock Token Endpoint + mockServer.when( + HttpRequest.request() + .withMethod("POST") + .withPath("/token") + ).respond( + org.mockserver.model.HttpResponse.response() + .withStatusCode(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"access_token\": \"secret-token\", \"token_type\": \"Bearer\", \"expires_in\": 3600}") + ); + + var authz = Authentication.resourceOwnerPasswordCredentials("my-client-secret", "my-user", "my-pass", Collections.emptyList()); + var transportOptions = new RestTransportOptions( + "http", "localhost", mockServer.getLocalPort(), + Collections.emptyMap(), authz.getTokenProvider(noAuthTransport), null, new Timeout()); + + try (final var restClient = new DefaultRestTransport(transportOptions)) { + restClient.performRequest(null, SimpleEndpoint.sideEffect( + request -> "GET", request -> "/", request -> null)); + } catch (WeaviateApiException ex) { + if (ex.httpStatusCode() != 404) { + Assertions.fail("unexpected error", ex); + } + } + + // Verify token request had both password grant and client authentication + mockServer.verify( + HttpRequest.request() + .withMethod("POST") + .withPath("/token") + .withBody(org.mockserver.model.ParameterBody.params( + org.mockserver.model.Parameter.param("grant_type", "password"), + org.mockserver.model.Parameter.param("username", "my-user"), + org.mockserver.model.Parameter.param("password", "my-pass"), + org.mockserver.model.Parameter.param("client_id", "my-client-id"), + org.mockserver.model.Parameter.param("client_secret", "my-client-secret"), + org.mockserver.model.Parameter.param("scope", "offline_access") + )) + ); + + // Verify the actual request used the obtained token + mockServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/") + .withHeader("Authorization", "Bearer secret-token")); + } + @After public void stopMockServer() throws Exception { mockServer.stop(); diff --git a/src/test/java/io/weaviate/client6/v1/internal/rest/ProxyTest.java b/src/test/java/io/weaviate/client6/v1/internal/rest/ProxyTest.java new file mode 100644 index 000000000..2a61e51e8 --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/internal/rest/ProxyTest.java @@ -0,0 +1,142 @@ +package io.weaviate.client6.v1.internal.rest; + +import io.weaviate.client6.v1.api.WeaviateApiException; +import io.weaviate.client6.v1.api.Config; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.internal.Proxy; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpRequest; +import org.mockserver.model.HttpResponse; + +import java.io.IOException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockserver.model.HttpForward.forward; + +public class ProxyTest { + private ClientAndServer targetServer; + private ClientAndServer proxyServer; + private WeaviateClient client; + + @Before + public void setUp() { + targetServer = ClientAndServer.startClientAndServer(0); + proxyServer = ClientAndServer.startClientAndServer(0); + + // Set up target server to return a success response + targetServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/.well-known/live")) + .respond( + HttpResponse.response() + .withStatusCode(200)); + + targetServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")) + .respond( + HttpResponse.response() + .withStatusCode(200) + .withBody("{\"version\": \"1.32.0\"}")); + + // Set up proxy server to forward requests to the target server + proxyServer.when( + HttpRequest.request()) + .forward( + forward() + .withHost("localhost") + .withPort(targetServer.getLocalPort()) + .withScheme(org.mockserver.model.HttpForward.Scheme.HTTP)); + + Config config = Config.of(c -> c + .scheme("http") + .httpHost("localhost") + .httpPort(targetServer.getLocalPort()) + .grpcHost("localhost") + .grpcPort(targetServer.getLocalPort()) + .proxy(new Proxy("localhost", proxyServer.getLocalPort())) + .timeout(5) + ); + + client = new WeaviateClient(config); + } + + @Test + public void testClientInitializationWithProxy() { + // This test verifies that the client can be successfully created. + // The WeaviateClient constructor performs REST calls to /v1/.well-known/live + // and /v1/meta to verify the connection and version support. + // If these calls fail, the constructor throws a WeaviateConnectException. + // Since setUp() already creates a client using the proxy, we just need to + // verify it was initialized correctly. + assertThat(client).isNotNull(); + + // Verify that the initialization calls went through the proxy + proxyServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/.well-known/live")); + proxyServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")); + } + + @Test + public void testRestProxy() throws IOException { + // Perform a request that should go through the proxy + client.meta(); + + // Verify that the proxy server received the request + proxyServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")); + + // Verify that the target server also received the request (forwarded by proxy) + targetServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")); + } + + @Test + public void testProxyConfiguration() { + // In this test, we verify that the client has the proxy configured. + assertThat(client.getConfig().proxy()).isNotNull(); + assertThat(client.getConfig().proxy().port()).isEqualTo((long) proxyServer.getLocalPort()); + } + + @Test + public void testGrpcProxy() { + // gRPC proxying via HTTP CONNECT. + // DefaultGrpcTransport uses a custom ProxyDetector which returns a + // HttpConnectProxiedSocketAddress when a proxy is configured. + + // To verify that gRPC proxying is correctly set up, we check the configuration. + // Since actual CONNECT verification via MockServer is tricky in this setup, + // we focus on ensuring the client is correctly initialized with the proxy. + assertThatThrownBy(() -> client.collections.use("Test").size()) + .isInstanceOf(WeaviateApiException.class) + .hasMessageContaining("UNAVAILABLE: Network closed"); + } + + @After + public void tearDown() throws Exception { + if (client != null) { + client.close(); + } + if (proxyServer != null) { + proxyServer.stop(); + } + if (targetServer != null) { + targetServer.stop(); + } + } +} diff --git a/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java b/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java index 587cb2548..7991078cd 100644 --- a/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java +++ b/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java @@ -6,6 +6,7 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.rest.BooleanEndpoint; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.RestTransport; @@ -25,7 +26,7 @@ public interface AssertFunction { void apply(String method, String requestUrl, String body, Map queryParameters); } - private List> requests = new ArrayList<>(); + private final List> requests = new ArrayList<>(); public void assertNext(AssertFunction... assertions) { var assertN = Math.min(assertions.length, requests.size()); @@ -60,4 +61,9 @@ public CompletableFuture performReq @Override public void close() throws IOException { } + + @Override + public Proxy getProxy() { + return null; + } }