package hera.strategy;

import hera.annotation.ApiAudience;
import hera.annotation.ApiStability;
import hera.api.model.BytesValue;
import hera.exception.HerajException;
import hera.util.IoUtils;
import hera.util.ValidationUtils;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.okhttp.OkHttpChannelBuilder;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.Collection;
import java.util.UUID;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Security configuration strategy that enables mutual TLS on a gRPC channel.
 *
 * <p>The server certificate(s) are used as the trust anchors for the server connection. The
 * client certificate/key pair is presented for client authentication. The channel authority is
 * overridden with the configured server name.
 *
 * <p>Certificates are expected as PEM-encoded X.509. The client key is expected as an
 * unencrypted PKCS#8 PEM block ({@code -----BEGIN PRIVATE KEY-----}).
 *
 * <p>Netty ({@link NettyChannelBuilder}) and OkHttp ({@link OkHttpChannelBuilder}) builders are
 * supported. Any other builder type is rejected with a {@link HerajException}.
 */
@ApiAudience.Private
@ApiStability.Unstable
public class TlsChannelStrategy implements SecurityConfigurationStrategy {

    /** Key algorithms tried, in order, when decoding the PKCS#8 client key for OkHttp. */
    private static final String[] KEY_ALGORITHMS = {"RSA", "EC", "DSA"};

    /** PEM armor delimiting an unencrypted PKCS#8 private key. */
    private static final String PKCS8_BEGIN = "-----BEGIN PRIVATE KEY-----";
    private static final String PKCS8_END = "-----END PRIVATE KEY-----";

    protected final transient Logger logger = LoggerFactory.getLogger(getClass());
    protected final String serverName;
    protected final BytesValue serverCert;
    protected final BytesValue clientCert;
    protected final BytesValue clientKey;

    /**
     * Creates a TLS strategy. The contents of each stream are read eagerly, here in the
     * constructor.
     *
     * @param serverName server name used as the channel authority
     * @param serverCertInputStream PEM stream with the certificate(s) trusted for the server
     * @param clientCertInputStream PEM stream with the client certificate chain
     * @param clientKeyInputStream PEM stream with the client private key
     * @throws HerajException if reading any of the streams fails
     */
    public TlsChannelStrategy(String serverName, InputStream serverCertInputStream,
            InputStream clientCertInputStream, InputStream clientKeyInputStream) {
        ValidationUtils.assertNotNull(serverName, "Server common name must not null");
        ValidationUtils.assertNotNull(serverCertInputStream, "Server cert input stream must not null");
        ValidationUtils.assertNotNull(clientCertInputStream, "Client cert input stream must not null");
        ValidationUtils.assertNotNull(clientKeyInputStream, "Client key input stream must not null");
        this.serverName = serverName;
        try {
            this.serverCert = BytesValue.of(IoUtils.from(serverCertInputStream));
            this.clientCert = BytesValue.of(IoUtils.from(clientCertInputStream));
            this.clientKey = BytesValue.of(IoUtils.from(clientKeyInputStream));
        } catch (Exception e) {
            throw new HerajException(e);
        }
    }

    /**
     * Applies the TLS material to the builder. It then overrides the authority with
     * {@link #serverName} and switches the builder to transport security.
     *
     * @param managedChannelBuilder a Netty or OkHttp channel builder
     * @throws HerajException if the builder type is unsupported or the TLS material is invalid
     */
    @Override
    public void configure(ManagedChannelBuilder<?> managedChannelBuilder) {
        this.logger.info("Configure channel with tls (server name: {})", this.serverName);
        try {
            if (managedChannelBuilder instanceof NettyChannelBuilder) {
                ((NettyChannelBuilder) managedChannelBuilder).sslContext(GrpcSslContexts.forClient()
                        .trustManager(this.serverCert.getInputStream())
                        .keyManager(this.clientCert.getInputStream(), this.clientKey.getInputStream())
                        .build());
            } else if (managedChannelBuilder instanceof OkHttpChannelBuilder) {
                // OkHttp takes a JSSE socket factory. Build it from the same PEM material as the
                // Netty path so the configured certificates and key are actually applied.
                ((OkHttpChannelBuilder) managedChannelBuilder).sslSocketFactory(buildSslSocketFactory());
            } else {
                throw new HerajException("Unsupported channel builder type " + managedChannelBuilder.getClass());
            }
            managedChannelBuilder.overrideAuthority(this.serverName).useTransportSecurity();
        } catch (HerajException e) {
            throw e;
        } catch (Exception e) {
            throw new HerajException(e);
        }
    }

