// tunnel.c
//
// tcp connect, chat, and pass through - console based utility
//
// Copyright (c) 2001 Andrew McGill and Leading Edge Business Solutions
// You may use and distribute this code only under the terms of the 
// GNU General Public License (GPL).

#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/select.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <string.h>

// Descriptor of the connection to the server, shared by every function in
// this file. 0 means "no connection open": cli_serverconnect() resets it to
// 0 when connecting fails, and cli_cleanup() only closes it when non-zero.
int clientsocket;

// atexit() handler: close the server connection when the program exits.
// A value of 0 in clientsocket means nothing is open, so nothing is closed.
// Closing explicitly releases the descriptor; it does not by itself
// prevent TCP's TIME_WAIT state.
void cli_cleanup(void)
{
	if (clientsocket == 0)
		return;
	close(clientsocket);
}

// Connect a fresh TCP socket to servername:port. Each IPv4 address the
// resolver returns is tried until one accepts. TCP_NODELAY is set on success.
// Returns the connected descriptor, or -1 after printing a diagnostic.
static int cli_connect_tcp(char *servername, int port)
{
	struct protoent *pe;
	struct hostent *host;
	struct sockaddr_in addr_in;
	int fd, i;
	int err = 0;
	int one = 1;

	if (port < 0 || port > 65535) {
		fprintf(stderr, "%s: port %d out of range\n", servername, port);
		return -1;
	}
	// Neither getprotobyname() nor gethostbyname() sets errno, so perror()
	// would print an unrelated message for their failures
	pe = getprotobyname("tcp");
	if (!pe) {
		fprintf(stderr, "getprotobyname: tcp: protocol not found\n");
		return -1;
	}
	host = gethostbyname(servername);
	if (!host) {
		fprintf(stderr, "gethostbyname: %s: host lookup failed\n", servername);
		return -1;
	}
	if (host->h_addrtype != AF_INET
	    || host->h_length != (int)sizeof(addr_in.sin_addr)
	    || !host->h_addr_list[0]) {
		fprintf(stderr, "gethostbyname: %s: no IPv4 address\n", servername);
		return -1;
	}

	for (i = 0; host->h_addr_list[i]; i++) {
		fd = socket(AF_INET, SOCK_STREAM, pe->p_proto);
		if (fd == -1) {
			perror("socket");
			return -1;
		}
		memset(&addr_in, 0, sizeof(addr_in));
		addr_in.sin_family = AF_INET;
		addr_in.sin_port = htons((unsigned short)port);
		// memcpy rather than reading the char* entry through a
		// struct in_addr pointer (alignment/aliasing)
		memcpy(&addr_in.sin_addr, host->h_addr_list[i], sizeof(addr_in.sin_addr));
		if (connect(fd, (struct sockaddr *)&addr_in, sizeof(addr_in)) == 0) {
			// Interactive traffic: disable Nagle so small writes go out at
			// once. Failure only costs latency, so it is not fatal.
			setsockopt(fd, pe->p_proto, TCP_NODELAY, &one, sizeof(one));
			return fd;
		}
		// A socket whose connect() failed is in an unspecified state:
		// close it (keeping connect's errno for the final report) and
		// start afresh with the next address
		err = errno;
		close(fd);
	}
	errno = err;
	perror("connect");
	return -1;
}

// Connect a fresh Unix-domain stream socket to the socket file at path.
// Returns the connected descriptor, or -1 after printing a diagnostic.
static int cli_connect_unix(char *path)
{
	struct sockaddr_un addr_un;
	size_t len = strlen(path);
	int fd;

	// sun_path must hold the path plus its terminating NUL; strncpy()
	// would truncate a long path silently and leave it unterminated
	if (len >= sizeof(addr_un.sun_path)) {
		fprintf(stderr, "%s: socket path too long\n", path);
		return -1;
	}
	memset(&addr_un, 0, sizeof(addr_un));
	addr_un.sun_family = AF_UNIX;
	memcpy(addr_un.sun_path, path, len + 1);

	fd = socket(PF_UNIX, SOCK_STREAM, 0);
	if (fd == -1) {
		perror("socket");
		return -1;
	}
	if (connect(fd, (struct sockaddr *)&addr_un, sizeof(addr_un)) != 0) {
		perror("connect");
		close(fd);
		return -1;
	}
	return fd;
}

