doubango/tinyNET/src/tls/tnet_tls.c
c732d49e
 #if HAVE_CRT
 #define _CRTDBG_MAP_ALLOC 
 #include <stdlib.h> 
 #include <crtdbg.h>
 #endif //HAVE_CRT
 /*
 * Copyright (C) 2010-2012 Mamadou Diop.
 * Copyright (C) 2013 Doubango Telecom <http://www.doubango.org>
 *	
 * This file is part of Open Source Doubango Framework.
 *
 * DOUBANGO is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *	
 * DOUBANGO is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *	
 * You should have received a copy of the GNU General Public License
 * along with DOUBANGO.
 *
 */
 
 /**@file tnet_tls.c
  * @brief TLS utility functions, based on OpenSSL.
  */
 #include "tnet_tls.h"
 #include "tnet_utils.h"
 
 #include "tsk_object.h"
 #include "tsk_string.h"
 #include "tsk_memory.h"
 #include "tsk_debug.h"
 #include "tsk_safeobj.h"
 
 /* Timeout handed to tnet_sockfd_waitUntil() each time OpenSSL asks to wait for the socket
  * (presumably milliseconds -- TODO confirm against tnet_sockfd_waitUntil()). */
 #define TNET_TLS_TIMEOUT		2000
 /* Upper bound on the wait-and-retry cycles performed by the TLS I/O helpers in this file
  * (e.g. tnet_tls_socket_recv() on SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE). */
 #define TNET_TLS_RETRY_COUNT	10
 
 /**
  * TLS socket: an OpenSSL session bound to an already-created descriptor.
  * Instances are created by tnet_tls_socket_create() and released through the tsk object system
  * (see tnet_tls_socket_dtor()).
  */
 typedef struct tnet_tls_socket_s
 {
 	/* tsk object header (presumably required as first member by tsk_object_new() -- do not reorder) */
 	TSK_DECLARE_OBJECT;
 	
 	tnet_fd_t fd; /* not owner: do not try to close */
 
 #if HAVE_OPENSSL
 	/* Session created with SSL_new() in tnet_tls_socket_create(), freed by the destructor */
 	SSL *ssl;
 #endif
 
 	/* Lock taken by tsk_safeobj_lock()/tsk_safeobj_unlock() in the write/recv functions */
 	TSK_DECLARE_SAFEOBJ;
 }
 tnet_tls_socket_t;
 
 /**
  * Tells whether TLS support was compiled in.
  * @return tsk_true when the library was built with OpenSSL (HAVE_OPENSSL), tsk_false otherwise.
  */
 tsk_bool_t tnet_tls_is_supported()
 {
 	tsk_bool_t supported = tsk_false;
 #if HAVE_OPENSSL
 	supported = tsk_true;
 #endif
 	return supported;
 }
 
 /**
  * Wraps an existing descriptor into a new TLS socket object.
  *
  * @param fd descriptor to bind to the SSL session; the returned object does not own (close) it.
  * @param ssl_ctx OpenSSL context used to create the session with SSL_new().
  * @return the new TLS socket handle, or tsk_null on invalid parameters, allocation failure,
  *         SSL_new()/SSL_set_fd() failure, or when OpenSSL is not enabled.
  */
 tnet_tls_socket_handle_t* tnet_tls_socket_create(tnet_fd_t fd, struct ssl_ctx_st* ssl_ctx)
 {
 #if !HAVE_OPENSSL
 	TSK_DEBUG_ERROR("OpenSSL not enabled");
 	return tsk_null;
 #else
 	tnet_tls_socket_t* tls_socket = tsk_null;
 
 	if(fd <= 0 || !ssl_ctx){
 		TSK_DEBUG_ERROR("Invalid parameter");
 		return tsk_null;
 	}
 
 	tls_socket = tsk_object_new(tnet_tls_socket_def_t);
 	if(!tls_socket){
 		return tsk_null;
 	}
 	tls_socket->fd = fd;
 
 	tls_socket->ssl = SSL_new(ssl_ctx);
 	if(!tls_socket->ssl){
 		TSK_DEBUG_ERROR("SSL_new(CTX) failed [%s]", ERR_error_string(ERR_get_error(), tsk_null));
 		goto failed;
 	}
 	if(SSL_set_fd(tls_socket->ssl, tls_socket->fd) != 1){
 		TSK_DEBUG_ERROR("SSL_set_fd(%d) failed [%s]", tls_socket->fd, ERR_error_string(ERR_get_error(), tsk_null));
 		goto failed;
 	}
 	return tls_socket;
 
 failed:
 	/* the destructor releases the (possibly partially initialized) SSL session */
 	TSK_OBJECT_SAFE_FREE(tls_socket);
 	return tsk_null;
 #endif
 }
 
 /**
  * Starts (or continues) the client-side TLS handshake with SSL_connect().
  *
  * Non-fatal conditions (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_SYSCALL) are reported
  * as success: it is up to the caller to check that the socket is writable and valid.
  *
  * @param self TLS socket handle.
  * @retval 0 handshake done or still in progress.
  * @retval -1 @a self is null.
  * @retval -200 OpenSSL not enabled.
  * @return otherwise the (positive) SSL_get_error() code of the failure.
  */
 int tnet_tls_socket_connect(tnet_tls_socket_handle_t* self)
 {
 #if !HAVE_OPENSSL
 	TSK_DEBUG_ERROR("You MUST enable OpenSSL");
 	return -200;
 #else
 	tnet_tls_socket_t* tls_socket = self;
 	int rc;
 	int ssl_err;
 
 	if(!tls_socket){
 		TSK_DEBUG_ERROR("Invalid parameter");
 		return -1;
 	}
 
 	rc = SSL_connect(tls_socket->ssl);
 	if(rc == 1){
 		return 0;
 	}
 
 	ssl_err = SSL_get_error(tls_socket->ssl, rc);
 	switch(ssl_err){
 		case SSL_ERROR_WANT_WRITE:
 		case SSL_ERROR_WANT_READ:
 		case SSL_ERROR_SYSCALL:
 			/* up to the caller to check that the socket is writable and valid */
 			return 0;
 		default:
 			TSK_DEBUG_ERROR("SSL_connect failed [%d, %s]", ssl_err, ERR_error_string(ERR_get_error(), tsk_null));
 			return ssl_err;
 	}
 #endif
 }
 
 /**
  * Performs the server-side TLS handshake (SSL_accept()) on the wrapped descriptor.
  *
  * Whenever OpenSSL needs the peer to send data (SSL_ERROR_WANT_READ) or needs to flush its own
  * data (SSL_ERROR_WANT_WRITE), the call waits with select() -- without timeout -- until the
  * descriptor is ready and then retries, so it returns only once the handshake has succeeded
  * or failed. Transient select() failures (e.g. EINTR) are retried, but after
  * TNET_TLS_RETRY_COUNT failures the handshake is abandoned instead of spinning forever.
  *
  * @param self TLS socket handle.
  * @retval 0 handshake completed.
  * @retval -1 @a self is null.
  * @retval -3 handshake failed (fatal SSL error or persistent select() failure).
  * @retval -200 OpenSSL not enabled.
  */
 int tnet_tls_socket_accept(tnet_tls_socket_handle_t* self)
 {
 #if !HAVE_OPENSSL
 	TSK_DEBUG_ERROR("You MUST enable OpenSSL");
 	return -200;
 #else
 	int ret;
 	int select_failures = 0;
 	tnet_tls_socket_t* tls_socket = self;
 
 	if(!self){
 		TSK_DEBUG_ERROR("Invalid parameter");
 		return -1;
 	}
 
 	for(;;){
 		fd_set fds;
 		tsk_bool_t want_write;
 		int retval;
 
 		if((ret = SSL_accept(tls_socket->ssl)) == 1){
 			return 0;
 		}
 		ret = SSL_get_error(tls_socket->ssl, ret);
 		if(ret != SSL_ERROR_WANT_READ && ret != SSL_ERROR_WANT_WRITE){
 			break; /* fatal handshake error */
 		}
 		want_write = (ret == SSL_ERROR_WANT_WRITE);
 
 		/* Block until the descriptor is ready in the direction OpenSSL asked for */
 		FD_ZERO(&fds);
 		FD_SET(tls_socket->fd, &fds);
 		retval = select(tls_socket->fd + 1, want_write ? NULL : &fds, want_write ? &fds : NULL, NULL, NULL);
 		if(retval < 0){
 			TNET_PRINT_LAST_ERROR("select() failed");
 			/* tolerate transient failures, but a persistent one (e.g. EBADF) must not loop forever */
 			if(++select_failures >= TNET_TLS_RETRY_COUNT){
 				break;
 			}
 		}
 		else if(retval == 0){
 			break; /* not expected with an infinite timeout; kept as a safety net */
 		}
 	}
 
 	TSK_DEBUG_ERROR("SSL_accept() failed with error code [%d, %s]", ret, ERR_error_string(ERR_get_error(), tsk_null));
 	return -3;
 #endif
 }
 
 /**
  * Encrypts and sends @a size bytes from @a data over the TLS session.
  *
  * Calls are serialized with the socket lock. When OpenSSL reports SSL_ERROR_WANT_READ/WRITE,
  * the call waits for the descriptor (tnet_sockfd_waitUntil(), TNET_TLS_TIMEOUT per wait) and
  * retries SSL_write() with the same arguments. It gives up after TNET_TLS_RETRY_COUNT failed
  * waits rather than blocking forever. After a fatal SSL error no further SSL_write() is attempted.
  *
  * @param self TLS socket handle.
  * @param data bytes to send (may be null only when @a size is 0).
  * @param size number of bytes to send; 0 is a no-op.
  * @retval 0 the data was written (assumes partial writes are not enabled on the context -- confirm).
  * @retval -1 invalid parameter.
  * @retval -3 write failed or timed out.
  * @retval -200 OpenSSL not enabled.
  */
 int tnet_tls_socket_write(tnet_tls_socket_handle_t* self, const void* data, tsk_size_t size)
 {
 #if !HAVE_OPENSSL
 	TSK_DEBUG_ERROR("You MUST enable OpenSSL");
 	return -200;
 #else
 	int ret = -1;
 	int wait_failures = 0;
 	tnet_tls_socket_t* tls_socket = self;
 
 	if(!self || (!data && size)){
 		TSK_DEBUG_ERROR("Invalid parameter");
 		return -1;
 	}
 	if(!size){
 		return 0; /* nothing to send: SSL_write() behaviour is undefined with num=0 */
 	}
 
 	/* Write */
 	tsk_safeobj_lock(tls_socket);
 	for(;;){
 		tsk_bool_t want_write;
 
 		if((ret = SSL_write(tls_socket->ssl, data, (int)size)) > 0){
 			break; /* success */
 		}
 		ret = SSL_get_error(tls_socket->ssl, ret);
 		if(ret != SSL_ERROR_WANT_READ && ret != SSL_ERROR_WANT_WRITE){
 			/* fatal: OpenSSL forbids any further I/O on this session, so do not call SSL_write() again */
 			TSK_DEBUG_ERROR("SSL_write failed [%d, %s]", ret, ERR_error_string(ERR_get_error(), tsk_null));
 			ret = -3;
 			break;
 		}
 		want_write = (ret == SSL_ERROR_WANT_WRITE);
 		if(tnet_sockfd_waitUntil(tls_socket->fd, TNET_TLS_TIMEOUT, want_write) != 0 && ++wait_failures >= TNET_TLS_RETRY_COUNT){
 			TSK_DEBUG_ERROR("SSL_write timed out");
 			ret = -3;
 			break;
 		}
 		/* descriptor ready (or transient wait failure): retry with the same arguments */
 	}
 	tsk_safeobj_unlock(tls_socket);
 	
 	return (ret > 0) ? 0 : -3;
 #endif
 }
 
 #if HAVE_OPENSSL
 /* Appends to *data every decrypted byte OpenSSL already buffers (SSL_pending()), growing *data with
  * tsk_realloc(). *nread is the number of valid bytes already in *data and is updated in place.
  * Stops early, keeping what was read so far, if the buffer cannot be grown or SSL_read() fails.
  * Intended to be called with the socket locked. */
 static void tnet_tls_socket_drain_pending(tnet_tls_socket_t* tls_socket, void** data, tsk_size_t* nread)
 {
 	int pending;
 	while((pending = SSL_pending(tls_socket->ssl)) > 0){
 		int ret;
 		void* ptr = tsk_realloc(*data, (*nread + (tsk_size_t)pending));
 		if(!ptr){
 			break;
 		}
 		*data = ptr;
 		if((ret = SSL_read(tls_socket->ssl, ((uint8_t*)*data) + *nread, pending)) <= 0){
 			break;
 		}
 		*nread += (tsk_size_t)ret;
 	}
 }
 #endif
 
 /**
  * Reads decrypted application data from the TLS session.
  *
  * While the handshake is not finished, SSL_read() is only used to drive it forward. If the handshake
  * completes during this call and application data gets decrypted, that data is delivered normally
  * and @a *isEncrypted is reset to tsk_false; previously it was silently dropped.
  * Once the handshake is done, data is read into @a *data. When the socket is not ready, the call waits
  * (tnet_sockfd_waitUntil(), TNET_TLS_TIMEOUT each) at most TNET_TLS_RETRY_COUNT times.
  * If OpenSSL holds more decrypted bytes than @a *size, @a *data is grown with tsk_realloc() and
  * the extra bytes are appended.
  *
  * @param self TLS socket handle.
  * @param data [in/out] tsk_realloc()-compatible buffer of at least @a *size bytes; may be replaced by a larger one.
  * @param size [in/out] on input the capacity of @a *data; on output the number of bytes delivered (0 if none).
  * @param isEncrypted [out] tsk_true when the handshake was still in progress and no data was delivered.
  * @retval 0 success, possibly with @a *size == 0: nothing to deliver yet, or the peer closed the connection.
  * @retval -1 invalid parameter.
  * @retval -200 OpenSSL not enabled.
  * @return any other non-zero value: an SSL_get_error() code or a tnet_sockfd_waitUntil() failure.
  */
 int tnet_tls_socket_recv(tnet_tls_socket_handle_t* self, void** data, tsk_size_t *size, tsk_bool_t *isEncrypted)
 {
 #if !HAVE_OPENSSL
 	TSK_DEBUG_ERROR("You MUST enable OpenSSL");
 	return -200;
 #else
 	int ret = -1;
 	tsk_size_t nread = 0;
 	int rcount = TNET_TLS_RETRY_COUNT;
 	tnet_tls_socket_t* tls_socket = self;
 
 	/* every out-parameter is dereferenced below: validate before touching any of them */
 	if(!self || !data || !size || !isEncrypted){
 		TSK_DEBUG_ERROR("Invalid parameter");
 		return -1;
 	}
 	
 	tsk_safeobj_lock(tls_socket);
 
 	*isEncrypted = SSL_is_init_finished(tls_socket->ssl) ? tsk_false : tsk_true;
 
 	if(*isEncrypted){
 		/* Handshake still in progress: SSL_read() advances it. Read straight into the caller's buffer
 		 * so that application data decrypted right after the handshake is not lost. */
 		if((ret = SSL_read(tls_socket->ssl, *data, (int)*size)) > 0){
 			*isEncrypted = tsk_false;
 			nread = (tsk_size_t)ret;
 			tnet_tls_socket_drain_pending(tls_socket, data, &nread);
 			ret = 0;
 		}
 		else{
 			ret = SSL_get_error(tls_socket->ssl, ret);
 			if(ret == SSL_ERROR_WANT_WRITE || ret == SSL_ERROR_WANT_READ){
 				ret = 0;
 			}
 			else{
 				TSK_DEBUG_ERROR("SSL_read failed [%d, %s]", ret, ERR_error_string(ERR_get_error(), tsk_null));
 			}
 		}
 	}
 	else{
 		/* Read Application data */
 		while(rcount > 0){
 			if((ret = SSL_read(tls_socket->ssl, *data, (int)*size)) > 0){
 				nread = (tsk_size_t)ret;
 				tnet_tls_socket_drain_pending(tls_socket, data, &nread);
 				ret = 0;
 				break;
 			}
 			ret = SSL_get_error(tls_socket->ssl, ret);
 			if(ret == SSL_ERROR_WANT_WRITE || ret == SSL_ERROR_WANT_READ){
 				if(!(ret = tnet_sockfd_waitUntil(tls_socket->fd, TNET_TLS_TIMEOUT, (ret == SSL_ERROR_WANT_WRITE)))){
 					--rcount;
 					continue; /* descriptor ready: try again */
 				}
 				/* wait failed or timed out: 'ret' holds the tnet_sockfd_waitUntil() result */
 			}
 			else if(ret == SSL_ERROR_ZERO_RETURN){ /* connection closed: do nothing, the transport layer will be alerted. */
 				ret = 0;
 				TSK_DEBUG_INFO("TLS connection closed.");
 			}
 			else{
 				TSK_DEBUG_ERROR("SSL_read failed [%d, %s]", ret, ERR_error_string(ERR_get_error(), tsk_null));
 			}
 			break;
 		}
 		/* retries exhausted => ret is 0 (last wait succeeded) and nothing was read: "no data yet" */
 	}
 
 	tsk_safeobj_unlock(tls_socket);
 
 	/* never leave the input capacity in *size when nothing was delivered */
 	*size = nread;
 	return nread ? 0 : ret;
 #endif
 }
 
 
 
 
 //=================================================================================================
 //	TLS socket object definition
 //
 /* Constructor registered in tnet_tls_socket_def_s: initializes the safe-object lock used by
  * tsk_safeobj_lock()/tsk_safeobj_unlock(). The variadic arguments are not used. */
 static tsk_object_t* tnet_tls_socket_ctor(tsk_object_t * self, va_list * app)
 {
 	tnet_tls_socket_t *tls_socket = self;
 	(void)app;
 	if(!tls_socket){
 		return self;
 	}
 	tsk_safeobj_init(tls_socket);
 	return self;
 }
 
 /* Destructor registered in tnet_tls_socket_def_s: ends and frees the SSL session and releases the
  * safe-object lock. The descriptor is not closed (the object does not own it). */
 static tsk_object_t* tnet_tls_socket_dtor(tsk_object_t * self)
 { 
 	tnet_tls_socket_t *tls_socket = self;
 	if(tls_socket){
 #if HAVE_OPENSSL
 		if(tls_socket->ssl){
 			/* Only send close_notify on an established session: SSL_shutdown() during (or before)
 			 * the handshake fails -- e.g. on the tnet_tls_socket_create() error path. */
 			if(SSL_is_init_finished(tls_socket->ssl)){
 				SSL_shutdown(tls_socket->ssl);
 			}
 			SSL_free(tls_socket->ssl);
 			tls_socket->ssl = tsk_null;
 			/* A failed shutdown leaves entries in this thread's error queue, which would make the next
 			 * SSL_get_error() on this thread unreliable: drop them. */
 			ERR_clear_error();
 		}
 #endif
 		tsk_safeobj_deinit(tls_socket);
 	}
 	return self;
 }
 
 /* tsk object definition passed to tsk_object_new() by tnet_tls_socket_create(). */
 static const tsk_object_def_t tnet_tls_socket_def_s = 
 {
 	sizeof(tnet_tls_socket_t),
 	tnet_tls_socket_ctor, 
 	tnet_tls_socket_dtor,
 	tsk_null, /* presumably the comparator slot -- none for TLS sockets (confirm against tsk_object_def_t) */
 };
 /* Public handle to the definition above. */
 const tsk_object_def_t *tnet_tls_socket_def_t = &tnet_tls_socket_def_s;