/*****************************************************************************
 *                  TmNS Source Selector - Server Interface                  *
 *                           Reference Application                           *
 *                                                                           *
 *  Provides SSL tunnel for data to/from TSS Server component.  Application  *
 *  listens for incoming connection requests from a TSS Client application   *
 *  and establishes the tunnel.  Tap interface is created on the server      *
 *  platform.                                                                *
 *                                                                           *
 *  This code is not designed to be thread safe.                             *
 *                                                                           *
 *  Date: April 27, 2017                                                     *
 *****************************************************************************/

#include "tss_common.h"
#include <fcntl.h>
#include <linux/if_tun.h>

/*---------------------------------------------------------------------*/
/*--- CONSTANTS                                                     ---*/
/*---------------------------------------------------------------------*/

/* Upper bound on the crc+frame payload accepted from the TCP side;
   longer payloads are truncated by main(). */
#define MTU 1600
/* Number of leading frame bytes skipped when computing the CRC.
   NOTE(review): 14 matches an Ethernet header length - confirm intent. */
#define MSG_ID_CALC_OFFSET 14

/*---------------------------------------------------------------------*/
/*--- FORWARD DECLARATIONS AND PROTOTYPES                           ---*/
/*---------------------------------------------------------------------*/

void usage(void);
void tss_server_init(char *cert_file, char *key_file, int listen_port,
                     char *tap_if_name, char *tap_ip_addr, char *tap_bcast_addr,
                     char *tap_netmask, char *post_init_script);
void send_if_value(char *iface, SSL *s1, int s2, int32_t get_cmd, int32_t set_cmd);


/*---------------------------------------------------------------------*/
/*--- GLOBAL VARIABLES                                              ---*/
/*---------------------------------------------------------------------*/

/* Scratch interface request reused by tss_server_init() and send_if_value(). */
struct ifreq ifr;
/* File descriptor of the tap device opened from /dev/net/tun. */
int tap_if_s;
/* Accepted TCP connection that carries the SSL tunnel. */
int tcp_s;
/* SSL session attached to tcp_s. */
SSL *ssl;
/* Debug verbosity: incremented by -d or set by DEBUG_LEVEL; >= 3 traces interface setup. */
int DEBUG;


/*---------------------------------------------------------------------*/
/*--- usage - print out the command line parameters and options     ---*/
/*---------------------------------------------------------------------*/
/*
 * Writes the command-line option summary to stderr and terminates the
 * process with exit status 0.  main() calls it for -h and for any
 * unrecognized option.
 */
void usage(void)
{
    static const char help_text[] =
        "Usage: tss_server [-C conf | -c certfile -k key [-l port] [-n ifname] [-i ipaddr] [-b bcastaddr] [-m netmask] [-p post_init_script]]\n"
        "   -b <bcastaddr>   : Broadcast address to associate with the TSS remote interface\n"
        "   -C <config file> : Configuration file that contains all other initialization parameters\n"
        "   -c <certfile>    : Certificate file for SSL\n"
        "   -k <keyfile>     : Private key file for SSL\n"
        "   -i <IP address>  : IP address to assign to this TSS remote interface (192.168.1.1)\n"
        "   -l <listen port> : Port to listen on for incoming TCP connection (55000)\n"
        "   -n <ifname>      : Name to give the TSS remote interface (tap0)\n"
        "   -m <netmask>     : Netmask of the TSS remote interface\n"
        "   -p <post init>   : Post initialization script name to be run at the end of initialization\n"
        "   -d               : Increase debug level (can be used more than once)\n"
        "   -h               : This help menu\n";

    /* The text contains no conversion specifiers, so fputs emits exactly what fprintf did. */
    fputs(help_text, stderr);
    exit(0);
}


/*---------------------------------------------------------------------*/
/*--- tss_config_strdup - copy a configuration-file value           ---*/
/*---------------------------------------------------------------------*/
/* Returns a malloc'd, NUL-terminated copy of `value` and exits the process if
   the allocation fails.  main() keeps these copies for the life of the
   process and never frees them. */
static char *tss_config_strdup(const char *value)
{
    size_t size = strlen(value) + 1;
    char  *copy = malloc(size);

    if (NULL == copy)
    {
        fprintf(stderr, "Out of memory while reading configuration\n");
        exit(1);
    }
    memcpy(copy, value, size);
    return copy;
}


/*****************************************************************************/
/*---------------------------------------------------------------------------*/
/*--- MAIN FUNCTION                                                       ---*/
/*---------------------------------------------------------------------------*/
/*****************************************************************************/

