/* -*- Mode: C; c-basic-offset: 8; indent-tabs-mode: t -*- */
/**
  @file ic-compat-preload.c

  osso-ic-oss Internet Connectivity library
  Copyright (C) 2005 Nokia Corporation

  This library is free software; you can redistribute it and/or modify it
  under the terms of the GNU Lesser General Public License as published by
  the Free Software Foundation; either version 2.1 of the License, or (at
  your option) any later version.

  This library is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
  General Public License for more details.

  You should have received a copy of the GNU Lesser General Public License
  along with this library; if not, write to the Free Software Foundation,
  Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/

#define _GNU_SOURCE
#include <errno.h>
#include <stdio.h>
#include <dlfcn.h>
#include <resolv.h>
#include <unistd.h>
#include <signal.h>
#include <string.h>
#include <stdlib.h>
#include <pthread.h>

#include <netdb.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <dbus/dbus.h>

#include <osso-ic.h>
#include <osso-ic-dbus.h>

#define DBUS_SYSTEM_BUS_DEFAULT_ADDRESS \
	"unix:path=/var/run/dbus/system_bus_socket"

static char *iap_name = NULL;

static const char * const domain_names[] = {
	"PF_UNSPEC", "PF_UNIX", "PF_INET", "PF_AX25", "PF_IPX",
	"PF_APPLETALK", "PF_NETROM", "PF_BRIDGE", "PF_ATMPVC",
	"PF_X25", "PF_INET6", "PF_ROSE", "PF_DECnet", "PF_NETBEUI",
	"PF_SECURITY", "PF_KEY", "PF_NETLINK", "PF_PACKET", "PF_ASH",
	"PF_ECONET", "PF_ATMSVC", "PF_SNA", "PF_IRDA", "PF_PPPOX",
	"PF_WANPIPE", "PF_BLUETOOTH"
};

static const char * const type_names[] = {
	"0", "SOCK_STREAM", "SOCK_DGRAM", "SOCK_RAW", "SOCK_RDM",
	"SOCK_SEQ_PACKET", "6", "7", "8", "9", "SOCK_PACKET"
};

#define tablesize(table) (sizeof(table)/sizeof(table[0]))

#define SAFE_LOOKUP(i,table) \
	(((i) < 0 || (i) >= tablesize(table)) ? "?" : (table[i]))

/* Nonzero when dlsym(RTLD_NEXT, "osso_iap_cb") finds the symbol
 * (presumably: the application uses the IC API itself and manages IAPs on
 * its own). socket() and close() then skip all IAP bookkeeping.
 * Set in icpreld_init(). */
static int using_icapi;
/* Next definitions of the overridden functions, resolved with
 * dlsym(RTLD_NEXT, ...) in icpreld_init(). */
static int (*old_socket)(int, int, int) = NULL;
static int (*old_close)(int) = NULL;
static void (*old_res_nclose)(res_state) = NULL;
/* PF_INET/PF_INET6 sockets created through socket() and not yet closed. */
static fd_set open_sockets;
/* Recursive mutex; socket() holds it while updating open_sockets and
 * iap_name. NOTE(review): the close()/__res_nclose() paths and the signal
 * handler touch the same state without taking it. */
static pthread_mutex_t mutex = PTHREAD_RECURSIVE_MUTEX_INITIALIZER_NP;

/* Handler installed for SIGRTMAX-1 in icpreld_init(); defined below. */
static void on_rt_signal(int signum, siginfo_t *info, void *ctx);

static void icpreld_init(void)
{
	struct sigaction act;

	using_icapi = dlsym(RTLD_NEXT, "osso_iap_cb") != NULL;
	old_socket = dlsym(RTLD_NEXT, "socket");
	old_close = dlsym(RTLD_NEXT, "close");
	old_res_nclose = dlsym(RTLD_NEXT, "__res_nclose");
	FD_ZERO(&open_sockets);

	act.sa_sigaction = &on_rt_signal;
	act.sa_flags = SA_SIGINFO;
	sigemptyset(&act.sa_mask);
	sigaction(SIGRTMAX-1, &act, NULL);

#ifdef DEBUG
	fprintf(stderr, "[%d] icpreld_init(): %d\n",
		getpid(), using_icapi);
#endif
}