// Open a stream connection to the server and store its descriptor in the
// global clientsocket.
//   servername: host name or dotted IPv4 address when port is non-zero,
//               otherwise the path of a Unix-domain socket
//   port:       TCP port in host byte order; 0 selects a Unix-domain socket
// Returns 0 on success. On failure it prints a diagnostic to stderr, leaves
// clientsocket set to 0 and returns 1. The first call registers
// cli_cleanup() with atexit() so the connection is closed on exit.
int cli_serverconnect(char *servername, int port)
{
	static int init;
	int fd;

	if (!init) {
		atexit(cli_cleanup);
		init = 1;
	}
	// Release a connection left over from an earlier call instead of leaking it
	if (clientsocket > 0)
		close(clientsocket);
	clientsocket = 0;

	fd = port ? cli_connect_tcp(servername, port) : cli_connect_unix(servername);
	if (fd == -1)
		return 1;
	clientsocket = fd;
	return 0;
}

// Send the whole of string to the server; an empty string sends nothing.
// MSG_NOSIGNAL makes a closed connection come back as a send() error
// instead of raising SIGPIPE. Exits with status 1 if sending fails.
// With LOCALECHO defined, the text is echoed to stderr, with each
// carriage return shown as a newline.
void sendstring(char *string)
{
	size_t len = strlen(string);
	size_t sent = 0;
	ssize_t bytes;
#ifdef LOCALECHO
	char *p;
#endif

	// send() on a stream socket may accept only part of the buffer, so
	// keep going until all of it has been handed over
	while (sent < len) {
		bytes = send(clientsocket, string + sent, len - sent, MSG_NOSIGNAL);
		if (bytes < 0 && errno == EINTR)
			continue;
		if (bytes <= 0) {
			perror("send");
			exit(1);
		}
		sent += (size_t)bytes;
	}
#ifdef LOCALECHO
	for (p = string; *p; p++)
		fputc(*p == '\r' ? '\n' : *p, stderr);
#endif
}

// Wait for a character.
// Reads from the server one byte at a time, echoing each byte to stderr,
// until the byte waitfor has been received (it is echoed too). Exits with
// status 1 if the connection closes or a read fails before that.
void waitchar(char waitfor)
{
	ssize_t bytes;
	char received;

	for (;;) {
		// A single byte per read, so nothing after the awaited
		// character is consumed here
		bytes = read(clientsocket, &received, 1);
		if (bytes == 0) {
			// EOF does not set errno, so perror() would report a stale error
			fprintf(stderr, "read: connection closed by server\n");
			exit(1);
		}
		if (bytes < 0) {
			if (errno == EINTR)
				continue;
			perror("read");
			exit(1);
		}
		fputc(received, stderr);
		if (received == waitfor)
			return;
	}
}

// Wait until the characters of chars have arrived from the server in order,
// echoing everything received. Other characters may arrive in between, so
// "login:" is satisfied by any input containing l, o, g, i, n, ':' in
// sequence. An empty string returns immediately.
void waitchars(char *chars)
{
	while (*chars != '\0')
		waitchar(*chars++);
}

// Write all len bytes of buf to fd, retrying after short writes and
// interrupted calls. Returns 0 on success, -1 if write() fails or
// writes nothing.
static int passthru_writeall(int fd, const char *buf, size_t len)
{
	ssize_t n;

	while (len > 0) {
		n = write(fd, buf, len);
		if (n < 0 && errno == EINTR)
			continue;
		if (n <= 0)
			return -1;
		buf += n;
		len -= (size_t)n;
	}
	return 0;
}

// Relay one round of traffic between the console and the server; meant to
// be called repeatedly. Waits up to five minutes for stdin or clientsocket
// to become readable, then copies stdin -> server and server -> stdout.
// After stdin reaches end of file it is no longer watched. Exits the
// process with status 0 when the server closes the connection, and with
// status 1 on an I/O error.
void passthru(void)
{
	// Cleared at EOF: a stdin at end of file always selects as readable
	static int stdin_open = 1;
	struct timeval tv;
	fd_set rfds;
	char buffer[64000];
	ssize_t bytes;
	int retval;

	// Watch for readability only. stdout and the socket are almost always
	// writable, so putting them in a write set made select() return at
	// once and the caller's loop spin at full CPU.
	FD_ZERO(&rfds);
	if (stdin_open)
		FD_SET(0, &rfds);
	FD_SET(clientsocket, &rfds);
	/* Wait up to five minutes */
	tv.tv_sec = 300;
	tv.tv_usec = 0;

	// nfds is the highest descriptor plus one; a fixed 5 would ignore a
	// socket numbered 5 or above
	retval = select(clientsocket + 1, &rfds, NULL, NULL, &tv);
	if (retval == 0)
		return;		// timed out; the caller simply loops
	if (retval < 0) {
		// The fd sets are unspecified after an error, so don't inspect them
		if (errno == EINTR)
			return;
		perror("select");
		exit(1);
	}

	if (stdin_open && FD_ISSET(0, &rfds)) {
		bytes = read(0, buffer, sizeof(buffer));
		if (bytes == 0) {
			stdin_open = 0;
		}
		else if (bytes < 0) {
			if (errno != EINTR) {
				perror("read");
				exit(1);
			}
		}
		else if (passthru_writeall(clientsocket, buffer, (size_t)bytes) != 0) {
			perror("write");
			exit(1);
		}
	}
	if (FD_ISSET(clientsocket, &rfds)) {
		bytes = read(clientsocket, buffer, sizeof(buffer));
		if (bytes == 0) {
			fprintf(stderr, "\nConnection closed by server\n");
			exit(0);
		}
		else if (bytes < 0) {
			if (errno != EINTR) {
				perror("read");
				exit(1);
			}
		}
		else if (passthru_writeall(1, buffer, (size_t)bytes) != 0) {
			perror("write");
			exit(1);
		}
	}
}

