diff --git a/components/openssl/include/internal/ssl_types.h b/components/openssl/include/internal/ssl_types.h index 5aaee94176..15295243f6 100644 --- a/components/openssl/include/internal/ssl_types.h +++ b/components/openssl/include/internal/ssl_types.h @@ -81,6 +81,9 @@ typedef struct x509_method_st X509_METHOD; struct pkey_method_st; typedef struct pkey_method_st PKEY_METHOD; +struct ssl_alpn_st; +typedef struct ssl_alpn_st SSL_ALPN; + struct stack_st { char **data; @@ -144,6 +147,16 @@ struct X509_VERIFY_PARAM_st { }; +typedef enum { ALPN_INIT, ALPN_ENABLE, ALPN_DISABLE, ALPN_ERROR } ALPN_STATUS; +struct ssl_alpn_st { + ALPN_STATUS alpn_status; + /* This is dynamically allocated */ + char *alpn_string; + /* This only points to the members in the string */ +#define ALPN_LIST_MAX 10 + const char *alpn_list[ALPN_LIST_MAX]; +}; + struct ssl_ctx_st { int version; @@ -152,9 +165,7 @@ struct ssl_ctx_st unsigned long options; - #if 0 - struct alpn_protocols alpn_protocol; - #endif + SSL_ALPN ssl_alpn; const SSL_METHOD *method; @@ -277,6 +288,7 @@ struct pkey_method_st { int (*pkey_load)(EVP_PKEY *pkey, const unsigned char *buf, int len); }; + typedef int (*next_proto_cb)(SSL *ssl, unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg); diff --git a/components/openssl/library/ssl_lib.c b/components/openssl/library/ssl_lib.c index 8b539826dc..a89eab9739 100644 --- a/components/openssl/library/ssl_lib.c +++ b/components/openssl/library/ssl_lib.c @@ -224,6 +224,10 @@ void SSL_CTX_free(SSL_CTX* ctx) X509_free(ctx->client_CA); + if (ctx->ssl_alpn.alpn_string) { + ssl_mem_free((void *)ctx->ssl_alpn.alpn_string); + } + ssl_mem_free(ctx); } @@ -1554,3 +1558,38 @@ void SSL_set_verify(SSL *ssl, int mode, int (*verify_callback)(int, X509_STORE_C ssl->verify_mode = mode; ssl->verify_callback = verify_callback; } + +/** + * @brief set the ALPN protocols in the preferred order. SSL APIs require the + * protocols in a format. mbedtls doesn't need + * that though. We sanitize that here itself. So convert from: + * "\x02h2\x06spdy/1" to { {"h2"}, {"spdy/1}, {NULL}} + */ +int SSL_CTX_set_alpn_protos(SSL_CTX *ctx, const unsigned char *protos, unsigned protos_len) +{ + ctx->ssl_alpn.alpn_string = ssl_mem_zalloc(protos_len + 1); + if (! ctx->ssl_alpn.alpn_string) { + return 1; + } + ctx->ssl_alpn.alpn_status = ALPN_ENABLE; + memcpy(ctx->ssl_alpn.alpn_string, protos, protos_len); + + char *ptr = ctx->ssl_alpn.alpn_string; + int i; + /* Only running to 1 less than the actual size */ + for (i = 0; i < ALPN_LIST_MAX - 1; i++) { + char len = *ptr; + *ptr = '\0'; // Overwrite the length to act as previous element's string terminator + ptr++; + protos_len--; + ctx->ssl_alpn.alpn_list[i] = ptr; + ptr += len; + protos_len -= len; + if (! protos_len) { + i++; + break; + } + } + ctx->ssl_alpn.alpn_list[i] = NULL; + return 0; +} diff --git a/components/openssl/platform/ssl_pm.c b/components/openssl/platform/ssl_pm.c index 54319d2550..3d8849e3a9 100755 --- a/components/openssl/platform/ssl_pm.c +++ b/components/openssl/platform/ssl_pm.c @@ -153,6 +153,9 @@ int ssl_pm_new(SSL *ssl) mbedtls_ssl_conf_min_version(&ssl_pm->conf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_0); } + if (ssl->ctx->ssl_alpn.alpn_status == ALPN_ENABLE) { + mbedtls_ssl_conf_alpn_protocols( &ssl_pm->conf, ssl->ctx->ssl_alpn.alpn_list ); + } mbedtls_ssl_conf_rng(&ssl_pm->conf, mbedtls_ctr_drbg_random, &ssl_pm->ctr_drbg); #ifdef CONFIG_OPENSSL_LOWLEVEL_DEBUG