/* Forget the current IAP. The function releases the name obtained in
 * connect_iap_blocking() and resets iap_name to NULL, so the next
 * PF_INET/PF_INET6 socket() call requests a new connection.
 *
 * Two callers reach this: normal code via disconnect_if_no_sockets(), and
 * the SIGRTMAX-1 handler. The pointer is detached before it is freed.
 * That narrows the window in which a handler interrupting this function
 * could free the same string a second time. dbus_free(NULL) does nothing,
 * so no separate NULL check is required. */
static void on_iap_lost(void)
{
	char *name = iap_name;

	iap_name = NULL;
	dbus_free(name);
}

/* SIGRTMAX-1 handler (installed with SA_SIGINFO in icpreld_init()).
 *
 * Only a signal sent with sigqueue() (si_code == SI_QUEUE) carrying the
 * value 0xC0DE is acted upon; anything else is ignored. On a valid signal
 * every tracked socket is shut down, the tracking set is cleared and the
 * current IAP is forgotten via on_iap_lost().
 *
 * NOTE(review): this runs in signal context, yet on_iap_lost() ends in
 * dbus_free(), and DEBUG builds call fprintf(); neither is
 * async-signal-safe. The handler also modifies open_sockets and iap_name
 * without the mutex, which cannot be taken here. Deferring this work to
 * normal context would be safer. */
static void on_rt_signal(int signum, siginfo_t *info, void *ctx)
{
	int fd;

	if (info->si_code != SI_QUEUE ||
	    info->si_value.sival_int != 0xC0DE) {
#ifdef DEBUG
		fprintf(stderr, "[%d] Spurious signal (0x%x)\n",
			getpid(), info->si_value.sival_int);
#endif
		return;
	}

#ifdef DEBUG
	fprintf(stderr, "[%d] IAP fscked: killing active sockets.\n",
		getpid());
#endif

	/* shutdown() rather than close(): the descriptors stay allocated, so
	 * the application's own later close() on them remains valid. */
	for (fd = 0; fd < FD_SETSIZE; fd++) {
		if (FD_ISSET(fd, &open_sockets))
			shutdown(fd, SHUT_RDWR);
	}

	FD_ZERO(&open_sockets);
	on_iap_lost();
}

/* Return the cached connection to the system bus at
 * DBUS_SYSTEM_BUS_DEFAULT_ADDRESS. On first use the connection is opened
 * and registered with the bus. Returns NULL if either step fails; the
 * cache then stays empty, so the next call tries again.
 * Declared with (void) so that it has a real prototype. */
static DBusConnection *get_connection(void)
{
	static DBusConnection *connection = NULL;

	if (connection != NULL)
		return connection;

	connection = dbus_connection_open(DBUS_SYSTEM_BUS_DEFAULT_ADDRESS,
					  NULL);
	if (connection == NULL)
		return NULL;

	if (!dbus_bus_register(connection, NULL)) {
		/* Tear down the half-set-up connection; retry on next call. */
		dbus_connection_disconnect(connection);
		dbus_connection_unref(connection);
		connection = NULL;
	}

	return connection;
}

/* Send msg on the system bus connection and wait up to timeout
 * milliseconds for the reply. Returns the reply, which the caller must
 * dbus_message_unref(), or NULL on failure. A missing bus connection also
 * counts as a failure. */
static DBusMessage *send_message(DBusMessage *msg,
				 int timeout)
{
	DBusConnection *conn = get_connection();

	/* get_connection() yields NULL when the bus is unreachable, and
	 * libdbus does not accept a NULL connection. */
	if (conn == NULL)
		return NULL;

	return dbus_connection_send_with_reply_and_block(conn,
							 msg,
							 timeout,
							 NULL);
}

/* Send msg on the system bus without expecting a reply. Does nothing if
 * no bus connection is available. */