    /**
     * Builds a socket factory that trusts {@link #serverCert} and presents {@link #clientCert} /
     * {@link #clientKey} for client authentication.
     *
     * @return a socket factory backed by a freshly initialised TLS {@link SSLContext}
     * @throws Exception if the PEM material cannot be parsed or the TLS context cannot be created
     */
    private SSLSocketFactory buildSslSocketFactory() throws Exception {
        final KeyStore trustStore = newEmptyKeyStore();
        int index = 0;
        for (final Certificate certificate : readCertificates(this.serverCert, "server")) {
            trustStore.setCertificateEntry("server-cert-" + index++, certificate);
        }
        final TrustManagerFactory trustManagerFactory =
                TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init(trustStore);

        // The key store exists only in memory and the password never leaves this method. A random
        // non-empty password avoids depending on how the key store type treats empty passwords.
        final char[] password = UUID.randomUUID().toString().toCharArray();
        final Certificate[] clientChain =
                readCertificates(this.clientCert, "client").toArray(new Certificate[0]);
        final KeyStore keyStore = newEmptyKeyStore();
        keyStore.setKeyEntry("client", readPrivateKey(this.clientKey), password, clientChain);
        final KeyManagerFactory keyManagerFactory =
                KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        keyManagerFactory.init(keyStore, password);

        final SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);
        return sslContext.getSocketFactory();
    }

    /** Creates an empty, loaded key store of the platform default type. */
    private static KeyStore newEmptyKeyStore() throws Exception {
        final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
        keyStore.load(null, null);
        return keyStore;
    }

    /**
     * Parses every X.509 certificate contained in a PEM payload.
     *
     * @param pem PEM-encoded certificate(s)
     * @param owner label used in the error message ("server" or "client")
     * @return the parsed certificates, never empty
     * @throws HerajException if the payload contains no certificate
     */
    private static Collection<? extends Certificate> readCertificates(final BytesValue pem,
            final String owner) throws Exception {
        final Collection<? extends Certificate> certificates;
        try (InputStream in = pem.getInputStream()) {
            certificates = CertificateFactory.getInstance("X.509").generateCertificates(in);
        }
        if (certificates.isEmpty()) {
            throw new HerajException("No " + owner + " certificate found");
        }
        return certificates;
    }

    /**
     * Decodes an unencrypted PKCS#8 PEM private key. It tries each of {@link #KEY_ALGORITHMS}
     * in turn.
     *
     * @param pem PEM payload containing a {@code -----BEGIN PRIVATE KEY-----} block
     * @return the decoded private key
     * @throws HerajException if no PKCS#8 block is present or no supported algorithm accepts it
     */
    private static PrivateKey readPrivateKey(final BytesValue pem) throws Exception {
        final String text;
        try (InputStream in = pem.getInputStream()) {
            text = new String(IoUtils.from(in), StandardCharsets.US_ASCII);
        }
        final int begin = text.indexOf(PKCS8_BEGIN);
        final int end = begin < 0 ? -1 : text.indexOf(PKCS8_END, begin + PKCS8_BEGIN.length());
        if (end < 0) {
            throw new HerajException(
                    "Client key must be an unencrypted PKCS#8 PEM block (" + PKCS8_BEGIN + ")");
        }
        // The MIME decoder skips the line breaks inside the PEM body.
        final byte[] der =
                Base64.getMimeDecoder().decode(text.substring(begin + PKCS8_BEGIN.length(), end));
        final PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(der);
        for (final String algorithm : KEY_ALGORITHMS) {
            try {
                return KeyFactory.getInstance(algorithm).generatePrivate(keySpec);
            } catch (InvalidKeySpecException ignored) {
                // Not a key of this algorithm; try the next one.
            }
        }
        throw new HerajException("Unsupported client key algorithm (expected one of RSA, EC, DSA)");
    }

    public String toString() {
        return "TlsChannelStrategy(serverName=" + this.serverName + ")";
    }
}
