/* Copyright (C) 2008 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

/* Kernel module implementing an IP set type: the keyword type */

#include <linux/module.h>
#include <net/ip.h>
#include <net/route.h>
#include <net/tcp.h>
#include <linux/skbuff.h>
#include <linux/errno.h>
#include <linux/string.h>
#include <linux/ctype.h>
#include <linux/netfilter_ipv4/ip_set_keyword.h>

/* Upper bound, in bytes, on the HTTP block page reset_connect_keyword() sends back. */
#define MAX_REQUEST_BUFFER_LEN 1024
/*
 * Call-form adapter: code in this file calls tcp_v4_check() with five
 * arguments (header pointer, length, saddr, daddr, partial checksum).  The
 * expansion drops the header pointer and forwards the remaining four.  A
 * function-like macro is not re-expanded inside its own expansion, so the
 * inner tcp_v4_check refers to the kernel's four-argument function.
 */
#define tcp_v4_check(tcph, tcph_sz, s, d, csp) tcp_v4_check((tcph_sz), (s), (d), (csp))

/*
 * Used to transmit the forged reply.  NOTE(review): ip_finish_output is
 * static in mainline net/ipv4/ip_output.c; this assumes the target kernel
 * exports it -- confirm.
 */
extern int ip_finish_output(struct sk_buff *skb);

/*
 * strnistr - case-insensitive search for @find within the first @slen bytes of @s
 * @s:    haystack; need not be NUL-terminated within @slen bytes
 * @find: NUL-terminated needle
 * @slen: maximum number of bytes of @s that may be examined
 *
 * Returns a pointer to the first occurrence of @find that lies entirely within
 * the first @slen bytes of @s, @s itself when @find is empty, or NULL.  The
 * scan also stops at a NUL byte in @s.  Every character of the needle --
 * including the first -- is compared case-insensitively, so "Host:" matches
 * "host:" (HTTP header names are case-insensitive).
 */
char *strnistr(const char *s, const char *find, size_t slen)
{
	size_t flen = strlen(find);
	size_t i;

	if (flen == 0)
		return (char *)s;

	/* Only start a comparison where the whole needle still fits in @slen. */
	for (; slen >= flen && *s != '\0'; s++, slen--) {
		for (i = 0; i < flen; i++) {
			/* A NUL in @s never equals a needle byte, so no match spans it. */
			if (tolower((unsigned char)s[i]) !=
			    tolower((unsigned char)find[i]))
				break;
		}
		if (i == flen)
			return (char *)s;
	}
	return NULL;
}

