/*****************************************************************************
 *                  TmNS Source Selector - Client Interface                  *
 *                           Reference Application                           *
 *                                                                           *
 *  Provides SSL tunnel for data to/from TSS Client component.  Application  *
 *  connects to a TSS Server and establishes the tunnel.  Tap interface is   *
 *  created on the client platform.                                          *
 *                                                                           *
 *  This code is not designed to be thread safe.                             *
 *                                                                           *
 *  Date: April 27, 2017                                                     *
 *****************************************************************************/


#include "tss_common.h"
#include <fcntl.h>
#include <linux/if_tun.h>
#include <netdb.h>


/*---------------------------------------------------------------------*/
/*--- CONSTANTS                                                     ---*/
/*---------------------------------------------------------------------*/

/* UDP port that main() binds the control socket to (on INADDR_ANY). */
#define CTRL_PORT 50505
/* Length cap that main() applies to the wire length field of frames
   received from the server; longer frames trigger a truncation warning. */
#define MTU 1600

/*---------------------------------------------------------------------*/
/*--- FORWARD DECLARATIONS AND PROTOTYPES                           ---*/
/*---------------------------------------------------------------------*/

/* Prints the command-line help to stderr and exits the process. */
void usage();
/* Creates the tap device, connects to the server named in tss_if, performs
   the SSL handshake and configures/brings up the tap interface. */
void connect_tunnel(void);

/*---------------------------------------------------------------------*/
/*--- LOCAL STRUCTS                                                 ---*/
/*---------------------------------------------------------------------*/

/* State for one tunnel between the local tap device and a TSS server.
   Populated partly by main() (command-line values and defaults) and partly
   by connect_tunnel() (descriptors and SSL handles). */
typedef struct connection_struct
{
    int conn_port;              /* server TCP port: the text after ':' in the -c argument */
    char *conn_ip;              /* server host: the text before ':' in -c; resolved with gethostbyname() */
    SSL *ssl;                   /* SSL session created by SSL_new() and attached to tcp_s */
    SSL_CTX *ctx;               /* SSL context returned by InitCTX() */
    int tcp_s;                  /* TCP socket to the server; also used as the fd for interface ioctls */
    int tap_s;                  /* descriptor obtained by opening /dev/net/tun */
    int32_t req;                /* first field of each setup message read in connect_tunnel(); not interpreted in this file */
    struct ifreq ifr;           /* receives 32 bytes of each setup message in connect_tunnel(); not interpreted in this file */
    char *tss_ifname;           /* tap interface name; defaults to "tap0", overridden by -I */
    char *tss_ip;               /* IPv4 address assigned to the tap interface (dotted-quad string) */
    char *tss_bcast;            /* broadcast address assigned to the tap interface */
    char *tss_netmask;          /* netmask assigned to the tap interface */
} CONNECTION;


/*---------------------------------------------------------------------*/
/*--- GLOBAL VARIABLES                                              ---*/
/*---------------------------------------------------------------------*/

/* The single tunnel connection handled by this client. */
CONNECTION tss_if;
/* Incremented once per -c option; main() insists on exactly one. */
int        tss_if_count = 0;

/* Highest descriptor placed in readset; recomputed on every pass of the main loop. */
int        max_s        = 0;  // used for select() call
/* Descriptor set rebuilt and passed to select() on every pass of the main loop. */
fd_set     readset;

/* Verbosity level; each -d on the command line adds one. */
int        debug        = 0;    // default: No debug printouts (0)


/*---------------------------------------------------------------------*/
/*--- usage - print out the command line parameters and options     ---*/
/*---------------------------------------------------------------------*/

void usage()
{
    /* Help text is emitted verbatim; it contains no format conversions. */
    static const char help_text[] =
        "Usage: tss_client -c targetip:port [-I ifname]\n"
        "-c <IP>:<port>   : connect to <IP>:<port>\n"
        "-I <ifname>      : choose the name of the virtual interface\n"
        "-d               : increase debug level (can be used more than once)\n";

    fputs(help_text, stderr);

    /* Terminates the process with status 0 after printing. */
    exit(0);
}



/*---------------------------------------------------------------------*/
/*---------------------------------------------------------------------*/
/*--- Main Function.                                                ---*/
/*---------------------------------------------------------------------*/
/*---------------------------------------------------------------------*/

