/*
 * Copyright (c) 2003, 2004 Nokia
 * Author: tsavola@movial.fi
 *
 * This program is licensed under GPL (see COPYING for details)
 */

#define _GNU_SOURCE

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <unistd.h>
#include <netdb.h>
#include <errno.h>
#include <termios.h>
#include <signal.h>
#include <getopt.h>
#include <time.h>
#include <pwd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include "client.h"
#include "common.h"
#include "protocol.h"
#include "buffer.h"
#include "config.h"
#include "mount.h"

/*
 * The environment of the process. main() passes it to write_cmd() and, for
 * protocol version 3 servers, replaces SBOX_ENV_* entries in place.
 */
extern char **environ;

/** The name of this program (derived from argv[0] in main()). */
static char *progname;

/**
 * Path of the config file (-c/--config). When NULL, read_user_config()
 * uses $HOME "/" CONFIG_NAME instead.
 */
static char *configpath = NULL;

/**
 * The command line arguments: the target name, the remote working
 * directory (-d/--directory, "" when not given) and the remote command
 * with its arguments (a vector into the argv rebuilt by modify_args()).
 */
static char *target = NULL;
static char *curdir = "";
static char **args = NULL;

/** The socket fd connected to the server (-1 until main() creates it). */
static int sd = -1;

/**
 * Buffers used to copy data around. buf_out/buf_err hold OUT/ERR DATA
 * received from the server until it is written to stdout/stderr; tmp_buf
 * is a BUFFER_SIZE scratch area (allocated in main()) used for stdin reads
 * and incoming message packets.
 */
static buffer_t buf_out;
static buffer_t buf_err;
static char *tmp_buf = NULL;

/** The return code that will eventually be received from the server. */
static int16_t rc = 0;

/** Is the tty in raw mode? Set by set_raw_mode(), read by set_old_mode(). */
static bool_t rawmode = FALSE;

/** The original mode of the tty, saved by set_raw_mode() for restoring. */
static struct termios oldtio;

/**
 * How much IN DATA the server has requested that has not been sent yet:
 * set to BUFFER_SIZE by an IN REQ packet, decreased by send_data() and
 * zeroed when stdin hits EOF.
 */
static size_t inreq = 0;

/**
 * Have we requested more OUT/ERR data? Set by send_request(), cleared by
 * write_buffer() once the corresponding buffer has been emptied.
 */
static bool_t outwait = FALSE;
static bool_t errwait = FALSE;

/** The protocol version of the daemon (read with get_version() in main()). */
static int daemon_version = 0;

#ifdef DEBUG
/**
 * Prints a debug line to stderr, prefixed with progname and the pid.
 * @param msg printf-style format string, followed by its arguments
 */
static void debug(const char *msg,
		  ...)
{
	va_list arg;

	fprintf(stderr, "%s (%d) debug: ", progname, getpid());

	va_start(arg, msg);
	vfprintf(stderr, msg, arg);
	va_end(arg);

	fprintf(stderr, "\n");
	fflush(stderr);
}
#else
/*
 * No-op in non-debug builds. The macro is fully variadic because callers
 * often pass only a format string (e.g. debug("Stdin hit EOF")); with a
 * "msg, ..." parameter list that is not valid ISO C before C23. Expanding
 * to an expression keeps it usable anywhere a statement is expected. The
 * arguments are never evaluated.
 */
# define debug(...) ((void) 0)
#endif

/*
 * Reports an error on stderr: "<progname>: <formatted msg>", followed by
 * " (<strerror(errno)>)" when errno is positive, then a newline.
 * The errno description is looked up before anything is printed.
 */
void error(const char *msg,
	   ...)
{
	va_list ap;
	const char *reason = (errno > 0) ? strerror(errno) : NULL;

#ifdef DEBUG
	fprintf(stderr, "%s (%d): ", progname, getpid());
#else
	fprintf(stderr, "%s: ", progname);
#endif

	va_start(ap, msg);
	vfprintf(stderr, msg, ap);
	va_end(ap);

	if (reason != NULL) {
		fprintf(stderr, " (%s)", reason);
	}

	fputc('\n', stderr);
	fflush(stderr);
}

/**
 * Writes as much of a buffer to a file as buf_write_out() manages in one go.
 * @param buf the buffer to drain
 * @param fd the target file descriptor (stdout or stderr)
 * @param wait cleared to FALSE once the buffer has become empty
 * @return 0 on success, -1 on error (reported with error())
 */
