/*
 *  Copyright (C) 1999-2007 Sam Hocevar <sam@zoy.org>
 * This program comes with no warranty. Feel free to use it under the terms
 * of the GNU GPL, BSD, MIT or WTFPL licenses, whichever suits you most.
 *
 * I use this program to have both an SSL web server and an SSH server on
 * port 443 of my machine. Port 443 is often the only port offering you a
 * direct TCP link when in a hostile location such as your workplace.
 *
 * It will listen on LISTEN_PORT. If a client connects and sends data, we
 * redirect the traffic to DEFAULT_HOST:DEFAULT_PORT. If a client connects
 * but does not send any data for SSH_TIMEOUT ms, we redirect the traffic
 * to SSH_HOST:SSH_PORT. 
 *
 * Please be aware of the drawbacks of this method:
 *  - the client IP will not be forwarded to the services
 *  - error checking is far from complete and connection issues will often
 *    be ignored
 *  - there is no flood control, multiple connections will cause as many
 *    child processes to be spawned
 *  - more generally, it's old, ugly and probably buggy; you should borrow
 *    the idea, not the code
 */

/* TCP port the proxy accepts clients on, on every local IPv4 address
 * (see main()). */
#define LISTEN_PORT   443

/* Backend for clients that send data first; in the author's setup this is
 * the SSL web server described in the header comment. */
#define DEFAULT_HOST  "localhost"
#define DEFAULT_PORT  8080

/* Backend for clients that stay silent for SSH_TIMEOUT. */
#define SSH_HOST      "localhost"
#define SSH_PORT      22

/* How long Listen() waits for the client's first bytes before assuming
 * the client is SSH and connecting it to SSH_HOST:SSH_PORT. */
#define SSH_TIMEOUT   2000 /* milliseconds */

/*
 * Required headers
 */
#include <sys/types.h> /* socket, fork */
#include <sys/time.h> /* select */
#include <sys/socket.h> /* socket, inet_ntoa */
#include <netinet/in.h> /* htons, htonl, inet_ntoa */
#include <arpa/inet.h> /* inet_ntoa */
#include <netdb.h> /* gethostbyname */
#include <unistd.h> /* fork */
#include <signal.h> /* signal */
#include <stdlib.h> /* realloc, free */
#include <stdio.h> /* stderr */
#include <string.h> /* strcmp */
#include <errno.h> /* errno */

/*
 * You should not need to change this.
 */
/* Maximum number of bytes Copy() reads and forwards in one call. */
#define BUFFER_SIZE 1024
/* Backlog handed to listen(): how many connections may wait to be accepted.
 * It does NOT limit concurrent clients -- main() forks one child per
 * accepted connection without any cap. */
#define MAX_CLIENTS 20

/*
 * Prototypes
 *
 * Declared static so these helpers stay private to this file; the static
 * linkage also carries over to later definitions that omit the keyword.
 */
static int  Connect (char *, int);  /* server name/address, port -> fd or -1 */
static void Listen  (int);          /* serve one accepted client socket */
static int  Copy    (int, int);     /* to, from -> bytes read, 0 at EOF, <0 on error */

/*
 * The main loop.
 *
 * Binds LISTEN_PORT on every local IPv4 address, then accepts clients
 * forever, forking one child per connection to run Listen(). Only returns
 * (with -1) when the listening socket cannot be set up.
 */