static void send_message_no_reply(DBusMessage *msg)
{
	DBusConnection *conn = get_connection();

	if (conn == NULL)
		return;

	dbus_message_set_no_reply(msg, TRUE);
	if (dbus_connection_send(conn, msg, NULL)) {
		/* dbus_connection_send() only queues the message, and nothing
		 * in this file drives the connection afterwards. Flush so the
		 * request is actually written out now. */
		dbus_connection_flush(conn);
	}
}


static int connect_iap_blocking(const char *iap)
{
	DBusMessage *msg;
	DBusMessage *reply;

	msg = dbus_message_new_method_call(
		ICD_DBUS_SERVICE,
		ICD_DBUS_PATH,
		ICD_DBUS_INTERFACE,
		ICD_CONNECT_REQ);
	if (msg == NULL)
		return -1;

	if (!dbus_message_append_args(msg,
				      DBUS_TYPE_STRING, iap,
				      DBUS_TYPE_UINT32, 0x0,
				      DBUS_TYPE_UINT32, getpid(),
				      DBUS_TYPE_INVALID)) {
		dbus_message_unref(msg);
		return -1;
	}
	reply = send_message(msg, 3*60*1000);
	dbus_message_unref(msg);

	if (reply == NULL)
		return -1;

	if (!dbus_message_get_args(reply, NULL,
				   DBUS_TYPE_STRING, &iap_name,
				   DBUS_TYPE_INVALID)) {
		dbus_message_unref(reply);
		return -1;
	}
	dbus_message_unref(reply);

	return 0;
}

/* Ask for the IAP named iap to be disconnected (ICD_DISCONNECT_REQ)
 * without waiting for an answer. Returns 0 once the request has been
 * handed to send_message_no_reply(), or -1 if the message could not be
 * built. */
static int disconnect_iap_async(const char *iap)
{
	DBusMessage *request = dbus_message_new_method_call(ICD_DBUS_SERVICE,
							    ICD_DBUS_PATH,
							    ICD_DBUS_INTERFACE,
							    ICD_DISCONNECT_REQ);
	int status = -1;

	if (request != NULL) {
		if (dbus_message_append_args(request,
					     DBUS_TYPE_STRING, iap,
					     DBUS_TYPE_INVALID)) {
			send_message_no_reply(request);
			status = 0;
		}
		dbus_message_unref(request);
	}

	return status;
}

/* Overridden version of socket().
 * The socket is created with the real socket(). For PF_INET and PF_INET6
 * sockets the process must also hold an active IAP. The first such
 * socket triggers a blocking connect request. The IAP name comes from the
 * IAP_NAME environment variable, or OSSO_IAP_ANY when that is unset.
 * If no IAP is obtained, the new socket is closed again and -1 is
 * returned with errno set to EHOSTUNREACH.
 * When using_icapi is set, the real socket() result is returned untouched. */
int socket(int domain, int type, int protocol)
{
	int rc;
	int tracked;

	if (old_socket == NULL)
		icpreld_init();

	rc = old_socket(domain, type, protocol);
	if (using_icapi || rc < 0)
		return rc;
	if (domain != PF_INET && domain != PF_INET6)
		return rc;

	pthread_mutex_lock(&mutex);

#ifdef DEBUG
	fprintf(stderr, "[%d] socket(%s, %s, %d)=%d\n",
		getpid(), SAFE_LOOKUP(domain, domain_names),
		SAFE_LOOKUP(type, type_names),
		protocol, rc);
#endif

	/* Track this socket so close() can release the IAP after the last
	 * one is gone. An fd_set only has room for descriptors below
	 * FD_SETSIZE, and FD_SET() on a larger one writes past its end. Such
	 * sockets are therefore left untracked: they still get an IAP, but
	 * they do not keep it alive. */
	tracked = rc < FD_SETSIZE;
	if (tracked)
		FD_SET(rc, &open_sockets);

	if (iap_name == NULL) {
		/* First socket created -> connect */
		const char *wanted_iap = getenv("IAP_NAME");
		if (wanted_iap == NULL)
			wanted_iap = OSSO_IAP_ANY;

#ifdef DEBUG
		fprintf(stderr, "[%d] connecting to IAP [%s]... ",
			getpid(), wanted_iap);
		fflush(stderr);
#endif

		connect_iap_blocking(wanted_iap);

		if (iap_name == NULL) {
			old_close(rc);

			if (tracked)
				FD_CLR(rc, &open_sockets);
			errno = EHOSTUNREACH;
			rc = -1;

#ifdef DEBUG
			fprintf(stderr, "FAILED.\n");
#endif
		} else {
#ifdef DEBUG
			fprintf(stderr, "OK, got [%s].\n", iap_name);
#endif
		}
	}
	pthread_mutex_unlock(&mutex);

	return rc;
}