static int write_buffer(buffer_t *buf,
			int fd,
			bool_t *wait)
{
	/* The buffer is never at EOF here, so buf_write_out() leaves fd alone. */
	if (buf_write_out(buf, &fd) >= 0) {
		if (buf_is_empty(buf)) {
			*wait = FALSE;
		}
		return 0;
	}

	if (fd == STDOUT_FILENO) {
		error("Can't write buffer to stdout");
	} else {
		error("Can't write buffer to stderr");
	}
	return -1;
}

/**
 * Gets the amount of data we can read from a file (which may be a tty).
 * For a tty the answer is the number of bytes currently queued (FIONREAD),
 * capped at max; for anything else it is simply max.
 * @param fd the file descriptor to inspect
 * @param max upper bound for the result
 * @return number of bytes to read (possibly 0 for a tty), -1 on error
 */
static ssize_t get_data_length(int fd,
			       size_t max)
{
	/* FIONREAD stores an int; passing a wider ssize_t fills only part of it. */
	int avail = 0;

	if (!isatty(fd)) {
		return (ssize_t) max;
	}

	if (ioctl(fd, FIONREAD, &avail) < 0) {
		error("Can't check tty for available data");
		return -1;
	}

	if (avail < 0) {
		error("ioctl(tty) gave invalid read length");
		return -1;
	}

	return (size_t) avail > max ? (ssize_t) max : (ssize_t) avail;
}

/**
 * Sends an IN DATA packet.
 * @return 0 usually, 1 if stdin reached EOF, -1 on error
 */
static int send_data(void)
{
	bool_t ok = FALSE;
	ssize_t len;

	len = get_data_length(STDIN_FILENO, BUFFER_SIZE);
	if (len < 0) {
		return -1;
	}
	if (len == 0) {
		return 0;
	}

	len = read(STDIN_FILENO, tmp_buf, len);
	if (len < 0) {
		error("Can't read from stdin");
		return -1;
	}

	if (len == 0) {
		debug("Stdin hit EOF");
		ok = 1;
		inreq = 0;
	}

	if (write_packet(sd, PTYPE_IN_DATA, tmp_buf, len) < 0) {
		error("Can't write IN DATA packet to socket");
		return -1;
	}

	/* At EOF: 0-0=0 */
	inreq -= len;

	return ok;
}

/**
 * Appends the payload of an OUT DATA or ERR DATA packet to its buffer.
 * An empty payload is rejected as an error.
 * @param len payload size from the packet header
 * @param buf &buf_out or &buf_err
 * @return 0 on success, -1 on error
 */
static int receive_stream(size_t len,
			  buffer_t *buf)
{
	int is_out = (buf == &buf_out);

	if (!len) {
		error_noerr(is_out ? "Received empty OUT DATA packet (EOF)"
				   : "Received empty ERR DATA packet (EOF)");
		return -1;
	}

	if (buf_read_in(buf, sd, len) >= 0) {
		return 0;
	}

	error(is_out ? "Can't append OUT DATA packet to buffer"
		     : "Can't append ERR DATA packet to buffer");
	return -1;
}

/**
 * Reads a message packet's payload from the socket and prints it to stderr.
 * tmp_buf holds BUFFER_SIZE bytes, so longer messages are truncated to
 * BUFFER_SIZE - 1 characters. The excess is still read (and discarded) so
 * the next packet header is parsed from the right place.
 * @param len the length of the message (from the packet header)
 * @param type PTYPE_ERROR (printed as a server error) or PTYPE_MESSAGE
 * @return 0 on success, -1 on error
 */
