/* DHCP Relay for 'DHCPv4 Configuration of IPSec Tunnel Mode' support 
 * Copyright (C) 2002 Mario Strasser <mast@gmx.net>, 
 *                    Zuercher Hochschule Winterthur,
 *                    Netbeat AG 
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 *
 * RCSID $Id: dhcprelay.c,v 1.2 2002/08/22 09:41:30 sri Exp $
 */

#include <sys/socket.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/ioctl.h>
#include <signal.h>
#include <netdb.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <net/if.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <fcntl.h>
#include <time.h>
#include "dhcp.h"
#include "config.h"

/* constants (MAX_LIFETIME is parenthesized so it is safe inside larger
   expressions such as "x % MAX_LIFETIME") */
#define SELECT_TIMEOUT  5           /* select() timeout in seconds */
#define MAX_LIFETIME    (2*60)      /* lifetime of an xid entry in seconds */

/* debug macro: in DEBUG builds prints a timestamped "DEBUG:" line to stdout,
   otherwise expands to an empty statement. The do { } while (0) wrapper makes
   "DBG(...);" exactly one statement, so it is safe in an unbraced if/else. */
#ifdef DEBUG
#define DBG(msg, args...) do {                                          \
        char _ts[20] = {0}; time_t _t; struct tm _tm;                   \
        time(&_t); localtime_r(&_t, &_tm);                              \
        strftime(_ts, sizeof(_ts), "%b %d %H:%M:%S", &_tm);             \
        printf("%s: DEBUG: " __FILE__ ": ", _ts);                       \
        printf(msg, ## args); printf("\n");                             \
    } while (0)
#else
#define DBG(msg, args...) do { } while (0)
#endif

/* log macro: in non-DEBUG builds prints a timestamped line to stdout (in
   DEBUG builds DBG() reports instead and LOG() is an empty statement). The
   do { } while (0) wrapper makes "LOG(...);" exactly one statement, so it is
   safe in an unbraced if/else. */
#ifndef DEBUG
#define LOG(msg, args...) do {                                          \
        char _ts[20] = {0}; time_t _t; struct tm _tm;                   \
        time(&_t); localtime_r(&_t, &_tm);                              \
        strftime(_ts, sizeof(_ts), "%b %d %H:%M:%S", &_tm);             \
        printf("%s: " __FILE__ ": ", _ts);                              \
        printf(msg, ## args); printf("\n");                             \
    } while (0)
#else
#define LOG(msg, args...) do { } while (0)
#endif

/* If defined, pass_on() fills the giaddr field with the IP address of the
   client device the request arrived on, falling back to gw_ip if that lookup
   fails. Otherwise gw_ip (the address of the device to the DHCP server) is
   used. */
#define ADD_DEV_IP 1

/* stop flag, quit if set to '1'. It is written by signal_handler() and polled
   by the main loop, so it must be a volatile sig_atomic_t: that is the only
   object type the C standard allows a signal handler to assign safely. */
volatile sig_atomic_t stopflag = 0;

/* addresses and sockets */
const char *server;                 /* DHCP server (host name or dotted quad);
                                       const because main() may point it at
                                       a string literal */
char *server_dev;                   /* device to the DHCP server */
char **client_dev = NULL;           /* devices to the DHCP clients */
struct sockaddr_in server_address;  /* server's address */
struct sockaddr_in client_address;  /* client's address: source of the last
                                       received packet / reply destination */
struct sockaddr_in gw_address;      /* gateway's (local bind) address */
struct in_addr gw_ip;               /* gateway's ip address */
int server_socket = 0;              /* server's socket, 0 = not created */
int *client_socket = NULL;          /* clients' sockets, one per client_dev */
int client_number;                  /* number of clients */

/* One pending client request per node. pass_on() pushes a node for every
   relayed request, pass_back() looks replies up by xid, and update_xid_list()
   drops nodes older than MAX_LIFETIME. */
struct xid_item {
    u_int32_t xid;              /* transaction id copied from the request */
    struct sockaddr_in ip;      /* source address the request came from */
    int client;                 /* index into client_dev / client_socket */
    time_t timestamp;           /* creation time, used for aging */
    struct xid_item *next;      /* next node, NULL ends the list */
};

/* dummy list head: only its next member is used; everything starts zeroed */
struct xid_item xid_list = {0};