void
reset_connect_keyword(struct sk_buff *oldskb, char *host, char *page, char *keyword)
{
       struct sk_buff *nskb;
       struct tcphdr *otcph, *tcph;
       struct rtable *rt;
	struct flowi flow_i;
       unsigned int otcplen;
       u_int16_t tmp_port;
       u_int32_t tmp_addr;
       char msg_info[MAX_REQUEST_BUFFER_LEN]; 
	int msg_len = 0;  //the actual length of msg_info
	u_int8_t tos_value;

	msg_len = sprintf(msg_info, "HTTP/1.1 %s\r\n\r\n<html><head><title>%s</title></head><body><center><h3>%s%u.%u.%u.%u%s%s%s%s%s%s</h3></center></body></html>\n\n",
	  		"403 Forbidden","403 Forbidden", "<br><br><br>The requested Web page<br>from ", NIPQUAD(ip_hdr(oldskb)->saddr),
	  		"<br>to ", host, page,
	  		"<br>matched keyword [", keyword,
	  		"]<br>has been blocked by Vigor3900 Web Content Filter.<br><br>Please contact your system administrator for further information.<br><br><br>[Powered by DrayTek]");

       /* IP header checks: fragment, too short. */
       if (ip_hdr(oldskb)->frag_off & htons(IP_OFFSET)
            || oldskb->len < (ip_hdr(oldskb)->ihl<<2) + sizeof(struct tcphdr))
		return;
       otcph = (struct tcphdr *)((u_int32_t*)ip_hdr(oldskb) + ip_hdr(oldskb)->ihl);
       otcplen = oldskb->len - ip_hdr(oldskb)->ihl*4;

       /* No RST for RST. either FIN for FIN*/
       if (otcph->rst || otcph->fin)
		return;

       /* Check checksum. */
       if (tcp_v4_check(otcph, otcplen, ip_hdr(oldskb)->saddr,
                         		ip_hdr(oldskb)->daddr,
                         		csum_partial((char *)otcph, otcplen, 0)) != 0)
		return;

	   /* Copy skb (even if skb is about to be dropped, we can't just
           clone it because there may be other things, such as tcpdump,
           interested in it) */
       nskb = skb_copy(oldskb, GFP_ATOMIC);
       if (!nskb)
		return;
        
       /* This packet will not be the same as the other: clear nf fields */
       nf_conntrack_put(nskb->nfct);
       nskb->nfct = NULL;
       //nskb->nfcache = 0;
       nskb->mark = 0;
       tcph = (struct tcphdr *)((u_int32_t*)ip_hdr(nskb) + ip_hdr(nskb)->ihl);

       /* Swap source and dest */
       tmp_addr = ip_hdr(nskb)->saddr;
       ip_hdr(nskb)->saddr = ip_hdr(nskb)->daddr;
       ip_hdr(nskb)->daddr = tmp_addr;
       tmp_port = tcph->source;
       tcph->source = tcph->dest;
       tcph->dest = tmp_port;

       /* change the total length field of ip header */
       tcph->doff = sizeof(struct tcphdr)/4;
       skb_trim(nskb, ip_hdr(nskb)->ihl*4 + sizeof(struct tcphdr) + msg_len);
       ip_hdr(nskb)->tot_len = htons(nskb->len);   

	   /* Set flags FIN ACK*/
       ((u_int8_t *)tcph)[13] = 0;
       tcph->fin = 1;
       tcph->ack = 1; 
       tcph->seq = otcph->ack_seq ? otcph->ack_seq : 1;
       tcph->ack_seq = otcph->seq ? otcph->seq : 1;
        
       tcph->window = 0;
       tcph->urg_ptr = 0;

       /* Add alert mesg here*/
       strncpy( (char *) tcph + 20 , msg_info, msg_len );
        
        
       /* Adjust TCP checksum */
       tcph->check = 0;        
       tcph->check = tcp_v4_check(tcph, nskb->len - ip_hdr(nskb)->ihl*4,
                                   ip_hdr(nskb)->saddr,
                                   ip_hdr(nskb)->daddr,
                                   csum_partial((char *)tcph,
                                                nskb->len - ip_hdr(nskb)->ihl*4, 0));
       /* Adjust IP TTL, DF */
       ip_hdr(nskb)->ttl = MAXTTL;

	/* Set DF, id = 0 */
       ip_hdr(nskb)->frag_off = htons(IP_DF);
       ip_hdr(nskb)->id = 0;

       /* Adjust IP checksum */
       ip_hdr(nskb)->check = 0;
       ip_hdr(nskb)->check = ip_fast_csum((unsigned char *)ip_hdr(nskb), 
                                           ip_hdr(nskb)->ihl);
       /* Routing: if not headed for us, route won't like source */
/*
	if (ip_route_output(&rt, nskb->nh.iph->daddr,
                            0,
                            RT_TOS(nskb->nh.iph->tos) | RTO_CONN,
                            0) != 0)
*/
	memset(&flow_i, 0, sizeof(flow_i));
	memcpy(&(flow_i.nl_u.ip4_u.daddr), &(ip_hdr(nskb)->daddr), sizeof(flow_i.fl4_dst));
	tos_value = RT_TOS(ip_hdr(nskb)->tos) | RTO_CONN;
	memcpy(&(flow_i.nl_u.ip4_u.tos), &tos_value, sizeof(flow_i.fl4_tos));

	if(ip_route_output_key(&rt, &flow_i) !=0)
		goto free_nskb;
	
       dst_release(nskb->dst);
       nskb->dst = &rt->u.dst;
       /* "Never happens" */
/*
       if (nskb->len > nskb->dst->pmtu)
       	goto free_nskb;
*/
       ip_finish_output(nskb);
       	return;

free_nskb:
       kfree_skb(nskb);
}

/*
 * Userspace membership test: return 1 when a keyword whose contents compare
 * equal (strcmp) to the request's contents is in the set, 0 otherwise.
 * @size is not consulted.
 */
static int
keyword_utest(struct ip_set *set, const void *data, u_int32_t size)
{
	const struct ip_set_req_keyword *req = data;
	struct ip_set_keyword *map = set->data;
	struct keyword_list *entry;

	list_for_each_entry(entry, &map->head, list) {
		if (!strcmp(entry->keyword.contents, req->contents))
			return 1;
	}
	return 0;
}