int main(int argc, char *argv[])
{
    int32_t len;
    
    int c;
    char *p;
    TSS_MSG msg;

    struct sockaddr_in ctrl_sin;
    int ctrl_s;    

    int nread;
    int16_t plength;
    
    // Default interface values
    tss_if.tss_ifname = "tap0";
    tss_if.tss_ip = "192.168.1.2";
    tss_if.tss_bcast = "192.168.1.255";
    tss_if.tss_netmask = "255.255.255.0";
 
    /* Parse command line arguments */
    while ((c = getopt(argc, argv, "c:I:dh")) != -1)
    {
        switch (c) 
        {
            case 'c':
                        p = memchr(optarg,':',16);
                        if (!p) ERROR("invalid argument : [%s]\n",optarg);
                        *p = 0;
                        tss_if.conn_ip    = optarg;
                        tss_if.conn_port  = atoi(p+1);
                        tss_if_count++;
                        break;
            case 'I':   tss_if.tss_ifname = optarg;       break;
            case 'd':   debug++;                          break;
            case 'h':   usage();                          break;
            default:    usage();                          break;
		}
    }
    
    if (tss_if_count != 1) ERROR("Must have one connection.  See help menu (-h).\n");

    connect_tunnel();  // Attempt to connect to TSS Server
    
    /* Open socket for listening for commands from the LM. */
    {
        if ((ctrl_s = socket(PF_INET, SOCK_DGRAM, 0)) < 0) PERROR("Cannot open control socket!");
        ctrl_sin.sin_family = AF_INET;
        ctrl_sin.sin_port = htons(CTRL_PORT);
        ctrl_sin.sin_addr.s_addr = htonl(INADDR_ANY);
        if (bind(ctrl_s, (struct sockaddr *)&ctrl_sin, sizeof(ctrl_sin)) < 0) PERROR("Cannot bind control socket!");
    }

    
    while (1) 
    {
        /***********************************************
        * Create FD_SET for select  -- must include    *
        *   all data and control connections to remote *
        *   interfaces.  Must also include stdin.      *                        
        ***********************************************/
        FD_ZERO(&readset);
        FD_SET(ctrl_s,&readset);
        max_s = ctrl_s;
        FD_SET(fileno(stdin), &readset);
        if (fileno(stdin) > max_s)    max_s = fileno(stdin);
        FD_SET(tss_if.tap_s,&readset);
        if (tss_if.tap_s > max_s)     max_s = tss_if.tap_s;
        FD_SET(tss_if.tcp_s,&readset);
        if (tss_if.tcp_s > max_s)     max_s = tss_if.tcp_s;

				
        /***********************************************
        * Select to see which socket need service      *
        ***********************************************/
        if (select(max_s + 1, &readset, NULL, NULL, NULL) < 0)
        {
            if (errno == EINTR) continue;
            PERROR("select()");
        }
        					
        /******************************************************
        * Receive data from a remote IF and write to tap      *
        ******************************************************/

        if (FD_ISSET(tss_if.tcp_s, &readset))
        {
            if (debug==5) write(1,"v", 1);
            len = cread(tss_if.ssl, tss_if.tcp_s, (unsigned char *)&msg.len, sizeof(msg.len));
            if(len<0) PERROR("bad read");
            if(len==0) continue;
                
            if(len<sizeof(msg.len)) bio_read(tss_if.ssl, tss_if.tcp_s, ((unsigned char *)&msg.len)+len, sizeof(msg.len)-len);
            plength = ntohs(msg.len);
				
            if (plength > MTU) { plength = MTU; fprintf(stderr, "WARNING: truncated packet\n"); }
            if (bio_read(tss_if.ssl, tss_if.tcp_s, (unsigned char *)&msg.crc, sizeof(msg.crc)) < 0) PERROR("read crc data");
            if (bio_read(tss_if.ssl, tss_if.tcp_s, msg.buffer, plength - sizeof(msg.crc)) < 0) PERROR("read packet data");
                
            if (write(tss_if.tap_s, msg.buffer, plength - sizeof(msg.crc)) < 0) PERROR_CONTINUE("write to tap");
        }

     
        /**********************************
        * Send data to the TSS server     *
        **********************************/
        if (FD_ISSET(tss_if.tap_s, &readset))
        {
            if (debug==5) write(1,"^", 1);
            nread = read(tss_if.tap_s, msg.buffer, sizeof(msg.buffer));
            if (nread < 0) PERROR_CONTINUE("read");

            msg.len = htons(nread + sizeof(msg.crc));
            msg.crc = htonl(get_crc32(msg.buffer, nread));
            if (cwrite(tss_if.ssl, tss_if.tcp_s, (void *)&msg, ntohs(msg.len) + sizeof(msg.len)) < 0) PERROR("cwrite msg");
        }
    } 

    return 0;  // never reached
}