static int receive_message(size_t len,
			   ptype_t type)
{
	size_t keep, rest;

	/* len comes from the peer: never read more than tmp_buf can hold. */
	keep = len < BUFFER_SIZE ? len : BUFFER_SIZE - 1;

	if (read_buf(sd, tmp_buf, keep) < 0) {
		error("Can't read message packet");
		return -1;
	}
	tmp_buf[keep] = '\0';

	for (rest = len - keep; rest > 0; ) {
		char discard[256];
		size_t chunk = rest < sizeof (discard) ? rest : sizeof (discard);

		if (read_buf(sd, discard, chunk) < 0) {
			error("Can't read message packet");
			return -1;
		}
		rest -= chunk;
	}

#ifdef DEBUG
	if (type == PTYPE_ERROR) {
		fprintf(stderr, "%s (%d) server: %s\n", progname,
			getpid(), tmp_buf);
	} else {
		fprintf(stderr, "%s (%d): %s\n", progname,
			getpid(), tmp_buf);
	}
#else
	if (type == PTYPE_ERROR) {
		fprintf(stderr, "%s server: %s\n", progname, tmp_buf);
	} else {
		fprintf(stderr, "%s: %s\n", progname, tmp_buf);
	}
#endif

	return 0;
}

/**
 * Reads one packet from the socket (sd) and dispatches it:
 * IN REQ refills the inreq budget, OUT/ERR DATA go to their buffers,
 * ERROR/MESSAGE are printed and RC is stored in the global rc.
 * @return 1 on RC, 0 on other valid packet type, -1 on error
 */
static int receive_packet(void)
{
	phead_t head;

	if (read_phead(sd, &head) < 0) {
		error("Can't read packet header from socket");
		return -1;
	}

	switch (head.type) {
	case PTYPE_OUT_DATA:
		return receive_stream(head.size, &buf_out);

	case PTYPE_ERR_DATA:
		return receive_stream(head.size, &buf_err);

	case PTYPE_ERROR:
		return receive_message(head.size, PTYPE_ERROR);

	case PTYPE_MESSAGE:
		return receive_message(head.size, PTYPE_MESSAGE);

	case PTYPE_IN_REQ:
		if (head.size != 0) {
			error_noerr("IN REQ packet size is non-zero");
			return -1;
		}
		inreq = BUFFER_SIZE;
		return 0;

	case PTYPE_RC:
		rc = read_int16(sd);
		if (rc >= 0) {
			return 1;
		}
		error("Can't read RC packet from socket");
		return -1;

	default:
		break;
	}

	error_noerr("Received packet has invalid type: %d", head.type);
	return -1;
}

/**
 * Asks the server for more OUT or ERR data with an empty request packet.
 * @param ptype PTYPE_OUT_REQ or PTYPE_ERR_REQ
 * @param wait set to TRUE once the request has been written
 * @return 0 on success, -1 on error
 */
static int send_request(ptype_t ptype,
			bool_t *wait)
{
	if (write_packet(sd, ptype, NULL, 0) >= 0) {
		*wait = TRUE;
		return 0;
	}

	error("Can't write packet to socket");
	return -1;
}

/**
 * Reads packets until the server answers the authentication request.
 * ERROR and MESSAGE packets received meanwhile are printed to stderr.
 * @return 0 for AUTH OK; -1 when an RC packet arrives instead (the rc
 *         value is ignored) or on any read/protocol error
 */
static int get_auth_reply(void)
{
	phead_t head;

	while (1) {
		if (read_phead(sd, &head) < 0) {
			error("Can't read packet header from socket");
			return -1;
		}

		switch (head.type) {
		case PTYPE_AUTH_OK:
			/* There's no data involved. */
			return 0;

		case PTYPE_RC:
			if (read_int16(sd) < 0) {
				error("Can't read RC packet from socket");
				return -1;
			}
			/* We don't really care about the rc value here. */
			return -1;

		case PTYPE_ERROR:
			/*
			 * A failed read leaves the stream in the middle of a
			 * packet; stop rather than parse payload as a header.
			 */
			if (receive_message(head.size, PTYPE_ERROR) < 0) {
				return -1;
			}
			continue;

		case PTYPE_MESSAGE:
			if (receive_message(head.size, PTYPE_MESSAGE) < 0) {
				return -1;
			}
			continue;
		}

		error_noerr("Received packet has invalid type");
		return -1;
	}
}

/**
 * Reads packets until an RC packet arrives and stores its value in the
 * global rc. ERROR and MESSAGE packets received meanwhile are printed.
 * @return 0 if the global rc variable was set or -1 on error
 */
