Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.13.0</version>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
Comment on lines +235 to +243
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This project doesn't use Lombok for generating getters/setters. If you apply the suggestion about getConfig() in WeaviateClient.java then this configuration will probably become redundant.

</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/io/weaviate/client6/v1/api/Authentication.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ public static Authentication resourceOwnerPassword(String username, String passw
};
}

/**
* Authenticate using Resource Owner Password Credentials authorization grant.
*
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
* @param scopes Client scopes.
*
* @return Authentication provider.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static Authentication resourceOwnerPasswordCredentials(String clientSecret, String username, String password,
List<String> scopes) {
return transport -> {
OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes).withScopes("offline_access");
return TokenProvider.resourceOwnerPasswordCredentials(oidc, clientSecret, username, password);
};
}

/**
* Authenticate using Client Credentials authorization grant.
*
Expand Down
21 changes: 17 additions & 4 deletions src/main/java/io/weaviate/client6/v1/api/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import io.weaviate.client6.v1.internal.BuildInfo;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.Timeout;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions;
Expand All @@ -23,7 +24,8 @@ public record Config(
Map<String, String> headers,
Authentication authentication,
TrustManagerFactory trustManagerFactory,
Timeout timeout) {
Timeout timeout,
Proxy proxy) {

public static Config of(Function<Custom, ObjectBuilder<Config>> fn) {
return fn.apply(new Custom()).build();
Expand All @@ -39,23 +41,24 @@ private Config(Builder<?> builder) {
builder.headers,
builder.authentication,
builder.trustManagerFactory,
builder.timeout);
builder.timeout,
builder.proxy);
}

RestTransportOptions restTransportOptions() {
return restTransportOptions(null);
}

RestTransportOptions restTransportOptions(TokenProvider tokenProvider) {
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout);
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

GrpcChannelOptions grpcTransportOptions() {
return grpcTransportOptions(null);
}

GrpcChannelOptions grpcTransportOptions(TokenProvider tokenProvider) {
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout);
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

private abstract static class Builder<SelfT extends Builder<SelfT>> implements ObjectBuilder<Config> {
Expand All @@ -69,6 +72,7 @@ private abstract static class Builder<SelfT extends Builder<SelfT>> implements O
protected TrustManagerFactory trustManagerFactory;
protected Timeout timeout = new Timeout();
protected Map<String, String> headers = new HashMap<>();
protected Proxy proxy;

/**
* Set URL scheme. Subclasses may increase the visibility of this method to
Expand Down Expand Up @@ -174,6 +178,15 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) {
return (SelfT) this;
}

/**
* Set proxy for all requests.
*/
@SuppressWarnings("unchecked")
public SelfT proxy(Proxy proxy) {
this.proxy = proxy;
return (SelfT) this;
}

/**
* Weaviate will use the URL in this header to call Weaviate Embeddings
* Service if an appropriate vectorizer is configured for collection.
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import io.weaviate.client6.v1.internal.rest.DefaultRestTransport;
import io.weaviate.client6.v1.internal.rest.RestTransport;
import io.weaviate.client6.v1.internal.rest.RestTransportOptions;
import lombok.Getter;

public class WeaviateClient implements AutoCloseable {
/** Store this for {@link #async()} helper. */
@Getter
private final Config config;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: replace @Getter with an explicit Config getConfig