/**
 * signal_handler - requests a clean shutdown on SIGTERM, SIGQUIT and SIGINT
 * sig - sent signal
 *
 * Only async-signal-safe work is allowed here. The former DBG() trace called
 * printf(), localtime_r() and strftime(), which are not async-signal-safe.
 * That is undefined behaviour if the signal interrupts the main loop while it
 * is printing, so the trace was removed. Any other signal number is ignored.
 */
void signal_handler(int sig)
{
    switch (sig) {
    case SIGTERM:
    case SIGQUIT:
    case SIGINT:
        stopflag = 1;
        break;
    default:
        break;
    }
}

/**
 * install_signal_handler - installes the signal handlers
 * returns 0 on success, -1 otherwise
 */
int install_signal_handler()
{
    struct sigaction action;

    /* set SIGTERM handler */
    action.sa_handler = signal_handler;
    sigemptyset(&action.sa_mask);
    action.sa_flags = 0;
    if (sigaction(SIGTERM, &action, NULL) == -1)
        return -1;
    /* set SIGQUIT handler */
    action.sa_handler = signal_handler;
    sigemptyset(&action.sa_mask);
    action.sa_flags = 0;
    if (sigaction(SIGQUIT, &action, NULL) == -1)
        return -1;
    /* set SIGINT handler */
    action.sa_handler = signal_handler;
    sigemptyset(&action.sa_mask);
    action.sa_flags = 0;
    if (sigaction(SIGINT, &action, NULL) == -1)
        return -1;
    return 0;
}

/**
 * update_xid_list - frees every xid_list entry whose age exceeds MAX_LIFETIME
 * seconds, keeping the remaining entries in their original order
 */
void update_xid_list()
{
    time_t now = time(NULL);
    struct xid_item **link = &xid_list.next;

    /* link always points at the pointer that refers to the current node,
       so unlinking needs no separate "previous" node */
    while (*link != NULL) {
        struct xid_item *entry = *link;

        if ((now - entry->timestamp) > MAX_LIFETIME) {
            *link = entry->next;
            free(entry);
        } else {
            link = &entry->next;
        }
    }
}

/**
 * get_dev_ip - gets the IPv4 address of a network device
 * name - name of the device; must be shorter than IFNAMSIZ characters
 * addr - receives the address; it is cleared first, so it holds 0.0.0.0
 *        whenever -1 is returned
 * returns 0 on success, -1 otherwise (a message is printed with perror()
 * and errno describes the failure)
 */
int get_dev_ip(char *name, struct in_addr *addr)
{
    int s, saved_errno;
    size_t len;
    struct ifreq if_data;

    /* clear address */
    memset(addr, 0, sizeof(struct in_addr));

    /* ifr_name is a fixed-size buffer: reject names that would overflow it
       instead of copying them blindly */
    len = strlen(name);
    if (len >= sizeof(if_data.ifr_name)) {
        errno = ENAMETOOLONG;
        perror("Invalid device name");
        return -1;
    }
    memset(&if_data, 0, sizeof(if_data));
    memcpy(if_data.ifr_name, name, len + 1);

    /* get ip address of the device */
    s = socket(PF_INET, SOCK_DGRAM, 0);
    if (s == -1) {
        perror("socket() failed");
        return -1;
    }
    if (ioctl(s, SIOCGIFADDR, &if_data) == -1) {
        perror("Can't get ip addres");
        /* don't leak the socket, and keep the ioctl() errno for the caller */
        saved_errno = errno;
        close(s);
        errno = saved_errno;
        return -1;
    }
    close(s);
    /* copy address and return */
    *addr = ((struct sockaddr_in*)&if_data.ifr_addr)->sin_addr;
    return 0;
}

/**
 * get_client_devices - parses the list of client devices
 * dev_list - comma separated list of devices, e.g. "ipsec0,ipsec1"
 *
 * Fills the globals client_dev (heap-allocated names) and client_number.
 * Empty entries ("a,,b", a leading or trailing comma) are skipped, and
 * client_number is the number of names actually stored. Every
 * client_dev[i] with i < client_number is therefore non-NULL.
 * returns 0 on success, -1 otherwise (errno is set; EINVAL if the list holds
 * no device name at all). On failure client_dev is NULL and client_number 0.
 */