/* Remove fd from the set of tracked sockets without touching the IAP.
 * Returns 1 if fd was tracked, 0 otherwise.
 * close() calls this for every descriptor it successfully closes, not
 * only for sockets. Descriptors outside [0, FD_SETSIZE) cannot be
 * represented in an fd_set, and FD_ISSET()/FD_CLR() on them would index
 * past its end, so they are rejected up front. */
static int close_without_disconnect(int fd)
{
#ifdef DEBUG
	fprintf(stderr, "[%d] close_without_disconnect(%d)\n", getpid(), fd);
#endif

	if (fd < 0 || fd >= FD_SETSIZE)
		return 0;

	if (!FD_ISSET(fd, &open_sockets))
		return 0;

	FD_CLR(fd, &open_sockets);
	return 1;
}

/* If no tracked socket remains and an IAP is held, request its
 * disconnection (without waiting for the answer) and forget it. */
static void disconnect_if_no_sockets(void)
{
	int i;

	/* Stop at the first tracked socket instead of counting them.
	 * FD_ISSET() is only specified to return a non-zero value. Some C
	 * libraries return the raw mask bit, so summing the results into an
	 * int can overflow, or even wrap back to 0, while sockets are still
	 * open. */
	for (i = 0; i < FD_SETSIZE; i++) {
		if (FD_ISSET(i, &open_sockets))
			return;
	}

	if (iap_name == NULL)
		return;

	/* Last socket closed and API not used -> disconnect */
#ifdef DEBUG
	fprintf(stderr, "[%d] disconnecting IAP... ", getpid());
	fflush(stderr);
#endif

	disconnect_iap_async(iap_name);
	on_iap_lost();

#ifdef DEBUG
	fprintf(stderr, "done.\n");
#endif
}

/* Overridden version of close().
 * Closes fd with the real close(). If that succeeds and fd was one of the
 * tracked PF_INET/PF_INET6 sockets, it is dropped from the bookkeeping.
 * If it was the last tracked socket, the IAP is released as well.
 * The bookkeeping is skipped when the real close() fails or when
 * using_icapi is set. */
int close(int fd)
{
	int result;

	if (!old_close)
		icpreld_init();

	result = old_close(fd);

	/* Short-circuit order matters: only consult the bookkeeping after a
	 * successful close in a process whose sockets we track. */
	if (result >= 0 && !using_icapi && close_without_disconnect(fd))
		disconnect_if_no_sockets();

	return result;
}

/* Overridden version of __res_nclose().
 * Needed here because of b0rked glibc/resolv implementation which does
 * not allow hooking of close(). Also after name query we don't release
 * the IAP even if there are no more sockets since the application most
 * likely will next try to connect to the just queried host.
 * WARNING: this is an ugly HACK and will get busted if/when the glibc
 * resolv implementation changes. */
void __res_nclose(res_state statp)
{
	int ns;

#ifdef DEBUG
	fprintf(stderr, "[%d] res_nclose(%p)\n", getpid(), statp);
#endif

	if (statp->_vcsock >= 0)
		close_without_disconnect(statp->_vcsock);

	for (ns = 0; ns < MAXNS; ns++)
		if (statp->_u._ext.nsaddrs[ns] &&
		    statp->_u._ext.nssocks[ns] != -1)
			close_without_disconnect(statp->_u._ext.nssocks[ns]);

	old_res_nclose(statp);
}