int main(int i_argc, char *ppsz_argv[])
{
    int i_ret;
    int i_dummy;

    /* Server information */
    int i_fd;
    struct sockaddr_in server;

    /* No command-line options are supported. */
    (void) i_argc;
    (void) ppsz_argv;

    /* Ignore SIGCHLD so exited children are reaped automatically instead of
     * lingering as zombies. */
    signal(SIGCHLD, SIG_IGN);

    /* Create socket */
    i_fd = socket(PF_INET, SOCK_STREAM, 0);
    if(i_fd < 0)
    {
        fprintf(stderr, "Failed to create socket\n");
        return -1;
    }

    /* Configure socket to reuse address, so a restarted proxy can bind the
     * port again right away. */
    i_dummy = 1;
    i_ret = setsockopt(i_fd, SOL_SOCKET, SO_REUSEADDR, (void *) &i_dummy,
                        sizeof(i_dummy));
    if(i_ret == -1)
    {
        fprintf(stderr, "Failed to setsockopt SO_REUSEADDR (%s)\n",
                         strerror(errno));
        return -1;
    }

    /* Listen on LISTEN_PORT on all IPv4 interfaces */
    memset(&server, 0x00, sizeof(server));
    server.sin_family = AF_INET;
    server.sin_port = htons(LISTEN_PORT);
    server.sin_addr.s_addr = htonl(INADDR_ANY);

    /* Bind socket */
    i_ret = bind(i_fd, (struct sockaddr *) &server, sizeof(server));
    if(i_ret < 0)
    {
        fprintf(stderr, "Failed to bind socket (%s)\n", strerror(errno));
        return -1;
    }

    /* Set socket to listen mode */
    i_ret = listen(i_fd, MAX_CLIENTS);
    if(i_ret < 0)
    {
        fprintf(stderr, "Failed to listen to socket (%s)\n", strerror(errno));
        return -1;
    }

    fprintf(stderr, "Main loop listening\n");

    /* Serve clients forever; each one is handled by its own child process */
    for(;;)
    {
        struct sockaddr_in client;
        socklen_t addrlen = sizeof(client);
        int i_newfd;

        /* Block until a client connects; there is nothing else to do
         * meanwhile, so polling with select() would only burn wakeups. */
        i_newfd = accept(i_fd, (struct sockaddr *) &client, &addrlen);
        if(i_newfd == -1)
        {
            /* An interrupted wait is not worth reporting */
            if(errno != EINTR)
            {
                fprintf(stderr, "Failed accepting connection (%s)\n",
                                 strerror(errno));
            }
            continue;
        }

        fprintf(stderr, "New client %s\n", inet_ntoa(client.sin_addr));

        switch(fork())
        {
            /* We are the child, handle the socket */
            case 0:
                /* The listening socket belongs to the parent. Keeping it open
                 * here would leave the port bound (and queueing connections
                 * nobody accepts) for as long as this client stays connected,
                 * even after the parent exits. */
                close(i_fd);
                Listen(i_newfd);
                return 0;

            /* Boo. */
            case -1:
                fprintf(stderr, "Fork failed (%s)\n", strerror(errno));
                /* Nobody will serve this client: drop it rather than leak the
                 * descriptor in the long-running parent. */
                close(i_newfd);
                break;

            /* We are the parent: the child now owns the client socket */
            default:
                if(close(i_newfd))
                {
                    fprintf(stderr, "Failed to close connection %d (%s)\n",
                                     i_newfd, strerror(errno));
                }
                break;
        }
    }
}

/*
 * Wait at most SSH_TIMEOUT milliseconds for the client to send its first
 * bytes (or to close the connection). The data is left unread so it can be
 * forwarded later. Returns 1 if the client is readable, 0 on timeout, or -1
 * if select() failed (errno is preserved for the caller).
 */
static int WaitForClientData(int i_client)
{
    fd_set fdset;
    struct timeval timer;
    int i_ret;

    do
    {
        FD_ZERO(&fdset);
        FD_SET(i_client, &fdset);

        /* One select() covering the whole timeout instead of polling.
         * Re-armed on every attempt because select() may modify it. */
        timer.tv_sec = SSH_TIMEOUT / 1000;
        timer.tv_usec = (SSH_TIMEOUT % 1000) * 1000;

        i_ret = select(i_client + 1, &fdset, NULL, NULL, &timer);
    }
    while(i_ret < 0 && errno == EINTR); /* interrupted: wait again */

    if(i_ret < 0)
    {
        return -1;
    }

    return i_ret > 0 ? 1 : 0;
}

/*
 * Shuttle bytes in both directions between i_client and i_server until
 * either side closes or an error occurs. Does not close either descriptor.
 */
static void Relay(int i_client, int i_server)
{
    fd_set fdset;
    int i_select = (i_client > i_server ? i_client : i_server) + 1;

    for(;;)
    {
        FD_ZERO(&fdset);
        FD_SET(i_client, &fdset);
        FD_SET(i_server, &fdset);

        /* No timeout: sleep until one side has something to say */
        if(select(i_select, &fdset, NULL, NULL, NULL) < 0)
        {
            if(errno == EINTR)
            {
                continue;
            }
            /* After a failure the fd_set contents are unspecified: bail out */
            fprintf(stderr, "Failed to wait for data (%s)\n", strerror(errno));
            return;
        }

        if(FD_ISSET(i_client, &fdset) && Copy(i_server, i_client) <= 0)
        {
            return;
        }

        if(FD_ISSET(i_server, &fdset) && Copy(i_client, i_server) <= 0)
        {
            return;
        }
    }
}

/*
 * Listen to what the client says, pick a backend, and relay traffic.
 *
 * If the client sends anything within SSH_TIMEOUT ms, the connection goes to
 * DEFAULT_HOST:DEFAULT_PORT; otherwise it goes to SSH_HOST:SSH_PORT. Takes
 * ownership of i_client: it is always closed before returning.
 */