/*----------------------------------------------------------------------------------*/
/*--- connect_tunnel - Connect to a new TSS Remote System                        ---*/
/*----------------------------------------------------------------------------------*/
void connect_tunnel(void)
{
    int opt;
    struct sockaddr_in tss_server_sin;
    struct sockaddr_in local_sin;
    socklen_t local_sinlen;
    struct ifreq ifr;
    struct hostent *host;
    
    {
        /******************************************************
        * Create virtual socket for the TSS IF                *
        ******************************************************/
        SSL_library_init();
        if ((tss_if.tap_s = open("/dev/net/tun",O_RDWR)) < 0) PERROR("open tunnel");
        memset(&ifr, 0, sizeof(ifr));
        ifr.ifr_flags = IFF_TAP | IFF_NO_PI;
        strncpy(ifr.ifr_name, tss_if.tss_ifname, IFNAMSIZ);
        if (ioctl(tss_if.tap_s, TUNSETIFF, (void *)&ifr) < 0) PERROR("ioctl TUNSETIFF");
       
        if (debug>=1) fprintf(stderr, "Allocated interface is [%s]\n", tss_if.tss_ifname);
   
        /*****************************************
        * Connect to the remote interface        *
        *****************************************/
    
        /* Socket for data routing to and from this if */
        /* Socket for controlling this if              */
        tss_if.tcp_s = socket(PF_INET, SOCK_STREAM, 0);
    
        tss_server_sin.sin_family = AF_INET;
        tss_server_sin.sin_port = htons(tss_if.conn_port);
        host = gethostbyname(tss_if.conn_ip);
        if (!host) ERROR("can't resolve [%s]\n",tss_if.conn_ip);
        tss_server_sin.sin_addr = *(struct in_addr *)host->h_addr;
        if (debug>=1) fprintf(stderr, "Connecting to %s:%i...\n", inet_ntoa(tss_server_sin.sin_addr), ntohs(tss_server_sin.sin_port));
        if (connect(tss_if.tcp_s, (struct sockaddr *)&tss_server_sin, sizeof(tss_server_sin)) < 0) PERROR("connect");
        
        opt=1; setsockopt(tss_if.tcp_s, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt));
        opt=1; setsockopt(tss_if.tcp_s, IPPROTO_TCP, TCP_QUICKACK, &opt, sizeof(opt));

        tss_if.ctx = InitCTX();
        tss_if.ssl = SSL_new(tss_if.ctx);               /* create new SSL connection state */
        SSL_set_fd(tss_if.ssl, tss_if.tcp_s);           /* attach the socket descriptor */

        if ( SSL_connect(tss_if.ssl) < 0 )   { ERR_print_errors_fp(stderr); abort(); }

        local_sinlen = sizeof(local_sin);
        getsockname(tss_if.tcp_s, (struct sockaddr *)&local_sin, &local_sinlen);
        if (debug>=1) fprintf(stderr, "(TSS Client) %s:%i \t<----->\t",inet_ntoa(local_sin.sin_addr), ntohs(local_sin.sin_port));
        if (debug>=1) fprintf(stderr, "%s:%i (TSS Server)\n",inet_ntoa(tss_server_sin.sin_addr), ntohs(tss_server_sin.sin_port));
  
        //set MAC
        memset(&ifr, 0, sizeof(ifr));
        strncpy(ifr.ifr_name, tss_if.tss_ifname, IFNAMSIZ);
    
        ifr.ifr_hwaddr.sa_family = AF_UNIX;
        ifr.ifr_hwaddr.sa_data[0] = 0x2a;
        ifr.ifr_hwaddr.sa_data[1] = 0x5f;
        ifr.ifr_hwaddr.sa_data[2] = 0x54;
        ifr.ifr_hwaddr.sa_data[3] = 0x53;
        ifr.ifr_hwaddr.sa_data[4] = 0x53;
        ifr.ifr_hwaddr.sa_data[5] = (0x5f);
        if (debug>=3) fprintf(stderr, "MAC Addr: %02x:%02x:%02x:%02x:%02x:%02x\n",
                              (unsigned char)ifr.ifr_hwaddr.sa_data[0],
                              (unsigned char)ifr.ifr_hwaddr.sa_data[1],
                              (unsigned char)ifr.ifr_hwaddr.sa_data[2],
                              (unsigned char)ifr.ifr_hwaddr.sa_data[3],
                              (unsigned char)ifr.ifr_hwaddr.sa_data[4],
                              (unsigned char)ifr.ifr_hwaddr.sa_data[5]);
        if (ioctl(tss_if.tcp_s, SIOCSIFHWADDR, (void*)&ifr) < 0) PERROR_CONTINUE("ioctl set: MAC Addr");
    
        //set MTU
        memset(&ifr, 0, sizeof(ifr));
        strncpy(ifr.ifr_name, tss_if.tss_ifname, IFNAMSIZ);
        ifr.ifr_mtu = 1500;
        if (debug>=3) fprintf(stderr, "MTU: %d\n", ifr.ifr_mtu);
        if (ioctl(tss_if.tcp_s, SIOCSIFMTU, (void*)&ifr) < 0) PERROR_CONTINUE("ioctl set: MTU");
    
        //set IP Address
        memset(&ifr, 0, sizeof(ifr));
        strncpy(ifr.ifr_name, tss_if.tss_ifname, IFNAMSIZ);
        ifr.ifr_addr.sa_family = AF_INET;
        struct sockaddr_in address;
        memset(&address, 0, sizeof (struct sockaddr_in));
        address.sin_family=AF_INET;
        address.sin_port=0;
        address.sin_addr.s_addr=inet_addr(tss_if.tss_ip);
        memcpy(&ifr.ifr_addr, &address, sizeof(struct sockaddr_in));
    
        if (debug>=3) fprintf(stderr, "IP Addr: %s\n", inet_ntoa(((struct sockaddr_in*)&ifr.ifr_addr)->sin_addr));
        if (ioctl(tss_if.tcp_s, SIOCSIFADDR, (void*)&ifr) < 0) PERROR_CONTINUE("ioctl set: IP Addr");
    
        //set IP Destination Address
        if (debug>=3) fprintf(stderr, "Dst IP Addr: %s\n", inet_ntoa(((struct sockaddr_in*)&ifr.ifr_dstaddr)->sin_addr));
        if (ioctl(tss_if.tcp_s, SIOCSIFDSTADDR, (void*)&ifr) < 0) PERROR_CONTINUE("ioctl set: Dst IP Addr");
    
        //set Broadcast Address
        memset(&ifr, 0, sizeof(ifr));
        strncpy(ifr.ifr_name, tss_if.tss_ifname, IFNAMSIZ);
        memset(&address, 0, sizeof (struct sockaddr_in));
        address.sin_family=AF_INET;
        address.sin_port=0;
        address.sin_addr.s_addr=inet_addr(tss_if.tss_bcast);
        memcpy(&ifr.ifr_addr, &address, sizeof(struct sockaddr_in));
        if (debug>=3) fprintf(stderr, "Broadcast Addr: %s\n", inet_ntoa(((struct sockaddr_in*)&ifr.ifr_broadaddr)->sin_addr));
        if (ioctl(tss_if.tcp_s, SIOCSIFBRDADDR, (void*)&ifr) < 0) PERROR_CONTINUE("ioctl set: Bcast Addr");
    
        //set Netmask
        memset(&ifr, 0, sizeof(ifr));
        strncpy(ifr.ifr_name, tss_if.tss_ifname, IFNAMSIZ);
        memset(&address, 0, sizeof (struct sockaddr_in));
        address.sin_family=AF_INET;
        address.sin_port=0;
        address.sin_addr.s_addr=inet_addr(tss_if.tss_netmask);
        memcpy(&ifr.ifr_addr, &address, sizeof(struct sockaddr_in));
        if (debug>=3) fprintf(stderr, "Netmask: %s\n", inet_ntoa(((struct sockaddr_in*)&ifr.ifr_netmask)->sin_addr));
        if (ioctl(tss_if.tcp_s, SIOCSIFNETMASK, (void*)&ifr) < 0) PERROR_CONTINUE("ioctl set: Netmask");
    
        // Bring up the new TSS interface
        memset(&ifr, 0, sizeof(ifr));
        strncpy(ifr.ifr_name, tss_if.tss_ifname, IFNAMSIZ);
        if (ioctl(tss_if.tcp_s, SIOCGIFFLAGS, (void*)&ifr) < 0) PERROR_CONTINUE("ioctl get: SIOCGIFFLAGS");
        ifr.ifr_flags |= IFF_UP | IFF_RUNNING;
        if (debug>=3) fprintf(stderr, "Bring interface %s up...\n", tss_if.tss_ifname);
        if (ioctl(tss_if.tcp_s, SIOCSIFFLAGS, (void *)&ifr) < 0) PERROR("ioctl set: SIOCSIFFLAGS");

            
        /**********************************************************
        *  Collect detailed setup/hardware info from remote IFs   *
        **********************************************************/
        {
            int j;
            for (j=0; j<6; j++)     /* we expect to receive 6 messages from each remote if */
            {
                bio_read(tss_if.ssl, tss_if.tcp_s, (unsigned char *)&(tss_if.req), sizeof(tss_if.req));
                bio_read(tss_if.ssl, tss_if.tcp_s, (unsigned char *)&(tss_if.ifr), 32);
            }
        }
        
        if (debug>=3) fprintf(stderr, "%s initialized.\n", tss_if.tss_ifname);
        if (debug>=2) fprintf(stderr, "Interface '%s' at %s successfully created.  Now up and running.\n", tss_if.tss_ifname, tss_if.tss_ip);
        if (debug>=1) fprintf(stderr, "\n");
    }
}