static int get_rc(void)
{
	phead_t head;

	while (1) {
		if (read_phead(sd, &head) < 0) {
			error("Can't read packet header from socket");
			return -1;
		}

		switch (head.type) {
		case PTYPE_RC:
			rc = read_int16(sd);
			if (rc < 0) {
				error("Can't read RC packet from socket");
				return -1;
			}
			return 0;

		case PTYPE_ERROR:
			/*
			 * A failed read leaves the stream in the middle of a
			 * packet; stop rather than parse payload as a header.
			 */
			if (receive_message(head.size, PTYPE_ERROR) < 0) {
				return -1;
			}
			continue;

		case PTYPE_MESSAGE:
			if (receive_message(head.size, PTYPE_MESSAGE) < 0) {
				return -1;
			}
			continue;
		}

		error_noerr("Received packet has invalid type");
		return -1;
	}
}

/**
 * Manages stdin/stdout/stderr and waits for the return code.
 * Loops on select(): packets from the socket are dispatched, stdin is
 * forwarded while the server has requested data, and buffered OUT/ERR data
 * is written out (or more is requested once a buffer is empty). stderr is
 * only serviced when is_tty is false.
 * @param is_tty whether the session runs in terminal mode
 * @return 0 once an RC packet has been received, -1 on error
 */
static int manage(bool_t is_tty)
{
	fd_set readfds, writefds;
	bool_t inopen = TRUE;
	int nready;
	int ok;

	while (1) {
		FD_ZERO(&readfds);
		FD_ZERO(&writefds);

		FD_SET(sd, &readfds);

		if (inopen && inreq) {
			FD_SET(STDIN_FILENO, &readfds);
		}

		if (!buf_is_empty(&buf_out) || !outwait) {
			FD_SET(STDOUT_FILENO, &writefds);
		}

		if (!is_tty &&
		    (!buf_is_empty(&buf_err) || !errwait)) {
			FD_SET(STDERR_FILENO, &writefds);
		}

		nready = select(sd + 1, &readfds, &writefds, NULL, NULL);
		if (nready < 0 && errno == EINTR) {
			/* Interrupted by a signal; the fd sets are rebuilt above. */
			continue;
		}
		if (nready <= 0) {
			error("Can't select");
			return -1;
		}

		/* Read packet from socket. */
		if (FD_ISSET(sd, &readfds)) {
			ok = receive_packet();
			if (ok < 0) {
				return -1;
			}
			if (ok > 0) {
				break;
			}
		}

		/* Send data from stdin if requested. */
		if (inopen && inreq && FD_ISSET(STDIN_FILENO, &readfds)) {
			ok = send_data();
			if (ok < 0) {
				return -1;
			}
			if (ok > 0) {
				inopen = FALSE;
			}
		}

		/* Flush buf_out to stdout or request more data. */
		if (FD_ISSET(STDOUT_FILENO, &writefds)) {
			if (buf_is_empty(&buf_out)) {
				ok = outwait || send_request(PTYPE_OUT_REQ,
							     &outwait) >= 0;
			} else {
				ok = write_buffer(&buf_out, STDOUT_FILENO,
						  &outwait) >= 0;
			}

			if (!ok) {
				return -1;
			}
		}

		/* Flush buf_err to stderr or request more data. */
		if (FD_ISSET(STDERR_FILENO, &writefds)) {
			if (buf_is_empty(&buf_err)) {
				ok = errwait || send_request(PTYPE_ERR_REQ,
							     &errwait) >= 0;
			} else {
				ok = write_buffer(&buf_err, STDERR_FILENO,
						  &errwait) >= 0;
			}

			if (!ok) {
				return -1;
			}
		}
	}

	return 0;
}

/**
 * Sets the terminal on stdin to raw mode. The previous settings are saved
 * in oldtio so set_old_mode() can restore them.
 * @return 0 on success, -1 on error
 */
static int set_raw_mode(void)
{
	struct termios tio;

	if (tcgetattr(STDIN_FILENO, &tio) < 0) {
		error("Can't get termios");
		return -1;
	}
	oldtio = tio;

	/*
	 * Modeled on OpenSSH's raw mode. Each flag must be cleared in the field
	 * it belongs to: IXOFF is an input flag (c_iflag) and OPOST an output
	 * flag (c_oflag). Masking them out of c_lflag cleared whatever local
	 * bits share those values and left input flow control and output
	 * post-processing enabled.
	 */
	tio.c_iflag |= IGNPAR;
	tio.c_iflag &= ~(ISTRIP | INLCR | IGNCR | ICRNL | IXON | IXANY | IXOFF);
	tio.c_lflag &= ~(ISIG | ICANON | ECHO | ECHOE | ECHOK);
	tio.c_lflag &= ~(ECHONL | IEXTEN);
	/*
	 * Pass output through untouched; presumably the remote side already
	 * does its own output processing. NOTE: with OPOST off a locally
	 * printed "\n" no longer implies a carriage return.
	 */
	tio.c_oflag &= ~OPOST;
	tio.c_cc[VMIN] = 1;
	tio.c_cc[VTIME] = 0;

	debug("Entering raw mode");

	if (tcsetattr(STDIN_FILENO, TCSADRAIN, &tio) < 0) {
		error("Can't change termios");
		return -1;
	}

	rawmode = TRUE;
	return 0;
}