static int
keyword_ktest(struct ip_set *set,
	      const struct sk_buff *skb,
	      const u_int32_t *flags)
{
	struct iphdr *ip = ip_hdr(skb);
	int offset = ntohs(ip->frag_off) & IP_OFFSET;
	if (offset != 0) {
		return 0;
	}

	if (skb_is_nonlinear(skb)) {
		return 0;
	}
	
	if(ip->protocol == IPPROTO_TCP){
		struct tcphdr* tcp_hdr = (struct tcphdr*)(((unsigned char*)ip) + (ip->ihl*4));
		unsigned short payload_offset = (tcp_hdr->doff*4) + (ip->ihl*4);
		unsigned char* packet_data = ((unsigned char*)ip) + payload_offset;
		unsigned short packet_length = ntohs(ip->tot_len) - payload_offset;

		if(packet_length > 10) {
			if(strnicmp((char*)packet_data, "GET ", 4) == 0 || strnicmp((char*)packet_data, "POST ", 5) == 0 || strnicmp((char*)packet_data, "HEAD ", 5) == 0){
				char path[256] = "";
				char host[256] = "";
				char url[512] = "";
				int path_start_index;
				int path_end_index;
				int last_header_index;
				char last_two_buf[2];
				int end_found;
				char* host_match;
				/* get path portion of URL */
				path_start_index = (int)(strchr((char*)packet_data, ' ') - (char*)packet_data);
				while( packet_data[path_start_index] == ' ')
				{
					path_start_index++;
				}
				path_end_index= (int)(strchr( (char*)(packet_data+path_start_index), ' ') -  (char*)packet_data);
				if(path_end_index > 0) 
				{
					int path_length = path_end_index-path_start_index;
					path_length = path_length < 256 ? path_length : 255; /* prevent overflow */
					memcpy(path, packet_data+path_start_index, path_length);
					path[path_length] = '\0';
				}

				/* get header length */
				last_header_index = 2;
				memcpy(last_two_buf,(char*)packet_data, 2);
				end_found = 0;
				while(end_found == 0 && last_header_index < packet_length)
				{
					char next = (char)packet_data[last_header_index];
					if(next == '\n')
					{
						end_found = last_two_buf[1] == '\n' || (last_two_buf[0] == '\n' && last_two_buf[1] == '\r') ? 1 : 0;
					}
					if(end_found == 0)
					{
						last_two_buf[0] = last_two_buf[1];
						last_two_buf[1] = next;
						last_header_index++;
					}
				}

				/* get host portion of URL */
				host_match = strnistr( (char*)packet_data, "Host:", last_header_index);
				if(host_match != NULL)
				{
					int host_end_index;
					host_match = host_match + 5; /* character after "Host:" */
					while(host_match[0] == ' ')
					{
						host_match++;
					}
			
					host_end_index = 0;
					while(host_match[host_end_index] != '\n' && host_match[host_end_index] != '\r' && host_match[host_end_index] != ' ' && ((char*)host_match - (char*)packet_data)+host_end_index < last_header_index )
					{
						host_end_index++;
					}
					host_end_index = host_end_index < 256 ? host_end_index : 255; /* prevent overflow */
					memcpy(host, host_match, host_end_index);
					host[host_end_index] = '\0';	
				}
				
				strcat(url, host);
				strcat(url, path);
				{
					struct ip_set_keyword *map = set->data;
					struct list_head *head = &map->head;
					struct list_head *pos;
					struct keyword_list *tmp;
	
					list_for_each(pos, head) {
						tmp = (struct keyword_list *) list_entry(pos, struct keyword_list, list);
						if (strstr(url, tmp->keyword.contents) != NULL) {
							
							if(!(*flags & IPSET_MATCH_PASS))
								reset_connect_keyword(skb, url, "", tmp->keyword.contents);
							
							if(*flags & IPSET_MATCH_LOG) {
								printk(KERN_DEBUG "WCF %s [%s] by keyword [%s], Local user %u.%u.%u.%u\n",
							       *flags & IPSET_MATCH_PASS ? "Pass" : "Blocking",
							       url,
							       tmp->keyword.contents,
							       NIPQUAD(ip->saddr));
							}
							return 1;
						}
					}
				}
			}
		}	
	}
	return 0;
}

/*
 * Userspace add: insert a copy of the request at the head of the keyword
 * list.  Returns -EEXIST when an entry with equal contents (strcmp) is
 * already present, -ENOMEM on allocation failure, 0 on success.
 */
static int
keyword_uadd(struct ip_set *set, const void *data, u_int32_t size)
{
	struct ip_set_keyword *map = set->data;
	const struct ip_set_req_keyword *req = data;
	struct keyword_list *entry;

	list_for_each_entry(entry, &map->head, list) {
		if (strcmp(entry->keyword.contents, req->contents) == 0)
			return -EEXIST;
	}

	/*
	 * NOTE(review): the ipset core appears to call add handlers with the
	 * set's rwlock write-held and BHs disabled (write_lock_bh in
	 * ip_set_addip), where GFP_KERNEL may sleep.  GFP_ATOMIC is safe in
	 * either context -- confirm against the core in use.
	 */
	entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
	if (!entry)
		return -ENOMEM;
	memcpy(&entry->keyword, req, sizeof(struct ip_set_req_keyword));
	list_add(&entry->list, &map->head);
	map->size++;
	return 0;
}