/*
 * Entry point.  Settings come from the command line and, when -C is given,
 * from a configuration file of "NAME=value" lines (':' ',' and whitespace
 * also separate).  Values read from the file replace command-line values.
 * After tss_server_init() has accepted the SSL client and created the tap
 * interface, frames are forwarded in both directions forever:
 *   TCP -> tap : a 16-bit network-order length, then that many bytes holding
 *                msg.crc followed by the frame.  The frame is written to the
 *                tap device; the received CRC is not checked here.
 *   tap -> TCP : the frame read from tap is sent in the same format, with a
 *                CRC over the bytes after MSG_ID_CALC_OFFSET.
 */
int main(int argc, char **argv)
{
    fd_set   readset;
    int      max_s;
    TSS_MSG  msg;
    int      nread;
    uint16_t plength;   /* unsigned: lengths above 32767 must not turn negative */

    int   c;
    char *tap_bcast_addr    = "192.168.1.255";
    char *config_file       = " ";
    char *cert_file         = " ";
    char *key_file          = " ";
    char *tap_ip_addr       = "192.168.1.1";
    int   listen_port       = 55000;
    char *tap_if_name       = "tap0";
    char *tap_netmask       = "255.255.255.0";
    char *post_init_script  = " ";

    DEBUG                   = 0;


    /* Parse command line arguments */
    while ((c = getopt(argc, argv, "b:C:c:k:i:l:n:m:p:dh")) != -1)
    {
        switch (c)
        {
            case 'b':   tap_bcast_addr   = optarg;       break;
            case 'C':   config_file      = optarg;       break;
            case 'c':   cert_file        = optarg;       break;
            case 'k':   key_file         = optarg;       break;
            case 'i':   tap_ip_addr      = optarg;       break;
            case 'l':   listen_port      = atoi(optarg); break;
            case 'n':   tap_if_name      = optarg;       break;
            case 'm':   tap_netmask      = optarg;       break;
            case 'p':   post_init_script = optarg;       break;
            case 'd':   DEBUG++;                         break;
            case 'h':   usage();                         break;
            default:    usage();                         break;
        }
    }

    /* Check for config file.  If config file exists, then parse it and assign values accordingly */
    if (0 != strncmp(" ", config_file, 2))
    {
        char carray[256];
        FILE *p_fd = fopen(config_file, "r");
        if (NULL == p_fd)
        {
            fprintf(stderr, "Failed to open filename: %s\n", config_file);
        }
        else
        {
            /* '\t' and '\r' are delimiters too, so tab-separated or CRLF files parse cleanly. */
            const char delims[] = "=,: \t\r\n";
            char *name;
            char *value;

            while (NULL != fgets(carray, sizeof(carray), p_fd))
            {
                name = strtok(carray, delims);
                /* Blank line: skip it rather than abandoning the rest of the file. */
                if (NULL == name) continue;
                value = strtok(NULL, delims);
                /* A key without a value would otherwise be copied from a NULL pointer. */
                if (NULL == value)
                {
                    fprintf(stderr, "WARNING: no value given for config key %s\n", name);
                    continue;
                }

                if (0 == strncmp(name, "CERT_FILE", 9))
                {
                    cert_file = tss_config_strdup(value);
                }
                else if (0 == strncmp(name, "KEY_FILE", 8))
                {
                    key_file = tss_config_strdup(value);
                }
                else if (0 == strncmp(name, "POST_INIT_SCRIPT", 16))
                {
                    post_init_script = tss_config_strdup(value);
                }
                else if (0 == strncmp(name, "LISTEN_PORT", 11))
                {
                    listen_port = atoi(value);
                }
                else if (0 == strncmp(name, "TAP_IFNAME", 10))
                {
                    tap_if_name = tss_config_strdup(value);
                }
                else if (0 == strncmp(name, "TAP_IPADDR", 10))
                {
                    tap_ip_addr = tss_config_strdup(value);
                }
                else if (0 == strncmp(name, "TAP_NETMASK", 11))
                {
                    tap_netmask = tss_config_strdup(value);
                }
                else if (0 == strncmp(name, "TAP_BCAST", 9))
                {
                    tap_bcast_addr = tss_config_strdup(value);
                }
                else if (0 == strncmp(name, "DEBUG_LEVEL", 11))
                {
                    DEBUG = atoi(value);
                }
            }
            fclose(p_fd);
        }
    }

    if (strncmp(cert_file, " ", 1) == 0) { ERROR("Must provide a certificate file.\n"); }
    if (strncmp(key_file, " ", 1) == 0) { ERROR("Must provide a key file.\n"); }
    /* atoi yields 0 for garbage; out-of-range values would silently wrap in htons. */
    if (listen_port < 1 || listen_port > 65535) { ERROR("Listen port must be between 1 and 65535.\n"); }

    /* Initialize server: open listening socket and wait for connection */
    tss_server_init(cert_file, key_file, listen_port, tap_if_name, tap_ip_addr, tap_bcast_addr, tap_netmask, post_init_script);

    /* get maximum socket for select() call below */
    max_s = tap_if_s;
    if (max_s < tcp_s) max_s = tcp_s;

    /********************************************************************************/
    /* Main Algorithm Loop - do forever                                             */
    /*   1) Add TCP sockets to the readset                                          */
    /*   2) Select() on readset until data is available to read                     */
    /*   3) For each socket ready for reading:                                      */
    /*      3a) Read data from socket                                               */
    /*      3b) Write data to destination socket / or process command               */
    /********************************************************************************/

    do
    {
        /* 1 */
        FD_ZERO(&readset);
        FD_SET(tap_if_s, &readset);
        FD_SET(tcp_s, &readset);

        /* 2 */
        if (select(max_s + 1, &readset, NULL, NULL, NULL) < 0)
        {
            if (errno == EINTR) continue;

            PERROR("select()");
        }


        /* 3 - Receive frame from TSS Master via TCP.  Write to tap interface (i.e., packet being received by tap interface) */
        if (FD_ISSET(tcp_s, &readset))
        {
            size_t max_payload;
            size_t excess = 0;

            nread = cread(ssl, tcp_s, (unsigned char *)&msg.len, sizeof(msg.len));
            if (nread < 0) PERROR("bad read");
            if (nread == 0) PERROR("TCP connection is likely broken");
            if ((size_t)nread < sizeof(msg.len))
            {
                if (bio_read(ssl, tcp_s, ((unsigned char *)&(msg.len)) + nread, sizeof(msg.len) - nread) < 0)
                    PERROR("read length");
            }

            plength = ntohs(msg.len);
            if (plength == 0) continue;

            /* Too short to hold the CRC: consume it to stay in sync, then drop it.
               (plength - sizeof(msg.crc)) would otherwise wrap to a huge size_t. */
            if (plength < sizeof(msg.crc))
            {
                if (bio_read(ssl, tcp_s, (unsigned char *)&msg.crc, plength) < 0) PERROR("read short frame");
                fprintf(stderr, "WARNING: dropped packet too short to hold a CRC\n");
                continue;
            }

            /* NOTE(review): the payload is read starting at &msg.crc and is assumed to
               run on into msg.buffer (the tap->TCP path relies on the same layout).
               Cap it to that storage as well as to MTU. */
            max_payload = sizeof(msg.crc) + sizeof(msg.buffer);
            if (max_payload > MTU) max_payload = MTU;
            if (plength > max_payload)
            {
                excess  = plength - max_payload;
                plength = (uint16_t)max_payload;
                fprintf(stderr,"WARNING: truncated packet\n");
            }

            if (bio_read(ssl, tcp_s, (unsigned char *)&msg.crc, plength) < 0) PERROR("read crc+frame");

            /* Consume the truncated tail so the next length prefix is read from the right offset. */
            while (excess > 0)
            {
                unsigned char discard[256];
                size_t chunk = (excess < sizeof(discard)) ? excess : sizeof(discard);

                if (bio_read(ssl, tcp_s, discard, chunk) < 0) PERROR("discard oversize frame");
                excess -= chunk;
            }

            if (write(tap_if_s, msg.buffer, plength - sizeof(msg.crc)) < 0) PERROR("write to tap");
        }


        /* 3 - Receive frame from tap interface.  Send to TSS Master via TCP (i.e., packet being sent from tap interface) */
        if (FD_ISSET(tap_if_s, &readset))
        {
            nread = read(tap_if_s, msg.buffer, sizeof(msg.buffer));
            if (nread < 0)
            {
                /* errno is only meaningful after a failure; checking it unconditionally
                   let a stale EINTR from select() discard good frames. */
                if (errno == EINTR) continue;
                PERROR_CONTINUE("read");
                continue;
            }
            /* The CRC covers nread - MSG_ID_CALC_OFFSET bytes; shorter frames would make that negative. */
            if (nread < MSG_ID_CALC_OFFSET)
            {
                fprintf(stderr, "WARNING: dropped %d-byte frame from tap interface\n", nread);
                continue;
            }

            msg.len = htons(nread + sizeof(msg.crc));
            msg.crc = htonl(get_crc32(msg.buffer+MSG_ID_CALC_OFFSET, nread-MSG_ID_CALC_OFFSET));
            if (cwrite(ssl, tcp_s, (void *)&msg, ntohs(msg.len) + sizeof(msg.len)) < 0) PERROR("cwrite msg");
        }

    } while (1);

    exit(0); /* never reached */
}