// Do some basic escapes, decoding line in place.
// Raw newline characters, such as the one fgets() leaves at the end of a
// line, are dropped. A backslash starts an escape: \n, \t and \r become
// newline, tab and carriage return. Any other escaped character, including
// a second backslash or a raw newline, stands for itself, so a trailing
// backslash keeps the line's newline. A backslash that is the very last
// character is kept literally. The output is never longer than the input.
void cescape(char *line)
{
	size_t src = 0, dst = 0;
	char ch;

	while ((ch = line[src++]) != '\0') {
		if (ch == '\n')
			continue;
		if (ch == '\\' && line[src] != '\0') {
			ch = line[src++];
			if (ch == 'n')
				ch = '\n';
			else if (ch == 't')
				ch = '\t';
			else if (ch == 'r')
				ch = '\r';
			// every other escaped character maps to itself
		}
		line[dst++] = ch;
	}
	line[dst] = '\0';
}

// Read one line of the command file into buf, which holds size bytes.
// Returns 1 if a line was read, 0 at end of file. Exits with status 1 if
// the line does not fit: fgets() would otherwise hand back the rest of it
// as the next line and silently shift every following expect/send pair.
static int read_cmdline(char *buf, int size, FILE *f, const char *filename)
{
	int c;

	if (fgets(buf, size, f) == NULL)
		return 0;
	if (strchr(buf, '\n') == NULL) {
		// No newline: either the final unterminated line or a truncated
		// one. A newline right after the buffer's worth still fits, since
		// cescape() drops it anyway.
		c = getc(f);
		if (c != EOF && c != '\n') {
			fprintf(stderr, "%s: line too long (max %d characters)\n",
				filename, size - 1);
			exit(1);
		}
	}
	return 1;
}

// Entry point: tunnel host port commandfile
// Connects to host:port (port 0 means host is a Unix-domain socket path).
// It then works through commandfile, whose lines alternate between text to
// wait for and text to send; both are decoded by cescape(). When the file is
// exhausted, it hands over to passthru() in an endless loop.
int main(int argc, char *argv[])
{
	char line[352];
	FILE *cmdfile;
	char *end;
	long port;

	if (argc!=4) {
		fprintf(stderr,"Usage: %s host port commandfile\n"
			"The file consists of lines alternating between what to\n"
			"expect, and what to send.\n", argv[0]);
		exit(1);
	}
	// strtol rather than atoi: a typo or a service name such as "telnet"
	// would otherwise become port 0 and silently select a Unix socket
	port = strtol(argv[2], &end, 10);
	if (end == argv[2] || *end != '\0' || port < 0 || port > 65535) {
		fprintf(stderr, "%s: invalid port: %s\n", argv[0], argv[2]);
		exit(1);
	}
	cmdfile = fopen(argv[3],"r");
	if (!cmdfile) {
		perror(argv[3]);
		exit(1);
	}
	// Without a connection clientsocket stays 0 (stdin), so stop here;
	// cli_serverconnect() has already reported why it failed
	if (cli_serverconnect(argv[1], (int)port) != 0)
		exit(1);

	while (read_cmdline(line, (int)sizeof(line), cmdfile, argv[3])) {
		// Odd line: text to wait for
		cescape(line);
		waitchars(line);
		// Even line: text to send
		if (!read_cmdline(line, (int)sizeof(line), cmdfile, argv[3]))
			break;
		cescape(line);
		sendstring(line);
	}
	fclose(cmdfile);

	fprintf(stderr,"Passthrough:");
	for (;;)
		passthru();
	return 0;
}