/*
 * Packet-path add handler.  Keyword sets cannot be populated from packets:
 * the arguments are ignored and success (0) is reported.
 */
static int
keyword_kadd(struct ip_set *set, const struct sk_buff *skb, const u_int32_t *flags)
{
	int ignored = 0;

	return ignored;
}

/*
 * Userspace delete: unlink and free the first entry whose contents compare
 * equal (strcmp) to the request's, and decrement the set size.  Returns 0
 * on success; a keyword that is not present is reported as -EEXIST.
 */
static int
keyword_udel(struct ip_set *set, const void *data, u_int32_t size)
{
	const struct ip_set_req_keyword *req = data;
	struct ip_set_keyword *map = set->data;
	struct keyword_list *entry;

	/* Returning right after kfree() means the cursor is never reused. */
	list_for_each_entry(entry, &map->head, list) {
		if (strcmp(entry->keyword.contents, req->contents) != 0)
			continue;
		list_del(&entry->list);
		kfree(entry);
		map->size--;
		return 0;
	}
	return -EEXIST;
}

/*
 * Packet-path delete handler.  Keyword sets cannot be modified from
 * packets: the arguments are ignored and success (0) is reported.
 */
static int
keyword_kdel(struct ip_set *set, const struct sk_buff *skb, const u_int32_t *flags)
{
	int ignored = 0;

	return ignored;
}

/*
 * Allocate the per-set state: an empty keyword list with size 0, stored in
 * set->data.  The creation request (@data, @size) is not used.  Returns
 * -ENOMEM if the allocation fails, 0 otherwise.
 */
static int
keyword_create(struct ip_set *set, const void *data, u_int32_t size)
{
	struct ip_set_keyword *state = kmalloc(sizeof(*state), GFP_KERNEL);

	if (state == NULL)
		return -ENOMEM;

	INIT_LIST_HEAD(&state->head);
	state->size = 0;
	set->data = state;
	return 0;
}

static void
keyword_destroy(struct ip_set *set)
{
	struct ip_set_keyword *map = set->data;
	struct list_head *head = &map->head;
	struct list_head *pos, *q;
	struct keyword_list *tmp;
	
	list_for_each_safe(pos, q, head) {
		tmp = (struct keyword_list *) list_entry(pos, struct keyword_list, list);
		list_del(pos);
		kfree(tmp);
	}
	
	kfree(map);
	
	return;
}

static void
keyword_flush(struct ip_set *set)
{
	struct ip_set_keyword *map = set->data;
	struct list_head *head = &map->head;
	struct list_head *pos, *q;
	struct keyword_list *tmp;
	
	list_for_each_safe(pos, q, head) {
		tmp = (struct keyword_list *) list_entry(pos, struct keyword_list, list);
		list_del(pos);
		kfree(tmp);
	}
	
	map->size = 0;
	
	return;
}

/* Report the number of keywords in the set through the list header in @data. */
static void
keyword_list_header(const struct ip_set *set, void *data)
{
	((struct ip_set_req_keyword_create *)data)->size =
		((const struct ip_set_keyword *)set->data)->size;
}

static int
keyword_list_members_size(const struct ip_set *set, char dont_align)
{
	struct ip_set_keyword *map = set->data;
	return map->size*sizeof(struct ip_set_req_keyword);
}

static void
keyword_list_members(const struct ip_set *set, void *data, char dont_align)
{
	struct ip_set_keyword *map = set->data;
	struct ip_set_req_keyword *d = data;
	struct list_head *head = &map->head;
	struct list_head *pos;
	struct keyword_list *tmp;
	
	list_for_each(pos,head) {
		tmp = (struct keyword_list *) list_entry(pos, struct keyword_list, list);
		memcpy(d, &tmp->keyword, sizeof(struct ip_set_req_keyword));
		d++;
	}
}

/*
 * Declare the "keyword" ip_set type.  NOTE(review): IP_SET_TYPE comes from
 * the ipset headers and presumably wires up the keyword_* handlers defined
 * above -- confirm against the macro definition.
 */
IP_SET_TYPE(keyword, IPSET_DATA_SINGLE)

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Eric Hsiao <erichs0608@gmail.com>");
MODULE_DESCRIPTION("keyword type of IP sets");

/* Module init/exit that register the type with the ipset core (macro from the ipset headers). */
REGISTER_MODULE(keyword)