/*---------------------------------------------------------------------*/
/*--- tss_ifr_reset - clear global ifr and set its interface name   ---*/
/*---------------------------------------------------------------------*/
/* Zeroes the file-scope `ifr` and stores `if_name` in ifr_name.  The result is
   always NUL-terminated; longer names are truncated to IFNAMSIZ-1 characters. */
static void tss_ifr_reset(const char *if_name)
{
    memset(&ifr, 0, sizeof(ifr));
    snprintf(ifr.ifr_name, sizeof(ifr.ifr_name), "%s", if_name);
}

/*---------------------------------------------------------------------*/
/*--- tss_ifr_set_ipv4 - store an IPv4 address in global ifr        ---*/
/*---------------------------------------------------------------------*/
/* Writes dotted-quad `addr` (parsed with inet_addr) into ifr.ifr_addr as an
   AF_INET sockaddr with port 0.  ifr_addr, ifr_dstaddr, ifr_broadaddr and
   ifr_netmask all share this storage in struct ifreq. */
static void tss_ifr_set_ipv4(const char *addr)
{
    struct sockaddr_in address;

    memset(&address, 0, sizeof(address));
    address.sin_family      = AF_INET;
    address.sin_port        = 0;
    address.sin_addr.s_addr = inet_addr(addr);
    memcpy(&ifr.ifr_addr, &address, sizeof(address));
}