/**
 * Restores the tty settings saved by set_raw_mode(), if raw mode was ever
 * entered. Failures are reported but not fatal.
 */
static void set_old_mode(void)
{
	if (rawmode) {
		debug("Leaving raw mode");

		if (tcsetattr(STDIN_FILENO, TCSADRAIN, &oldtio) < 0) {
			error("Can't restore original termios");
		}
	} else {
		debug("Not in raw mode");
	}
}

/**
 * Parses every mount option string into a mount_info_t. For NFS entries the
 * device is also stat'ed, unless we are only unmounting.
 * Calls exit() on error.
 * @param opts vector of mount option strings (its length is computed by
 *        calc_vec_len())
 * @param umount TRUE when only unmounting
 * @return a vector with one entry per option plus a zero-filled final slot
 */
static mount_info_t **parse_mounts(const char **opts,
				   bool_t umount)
{
	size_t count, idx;
	mount_info_t **vec;

	count = calc_vec_len((void **) opts);

	vec = calloc(count + 1, sizeof (mount_info_t *));
	if (vec == NULL) {
		error_noerr(oom);
		exit(1);
	}

	for (idx = 0; idx < count; ++idx) {
		mount_info_t *info = mntinfo_parse(opts[idx]);

		if (info == NULL) {
			exit(1);
		}

		debug("type=%d device=%s point=%s opts=%s",
		      info->type, info->device, info->point, info->opts);

		if (MTYPE_NFS == info->type && !umount) {
			if (mntinfo_stat_device(info) < 0) {
				exit(1);
			}

			debug("       device_dev=%lld", info->device_dev);
		}

		vec[idx] = info;
	}

	return vec;
}

/* Prints the command line synopsis to stderr. */
static void usage(void)
{
	const char *name = progname;

	fprintf(stderr, "Usage: %s <target> [-c|--config <path>]"
			" [-d|--directory <dir>] [<command>] [<args>]\n", name);
	fprintf(stderr, "       %s <target> [-c|--config <path>]"
			" --umount-all\n", name);
	fprintf(stderr, "       %s -v|--version\n", name);
	fprintf(stderr, "       %s -h|--help\n", name);
}

/**
 * Decides whether the argument after an sbrsh option is that option's value.
 * If so, modify_args() copies it along instead of treating it as the
 * start of the remote command.
 * Long options may be abbreviated, as getopt_long() allows. Short options
 * may be grouped ("-vc"); only the last one in a group can take its value
 * from the next argument.
 * @param arg a command line argument starting with '-'
 * @param longopts the getopt_long() option table (zero-terminated)
 * @return TRUE if arg names an option with a required argument whose value
 *         is not attached to arg itself
 */
static bool_t should_skip_parameter(char *arg,
				    struct option *longopts)
{
	struct option *opt;
	const char *p;

	if (arg[0] != '-' || arg[1] == '\0') {
		return FALSE;
	}

	if (arg[1] == '-') {
		const char *name = &arg[2];
		struct option *match = NULL;
		size_t namelen;

		/* "--" itself, or "--name=value" which carries its own value. */
		if (*name == '\0' || strchr(name, '=') != NULL) {
			return FALSE;
		}

		namelen = strlen(name);
		for (opt = longopts; opt->name; opt++) {
			if (strncmp(opt->name, name, namelen) != 0) {
				continue;
			}
			if (opt->name[namelen] == '\0') {
				/* An exact match wins over any abbreviation. */
				match = opt;
				break;
			}
			if (!match) {
				match = opt;
			}
		}

		return match && match->has_arg == required_argument;
	}