int get_client_devices(char *dev_list)
{
    const char *s;
    size_t len;
    int capacity, count = 0, saved_errno;

    if (dev_list == NULL) {
        errno = EINVAL;
        return -1;
    }

    /* upper bound for the number of names: one more than the commas */
    for (s = dev_list, capacity = 1; *s; s++)
        if (*s == ',') capacity++;

    /* alloc memory */
    client_dev = calloc(capacity, sizeof(*client_dev));
    if (client_dev == NULL) return -1;

    /* parse list; dev_list itself is left untouched */
    s = dev_list;
    while (*s != '\0') {
        len = strcspn(s, ",");
        if (len > 0) {
            char *name = malloc(len + 1);
            if (name == NULL) goto fail;
            memcpy(name, s, len);
            name[len] = '\0';
            client_dev[count++] = name;
        }
        s += len;
        if (*s == ',') s++;
    }

    if (count == 0) {
        errno = EINVAL;
        goto fail;
    }
    client_number = count;
    return 0;

fail:
    /* release what was built so far, preserving errno for perror() */
    saved_errno = errno;
    while (count > 0) free(client_dev[--count]);
    free(client_dev);
    client_dev = NULL;
    client_number = 0;
    errno = saved_errno;
    return -1;
}

/**
 * lookup_udp_port - returns the port of a UDP service in network byte order
 * name - service name in the services database
 * fallback - port (host byte order) used if the database has no entry
 */
static int lookup_udp_port(const char *name, int fallback)
{
    struct servent *service = getservbyname(name, "udp");

    if (service == NULL) {
        DBG("no %s/udp entry in the services database, using port %d",
            name, fallback);
        return htons(fallback);
    }
    return service->s_port;
}

/**
 * bind_to_device - restricts a socket to one network device (SO_BINDTODEVICE)
 * sock - socket to restrict
 * dev - device name
 * returns the setsockopt() result (0 on success, -1 otherwise)
 */
static int bind_to_device(int sock, const char *dev)
{
    size_t len = strlen(dev);

    if (len < 4) {
        /* short names are handed over zero padded to 4 bytes, as this relay
           always did. Presumably this works around kernels that reject option
           lengths below sizeof(int). NOTE(review): confirm */
        char tmp[4] = {0};
        memcpy(tmp, dev, len);
        return setsockopt(sock, SOL_SOCKET, SO_BINDTODEVICE, tmp, sizeof(tmp));
    }
    return setsockopt(sock, SOL_SOCKET, SO_BINDTODEVICE, dev,
        (socklen_t)(len + 1));
}

/**
 * init_sockets - creates all needed sockets
 * max_socket - receives the highest socket descriptor created (for select())
 *
 * Resolves the server address, binds one socket to server_dev and one socket
 * per entry of client_dev (port bootps). On failure the sockets created so
 * far stay in server_socket/client_socket for release_all() to close.
 * returns 0 on success, -1 otherwise
 */
