#    Openvpn Applet
#    Copyright (C) 2008 - 2010  Mikko Vartiainen <mvartiainen@gmail.com>
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

import socket
import time
import debug

from debug import log  # shared diagnostic logger used by Connection

class Connection(object):
    """Client for an OpenVPN management interface reached through a Unix socket.

    The connection is opened lazily by __admin_command__ and kept open until a
    method (or an error) closes it.  Written to run on Python 2.6+ and on
    Python 3: commands are sent as UTF-8 bytes and replies are decoded to text.
    """

    def __init__(self, socket):
        # ``socket`` is the filesystem path of the management socket; the
        # parameter name shadows the socket module only inside __init__.
        self.socket_name = socket
        # Connected socket object, or None while disconnected.
        self.socket = None

    @staticmethod
    def _error_text(err):
        """Return a readable message for a socket exception.

        Uses ``strerror`` when the error carries an errno and falls back to
        ``str(err)`` for errors raised with a single message argument (the
        old tuple unpacking crashed on those).
        """
        return getattr(err, "strerror", None) or str(err)

    @staticmethod
    def _quote_arg(value):
        """Return *value* as a double-quoted management-command argument.

        OpenVPN's command parser treats a backslash inside a quoted argument
        as an escape character, so backslashes are escaped too.  They are
        escaped first so the backslashes inserted before quotes are not
        doubled.  NOTE: CR/LF in *value* is not handled and would terminate
        the command line early.
        """
        return '"' + value.replace('\\', '\\\\').replace('"', '\\"') + '"'

    @staticmethod
    def _response_complete(buf, wait_for_ack):
        """Return True once *buf* (raw bytes received so far) holds a full reply.

        With *wait_for_ack* the reply is complete when a CRLF-terminated line
        is ``END`` or starts with ``SUCCESS:``, ``ERROR:`` or ``>FATAL``.  Only
        finished lines are examined, so a marker split across reads or one
        that merely appears inside a log message does not end the read early.
        Without *wait_for_ack* any buffer ending in CRLF counts as complete.
        """
        if not wait_for_ack:
            return buf.endswith(b"\r\n")
        for line in buf.split(b"\r\n")[:-1]:
            if line == b"END" or line.startswith((b"SUCCESS:", b"ERROR:", b">FATAL")):
                return True
        return False

    @staticmethod
    def _format_log(lines):
        """Join ``log all`` reply lines into text, hiding management chatter.

        Drops every line containing ``MANAGEMENT`` and the terminating ``END``
        line (only when it is actually present), keeping the trailing newline.
        """
        kept = [line for line in lines if "MANAGEMENT" not in line]
        if len(kept) >= 2 and kept[-2] == "END" and kept[-1] == "":
            del kept[-2]
        return "\n".join(kept)

    def __socket_connect__(self):
        """Connect to the management socket unless already connected.

        Returns True when a connection is available, False (after logging the
        error) when connect() failed.  A socket that failed to connect is
        closed rather than leaked.
        """
        if self.socket is not None:
            return True
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            # Timeout is set before connecting so connect() cannot block forever either.
            sock.settimeout(10)
            sock.connect(self.socket_name)
        except socket.error as e:
            log("socket.connect() failed:", getattr(e, "errno", None), self._error_text(e))
            sock.close()
            return False
        self.socket = sock
        return True

    def __socket_disconnect__(self):
        """Close the management connection, if one is open."""
        if self.socket is not None:
            self.socket.close()
            self.socket = None

    def __admin_command__(self, command, get_response=True, disconnect=False, wait_for_ack=True):
        """Send one management command and return its reply.

        Args:
            command: command text including its trailing newline.
            get_response: read the reply; when False, '' is returned.
            disconnect: close the connection after the exchange.
            wait_for_ack: forwarded to __admin_recv__ (see _response_complete).

        Returns:
            The reply split on CRLF (last element is '' when the reply ended
            with CRLF), '' when get_response is False, or
            ['ERROR: <reason>', ''] when the socket could not be used.
        """
        data = ''
        if self.socket is None:
            # Try up to four times, 0.2 s apart, before giving up.
            for _ in range(4):
                if self.__socket_connect__():
                    break
                time.sleep(0.2)
        if self.socket is None:
            return ['ERROR: socket.connect() failed', '']

        if not isinstance(command, bytes):
            command = command.encode("utf-8")
        try:
            # send() may transmit only part of the buffer; sendall() does not.
            self.socket.sendall(command)
        except socket.timeout as e:
            log("socket.timeout", e)
            self.__socket_disconnect__()
            return ['ERROR: socket timeout', '']
        except socket.error as e:
            self.__socket_disconnect__()
            return ['ERROR: ' + self._error_text(e), '']

        if get_response:
            try:
                data = self.__admin_recv__(wait_for_ack)
                log("Response:", data[-3:])
            except socket.timeout as e:
                log("socket.timeout", e)
                self.__socket_disconnect__()
                return ['ERROR: socket timeout', '']
            except socket.error as e:
                self.__socket_disconnect__()
                return ['ERROR: ' + self._error_text(e), '']

        if disconnect:
            self.__socket_disconnect__()
        return data

    def __admin_recv__(self, wait_for_ack=False):
        """Read one reply and return it split on CRLF (None if not connected).

        Reads until _response_complete() is satisfied or the peer closes the
        connection.  On close the socket is released so the next command
        reconnects.  socket.timeout / socket.error from recv() propagate.
        """
        if self.socket is None:
            return None
        buf = b""
        while True:
            chunk = self.socket.recv(4096)
            if not chunk:
                # EOF: without this check the loop would spin forever.
                self.__socket_disconnect__()
                break
            buf += chunk
            if self._response_complete(buf, wait_for_ack):
                break
        if not isinstance(buf, str):
            # Python 3 delivers bytes; on Python 2 the buffer is already str.
            buf = buf.decode("utf-8", "replace")
        return buf.split("\r\n")

    def __parse_data__(self, data, type="normal"):
        """Interpret reply lines.

        type="normal": returns ("SUCCESS", pid:int) for a ``SUCCESS: pid=N``
        line, otherwise ("SUCCESS"|"ERROR", message) from the last
        SUCCESS/ERROR/>FATAL line seen (message keeps the text after the
        marker), or ("ERROR", "unknown error") if none matched.

        type="password": returns (auth_name, auth_type) from the first
        ``>PASSWORD:Need 'name' type ...`` line, or ("ERROR", "") if there is
        none.  Other >PASSWORD notifications (e.g. verification failures) are
        not requests and are skipped.
        """
        if type == "normal":
            ret = ("ERROR", "unknown error")
            for line in data:
                if line.startswith("SUCCESS"):
                    if line[9:13] == "pid=":
                        return ("SUCCESS", int(line[13:]))
                    ret = ("SUCCESS", line[8:])
                if line.startswith("ERROR"):
                    ret = ("ERROR", line[6:])
                if line.startswith(">FATAL"):
                    ret = ("ERROR", line[7:])
            return ret

        elif type == "password":
            prefix = ">PASSWORD:Need '"
            for line in data:
                if not line.startswith(prefix):
                    continue
                # The name is quoted and may contain spaces; anything after the
                # type word (e.g. a static-challenge "SC:" field) is ignored.
                auth_name, quote, rest = line[len(prefix):].partition("'")
                if not quote:
                    continue
                words = rest.split()
                return (auth_name, words[0] if words else "")
            return ("ERROR", "")

    def disconnect(self):
        """Send ``signal SIGTERM``, drop the connection and return the parsed reply."""
        data = self.__admin_command__("signal SIGTERM\n")
        self.__socket_disconnect__()
        return self.__parse_data__(data)

    def connect(self):
        """Send ``hold off`` then ``hold release``; return the parsed reply to the release."""
        self.__admin_command__("hold off\n")
        data = self.__admin_command__("hold release\n")
        self.__socket_disconnect__()
        return self.__parse_data__(data)

    def restart(self):
        """Send ``signal SIGUSR1`` and return the parsed reply (connection kept open)."""
        data = self.__admin_command__("signal SIGUSR1\n")
        return self.__parse_data__(data)

    def password(self, auth_name, password):
        """Send *password* for the credential set *auth_name*; return the parsed reply."""
        command = "password %s %s\n" % (self._quote_arg(auth_name), self._quote_arg(password))
        data = self.__admin_command__(command)
        self.__socket_disconnect__()
        return self.__parse_data__(data)

    def username(self, auth_name, username):
        """Send *username* for the credential set *auth_name*; return the parsed reply."""
        command = "username %s %s\n" % (self._quote_arg(auth_name), self._quote_arg(username))
        data = self.__admin_command__(command)
        self.__socket_disconnect__()
        return self.__parse_data__(data)

    def poll_password_request(self):
        """Send ``test 1`` without waiting for an ack and return any pending
        >PASSWORD request as (auth_name, auth_type), else ("ERROR", "")."""
        data = self.__admin_command__("test 1\n", wait_for_ack=False)
        return self.__parse_data__(data, type="password")

    def get_log(self):
        """Return the buffered log (``log all``) as text without MANAGEMENT lines.

        Reconnects first, as get_statistics/get_state do, so that the line
        skipped by ``[1:]`` is the greeting sent on connect rather than a
        real log line.
        """
        self.__socket_disconnect__()
        data = self.__admin_command__("log all\n")
        return self._format_log(data[1:])

    def get_statistics(self):
        """Return the ``status`` reply as text, minus its first line and the END/'' tail."""
        self.__socket_disconnect__()
        data = self.__admin_command__("status\n")
        self.__socket_disconnect__()
        return "\n".join(data[1:-2])

    def get_state(self):
        """Return the ``state all`` reply as text, minus its first line and the END/'' tail."""
        self.__socket_disconnect__()
        data = self.__admin_command__("state all\n")
        self.__socket_disconnect__()
        return "\n".join(data[1:-2])

    def get_pid(self):
        """Send ``pid``; returns ("SUCCESS", pid) on success (see __parse_data__)."""
        data = self.__admin_command__("pid\n")
        return self.__parse_data__(data)
