From c39ba4a35c23d3262a4c89f3fbed6078f191fd58 Mon Sep 17 00:00:00 2001 From: "nov.lzf" Date: Tue, 21 Mar 2023 15:27:57 +0800 Subject: [PATCH] ssl context reload spi (#10150) --- .../common/remote/client/grpc/GrpcClient.java | 114 +++++++------- .../nacos/core/remote/BaseRpcServer.java | 45 ++++-- .../remote/RpcServerSslContextRefresher.java | 41 +++++ .../RpcServerSslContextRefresherHolder.java | 75 +++++++++ .../nacos/core/remote/RpcServerTlsConfig.java | 12 +- .../core/remote/SslContextChangeAware.java | 43 +++++ .../core/remote/grpc/BaseGrpcServer.java | 147 ++++++++++-------- .../grpc/OptionalTlsProtocolNegotiator.java | 53 ++++--- .../core/remote/grpc/GrpcServerTest.java | 63 +++++--- 9 files changed, 410 insertions(+), 183 deletions(-) create mode 100644 core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresher.java create mode 100644 core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresherHolder.java create mode 100644 core/src/main/java/com/alibaba/nacos/core/remote/SslContextChangeAware.java diff --git a/common/src/main/java/com/alibaba/nacos/common/remote/client/grpc/GrpcClient.java b/common/src/main/java/com/alibaba/nacos/common/remote/client/grpc/GrpcClient.java index 391be357c..c4b21b204 100644 --- a/common/src/main/java/com/alibaba/nacos/common/remote/client/grpc/GrpcClient.java +++ b/common/src/main/java/com/alibaba/nacos/common/remote/client/grpc/GrpcClient.java @@ -71,18 +71,18 @@ import java.util.concurrent.TimeUnit; */ @SuppressWarnings("PMD.AbstractClassShouldStartWithAbstractNamingRule") public abstract class GrpcClient extends RpcClient { - + private static final Logger LOGGER = LoggerFactory.getLogger(GrpcClient.class); - + private final GrpcClientConfig clientConfig; - + private ThreadPoolExecutor grpcExecutor; - + @Override public ConnectionType getConnectionType() { return ConnectionType.GRPC; } - + /** * constructor. * @@ -91,7 +91,7 @@ public abstract class GrpcClient extends RpcClient { public GrpcClient(String name) { this(DefaultGrpcClientConfig.newBuilder().setName(name).build()); } - + /** * constructor. * @@ -100,7 +100,7 @@ public abstract class GrpcClient extends RpcClient { public GrpcClient(Properties properties) { this(DefaultGrpcClientConfig.newBuilder().fromProperties(properties).build()); } - + /** * constructor. * @@ -110,7 +110,7 @@ public abstract class GrpcClient extends RpcClient { super(clientConfig); this.clientConfig = clientConfig; } - + /** * constructor. * @@ -121,7 +121,7 @@ public abstract class GrpcClient extends RpcClient { super(clientConfig, serverListFactory); this.clientConfig = clientConfig; } - + /** * constructor. * @@ -134,12 +134,13 @@ public abstract class GrpcClient extends RpcClient { this(DefaultGrpcClientConfig.newBuilder().setName(name).setThreadPoolCoreSize(threadPoolCoreSize) .setThreadPoolMaxSize(threadPoolMaxSize).setLabels(labels).build()); } - - public GrpcClient(String name, Integer threadPoolCoreSize, Integer threadPoolMaxSize, Map labels, RpcClientTlsConfig tlsConfig) { - this(DefaultGrpcClientConfig.newBuilder().setName(name).setThreadPoolCoreSize(threadPoolCoreSize).setTlsConfig(tlsConfig) - .setThreadPoolMaxSize(threadPoolMaxSize).setLabels(labels).build()); + + public GrpcClient(String name, Integer threadPoolCoreSize, Integer threadPoolMaxSize, Map labels, + RpcClientTlsConfig tlsConfig) { + this(DefaultGrpcClientConfig.newBuilder().setName(name).setThreadPoolCoreSize(threadPoolCoreSize) + .setTlsConfig(tlsConfig).setThreadPoolMaxSize(threadPoolMaxSize).setLabels(labels).build()); } - + protected ThreadPoolExecutor createGrpcExecutor(String serverIp) { // Thread name will use String.format, ipv6 maybe contain special word %, so handle it first. serverIp = serverIp.replaceAll("%", "-"); @@ -151,7 +152,7 @@ public abstract class GrpcClient extends RpcClient { grpcExecutor.allowCoreThreadTimeOut(true); return grpcExecutor; } - + @Override public void shutdown() throws NacosException { super.shutdown(); @@ -160,7 +161,7 @@ public abstract class GrpcClient extends RpcClient { grpcExecutor.shutdown(); } } - + /** * Create a stub using a channel. * @@ -170,7 +171,7 @@ public abstract class GrpcClient extends RpcClient { private RequestGrpc.RequestFutureStub createNewChannelStub(ManagedChannel managedChannelTemp) { return RequestGrpc.newFutureStub(managedChannelTemp); } - + /** * create a new channel with specific server address. * @@ -181,15 +182,15 @@ public abstract class GrpcClient extends RpcClient { private ManagedChannel createNewManagedChannel(String serverIp, int serverPort) { LOGGER.info("grpc client connection server:{} ip,serverPort:{},grpcTslConfig:{}", serverIp, serverPort, JacksonUtils.toJson(clientConfig.tlsConfig())); - ManagedChannelBuilder managedChannelBuilder = buildChannel(serverIp, serverPort, buildSslContext()) - .executor(grpcExecutor).compressorRegistry(CompressorRegistry.getDefaultInstance()) + ManagedChannelBuilder managedChannelBuilder = buildChannel(serverIp, serverPort, buildSslContext()).executor( + grpcExecutor).compressorRegistry(CompressorRegistry.getDefaultInstance()) .decompressorRegistry(DecompressorRegistry.getDefaultInstance()) .maxInboundMessageSize(clientConfig.maxInboundMessageSize()) .keepAliveTime(clientConfig.channelKeepAlive(), TimeUnit.MILLISECONDS) .keepAliveTimeout(clientConfig.channelKeepAliveTimeout(), TimeUnit.MILLISECONDS); return managedChannelBuilder.build(); } - + /** * shutdown a channel. * @@ -200,7 +201,7 @@ public abstract class GrpcClient extends RpcClient { managedChannel.shutdownNow(); } } - + /** * check server if success. * @@ -221,25 +222,30 @@ public abstract class GrpcClient extends RpcClient { } catch (Exception e) { LoggerUtils.printIfErrorEnabled(LOGGER, "Server check fail, please check server {} ,port {} is available , error ={}", ip, port, e); + if (this.clientConfig != null && this.clientConfig.tlsConfig() != null && this.clientConfig.tlsConfig() + .getEnableTls()) { + LoggerUtils.printIfErrorEnabled(LOGGER, + "current client is require tls encrypted ,server must support tls ,please check"); + } return null; } } - + private StreamObserver bindRequestStream(final BiRequestStreamGrpc.BiRequestStreamStub streamStub, - final GrpcConnection grpcConn) { - + final GrpcConnection grpcConn) { + return streamStub.requestBiStream(new StreamObserver() { - + @Override public void onNext(Payload payload) { - + LoggerUtils.printIfDebugEnabled(LOGGER, "[{}]Stream server request receive, original info: {}", grpcConn.getConnectionId(), payload.toString()); try { Object parseBody = GrpcUtils.parse(payload); final Request request = (Request) parseBody; if (request != null) { - + try { Response response = handleServerRequest(request); if (response != null) { @@ -249,7 +255,7 @@ public abstract class GrpcClient extends RpcClient { LOGGER.warn("[{}]Fail to process server request, ackId->{}", grpcConn.getConnectionId(), request.getRequestId()); } - + } catch (Exception e) { LoggerUtils.printIfErrorEnabled(LOGGER, "[{}]Handle server request exception: {}", grpcConn.getConnectionId(), payload.toString(), e.getMessage()); @@ -258,16 +264,16 @@ public abstract class GrpcClient extends RpcClient { errResponse.setRequestId(request.getRequestId()); sendResponse(errResponse); } - + } - + } catch (Exception e) { - + LoggerUtils.printIfErrorEnabled(LOGGER, "[{}]Error to process server push response: {}", grpcConn.getConnectionId(), payload.getBody().getValue().toStringUtf8()); } } - + @Override public void onError(Throwable throwable) { boolean isRunning = isRunning(); @@ -278,14 +284,14 @@ public abstract class GrpcClient extends RpcClient { if (rpcClientStatus.compareAndSet(RpcClientStatus.RUNNING, RpcClientStatus.UNHEALTHY)) { switchServerAsync(); } - + } else { LoggerUtils.printIfWarnEnabled(LOGGER, "[{}]Ignore error event,isRunning:{},isAbandon={}", grpcConn.getConnectionId(), isRunning, isAbandon); } - + } - + @Override public void onCompleted() { boolean isRunning = isRunning(); @@ -296,16 +302,16 @@ public abstract class GrpcClient extends RpcClient { if (rpcClientStatus.compareAndSet(RpcClientStatus.RUNNING, RpcClientStatus.UNHEALTHY)) { switchServerAsync(); } - + } else { LoggerUtils.printIfInfoEnabled(LOGGER, "[{}]Ignore complete event,isRunning:{},isAbandon={}", grpcConn.getConnectionId(), isRunning, isAbandon); } - + } }); } - + private void sendResponse(Response response) { try { ((GrpcConnection) this.currentConnection).sendResponse(response); @@ -314,7 +320,7 @@ public abstract class GrpcClient extends RpcClient { response.getRequestId()); } } - + @Override public Connection connectToServer(ServerInfo serverInfo) { try { @@ -325,21 +331,21 @@ public abstract class GrpcClient extends RpcClient { ManagedChannel managedChannel = createNewManagedChannel(serverInfo.getServerIp(), port); RequestGrpc.RequestFutureStub newChannelStubTemp = createNewChannelStub(managedChannel); if (newChannelStubTemp != null) { - + Response response = serverCheck(serverInfo.getServerIp(), port, newChannelStubTemp); if (response == null || !(response instanceof ServerCheckResponse)) { shuntDownChannel(managedChannel); return null; } - + BiRequestStreamGrpc.BiRequestStreamStub biRequestStreamStub = BiRequestStreamGrpc.newStub( newChannelStubTemp.getChannel()); GrpcConnection grpcConn = new GrpcConnection(serverInfo, grpcExecutor); grpcConn.setConnectionId(((ServerCheckResponse) response).getConnectionId()); - + //create stream request and bind connection event to this connection. StreamObserver payloadStreamObserver = bindRequestStream(biRequestStreamStub, grpcConn); - + // stream observer to send response to server grpcConn.setPayloadStreamObserver(payloadStreamObserver); grpcConn.setGrpcFutureServiceStub(newChannelStubTemp); @@ -361,21 +367,19 @@ public abstract class GrpcClient extends RpcClient { } return null; } - + private ManagedChannelBuilder buildChannel(String serverIp, int port, Optional sslContext) { if (sslContext.isPresent()) { - return NettyChannelBuilder.forAddress(serverIp, port) - .negotiationType(NegotiationType.TLS) + return NettyChannelBuilder.forAddress(serverIp, port).negotiationType(NegotiationType.TLS) .sslContext(sslContext.get()); - + } else { - return ManagedChannelBuilder - .forAddress(serverIp, port).usePlaintext(); + return ManagedChannelBuilder.forAddress(serverIp, port).usePlaintext(); } } - + private Optional buildSslContext() { - + RpcClientTlsConfig tlsConfig = clientConfig.tlsConfig(); if (!tlsConfig.getEnableTls()) { return Optional.absent(); @@ -398,14 +402,16 @@ public abstract class GrpcClient extends RpcClient { Resource resource = resourceLoader.getResource(tlsConfig.getTrustCollectionCertFile()); builder.trustManager(resource.getInputStream()); } - + if (tlsConfig.getMutualAuthEnable()) { - if (StringUtils.isBlank(tlsConfig.getCertChainFile()) || StringUtils.isBlank(tlsConfig.getCertPrivateKey())) { + if (StringUtils.isBlank(tlsConfig.getCertChainFile()) || StringUtils.isBlank( + tlsConfig.getCertPrivateKey())) { throw new IllegalArgumentException("client certChainFile or certPrivateKey must be not null"); } Resource certChainFile = resourceLoader.getResource(tlsConfig.getCertChainFile()); Resource privateKey = resourceLoader.getResource(tlsConfig.getCertPrivateKey()); - builder.keyManager(certChainFile.getInputStream(), privateKey.getInputStream(), tlsConfig.getCertPrivateKeyPassword()); + builder.keyManager(certChainFile.getInputStream(), privateKey.getInputStream(), + tlsConfig.getCertPrivateKeyPassword()); } return Optional.of(builder.build()); } catch (Exception e) { diff --git a/core/src/main/java/com/alibaba/nacos/core/remote/BaseRpcServer.java b/core/src/main/java/com/alibaba/nacos/core/remote/BaseRpcServer.java index 232c8a984..fac469256 100644 --- a/core/src/main/java/com/alibaba/nacos/core/remote/BaseRpcServer.java +++ b/core/src/main/java/com/alibaba/nacos/core/remote/BaseRpcServer.java @@ -16,7 +16,6 @@ package com.alibaba.nacos.core.remote; -import com.alibaba.nacos.common.JustForTest; import com.alibaba.nacos.common.remote.ConnectionType; import com.alibaba.nacos.common.remote.PayloadRegistry; import com.alibaba.nacos.common.utils.JacksonUtils; @@ -34,31 +33,32 @@ import javax.annotation.PreDestroy; * @version $Id: BaseRpcServer.java, v 0.1 2020年07月13日 3:41 PM liuzunfei Exp $ */ public abstract class BaseRpcServer { - + static { PayloadRegistry.init(); } - + @Autowired - protected RpcServerTlsConfig grpcServerConfig; - - @JustForTest - public void setGrpcServerConfig(RpcServerTlsConfig grpcServerConfig) { - this.grpcServerConfig = grpcServerConfig; - } - + protected RpcServerTlsConfig rpcServerTlsConfig; + /** * Start sever. */ @PostConstruct public void start() throws Exception { String serverName = getClass().getSimpleName(); - String tlsConfig = JacksonUtils.toJson(grpcServerConfig); - Loggers.REMOTE.info("Nacos {} Rpc server starting at port {} and tls config:{}", serverName, getServicePort(), tlsConfig); + String tlsConfig = JacksonUtils.toJson(rpcServerTlsConfig); + Loggers.REMOTE.info("Nacos {} Rpc server starting at port {} and tls config:{}", serverName, getServicePort(), + tlsConfig); startServer(); - - Loggers.REMOTE.info("Nacos {} Rpc server started at port {} and tls config:{}", serverName, getServicePort(), tlsConfig); + + if (RpcServerSslContextRefresherHolder.getInstance() != null) { + RpcServerSslContextRefresherHolder.getInstance().refresh(this); + } + + Loggers.REMOTE.info("Nacos {} Rpc server started at port {} and tls config:{}", serverName, getServicePort(), + tlsConfig); Runtime.getRuntime().addShutdownHook(new Thread(() -> { Loggers.REMOTE.info("Nacos {} Rpc server stopping", serverName); try { @@ -68,7 +68,7 @@ public abstract class BaseRpcServer { Loggers.REMOTE.error("Nacos {} Rpc server stopped fail...", serverName, e); } })); - + } /** @@ -78,6 +78,19 @@ public abstract class BaseRpcServer { */ public abstract ConnectionType getConnectionType(); + public RpcServerTlsConfig getRpcServerTlsConfig() { + return rpcServerTlsConfig; + } + + public void setRpcServerTlsConfig(RpcServerTlsConfig rpcServerTlsConfig) { + this.rpcServerTlsConfig = rpcServerTlsConfig; + } + + /** + * reload ssl context. + */ + public abstract void reloadSslContext(); + /** * Start sever. * @@ -115,5 +128,5 @@ public abstract class BaseRpcServer { */ @PreDestroy public abstract void shutdownServer(); - + } diff --git a/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresher.java b/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresher.java new file mode 100644 index 000000000..06ad43afa --- /dev/null +++ b/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresher.java @@ -0,0 +1,41 @@ +package com.alibaba.nacos.core.remote; + +/** + * ssl context refresher spi holder. + * + * @author liuzunfei + * @version $Id: RequestFilters.java, v 0.1 2023年03月17日 12:00 PM liuzunfei Exp $ + */ +/* + * Copyright 1999-2020 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +public interface RpcServerSslContextRefresher { + + /** + * listener current rpc server and do something on ssl context change. + * + * @param baseRpcServer rpc server. + * @return + */ + SslContextChangeAware refresh(BaseRpcServer baseRpcServer); + + /** + * refresher name. + * + * @return + */ + String getName(); +} diff --git a/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresherHolder.java b/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresherHolder.java new file mode 100644 index 000000000..6b910cd86 --- /dev/null +++ b/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerSslContextRefresherHolder.java @@ -0,0 +1,75 @@ +/* + * Copyright 1999-2020 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.nacos.core.remote; + +import com.alibaba.nacos.common.spi.NacosServiceLoader; +import com.alibaba.nacos.common.utils.StringUtils; +import com.alibaba.nacos.core.utils.Loggers; +import com.alibaba.nacos.sys.utils.ApplicationUtils; + +import java.util.Collection; + +/** + * ssl context refresher spi holder. + * + * @author liuzunfei + * @version $Id: RequestFilters.java, v 0.1 2023年03月17日 12:00 PM liuzunfei Exp $ + */ +public class RpcServerSslContextRefresherHolder { + + private static RpcServerSslContextRefresher instance; + + private static volatile boolean init = false; + + public static RpcServerSslContextRefresher getInstance() { + if (init) { + return instance; + } + synchronized (RpcServerSslContextRefresherHolder.class) { + if (init) { + return instance; + } + RpcServerTlsConfig rpcServerTlsConfig = ApplicationUtils.getBean(RpcServerTlsConfig.class); + String sslContextRefresher = rpcServerTlsConfig.getSslContextRefresher(); + if (StringUtils.isNotBlank(sslContextRefresher)) { + Collection load = NacosServiceLoader.load( + RpcServerSslContextRefresher.class); + for (RpcServerSslContextRefresher contextRefresher : load) { + if (sslContextRefresher.equals(contextRefresher.getName())) { + instance = contextRefresher; + Loggers.REMOTE.info("RpcServerSslContextRefresher of Name {} Founded->{}", sslContextRefresher, + contextRefresher.getClass().getSimpleName()); + break; + } + } + if (instance == null) { + Loggers.REMOTE.info("RpcServerSslContextRefresher of Name {} not found", sslContextRefresher); + } + + } else { + Loggers.REMOTE.info( + "No RpcServerSslContextRefresher specified,Ssl Context auto refresh not supported."); + } + + Loggers.REMOTE.info("RpcServerSslContextRefresher init end"); + init = true; + } + + return instance; + } + +} diff --git a/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerTlsConfig.java b/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerTlsConfig.java index 589f2b3e3..528d2f1c5 100644 --- a/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerTlsConfig.java +++ b/core/src/main/java/com/alibaba/nacos/core/remote/RpcServerTlsConfig.java @@ -31,7 +31,9 @@ import org.springframework.stereotype.Component; public class RpcServerTlsConfig extends TlsConfig { public static final String PREFIX = "nacos.remote.server.rpc.tls"; - + + private String sslContextRefresher = ""; + private Boolean compatibility = true; public Boolean getCompatibility() { @@ -41,4 +43,12 @@ public class RpcServerTlsConfig extends TlsConfig { public void setCompatibility(Boolean compatibility) { this.compatibility = compatibility; } + + public String getSslContextRefresher() { + return sslContextRefresher; + } + + public void setSslContextRefresher(String sslContextRefresher) { + this.sslContextRefresher = sslContextRefresher; + } } diff --git a/core/src/main/java/com/alibaba/nacos/core/remote/SslContextChangeAware.java b/core/src/main/java/com/alibaba/nacos/core/remote/SslContextChangeAware.java new file mode 100644 index 000000000..347de201d --- /dev/null +++ b/core/src/main/java/com/alibaba/nacos/core/remote/SslContextChangeAware.java @@ -0,0 +1,43 @@ +/* + * Copyright 1999-2020 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.nacos.core.remote; + +/** + * ssl context refresher spi holder. + * + * @author liuzunfei + * @version $Id: RequestFilters.java, v 0.1 2023年03月17日 12:00 PM liuzunfei Exp $ + */ +public interface SslContextChangeAware { + + /** + * init rpc server ssl context. + * + * @param baseRpcServer rpc server. + */ + void init(BaseRpcServer baseRpcServer); + + /** + * do something on ssl context change. + */ + void onSslContextChange(); + + /** + * shutdown to clear context. + */ + void shutdown(); +} diff --git a/core/src/main/java/com/alibaba/nacos/core/remote/grpc/BaseGrpcServer.java b/core/src/main/java/com/alibaba/nacos/core/remote/grpc/BaseGrpcServer.java index 517d73cf3..9007aea25 100644 --- a/core/src/main/java/com/alibaba/nacos/core/remote/grpc/BaseGrpcServer.java +++ b/core/src/main/java/com/alibaba/nacos/core/remote/grpc/BaseGrpcServer.java @@ -22,10 +22,12 @@ import com.alibaba.nacos.common.packagescan.resource.Resource; import com.alibaba.nacos.common.packagescan.resource.ResourceLoader; import com.alibaba.nacos.common.remote.ConnectionType; +import com.alibaba.nacos.common.utils.JacksonUtils; import com.alibaba.nacos.common.utils.StringUtils; import com.alibaba.nacos.common.utils.TlsTypeResolve; import com.alibaba.nacos.core.remote.BaseRpcServer; import com.alibaba.nacos.core.remote.ConnectionManager; +import com.alibaba.nacos.core.utils.Loggers; import com.alibaba.nacos.sys.env.EnvUtil; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; @@ -61,63 +63,78 @@ import java.util.concurrent.TimeUnit; * @version $Id: BaseGrpcServer.java, v 0.1 2020年07月13日 3:42 PM liuzunfei Exp $ */ public abstract class BaseGrpcServer extends BaseRpcServer { - + private Server server; - + private final ResourceLoader resourceLoader = new DefaultResourceLoader(); - + @Autowired private GrpcRequestAcceptor grpcCommonRequestAcceptor; - + @Autowired private GrpcBiStreamRequestAcceptor grpcBiStreamRequestAcceptor; - + @Autowired private ConnectionManager connectionManager; - + + private OptionalTlsProtocolNegotiator optionalTlsProtocolNegotiator; + @Override public ConnectionType getConnectionType() { return ConnectionType.GRPC; } - + @Override public void startServer() throws Exception { final MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry(); addServices(handlerRegistry, new GrpcConnectionInterceptor()); NettyServerBuilder builder = NettyServerBuilder.forPort(getServicePort()).executor(getRpcExecutor()); - - if (grpcServerConfig.getEnableTls()) { - if (grpcServerConfig.getCompatibility()) { - builder.protocolNegotiator(new OptionalTlsProtocolNegotiator(getSslContextBuilder())); - } else { - builder.sslContext(getSslContextBuilder()); - } + + if (rpcServerTlsConfig.getEnableTls()) { + builder.protocolNegotiator( + new OptionalTlsProtocolNegotiator(getSslContextBuilder(), rpcServerTlsConfig.getCompatibility())); + } - + server = builder.maxInboundMessageSize(getMaxInboundMessageSize()).fallbackHandlerRegistry(handlerRegistry) .compressorRegistry(CompressorRegistry.getDefaultInstance()) .decompressorRegistry(DecompressorRegistry.getDefaultInstance()) .addTransportFilter(new AddressTransportFilter(connectionManager)) .keepAliveTime(getKeepAliveTime(), TimeUnit.MILLISECONDS) .keepAliveTimeout(getKeepAliveTimeout(), TimeUnit.MILLISECONDS) - .permitKeepAliveTime(getPermitKeepAliveTime(), TimeUnit.MILLISECONDS) - .build(); - + .permitKeepAliveTime(getPermitKeepAliveTime(), TimeUnit.MILLISECONDS).build(); + server.start(); } - + + /** + * reload ssl context. + */ + public void reloadSslContext() { + if (optionalTlsProtocolNegotiator != null) { + try { + optionalTlsProtocolNegotiator.setSslContext(getSslContextBuilder()); + } catch (Throwable throwable) { + Loggers.REMOTE.info("Nacos {} Rpc server reload ssl context fail at port {} and tls config:{}", + this.getClass().getSimpleName(), getServicePort(), + JacksonUtils.toJson(super.rpcServerTlsConfig)); + throw throwable; + } + } + } + protected long getPermitKeepAliveTime() { return GrpcServerConstants.GrpcConfig.DEFAULT_GRPC_PERMIT_KEEP_ALIVE_TIME; } - + protected long getKeepAliveTime() { return GrpcServerConstants.GrpcConfig.DEFAULT_GRPC_KEEP_ALIVE_TIME; } - + protected long getKeepAliveTimeout() { return GrpcServerConstants.GrpcConfig.DEFAULT_GRPC_KEEP_ALIVE_TIMEOUT; } - + protected int getMaxInboundMessageSize() { Integer property = EnvUtil.getProperty(GrpcServerConstants.GrpcConfig.MAX_INBOUND_MSG_SIZE_PROPERTY, Integer.class); @@ -126,88 +143,90 @@ public abstract class BaseGrpcServer extends BaseRpcServer { } return GrpcServerConstants.GrpcConfig.DEFAULT_GRPC_MAX_INBOUND_MSG_SIZE; } - + private void addServices(MutableHandlerRegistry handlerRegistry, ServerInterceptor... serverInterceptor) { - + // unary common call register. final MethodDescriptor unaryPayloadMethod = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNARY) - .setFullMethodName(MethodDescriptor.generateFullMethodName(GrpcServerConstants.REQUEST_SERVICE_NAME, - GrpcServerConstants.REQUEST_METHOD_NAME)) + .setType(MethodDescriptor.MethodType.UNARY).setFullMethodName( + MethodDescriptor.generateFullMethodName(GrpcServerConstants.REQUEST_SERVICE_NAME, + GrpcServerConstants.REQUEST_METHOD_NAME)) .setRequestMarshaller(ProtoUtils.marshaller(Payload.getDefaultInstance())) .setResponseMarshaller(ProtoUtils.marshaller(Payload.getDefaultInstance())).build(); - - final ServerCallHandler payloadHandler = ServerCalls - .asyncUnaryCall((request, responseObserver) -> grpcCommonRequestAcceptor.request(request, responseObserver)); - + + final ServerCallHandler payloadHandler = ServerCalls.asyncUnaryCall( + (request, responseObserver) -> grpcCommonRequestAcceptor.request(request, responseObserver)); + final ServerServiceDefinition serviceDefOfUnaryPayload = ServerServiceDefinition.builder( - GrpcServerConstants.REQUEST_SERVICE_NAME) - .addMethod(unaryPayloadMethod, payloadHandler).build(); + GrpcServerConstants.REQUEST_SERVICE_NAME).addMethod(unaryPayloadMethod, payloadHandler).build(); handlerRegistry.addService(ServerInterceptors.intercept(serviceDefOfUnaryPayload, serverInterceptor)); - + // bi stream register. final ServerCallHandler biStreamHandler = ServerCalls.asyncBidiStreamingCall( (responseObserver) -> grpcBiStreamRequestAcceptor.requestBiStream(responseObserver)); - + final MethodDescriptor biStreamMethod = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.BIDI_STREAMING).setFullMethodName(MethodDescriptor - .generateFullMethodName(GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME, + .setType(MethodDescriptor.MethodType.BIDI_STREAMING).setFullMethodName( + MethodDescriptor.generateFullMethodName(GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME, GrpcServerConstants.REQUEST_BI_STREAM_METHOD_NAME)) .setRequestMarshaller(ProtoUtils.marshaller(Payload.newBuilder().build())) .setResponseMarshaller(ProtoUtils.marshaller(Payload.getDefaultInstance())).build(); - - final ServerServiceDefinition serviceDefOfBiStream = ServerServiceDefinition - .builder(GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME).addMethod(biStreamMethod, biStreamHandler).build(); + + final ServerServiceDefinition serviceDefOfBiStream = ServerServiceDefinition.builder( + GrpcServerConstants.REQUEST_BI_STREAM_SERVICE_NAME).addMethod(biStreamMethod, biStreamHandler).build(); handlerRegistry.addService(ServerInterceptors.intercept(serviceDefOfBiStream, serverInterceptor)); - + } - + @Override public void shutdownServer() { if (server != null) { server.shutdownNow(); } } - + private SslContext getSslContextBuilder() { try { - if (StringUtils.isBlank(grpcServerConfig.getCertChainFile()) || StringUtils.isBlank(grpcServerConfig.getCertPrivateKey())) { + if (StringUtils.isBlank(rpcServerTlsConfig.getCertChainFile()) || StringUtils.isBlank( + rpcServerTlsConfig.getCertPrivateKey())) { throw new IllegalArgumentException("Server certChainFile or certPrivateKey must be not null"); } - InputStream certificateChainFile = getInputStream(grpcServerConfig.getCertChainFile(), "certChainFile"); - InputStream privateKeyFile = getInputStream(grpcServerConfig.getCertPrivateKey(), "certPrivateKey"); - SslContextBuilder sslClientContextBuilder = SslContextBuilder.forServer(certificateChainFile, privateKeyFile, - grpcServerConfig.getCertPrivateKeyPassword()); - - if (StringUtils.isNotBlank(grpcServerConfig.getProtocols())) { - sslClientContextBuilder.protocols(grpcServerConfig.getProtocols().split(",")); + InputStream certificateChainFile = getInputStream(rpcServerTlsConfig.getCertChainFile(), "certChainFile"); + InputStream privateKeyFile = getInputStream(rpcServerTlsConfig.getCertPrivateKey(), "certPrivateKey"); + SslContextBuilder sslClientContextBuilder = SslContextBuilder.forServer(certificateChainFile, + privateKeyFile, rpcServerTlsConfig.getCertPrivateKeyPassword()); + + if (StringUtils.isNotBlank(rpcServerTlsConfig.getProtocols())) { + sslClientContextBuilder.protocols(rpcServerTlsConfig.getProtocols().split(",")); } - - if (StringUtils.isNotBlank(grpcServerConfig.getCiphers())) { - sslClientContextBuilder.ciphers(Arrays.asList(grpcServerConfig.getCiphers().split(","))); + + if (StringUtils.isNotBlank(rpcServerTlsConfig.getCiphers())) { + sslClientContextBuilder.ciphers(Arrays.asList(rpcServerTlsConfig.getCiphers().split(","))); } - if (grpcServerConfig.getMutualAuthEnable()) { + if (rpcServerTlsConfig.getMutualAuthEnable()) { // trust all certificate - if (grpcServerConfig.getTrustAll()) { + if (rpcServerTlsConfig.getTrustAll()) { sslClientContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); } else { - if (StringUtils.isBlank(grpcServerConfig.getTrustCollectionCertFile())) { - throw new IllegalArgumentException("enable mutual auth,trustCollectionCertFile must be not null"); + if (StringUtils.isBlank(rpcServerTlsConfig.getTrustCollectionCertFile())) { + throw new IllegalArgumentException( + "enable mutual auth,trustCollectionCertFile must be not null"); } - - InputStream clientCert = getInputStream(grpcServerConfig.getTrustCollectionCertFile(), "trustCollectionCertFile"); + + InputStream clientCert = getInputStream(rpcServerTlsConfig.getTrustCollectionCertFile(), + "trustCollectionCertFile"); sslClientContextBuilder.trustManager(clientCert); } sslClientContextBuilder.clientAuth(ClientAuth.REQUIRE); } SslContextBuilder configure = GrpcSslContexts.configure(sslClientContextBuilder, - TlsTypeResolve.getSslProvider(grpcServerConfig.getSslProvider())); + TlsTypeResolve.getSslProvider(rpcServerTlsConfig.getSslProvider())); return configure.build(); } catch (SSLException e) { throw new RuntimeException(e); } } - + private InputStream getInputStream(String path, String config) { try { Resource resource = resourceLoader.getResource(path); @@ -216,12 +235,12 @@ public abstract class BaseGrpcServer extends BaseRpcServer { throw new RuntimeException(config + " load fail", e); } } - + /** * get rpc executor. * * @return executor. */ public abstract ThreadPoolExecutor getRpcExecutor(); - + } diff --git a/core/src/main/java/com/alibaba/nacos/core/remote/grpc/OptionalTlsProtocolNegotiator.java b/core/src/main/java/com/alibaba/nacos/core/remote/grpc/OptionalTlsProtocolNegotiator.java index 85cd3516d..50cefc931 100644 --- a/core/src/main/java/com/alibaba/nacos/core/remote/grpc/OptionalTlsProtocolNegotiator.java +++ b/core/src/main/java/com/alibaba/nacos/core/remote/grpc/OptionalTlsProtocolNegotiator.java @@ -32,40 +32,45 @@ import java.lang.reflect.Field; import java.util.List; /** - * support the tls and plain protocol one the same port. + * support the tls and plain protocol one the same port. * * @author githubcheng2978. */ public class OptionalTlsProtocolNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { - + private static final int MAGIC_VALUE = 5; - + + private boolean supportPlainText; + private SslContext sslContext; - - public OptionalTlsProtocolNegotiator(SslContext sslContext) { + + public OptionalTlsProtocolNegotiator(SslContext sslContext, boolean supportPlainText) { + this.sslContext = sslContext; + this.supportPlainText = supportPlainText; + } + + void setSslContext(SslContext sslContext) { this.sslContext = sslContext; } - + @Override public AsciiString scheme() { return AsciiString.of("https"); } - + @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHttp2ConnectionHandler) { - ChannelHandler plaintext = - InternalProtocolNegotiators.serverPlaintext().newHandler(grpcHttp2ConnectionHandler); - ChannelHandler ssl = - InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHttp2ConnectionHandler); + ChannelHandler plaintext = InternalProtocolNegotiators.serverPlaintext().newHandler(grpcHttp2ConnectionHandler); + ChannelHandler ssl = InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHttp2ConnectionHandler); ChannelHandler decoder = new PortUnificationServerHandler(ssl, plaintext); return decoder; } - + @Override public void close() { - + } - + private ProtocolNegotiationEvent getDefPne() { ProtocolNegotiationEvent protocolNegotiationEvent = null; try { @@ -77,31 +82,31 @@ public class OptionalTlsProtocolNegotiator implements InternalProtocolNegotiator } return protocolNegotiationEvent; } - + public class PortUnificationServerHandler extends ByteToMessageDecoder { + private ProtocolNegotiationEvent pne; - + private final ChannelHandler ssl; - + private final ChannelHandler plaintext; - + public PortUnificationServerHandler(ChannelHandler ssl, ChannelHandler plaintext) { this.ssl = ssl; this.plaintext = plaintext; this.pne = getDefPne(); } - + private boolean isSsl(ByteBuf buf) { return SslHandler.isEncrypted(buf); } - + @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) - throws Exception { + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { if (in.readableBytes() < MAGIC_VALUE) { return; } - if (isSsl(in)) { + if (isSsl(in) || !supportPlainText) { ctx.pipeline().addAfter(ctx.name(), (String) null, this.ssl); ctx.fireUserEventTriggered(pne); ctx.pipeline().remove(this); @@ -112,5 +117,5 @@ public class OptionalTlsProtocolNegotiator implements InternalProtocolNegotiator } } } - + } diff --git a/core/src/test/java/com/alibaba/nacos/core/remote/grpc/GrpcServerTest.java b/core/src/test/java/com/alibaba/nacos/core/remote/grpc/GrpcServerTest.java index c8df469b4..d0250049a 100644 --- a/core/src/test/java/com/alibaba/nacos/core/remote/grpc/GrpcServerTest.java +++ b/core/src/test/java/com/alibaba/nacos/core/remote/grpc/GrpcServerTest.java @@ -20,10 +20,14 @@ package com.alibaba.nacos.core.remote.grpc; import com.alibaba.nacos.common.remote.ConnectionType; import com.alibaba.nacos.core.remote.RpcServerTlsConfig; import com.alibaba.nacos.sys.env.EnvUtil; +import com.alibaba.nacos.sys.utils.ApplicationUtils; +import org.junit.AfterClass; import org.junit.Assert; -import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.mock.env.MockEnvironment; @@ -40,36 +44,46 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.Silent.class) public class GrpcServerTest { - + private final RpcServerTlsConfig grpcServerConfig = mock(RpcServerTlsConfig.class); - - @Before - public void setUp() { + + static MockedStatic applicationUtilsMockedStatic = null; + + @BeforeClass + public static void setUp() { EnvUtil.setEnvironment(new MockEnvironment()); + applicationUtilsMockedStatic = Mockito.mockStatic(ApplicationUtils.class); } - + + @AfterClass + public static void after() { + applicationUtilsMockedStatic.close(); + } + @Test public void testGrpcSdkServer() throws Exception { BaseGrpcServer grpcSdkServer = new GrpcSdkServer(); - grpcSdkServer.setGrpcServerConfig(grpcServerConfig); + grpcSdkServer.setRpcServerTlsConfig(grpcServerConfig); when(grpcServerConfig.getEnableTls()).thenReturn(false); + when(ApplicationUtils.getBean(RpcServerTlsConfig.class)).thenReturn(grpcServerConfig); grpcSdkServer.start(); Assert.assertEquals(grpcSdkServer.getConnectionType(), ConnectionType.GRPC); Assert.assertEquals(grpcSdkServer.rpcPortOffset(), 1000); grpcSdkServer.stopServer(); } - + @Test public void testGrpcClusterServer() throws Exception { BaseGrpcServer grpcSdkServer = new GrpcClusterServer(); - grpcSdkServer.setGrpcServerConfig(grpcServerConfig); + grpcSdkServer.setRpcServerTlsConfig(grpcServerConfig); when(grpcServerConfig.getEnableTls()).thenReturn(false); + when(ApplicationUtils.getBean(RpcServerTlsConfig.class)).thenReturn(grpcServerConfig); grpcSdkServer.start(); Assert.assertEquals(grpcSdkServer.getConnectionType(), ConnectionType.GRPC); Assert.assertEquals(grpcSdkServer.rpcPortOffset(), 1001); grpcSdkServer.stopServer(); } - + @Test public void testGrpcEnableTls() throws Exception { final BaseGrpcServer grpcSdkServer = new BaseGrpcServer() { @@ -77,7 +91,7 @@ public class GrpcServerTest { public ThreadPoolExecutor getRpcExecutor() { return null; } - + @Override public int rpcPortOffset() { return 100; @@ -86,40 +100,41 @@ public class GrpcServerTest { when(grpcServerConfig.getEnableTls()).thenReturn(true); when(grpcServerConfig.getCiphers()).thenReturn("ECDHE-RSA-AES128-GCM-SHA256,ECDHE-RSA-AES256-GCM-SHA384"); when(grpcServerConfig.getProtocols()).thenReturn("TLSv1.2,TLSv1.3"); - + when(grpcServerConfig.getCertPrivateKey()).thenReturn("test-server-key.pem"); when(grpcServerConfig.getCertChainFile()).thenReturn("test-server-cert.pem"); - grpcSdkServer.setGrpcServerConfig(grpcServerConfig); + when(ApplicationUtils.getBean(RpcServerTlsConfig.class)).thenReturn(grpcServerConfig); + grpcSdkServer.setRpcServerTlsConfig(grpcServerConfig); grpcSdkServer.start(); grpcSdkServer.shutdownServer(); } - + @Test public void testGrpcEnableMutualAuthAndTrustAll() throws Exception { - + final BaseGrpcServer grpcSdkServer = new BaseGrpcServer() { @Override public ThreadPoolExecutor getRpcExecutor() { return null; } - + @Override public int rpcPortOffset() { return 100; } }; - + when(grpcServerConfig.getEnableTls()).thenReturn(true); when(grpcServerConfig.getTrustAll()).thenReturn(true); when(grpcServerConfig.getCiphers()).thenReturn("ECDHE-RSA-AES128-GCM-SHA256,ECDHE-RSA-AES256-GCM-SHA384"); when(grpcServerConfig.getProtocols()).thenReturn("TLSv1.2,TLSv1.3"); when(grpcServerConfig.getCertPrivateKey()).thenReturn("test-server-key.pem"); when(grpcServerConfig.getCertChainFile()).thenReturn("test-server-cert.pem"); - grpcSdkServer.setGrpcServerConfig(grpcServerConfig); + grpcSdkServer.setRpcServerTlsConfig(grpcServerConfig); grpcSdkServer.start(); grpcSdkServer.shutdownServer(); } - + @Test public void testGrpcEnableMutualAuthAndPart() throws Exception { final BaseGrpcServer grpcSdkServer = new BaseGrpcServer() { @@ -127,7 +142,7 @@ public class GrpcServerTest { public ThreadPoolExecutor getRpcExecutor() { return null; } - + @Override public int rpcPortOffset() { return 100; @@ -138,13 +153,13 @@ public class GrpcServerTest { when(grpcServerConfig.getEnableTls()).thenReturn(true); when(grpcServerConfig.getCiphers()).thenReturn("ECDHE-RSA-AES128-GCM-SHA256,ECDHE-RSA-AES256-GCM-SHA384"); when(grpcServerConfig.getProtocols()).thenReturn("TLSv1.2,TLSv1.3"); - + when(grpcServerConfig.getCertPrivateKey()).thenReturn("test-server-key.pem"); when(grpcServerConfig.getCertChainFile()).thenReturn("test-server-cert.pem"); when(grpcServerConfig.getTrustCollectionCertFile()).thenReturn("test-ca-cert.pem"); - - grpcSdkServer.setGrpcServerConfig(grpcServerConfig); - + + grpcSdkServer.setRpcServerTlsConfig(grpcServerConfig); + grpcSdkServer.start(); grpcSdkServer.shutdownServer(); }