/*---------------------------------------------------------------------*/
/*--- tss_server_init - Initialize the TSS Remote Server            ---*/
/*---------------------------------------------------------------------*/
/*
 * Sets up the SSL context from cert_file/key_file and blocks until one
 * client connects on listen_port.  It then completes the SSL handshake
 * (globals tcp_s, ssl) and creates the tap interface (global tap_if_s).
 * It configures the interface's MTU and addresses and brings it up.
 * It sends the interface properties to the peer and finally runs
 * post_init_script, if one was given (" " means none).
 * Exits the process if the SSL handshake fails.
 */
void tss_server_init(char *cert_file,
                     char *key_file,
                     int listen_port,
                     char *tap_if_name,
                     char *tap_ip_addr,
                     char *tap_bcast_addr,
                     char *tap_netmask,
                     char *post_init_script)
{
    struct sockaddr_in local;
    struct sockaddr_in remote;
    socklen_t remotelen;
    int listen_s;
    SSL_CTX *ctx;
    int optval = 1;
    char if_name[IFNAMSIZ];   /* name the kernel actually assigned to the tap device */


    SSL_library_init();
    ctx = InitServerCTX();
    LoadCertificates(ctx, cert_file, key_file);	/* load certs */

    /****************************************************************/
    /* Setup listen port and wait for a connection for the TCP side */
    /*    Reuse so in case restarting quicker than timeouts         */
    /*    Accept when connected                                     */
    /****************************************************************/
    if ((listen_s = socket(AF_INET, SOCK_STREAM, 0)) < 0) PERROR("socket()");
    if (setsockopt(listen_s, SOL_SOCKET, SO_REUSEADDR, (char *)&optval, sizeof(optval)) < 0) PERROR("setsockopt()");

    memset(&local, 0, sizeof(local));
    local.sin_family      = AF_INET;
    local.sin_addr.s_addr = htonl(INADDR_ANY);
    local.sin_port        = htons((uint16_t)listen_port);
    if (bind(listen_s, (struct sockaddr *)&local, sizeof(local)) < 0) PERROR("bind()");
    if (listen(listen_s, 5) < 0) PERROR("listen()");

    memset(&remote, 0, sizeof(remote));
    remotelen = sizeof(remote);
    if ((tcp_s = accept(listen_s, (struct sockaddr *)&remote, &remotelen)) < 0) PERROR("accept()");

    /* Only this one connection is ever served; release the listening socket
       instead of leaking it (and its port) for the life of the process. */
    close(listen_s);

    ssl = SSL_new(ctx);
    if (NULL == ssl) { ERR_print_errors_fp(stderr); exit(__LINE__); }
    if (SSL_set_fd(ssl, tcp_s) != 1) { ERR_print_errors_fp(stderr); exit(__LINE__); }

    /* SSL_accept returns 1 only on success; 0 is a failed (but clean) handshake and is fatal too. */
    if (SSL_accept(ssl) != 1) { ERR_print_errors_fp(stderr); exit(__LINE__); }
    ShowCerts(ssl);   /* get any certificates */


    /********************************************************************************/
    /* Create and configure tap interface to be used                                */
    /*   1) Create tap interface                                                    */
    /*   2) Set tap interface properties                                            */
    /*   3) Bring up new tap interface                                              */
    /*   4) Optimize TCP connections                                                */
    /*   5) Send tap interface properties to master                                 */
    /*   6) Run the post-initialization script, if provided in the argument list    */
    /********************************************************************************/

    /* 1 */
    if ((tap_if_s = open("/dev/net/tun", O_RDWR)) < 0) PERROR("open tunnel");
    tss_ifr_reset(tap_if_name);
    ifr.ifr_flags = IFF_TAP | IFF_NO_PI;
    if (ioctl(tap_if_s, TUNSETIFF, (void *)&ifr) < 0) PERROR("ioctl TUNSETIFF");

    /* TUNSETIFF writes back the name actually assigned (e.g. for a "tap%d"
       template or a truncated name); use it for every later request. */
    memcpy(if_name, ifr.ifr_name, sizeof(if_name));
    if_name[sizeof(if_name) - 1] = '\0';

    /* 2a */    // Set MTU
    tss_ifr_reset(if_name);
    ifr.ifr_mtu = 1500;
    if (DEBUG>=3) fprintf(stderr, "MTU: %d\n", ifr.ifr_mtu);
    if (ioctl(tcp_s, SIOCSIFMTU, (void *)&ifr) < 0) PERROR_CONTINUE("ioctl set: MTU");

    /* 2b */    // Set IP Address
    tss_ifr_reset(if_name);
    tss_ifr_set_ipv4(tap_ip_addr);
    if (DEBUG>=3) fprintf(stderr, "IP Addr: %s\n", inet_ntoa(((struct sockaddr_in *)&ifr.ifr_addr)->sin_addr));
    if (ioctl(tcp_s, SIOCSIFADDR, (void *)&ifr) < 0) PERROR_CONTINUE("ioctl set: IP Addr");

    /* 2c */    // Set IP Destination Address
    /* ifr still holds tap_ip_addr from 2b; because ifr_dstaddr overlays ifr_addr,
       the destination address is set to tap_ip_addr. */
    if (DEBUG>=3) fprintf(stderr, "Dst IP Addr: %s\n", inet_ntoa(((struct sockaddr_in *)&ifr.ifr_dstaddr)->sin_addr));
    if (ioctl(tcp_s, SIOCSIFDSTADDR, (void *)&ifr) < 0) PERROR_CONTINUE("ioctl set: Dst IP Addr");

    /* 2d */    // Set Broadcast Address
    tss_ifr_reset(if_name);
    tss_ifr_set_ipv4(tap_bcast_addr);
    if (DEBUG>=3) fprintf(stderr, "Broadcast Addr: %s\n", inet_ntoa(((struct sockaddr_in *)&ifr.ifr_broadaddr)->sin_addr));
    if (ioctl(tcp_s, SIOCSIFBRDADDR, (void *)&ifr) < 0) PERROR_CONTINUE("ioctl set: Bcast Addr");

    /* 2e */    // Set Netmask
    tss_ifr_reset(if_name);
    tss_ifr_set_ipv4(tap_netmask);
    if (DEBUG>=3) fprintf(stderr, "Netmask: %s\n", inet_ntoa(((struct sockaddr_in *)&ifr.ifr_netmask)->sin_addr));
    if (ioctl(tcp_s, SIOCSIFNETMASK, (void *)&ifr) < 0) PERROR_CONTINUE("ioctl set: Netmask");

    /* 3 */
    tss_ifr_reset(if_name);
    if (ioctl(tcp_s, SIOCGIFFLAGS, (void *)&ifr) < 0) PERROR_CONTINUE("ioctl get: SIOCGIFFLAGS");
    ifr.ifr_flags |= IFF_UP | IFF_RUNNING;
    if (DEBUG>=3) fprintf(stderr, "Bring interface %s up...\n", if_name);
    if (ioctl(tcp_s, SIOCSIFFLAGS, (void *)&ifr) < 0) PERROR("ioctl set: SIOCSIFFLAGS");

    if (DEBUG>=3) fprintf(stderr, "Interface '%s' at %s successfully created.  Now up and running.\n", if_name, tap_ip_addr);

    /* 4 */
    {
        /* Keep the option value separate from the return code: assigning the
           result of the first call to the option variable used to pass 0 to
           TCP_QUICKACK, turning it off instead of on. */
        int one = 1;
        if (setsockopt(tcp_s, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) < 0) PERROR_CONTINUE("setsockopt TCP_NODELAY");
        if (setsockopt(tcp_s, IPPROTO_TCP, TCP_QUICKACK, &one, sizeof(one)) < 0) PERROR_CONTINUE("setsockopt TCP_QUICKACK");
    }

    /* 5 */
    send_if_value(if_name, ssl, tcp_s, SIOCGIFHWADDR,  SIOCSIFHWADDR);
    send_if_value(if_name, ssl, tcp_s, SIOCGIFMTU,     SIOCSIFMTU);
    send_if_value(if_name, ssl, tcp_s, SIOCGIFADDR,    SIOCSIFADDR);
    send_if_value(if_name, ssl, tcp_s, SIOCGIFDSTADDR, SIOCSIFDSTADDR);
    send_if_value(if_name, ssl, tcp_s, SIOCGIFBRDADDR, SIOCSIFBRDADDR);
    send_if_value(if_name, ssl, tcp_s, SIOCGIFNETMASK, SIOCSIFNETMASK);

    /* 6 */
    if (0 != strncmp(" ", post_init_script, 2))
    {
        /* Passed straight to system(): staging it in a fixed-size buffer would
           overflow or truncate long commands.  NOTE(review): the string reaches
           the shell unmodified, so it must only come from a trusted operator. */
        int status = system(post_init_script);
        if (status != 0)
            fprintf(stderr, "WARNING: post-init script '%s' returned status %d\n", post_init_script, status);
    }
}