	for (p = &arg[1]; *p; p++) {
		for (opt = longopts; opt->name; opt++) {
			/* Entries with a flag pointer have no short form. */
			if (!opt->flag && opt->val == (unsigned char) *p) {
				break;
			}
		}

		if (!opt->name) {
			/* Unknown option: getopt_long() will report it. */
			return FALSE;
		}

		if (opt->has_arg != no_argument) {
			/* The rest of the group, if any, is the value. */
			return opt->has_arg == required_argument && p[1] == '\0';
		}
	}

	return FALSE;
}

/**
 * Rebuilds argv so that getopt_long() only parses sbrsh's own options.
 * The result is: progname, target, sbrsh options (with their values),
 * then "--", then the remote command and its arguments. The "--" is
 * inserted before the first non-option argument unless the user already
 * wrote one. Calls exit() when out of memory.
 * @param old_argc original argument count
 * @param old_argv original argument vector (NULL-terminated)
 * @param new_argc receives the number of entries in the new vector
 * @param longopts option table used to recognise options taking a value
 * @return the new NULL-terminated argument vector
 */
static char **modify_args(int old_argc,
			  char **old_argv,
			  int *new_argc,
			  struct option *longopts)
{
	char **new_argv;
	int i = 0;

	/* Every original argument, one inserted "--" and the final NULL. */
	new_argv = calloc(old_argc + 2, sizeof (char *));
	if (!new_argv) {
		error(oom);
		exit(1);
	}

	/* sbrsh's progname */
	new_argv[i++] = *old_argv++;

	/* target */
	if (old_argc > 1) {
		new_argv[i++] = *old_argv++;

		/* sbrsh's options and -- */
		while (i < old_argc) {
			char *arg = *old_argv++;

			/*
			 * An explicit "--" already ends sbrsh's options. Adding a
			 * second one would, after getopt's argument permutation,
			 * turn the user's "--" into the remote command.
			 */
			if (strcmp(arg, "--") == 0) {
				new_argv[i++] = arg;
				break;
			}

			if (arg[0] != '-') {
				new_argv[i++] = "--";
				new_argv[i++] = arg;
				break;
			}

			new_argv[i++] = arg;

			if (should_skip_parameter(arg, longopts) &&
			    i < old_argc) {
				new_argv[i++] = *old_argv++;
			}
		}

		/* command and its arguments */
		while (*old_argv) {
			new_argv[i++] = *old_argv++;
		}
	}

	*new_argc = i;
	return new_argv;
}

/**
 * Fills in the options from argv. Prints usage and exits on error.
 * Sets configpath (-c), curdir (-d), target and args; -v and -h print
 * their output and exit(0).
 * @return nonzero (TRUE) if we should --umount-all and not run a command
 */
static int read_args(int orig_argc,
		     char **orig_argv)
{
	char **argv;
	int argc;
	/*
	 * getopt_long() stores through the option's "int *flag", so this must
	 * be an int; bool_t is declared elsewhere and need not be int-sized.
	 */
	int umount = FALSE;

	struct option longopts[] = {
		{ "help",       no_argument,       0,       'h'   },
		{ "version",    no_argument,       0,       'v'   },
		{ "config",     required_argument, 0,       'c'   },
		{ "directory",  required_argument, 0,       'd'   },
		{ "umount-all", no_argument,       &umount, TRUE  },
		{ 0 }
	};

	argv = modify_args(orig_argc, orig_argv, &argc, longopts);

	while (1) {
		int c;

		c = getopt_long(argc, argv, "hvc:d:", longopts, NULL);
		if (c < 0) {
			break;
		}

		switch (c) {
		case 'c':
			configpath = optarg;
			break;

		case 'd':
			curdir = optarg;
			break;

		case 0:
			/* --umount-all (stored through the flag pointer) */
			break;

		case 'v':
			fprintf(stderr,
				"Scratchbox Remote Shell client " VERSION "\n"
				"Protocol version %d\n"
				"Compiled at %s %s\n",
				PROTOCOL_VERSION, __DATE__, __TIME__);
			exit(0);

		case 'h':
			usage();
			exit(0);

		default:
			usage();
			exit(1);
		}
	}

	if (optind >= argc) {
		usage();
		exit(1);
	}

	target = argv[optind];
	args = &argv[optind + 1];

	return umount;
}