int init_sockets(int *max_socket)
{
    int i, res, server_port, client_port;
    int on = 1;                 /* value for the boolean socket options */
    struct hostent *host;

    /* allocate memory */
    client_socket = calloc(client_number, sizeof(*client_socket));
    if (client_socket == NULL) return -1;

    /* get port numbers (network byte order), falling back to the well known
       BOOTP/DHCP ports if the services database lacks them */
    server_port = lookup_udp_port("bootps", 67);
    client_port = lookup_udp_port("bootpc", 68);

    /* set server's address */
    memset(&server_address, 0, sizeof(server_address));
    server_address.sin_family = AF_INET;
    server_address.sin_port = server_port;
    if (!inet_aton(server, &server_address.sin_addr)) {
        host = gethostbyname(server);
        if (host == NULL) return -1;
        memcpy(&server_address.sin_addr, host->h_addr_list[0],
            sizeof(server_address.sin_addr));
    }
    /* set client's address */
    memset(&client_address, 0, sizeof(client_address));
    client_address.sin_family = AF_INET;
    client_address.sin_port = client_port;
    client_address.sin_addr.s_addr = htonl(INADDR_ANY);
    /* set gateway's address */
    if (get_dev_ip(server_dev, &gw_ip) == -1) return -1;
    memset(&gw_address, 0, sizeof(gw_address));
    gw_address.sin_family = AF_INET;
    /* if the server is on the same host use port dhcpc (68) */
    if (!strcmp(server_dev, "lo")) gw_address.sin_port = client_port;
    else gw_address.sin_port = server_port;
    gw_address.sin_addr.s_addr = htonl(INADDR_ANY);

    DBG("DHCP server port:\t%d", ntohs(server_port));
    DBG("DHCP client port:\t%d", ntohs(client_port));
    DBG("Server:\t\t%s:%d", inet_ntoa(server_address.sin_addr),
        ntohs(server_address.sin_port));
    DBG("Client:\t\t%s:%d", inet_ntoa(client_address.sin_addr),
        ntohs(client_address.sin_port));

    /* create socket to the dhcp server */
    server_socket = socket(PF_INET, SOCK_DGRAM, 0);
    if (server_socket == -1) return -1;
    /* enable broadcast if needed */
    if (server_address.sin_addr.s_addr == htonl(INADDR_BROADCAST)) {
        res = setsockopt(server_socket, SOL_SOCKET, SO_BROADCAST,
            &on, sizeof(on));
        if (res == -1) return -1;
        DBG("Broadcast enabled!");
    }
    /* bind socket to device */
    if (bind_to_device(server_socket, server_dev) == -1) return -1;
    /* bind socket */
    res = bind(server_socket, (struct sockaddr *)&gw_address,
        sizeof(gw_address));
    if (res == -1) return -1;
    DBG("Gateway:\t\t%s:%d", inet_ntoa(gw_ip),
        ntohs(gw_address.sin_port));

    /* select() watches the server socket too, so it counts for the maximum */
    *max_socket = server_socket;
    /* always listen on port bootps (67) */
    gw_address.sin_port = server_port;
    for (i = 0; i < client_number; i++) {
        /* create socket */
        client_socket[i] = socket(PF_INET, SOCK_DGRAM, 0);
        if (client_socket[i] == -1) return -1;
        /* enable broadcast */
        res = setsockopt(client_socket[i], SOL_SOCKET, SO_BROADCAST,
            &on, sizeof(on));
        if (res == -1) return -1;
        /* bind socket to device */
        if (bind_to_device(client_socket[i], client_dev[i]) == -1) return -1;
        /* SO_REUSEADDR takes an int flag, not the device name */
        res = setsockopt(client_socket[i], SOL_SOCKET, SO_REUSEADDR,
            &on, sizeof(on));
        if (res == -1) return -1;
        /* bind socket */
        res = bind(client_socket[i], (struct sockaddr *)&gw_address,
            sizeof(gw_address));
        if (res == -1) return -1;
        /* save biggest socket */
        if (client_socket[i] > *max_socket) *max_socket = client_socket[i];
    }

    return 0;
}

/**
 * add_relay_agent_option() - adds the relay agent information option
 * p - dhcp packet
 * client - number of the client
 */
void add_relay_agent_option(struct dhcp_packet *p, int client)
{
    u_char *op;
    int oplen = 4 + strlen(client_dev[client]);

    /* create relay agent information option */
    op = malloc(oplen);
    if (op == NULL) {
        DBG("malloc() failed: %s", strerror(errno));
        LOG("malloc() failed: %s", strerror(errno));
        return;
    }
    op[0] = DHO_DHCP_AGENT_OPTIONS;
    op[1] = strlen(client_dev[client]) + 2;
    /* add agent circuit id sub-option */
    op[2] = 0x01;
    op[3] = strlen(client_dev[client]);
    memcpy(&op[4], client_dev[client], strlen(client_dev[client]));
    /* TODO: add agent remote id sub-option with the clients ipsec id */
    /* add option to the packet */
    if (add_dhcp_option(p, op, oplen)) {
        DBG("add_dhcp_option() failed");
        LOG("add_dhcp_option() failed");
    }
}

/**
 * pass_on() - forwards dhcp packets from client to server
 * p - packet to send (modified in place)
 * client - number of the client (index into client_dev)
 *
 * Accepts DHCPDISCOVER, DHCPREQUEST, DHCPDECLINE, DHCPRELEASE and DHCPINFORM.
 * Other packets are dropped. For an accepted packet it:
 *   - records (xid, client_address, client) in xid_list so pass_back() can
 *     route the reply;
 *   - fills giaddr if it is still zero;
 *   - adds the relay agent option if absent;
 *   - sends the packet to server_address.
 */