void Listen(int i_client)
{
    int i_server;
    int i_ready = WaitForClientData(i_client);

    if(i_ready < 0)
    {
        fprintf(stderr, "Failed to wait for client data (%s)\n",
                         strerror(errno));
        close(i_client);
        return;
    }

    if(i_ready)
    {
        /* We got data, which means it's not an SSH client */
        fprintf(stderr, "Got data, using default protocol\n");
        i_server = Connect(DEFAULT_HOST, DEFAULT_PORT);
    }
    else
    {
        fprintf(stderr, "No data for %i milliseconds, assuming SSH\n",
                         SSH_TIMEOUT);
        i_server = Connect(SSH_HOST, SSH_PORT);
    }

    /* Without a backend there is nothing to relay to: give up instead of
     * retrying in a loop. Connect() has already printed the reason. */
    if(i_server < 0)
    {
        close(i_client);
        return;
    }

    Relay(i_client, i_server);

    fprintf(stderr, "Closing connection %i\n", i_client);
    close(i_server);
    close(i_client);
}

/*
 * Connect to a remote server.
 *
 * psz_server is either a dotted IPv4 address or a host name resolved with
 * gethostbyname(); i_port is in host byte order. Returns the connected TCP
 * socket, or -1 after printing the reason to stderr. No descriptor is
 * leaked on failure.
 */
static int Connect(char *psz_server, int i_port)
{
    int i_ret;

    /* Socket information */
    int i_fd;
    struct sockaddr_in server;

    /* Resolve the address first, so a bad name never costs a socket */
    memset(&server, 0x00, sizeof(server));
    server.sin_family = AF_INET;
    server.sin_port = htons(i_port);
    if(!inet_aton(psz_server, &server.sin_addr))
    {
        struct hostent *p_hostent = gethostbyname(psz_server);
        if(p_hostent == NULL)
        {
            fprintf(stderr, "Invalid server name %s\n", psz_server);
            return -1;
        }

        /* Only accept an IPv4 address: copying h_length bytes of anything
         * else (e.g. a 16-byte IPv6 address) would overflow sin_addr. */
        if(p_hostent->h_addrtype != AF_INET
            || (size_t) p_hostent->h_length != sizeof(server.sin_addr)
            || p_hostent->h_addr_list[0] == NULL)
        {
            fprintf(stderr, "No IPv4 address for %s\n", psz_server);
            return -1;
        }

        memcpy(&server.sin_addr, p_hostent->h_addr_list[0],
               sizeof(server.sin_addr));
    }

    /* Create socket */
    i_fd = socket(AF_INET, SOCK_STREAM, 0);
    if(i_fd < 0)
    {
        fprintf(stderr, "Failed to create socket (%s)\n", strerror(errno));
        return -1;
    }

    /* Connect socket */
    i_ret = connect(i_fd, (struct sockaddr *) &server, sizeof(server));
    if(i_ret < 0)
    {
        fprintf(stderr, "Failed to connect socket (%s)\n", strerror(errno));
        close(i_fd);
        return -1;
    }

    return i_fd;
}

/*
 * Copy data from one fd to another.
 *
 * Performs one read() of at most BUFFER_SIZE bytes from i_from and writes
 * all of it to i_to, retrying interrupted calls and short writes. Returns
 * the number of bytes copied, 0 when i_from reached end-of-file, or -1 on a
 * read or write error (the reason is printed to stderr).
 */
int Copy(int i_to, int i_from)
{
    /* Buffer to store read data. Being static makes Copy non-reentrant,
     * which is fine in this single-threaded, one-process-per-client program. */
    static unsigned char p_buffer[BUFFER_SIZE];

    ssize_t i_read;
    ssize_t i_done = 0;

    do
    {
        i_read = read(i_from, p_buffer, sizeof(p_buffer));
    }
    while(i_read < 0 && errno == EINTR);

    if(i_read < 0)
    {
        fprintf(stderr, "Error while reading from client (%s)\n",
                         strerror(errno));
        return -1;
    }
    else if(i_read == 0)
    {
        fprintf(stderr, "Connection closed by peer\n");
        return 0;
    }

    /* write() may accept fewer bytes than asked: keep going until the whole
     * chunk is delivered, otherwise relayed data would be silently lost. */
    while(i_done < i_read)
    {
        ssize_t i_written = write(i_to, p_buffer + i_done,
                                  (size_t) (i_read - i_done));
        if(i_written < 0)
        {
            if(errno == EINTR)
            {
                continue;
            }
            fprintf(stderr, "Error while writing (%s)\n", strerror(errno));
            return -1;
        }
        i_done += i_written;
    }

    return (int) i_read;
}