private final RestTransport restTransport;
Expand Down Expand Up @@ -63,14 +65,13 @@ public class WeaviateClient implements AutoCloseable {
public final WeaviateClusterClient cluster;

public WeaviateClient(Config config) {
RestTransportOptions restOpt;
RestTransportOptions restOpt = config.restTransportOptions();
GrpcChannelOptions grpcOpt;
if (config.authentication() == null) {
restOpt = config.restTransportOptions();
grpcOpt = config.grpcTransportOptions();
} else {
TokenProvider tokenProvider;
try (final var noAuthRest = new DefaultRestTransport(config.restTransportOptions())) {
try (final var noAuthRest = new DefaultRestTransport(restOpt)) {
tokenProvider = config.authentication().getTokenProvider(noAuthRest);
} catch (Exception e) {
// Generally exceptions are caught in TokenProvider internals.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public CollectionHandle<Map<String, Object>> use(
return use(CollectionDescriptor.ofMap(collectionName), fn);
}

private <PropertiesT> CollectionHandle<PropertiesT> use(CollectionDescriptor<PropertiesT> collection,
public <PropertiesT> CollectionHandle<PropertiesT> use(CollectionDescriptor<PropertiesT> collection,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

Function<CollectionHandleDefaults.Builder, ObjectBuilder<CollectionHandleDefaults>> fn) {
return new CollectionHandle<>(restTransport, grpcTransport, collection, CollectionHandleDefaults.of(fn));
}
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/Proxy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.weaviate.client6.v1.internal;

import javax.annotation.Nullable;

public record Proxy(
String scheme,
String host,
int port,
@Nullable String username,
@Nullable String password
) {
public Proxy(String host, int port) {
this("http", host, port, null, null);
}
}
18 changes: 18 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String userna
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Resource Owner Password Credentials authorization grant.
*
* @param oidc OIDC config.
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
*
* @return Internal TokenProvider implementation.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static TokenProvider resourceOwnerPasswordCredentials(OidcConfig oidc, String clientSecret, String username,
String password) {
final var passwordGrant = NimbusTokenProvider.resouceOwnerPasswordCredentials(oidc, clientSecret, username, password);
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Client Credentials authorization grant.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ public abstract class TransportOptions<H> {
protected final H headers;
protected final TrustManagerFactory trustManagerFactory;
protected final Timeout timeout;
protected final Proxy proxy;

protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider,
TrustManagerFactory tmf, Timeout timeout) {
TrustManagerFactory tmf, Timeout timeout, Proxy proxy) {
this.scheme = scheme;
this.host = host;
this.port = port;
this.tokenProvider = tokenProvider;
this.headers = headers;
this.timeout = timeout;
this.trustManagerFactory = tmf;
this.proxy = proxy;
}

public boolean isSecure() {
Expand Down Expand Up @@ -57,4 +59,9 @@ public H headers() {
public TrustManagerFactory trustManagerFactory() {
return this.trustManagerFactory;
}

@Nullable
public Proxy proxy() {
return this.proxy;
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
package io.weaviate.client6.v1.internal.grpc;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLException;

import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;

import io.grpc.HttpConnectProxiedSocketAddress;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
Expand All @@ -17,10 +12,17 @@
import io.grpc.stub.AbstractStub;
import io.grpc.stub.MetadataUtils;
import io.weaviate.client6.v1.api.WeaviateApiException;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public final class DefaultGrpcTransport implements GrpcTransport {
private final ManagedChannel channel;

Expand Down Expand Up @@ -88,7 +90,7 @@ public <RequestT, RequestM, ReplyM, ResponseT> CompletableFuture<ResponseT> perf
var method = rpc.methodAsync();
var stub = applyTimeout(futureStub, rpc);
var reply = method.apply(stub, message);
return toCompletableFuture(reply).thenApply(r -> rpc.unmarshal(r));
return toCompletableFuture(reply).thenApply(rpc::unmarshal);
}

/**
Expand Down Expand Up @@ -139,6 +141,27 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions)
channel.sslContext(sslCtx);
}

if (transportOptions.proxy() != null) {
Proxy proxy = transportOptions.proxy();
if ("http".equals(proxy.scheme())) {
final SocketAddress proxyAddress = new InetSocketAddress(proxy.host(), proxy.port());
channel.proxyDetector(targetAddress -> {
if (targetAddress instanceof InetSocketAddress) {
HttpConnectProxiedSocketAddress.Builder builder = HttpConnectProxiedSocketAddress.newBuilder()
.setProxyAddress(proxyAddress)
.setTargetAddress((InetSocketAddress) targetAddress);

if (proxy.username() != null && proxy.password() != null) {
builder.setUsername(proxy.username());
builder.setPassword(proxy.password());
}
return builder.build();
}
return null;
});
}
}

channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));

return channel.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import javax.net.ssl.TrustManagerFactory;

import io.grpc.Metadata;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.Timeout;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.TransportOptions;
Expand All @@ -13,19 +14,19 @@ public class GrpcChannelOptions extends TransportOptions<Metadata> {
private final Integer maxMessageSize;

public GrpcChannelOptions(String scheme, String host, int port, Map<String, String> headers,
TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) {
this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout);
TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout, Proxy proxy) {
this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout, proxy);
}

private GrpcChannelOptions(String scheme, String host, int port, Metadata headers,
TokenProvider tokenProvider, TrustManagerFactory tmf, Integer maxMessageSize, Timeout timeout) {
super(scheme, host, port, headers, tokenProvider, tmf, timeout);
TokenProvider tokenProvider, TrustManagerFactory tmf, Integer maxMessageSize, Timeout timeout, Proxy proxy) {
super(scheme, host, port, headers, tokenProvider, tmf, timeout, proxy);
this.maxMessageSize = maxMessageSize;
}

public GrpcChannelOptions withMaxMessageSize(int maxMessageSize) {
return new GrpcChannelOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, maxMessageSize,
timeout);
timeout, proxy);
}

public Integer maxMessageSize() {
Expand Down
30 changes: 26 additions & 4 deletions src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.weaviate.client6.v1.internal.oidc;

import io.weaviate.client6.v1.internal.Proxy;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
Expand All @@ -11,16 +13,36 @@
public record OidcConfig(
String clientId,
String providerMetadata,
Set<String> scopes) {
Set<String> scopes,
OidcProxy proxy) {

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes) {
public record OidcProxy(
String host,
int port,
String scheme) {
Comment on lines +19 to +22
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: similarly to Proxy, can we order the parameters as scheme, host, port?


public static OidcProxy from(Proxy proxy) {
return proxy == null ? null : new OidcProxy(proxy.host(), proxy.port(), proxy.scheme());
}
Comment on lines +24 to +26
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public static OidcProxy from(Proxy proxy) {
return proxy == null ? null : new OidcProxy(proxy.host(), proxy.port(), proxy.scheme());
}
public OidcProxy(Proxy proxy) {
this(requireNonNull(proxy, "proxy is null").scheme(), proxy.host(), proxy.port());
}

Feels like proxy == null ? null will be causing ambiguity at the call site. IMO a less ambiguous contract would be "create a non-null OidcProxy from a non-null Proxy".

See another comment about the ordering of parameters.

}

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes, OidcProxy proxy) {
this.clientId = clientId;
this.providerMetadata = providerMetadata;
this.scopes = scopes != null ? Set.copyOf(scopes) : Collections.emptySet();
this.proxy = proxy;
}

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes) {
this(clientId, providerMetadata, scopes, null);
}

public OidcConfig(String clientId, String providerMetadata, List<String> scopes) {
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes));
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), null);
}

public OidcConfig(String clientId, String providerMetadata, List<String> scopes, OidcProxy proxy) {
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), proxy);
}

/** Create a new OIDC config with extended scopes. */
Expand All @@ -31,6 +53,6 @@ public OidcConfig withScopes(String... scopes) {
/** Create a new OIDC config with extended scopes. */
public OidcConfig withScopes(List<String> scopes) {
var newScopes = Stream.concat(this.scopes.stream(), scopes.stream()).collect(Collectors.toSet());
return new OidcConfig(clientId, providerMetadata, newScopes);
return new OidcConfig(clientId, providerMetadata, newScopes, proxy);
}
}
Loading