void pass_on(struct dhcp_packet *p, int client)
{
    int res;
    struct xid_item *item;
    int packet_len;

    DBG("pass_on() started");

    /* check packet_type */
    if (check_dhcp_packet(p, DHCPDISCOVER) == 0) {
        DBG("got a DHCPDISCOVER");
    } else if (check_dhcp_packet(p, DHCPREQUEST) == 0) {
        DBG("got a DHCPREQUEST");
    } else if (check_dhcp_packet(p, DHCPDECLINE) == 0) {
        DBG("got a DHCPDECLINE");
    } else if (check_dhcp_packet(p, DHCPRELEASE) == 0) {
        DBG("got a DHCPRELEASE");
    } else if (check_dhcp_packet(p, DHCPINFORM) == 0) {
        DBG("got a DHCPINFORM");
    } else {
        DBG("got a invalid dhcp packet");
        return;
    }

    /* create new xid entry */
    item = malloc(sizeof(*item));
    if (item == NULL) {
        DBG("malloc() failed: %s", strerror(errno));
        LOG("malloc() failed: %s", strerror(errno));
        return;
    }
    /* add xid entry (client_address holds the source of this packet) */
    item->ip = client_address;
    item->xid = p->xid;
    item->client = client;
    item->timestamp = time(NULL);
    item->next = xid_list.next;
    xid_list.next = item;

    /* add gateway address. RFC 1542 section 4.1.1: a relay agent fills in
       giaddr only while it is zero; a non-zero value was set by a relay agent
       closer to the client and must not be modified */
    if (p->giaddr.s_addr == 0) {
#ifdef ADD_DEV_IP
        if (get_dev_ip(client_dev[client], &p->giaddr) < 0) p->giaddr = gw_ip;
#else
        p->giaddr = gw_ip;
#endif
    }

    /* add relay agent information option if it doesn't already exist */
    if (get_dhcp_option(p, DHO_DHCP_AGENT_OPTIONS) == NULL)
        add_relay_agent_option(p, client);
#ifdef DEBUG
    show_dhcp_packet(p, 3);
#endif

    /* get packet length */
    packet_len = get_dhcp_packet_len(p);
    /* forward request to LAN (server) */
    res = sendto(server_socket, p, packet_len, 0,
        (struct sockaddr*)&server_address, sizeof(server_address));
    if (res != packet_len) {
        DBG("sendto() failed or was not complete: %s", strerror(errno));
        LOG("sendto() failed or was not complete: %s", strerror(errno));
        return;
    }

    DBG("pass_on() stopped");
}

/**
 * pass_back() - relays a dhcp reply from the server to the requesting client
 * p - packet received on the server socket
 *
 * The reply is dropped unless its xid is in xid_list and it is a DHCPOFFER,
 * DHCPACK or DHCPNAK. It goes out through the client socket recorded for the
 * newest matching entry, to the recorded source address (broadcast if that
 * was 0.0.0.0). Only after a complete send are all entries with this xid
 * freed. The global client_address is overwritten with the reply target.
 */
void pass_back(struct dhcp_packet *p)
{
    struct xid_item *entry;
    struct xid_item **link;
    int packet_len, sent;

    DBG("pass_back() started");

    /* find the newest pending request with this transaction id */
    for (entry = xid_list.next; entry != NULL; entry = entry->next) {
        if (entry->xid == p->xid) break;
    }
    if (entry == NULL) {
        DBG("got a invalid dhcp packet (xid)");
        return;
    }

    /* only server replies are relayed back */
    if (check_dhcp_packet(p, DHCPOFFER) == 0) {
        DBG("got a DHCPOFFER");
    } else if (check_dhcp_packet(p, DHCPACK) == 0) {
        DBG("got a DHCPACK");
    } else if (check_dhcp_packet(p, DHCPNAK) == 0) {
        DBG("got a DHCPNAK");
    } else {
        DBG("got a invalid dhcp packet (type)");
        return;
    }
#ifdef DEBUG
    show_dhcp_packet(p, 3);
#endif

    packet_len = get_dhcp_packet_len(p);

    /* a recorded source of 0.0.0.0 cannot be unicast to: broadcast instead */
    client_address = entry->ip;
    if (client_address.sin_addr.s_addr == htonl(INADDR_ANY))
        client_address.sin_addr.s_addr = htonl(INADDR_BROADCAST);
    sent = sendto(client_socket[entry->client], p, packet_len, 0,
        (struct sockaddr*)&client_address, sizeof(client_address));
    if (sent != packet_len) {
        DBG("sendto() failed or was not complete: %s", strerror(errno));
        LOG("sendto() failed or was not complete: %s", strerror(errno));
        return;
    }

    /* transaction answered: unlink and free every entry carrying its xid */
    link = &xid_list.next;
    while (*link != NULL) {
        entry = *link;
        if (entry->xid == p->xid) {
            *link = entry->next;
            free(entry);
        } else {
            link = &entry->next;
        }
    }

    DBG("pass_back() stopped");
}