/**
 * Reads the config for the target from configpath. When no path was given,
 * $HOME "/" CONFIG_NAME is used, or the passwd home directory if HOME is
 * unset. Exits on error.
 * User defaults to $USER (then the passwd entry) and port to DEFAULT_PORT.
 * @param cfg the config to fill in
 */
static void read_user_config(config_t *cfg)
{
	char *path = NULL;

	if (!configpath) {
		const char *home = getenv("HOME");
		size_t size;

		/* HOME may be unset (cron, env -i); strlen(NULL) would crash. */
		if (!home) {
			struct passwd *pass = getpwuid(getuid());

			if (!pass || !pass->pw_dir) {
				error("Can't get user info");
				exit(1);
			}
			home = pass->pw_dir;
		}

		size = strlen(home) + strlen("/" CONFIG_NAME) + 1;
		path = malloc(size);
		if (!path) {
			error(oom);
			exit(1);
		}
		snprintf(path, size, "%s/%s", home, CONFIG_NAME);

		configpath = path;
	}

	if (!read_config(configpath, target, cfg)) {
		exit(1);
	}

	if (path) {
		free(path);
		configpath = NULL;
	}

	if (!cfg->user) {
		const char *user = getenv("USER");

		if (!user) {
			struct passwd *pass = getpwuid(getuid());

			if (!pass) {
				error("Can't get user info");
				exit(1);
			}
			user = pass->pw_name;
		}

		/*
		 * Always store a heap copy so cfg->user has a single ownership
		 * model whichever source it came from. free_config() is not
		 * visible here; a copy is safe whether or not it frees user.
		 */
		cfg->user = strdup(user);
		if (!cfg->user) {
			error(oom);
			exit(1);
		}
	}

	if (!cfg->port) {
		cfg->port = DEFAULT_PORT;
	}
}

/**
 * Called when the process exits (registered with atexit() in main()).
 * Puts stdout/stderr back into blocking mode, flushes whatever is still
 * buffered for them and restores the original tty mode.
 */
static void cleanup(void)
{
	/*
	 * main() may have made stdout/stderr non-blocking. That flag lives on
	 * the open file description, which can be shared with the parent
	 * shell, so clear it unconditionally rather than only when output is
	 * pending. Blocking mode also lets the flush below complete.
	 */
	set_nonblocking(STDOUT_FILENO, FALSE);
	set_nonblocking(STDERR_FILENO, FALSE);

	/*
	 * write_buffer() writes only part of a buffer per call, so repeat
	 * until the buffer is empty or a write fails.
	 */
	while (!buf_is_empty(&buf_out) &&
	       write_buffer(&buf_out, STDOUT_FILENO, &outwait) == 0) {
		/* keep flushing */
	}

	while (!buf_is_empty(&buf_err) &&
	       write_buffer(&buf_err, STDERR_FILENO, &errwait) == 0) {
		/* keep flushing */
	}

	/* Terminal goes mental if we forget this. */
	set_old_mode();

	debug("sbrsh exiting");
}

/**
 * Handler for SIGINT, SIGHUP and SIGTERM (installed in main()). Exits with
 * status 1 so the atexit() cleanup runs and restores the terminal.
 * NOTE(review): exit() is not async-signal-safe (it runs atexit handlers
 * and flushes stdio from signal context); kept because cleanup() is what
 * restores the tty mode.
 */
static void sig_exit(int sig)
{
	(void) sig;	/* every signal is handled the same way */
	exit(1);
}

/*
 * Reads options, resolves and connects to host, talks with it, returns rc.
 */