/*---------------------------------------------------------------------*/
/*--- send_if_value - send interface info to Virtual Router Master  ---*/
/*---------------------------------------------------------------------*/
/*
 * Queries one property of interface `iface` with ioctl `get_cmd` on socket
 * s2 (the result lands in the global `ifr`).  It then sends two items over
 * the SSL session s1/s2:
 *   1. set_cmd as a 32-bit network-order integer;
 *   2. the first IFREQ_WIRE_LEN bytes of `ifr`.
 * Before item 2 is sent, fields are converted for the wire:
 *   - SIOCSIFHWADDR: the sockaddr family is set to AF_UNIX (network order);
 *   - SIOCSIFMTU: ifr_mtu is converted with htonl;
 *   - any other request: the family is set to AF_INET (network order).
 */
void send_if_value(char *iface, SSL *s1, int s2, int32_t get_cmd, int32_t set_cmd)
{
    /* Bytes of struct ifreq sent per record: the IFNAMSIZ-byte name plus the
       first 16 bytes of the request union.  NOTE(review): presumably fixed by
       the peer's protocol - keep in sync with the client. */
    enum { IFREQ_WIRE_LEN = 32 };
    int32_t set_cmd_net;

    memset(&ifr, 0, sizeof(ifr));
    /* Bounded copy: strcpy overran ifr_name for names of IFNAMSIZ or more characters. */
    snprintf(ifr.ifr_name, sizeof(ifr.ifr_name), "%s", iface);
    if (ioctl(s2, get_cmd, &ifr) < 0) PERROR("ioctl in send_if_value");

    set_cmd_net = htonl(set_cmd);
    if (cwrite(s1, s2, (void *)&set_cmd_net, sizeof(set_cmd_net)) < 0) PERROR("cwrite if command");

    if (set_cmd == SIOCSIFHWADDR) ifr.ifr_addr.sa_family = htons(AF_UNIX);
    else if (set_cmd == SIOCSIFMTU) ifr.ifr_mtu = htonl(ifr.ifr_mtu);
    else ifr.ifr_addr.sa_family = htons(AF_INET);

    if (cwrite(s1, s2, (void *)&ifr, IFREQ_WIRE_LEN) < 0) PERROR("cwrite if value");
}