/**
 * release_all() - closes all open sockets and release all data 
 */
void release_all()
{
    int i;

    if (server_socket) close(server_socket);
    for (i=0; i < client_number; i++) {
        if (client_socket[i]) close(client_socket[i]);
    }
    if (client_dev != NULL) {
        for (i=0; i < client_number; i++) {
            if (client_dev[i]) free(client_dev[i]);
        }
        free(client_dev);
    }
}

/**
 * main - DHCP relay entry point
 * usage: <prog> <device list> <device to the dhcp server> [dhcp server]
 *
 * Parses the arguments, installs the signal handlers and creates the
 * sockets. It then relays packets until SIGTERM, SIGQUIT or SIGINT sets
 * stopflag: requests from the client devices go to the server via pass_on(),
 * and replies from the server go back via pass_back(). Without a dhcp server
 * argument requests are sent to 255.255.255.255. Exits with -1 on setup
 * errors and with 0 after printing the usage.
 */
int main(int argc, char **argv)
{
    int i, res;
    fd_set rfds;
    struct timeval tv;
    struct dhcp_packet dhcp_p;
    socklen_t addr_len;
    int max_socket;

    /* read arguments */
    if (argc != 3 && argc != 4) {
        printf("usage: %s <device list> <device to the dhcp server> "
            "[dhcp server]\n", argv[0]);
        exit(0);
    }
    if (get_client_devices(argv[1]) == -1) {
        perror("get_client_devices() failed");
        exit(-1);
    }
    server_dev = argv[2];
    /* without an explicit server the requests are broadcast */
    server = (argc == 4) ? argv[3] : "255.255.255.255";
    for (i = 0; i < client_number; i++) {
        DBG("%s started - forwarding from device %s to %s", argv[0],
            client_dev[i], server);
        LOG("%s started - forwarding from device %s to %s", argv[0],
            client_dev[i], server);
    }

    /* install signal handler */
    if (install_signal_handler() == -1) {
        perror("sigaction() failed");
        exit(-1);
    }

    /* init sockets */
    if (init_sockets(&max_socket) == -1) {
        perror("init_sockets() failed");
        release_all();
        exit(-1);
    }

    /* main loop */
    while (!stopflag) {
        /* wait for incoming packets */
        FD_ZERO(&rfds);
        FD_SET(server_socket, &rfds);
        for (i = 0; i < client_number; i++) FD_SET(client_socket[i], &rfds);
        tv.tv_sec = SELECT_TIMEOUT;
        tv.tv_usec = 0;
        res = select(max_socket + 1, &rfds, NULL, NULL, &tv);
        if (res == -1) {
            /* EINTR only means a signal arrived (normally the one that set
               stopflag); report real failures only */
            if (errno != EINTR) {
                DBG("select() failed: %s", strerror(errno));
                LOG("select() failed: %s", strerror(errno));
            }
        } else if (res > 0) {
            /* got something from server */
            if (FD_ISSET(server_socket, &rfds)) {
                memset(&dhcp_p, 0, sizeof(dhcp_p));
                /* addr_len is value-result: reset it before every call */
                addr_len = sizeof(client_address);
                res = recvfrom(server_socket, &dhcp_p, sizeof(dhcp_p), 0,
                    (struct sockaddr*)&client_address, &addr_len);
                if (res == -1) {
                    DBG("read() failed: %s", strerror(errno));
                    LOG("read() failed: %s", strerror(errno));
                } else {
                    pass_back(&dhcp_p);
                }
            }
            /* got something from a client */
            for (i = 0; i < client_number; i++) {
                if (FD_ISSET(client_socket[i], &rfds)) {
                    memset(&dhcp_p, 0, sizeof(dhcp_p));
                    addr_len = sizeof(client_address);
                    res = recvfrom(client_socket[i], &dhcp_p, sizeof(dhcp_p), 0,
                        (struct sockaddr*)&client_address, &addr_len);
                    if (res == -1) {
                        DBG("read() failed: %s", strerror(errno));
                        LOG("read() failed: %s", strerror(errno));
                    } else {
                        pass_on(&dhcp_p, i);
                    }
                }
            }
        }
        /* update xid list */
        update_xid_list();
    }

    /* close sockets and release data */
    release_all();

    DBG("%s stopped", argv[0]);
    return 0;
}