int main(int argc,
	 char **argv)
{
	config_t cfg;
	struct sigaction act;
	struct winsize ws;
	struct sockaddr_in addr;
	uint32_t ipaddr;
	int16_t umount, is_tty;
	mount_info_t **mounts;

	progname = get_progname(argv[0]);
	debug("sbrsh version " VERSION);

	init_config(&cfg);
	umount = read_args(argc, argv);
	read_user_config(&cfg);

	/* Buffers. */

	tmp_buf = malloc(BUFFER_SIZE);
	if (!tmp_buf) {
		error(oom);
		return 1;
	}

	if (!buf_init(&buf_out) ||
	    !buf_init(&buf_err)) {
		error(oom);
		return 1;
	}

	/* We really want to exit nicely if we touch the tty mode later on. */

	atexit(cleanup);

	act.sa_handler = sig_exit;
	sigemptyset(&act.sa_mask);
	act.sa_flags = SA_ONESHOT;

	sigaction(SIGINT, &act, NULL);
	sigaction(SIGHUP, &act, NULL);
	sigaction(SIGTERM, &act, NULL);

	/* Parse and do stuff with mount entries. */
	mounts = parse_mounts((const char **) cfg.opts, umount ? TRUE : FALSE);

	/* Should terminal emulation be supported? */
	is_tty = isatty(STDIN_FILENO) &&
		 isatty(STDOUT_FILENO) &&
		 isatty(STDERR_FILENO);
	debug(is_tty ? "TTY mode" : "Not TTY mode");

	if (!is_tty || ioctl(STDIN_FILENO, TIOCGWINSZ, &ws) < 0) {
		memset(&ws, 0, sizeof (ws));
	}

	/* Lookup host. */
	ipaddr = resolve(cfg.host);
	if (!ipaddr) {
		return 1;
	}

	/* Create socket. */
	sd = socket(PF_INET, SOCK_STREAM, 0);
	if (sd < 0) {
		error("Can't create socket");
		return 1;
	}

	/* Connect to server. */
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(ipaddr);
	addr.sin_port = htons(cfg.port);

	if (!addr.sin_addr.s_addr) {
		return 1;
	}

	if (connect(sd, (struct sockaddr *) &addr, sizeof (addr)) < 0) {
		error("Can't connect");
		return 1;
	}

	/* Send protocol version. */
	if (send_version(sd) < 0) {
		error("Can't write protocol version packet to socket");
		return 1;
	}

	/* Wait for daemon's protocol version. */
	daemon_version = get_version(sd);
	if (daemon_version < 0) {
		error("Can't read protocol version packet from socket");
		return 1;
	}

	if (daemon_version < 3) {
		error_noerr("Server protocol version %d is too old",
			    daemon_version);
		return 1;
	}
	if (daemon_version != PROTOCOL_VERSION) {
		error_noerr("Server uses %s protocol version %d",
			daemon_version < PROTOCOL_VERSION ? "older" : "newer",
			daemon_version);
	}

	/* Temporary backward-compatibility hack */
	if (daemon_version == 3) {
		char **ptr, *str;

		for (ptr = environ; *ptr; ptr++) {
			if (strncmp(*ptr, "SBOX_ENV_", 9) == 0) {
				str = malloc(strlen(*ptr) + 2);
				if (!str) {
					error(oom);
					return 1;
				}

				strcpy(str, "SBRSH");
				strcat(str, *ptr + 4);

				error_noerr("Converting \"%s\" to \"%s\"",
					    *ptr, str);

				*ptr = str;
			}
		}
	}

	/* Send authentication info. */
	if (write_auth(sd, cfg.user, cfg.pwd) < 0) {
		error("Can't write authentication info to socket");
		return 1;
	}

	/* Wait for a reply... */
	if (get_auth_reply() < 0) {
		return 1;
	}

	debug("Authentication ok");

	/* Send command info. */
	if (write_cmd(sd, umount, target, (const mount_info_t **) mounts,
		      (const char **) args, curdir, (const char **) environ,
		      is_tty, &ws) < 0) {
		error("Can't write command info to socket");
		return 1;
	}

	/* We don't need this anymore. */
	free_config(&cfg);

	if (umount) {
		/* Just get the rc and possibly some error messages. */

		if (get_rc() < 0) {
			return 1;
		}
	} else {
		/* Handle the command execution. */

		if (is_tty && set_raw_mode() < 0) {
			return 1;
		}

		/* Non-blocking I/O. */

		if (set_nonblocking(STDOUT_FILENO, TRUE) < 0) {
			error("Can't make stdout non-blocking");
			return 1;
		}

		if (is_tty && set_nonblocking(STDERR_FILENO, TRUE) < 0) {
			error("Can't make stderr non-blocking");
			return 1;
		}

		/* Manage in/out/err and wait for the rc. */

		if (manage(is_tty) < 0) {
			return 1;
		}
	}

	if (rc == INTERNAL_ERROR_CODE) {
		debug("Internal server error");
		return 127;
	} else {
		debug("Return code: %d", rc);
		return rc;
	}
}
