import errno
import json
import socket
import struct
import time

try:
    from urllib.parse import urlparse
    from urllib.parse import parse_qs
except ImportError:
    from urlparse import urlparse
    from urlparse import parse_qs

from . import pac_server
from . import global_var as g
from .socket_wrap import SocketWrap
import utils
from .smart_route import handle_ip_proxy, handle_domain_proxy, netloc_to_host_port
from xlog import getLogger
# Logger shared by everything in this module.
xlog = getLogger("smart_router")

# Linux netfilter socket option number for SO_ORIGINAL_DST (queried at the
# SOL_IP level). ProxyServer.try_redirect passes it to getsockopt() to learn
# the destination a client originally connected to before the connection was
# redirected to this proxy.
SO_ORIGINAL_DST = 80


class ProxyServer():
    """Per-connection handler for the smart-router local proxy port.

    One instance serves one accepted client socket. ``handle`` first tries to
    treat the connection as redirected traffic (SO_ORIGINAL_DST). Otherwise it
    peeks at the first byte to choose a protocol: SOCKS4/4a, SOCKS5, HTTP
    CONNECT, plain HTTP proxy requests, or local PAC / DoH requests.
    """

    # Number of connections handled so far, across all instances.
    handle_num = 0

    # errno values treated as "no data available yet, retry" when reading the
    # client socket: EAGAIN / EWOULDBLOCK of the running platform (11 on
    # Linux, 35 on macOS), 10035 (WSAEWOULDBLOCK on Windows) and 2.
    _RETRY_ERRNOS = frozenset((2, 11, 10035, errno.EAGAIN, errno.EWOULDBLOCK))

    def __init__(self, sock, client, args):
        """Wrap an accepted client connection.

        :param sock: the accepted client socket.
        :param client: the client address; used in log messages and passed on
            to the proxy helpers.
        :param args: not used by this class.
        """
        self.conn = sock
        # Unbuffered file views of the socket. rfile is used to wait on the
        # SOCKS5 UDP-associate control connection; wfile is not used within
        # this class.
        self.rfile = self.conn.makefile("rb", 0)
        self.wfile = self.conn.makefile("wb", 0)
        self.client_address = client

        # Bytes received from the client; everything before buffer_start has
        # already been consumed by the read_* helpers.
        self.read_buffer = b""
        self.buffer_start = 0
        # Cleared once getsockopt(SO_ORIGINAL_DST) fails for this connection.
        self.support_redirect = True

    def try_redirect(self):
        """Forward the connection to its original destination if it was redirected.

        Returns False when the caller should parse the client's request
        itself. That happens when the original destination cannot be queried,
        or when it is this proxy's own listening address. Returns True when
        the connection was handed to handle_ip_proxy, including the case where
        that hand-off raised; the exception is logged.
        """
        if not self.support_redirect:
            return False

        try:
            dst = self.conn.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16)
        except Exception:
            # getsockopt failed, or socket.SOL_IP does not exist on this
            # platform: remember it so this connection skips the attempt.
            self.support_redirect = False
            return False

        try:
            # sockaddr_in layout: family(2) port(2, network order) addr(4) zero(8).
            dst_port, srv_ip = struct.unpack("!2xH4s8x", dst)
            ip_str = socket.inet_ntoa(srv_ip)

            if dst_port == g.config.proxy_port and utils.to_bytes(ip_str) in g.local_ips:
                # The client connected to the proxy directly; not a redirect.
                return False

            xlog.debug("Redirect to:%s:%d from:%s", ip_str, dst_port, self.client_address)
            handle_ip_proxy(self.conn, ip_str, dst_port, self.client_address)
        except Exception as e:
            xlog.exception("redirect except:%r", e)

        return True

    def handle(self):
        """Entry point: serve the connection as redirected traffic or as a proxy request."""
        self.__class__.handle_num += 1

        if not self.try_redirect():
            self.handle_request()

    def handle_request(self):
        """Dispatch on the first byte sent by the client (peeked, not consumed).

        Socket errors and unexpected exceptions are logged and the client
        connection is closed.
        """
        try:
            socks_version = self.conn.recv(1, socket.MSG_PEEK)
            if not socks_version:
                # Client closed the connection without sending anything.
                return

            if socks_version == b"\x04":
                self.socks4_handler()
            elif socks_version == b"\x05":
                self.socks5_handler()
            elif socks_version == b"C":
                # "CONNECT host:port HTTP/1.x"
                self.https_handler()
            elif socks_version in [b"G", b"P", b"D", b"O", b"H", b"T"]:
                # First letter of GET/POST/PUT/PATCH/DELETE/OPTIONS/HEAD/TRACE.
                self.http_handler()
            else:
                xlog.warn("socks version:%s[%s] not supported", socks_version, utils.str2hex(socks_version))
                return

        except socket.error as e:
            xlog.warn('socks handler read error:%r', e)
            self.conn.close()
        except Exception as e:
            xlog.exception("any err:%r", e)
            self.conn.close()

    def _recv_some(self, size=8192):
        """Append the next chunk received from the client to read_buffer.

        Intended for use while the socket is non-blocking. When no data is
        available yet, sleeps briefly and retries.

        :raises socket.error: "recv fail" when the peer has closed the
            connection, or any non-retryable socket error.
        """
        while True:
            try:
                data = self.conn.recv(size)
            except socket.error as e:
                if e.errno in self._RETRY_ERRNOS:
                    time.sleep(0.01)
                    continue
                raise
            if not data:
                # EOF: without this check the callers would loop forever.
                raise socket.error("recv fail")
            self.read_buffer += data
            return

    def _read_until(self, delimiter):
        """Return the bytes up to *delimiter* and consume the delimiter too.

        The socket is switched to non-blocking mode while reading and is
        always restored to blocking mode afterwards.
        """
        sock = self.conn
        sock.setblocking(0)
        try:
            while True:
                end = self.read_buffer.find(delimiter, self.buffer_start)
                if end > -1:
                    line = self.read_buffer[self.buffer_start:end]
                    self.buffer_start = end + len(delimiter)
                    return line
                self._recv_some()
        finally:
            sock.setblocking(1)

    def read_null_end_line(self):
        """Read one NUL-terminated field (SOCKS4 user id / SOCKS4a host name)."""
        return self._read_until(b"\x00")

    def read_crlf_line(self):
        """Read one CRLF-terminated line, without the CRLF."""
        return self._read_until(b"\r\n")

    def read_headers(self):
        """Read the header block that follows an already consumed request line.

        Returns the raw header bytes without the terminating blank line, or
        b"" when the request has no headers.
        """
        sock = self.conn
        sock.setblocking(0)
        try:
            while True:
                # No headers at all: the blank line directly follows the
                # request line. startswith() also matches when the client has
                # already sent more bytes after the blank line.
                if self.read_buffer.startswith(b"\r\n", self.buffer_start):
                    self.buffer_start += 2
                    return b""

                end = self.read_buffer.find(b"\r\n\r\n", self.buffer_start)
                if end > -1:
                    block = self.read_buffer[self.buffer_start:end]
                    self.buffer_start = end + 4
                    return block

                self._recv_some()
        finally:
            sock.setblocking(1)

    def read_bytes(self, size):
        """Return exactly *size* bytes from the client, reading as needed.

        Only the missing number of bytes is requested from the socket, so this
        never reads past the requested data.

        :raises socket.error: "recv fail" if the peer closes first.
        """
        sock = self.conn
        sock.setblocking(1)
        try:
            while len(self.read_buffer) - self.buffer_start < size:
                need = size - (len(self.read_buffer) - self.buffer_start)
                try:
                    data = sock.recv(need)
                except socket.error as e:
                    if e.errno in self._RETRY_ERRNOS:
                        time.sleep(0.01)
                        continue
                    raise
                if not data:
                    raise socket.error("recv fail")
                self.read_buffer += data
        finally:
            sock.setblocking(1)

        data = self.read_buffer[self.buffer_start:self.buffer_start + size]
        self.buffer_start += size
        return data

    def socks4_handler(self):
        """Serve a SOCKS4 / SOCKS4a CONNECT request.

        Request layout: VN(1) CD(1) DSTPORT(2) DSTIP(4) USERID NUL, then for
        SOCKS4a a host name and NUL. SOCKS4a is signalled by a DSTIP of
        0.0.0.x with x != 0.
        """
        sock = self.conn
        self.read_bytes(1)  # version byte, already checked by handle_request
        cmd = ord(self.read_bytes(1))
        if cmd != 1:
            xlog.warn("Socks4 cmd:%d not supported", cmd)
            # CD=0x5b: request rejected or failed.
            sock.sendall(b"\x00\x5b" + b"\x00" * 6)
            return

        data = self.read_bytes(6)
        port = struct.unpack(">H", data[0:2])[0]
        addr_pack = data[2:6]
        domain_mode = addr_pack[0:3] == b'\x00\x00\x00' and addr_pack[3:4] != b'\x00'

        user_id = self.read_null_end_line()
        if len(user_id):
            xlog.debug("Socks4 user_id:%s", user_id)

        if domain_mode:
            addr = self.read_null_end_line()
        else:
            addr = socket.inet_ntoa(addr_pack)

        # Reply: VN=0, CD=0x5a (granted), then DSTPORT followed by DSTIP.
        reply = b"\x00\x5a" + struct.pack(">H", port) + addr_pack
        sock.sendall(reply)

        if domain_mode:
            handle_domain_proxy(sock, addr, port, self.client_address)
        else:
            handle_ip_proxy(sock, addr, port, self.client_address)

    def handle_udp_associate(self, sock, addr, port, addrtype_pack, addr_pack):
        """Answer a SOCKS5 UDP ASSOCIATE request and keep the association open.

        The reply advertises g.dns_srv.udp_relay_port as the relay port and
        echoes the client's address field. The method then blocks on the TCP
        control connection and returns once a byte arrives or the client
        closes it.
        """
        udp_relay_port = g.dns_srv.udp_relay_port
        xlog.debug("socks5 from:%r udp associate to %s:%d use udp_relay_port:%d", self.client_address, addr, port, udp_relay_port)
        reply = b"\x05\x00\x00" + addrtype_pack + addr_pack + struct.pack(">H", udp_relay_port)
        sock.sendall(reply)

        self.rfile.read(1)
        xlog.debug("socks5 from:%r udp associate to %s:%d closed", self.client_address, addr, port)

    def socks5_handler(self):
        """Serve a SOCKS5 request (RFC 1928): CONNECT or UDP ASSOCIATE."""
        sock = self.conn
        self.read_bytes(1)  # version byte, already checked by handle_request
        auth_mode_num = ord(self.read_bytes(1))
        auth_methods = self.read_bytes(auth_mode_num)

        # Always select "no authentication required" (0x00), regardless of
        # which methods the client offered.
        sock.sendall(b"\x05\x00")
        try:
            data = self.read_bytes(4)
        except Exception as e:
            xlog.debug("socks5 auth num:%d, list:%s", auth_mode_num, utils.str2hex(auth_methods))
            xlog.warn("socks5 protocol error:%r", e)
            return

        socks_version = ord(data[0:1])
        if socks_version != 5:
            xlog.warn("request version:%d error", socks_version)
            return

        command = ord(data[1:2])
        addrtype_pack = data[3:4]
        addrtype = ord(addrtype_pack)
        if addrtype == 1:  # IPv4
            addr_pack = self.read_bytes(4)
            addr = socket.inet_ntoa(addr_pack)
        elif addrtype == 3:  # Domain name, length-prefixed
            domain_len_pack = self.read_bytes(1)
            domain = self.read_bytes(ord(domain_len_pack))
            addr_pack = domain_len_pack + domain
            addr = domain
        elif addrtype == 4:  # IPv6
            addr_pack = self.read_bytes(16)
            addr = socket.inet_ntop(socket.AF_INET6, addr_pack)
        else:
            xlog.warn("request address type unknown:%d", addrtype)
            # REP=0x08 "Address type not supported", BND.ADDR=0.0.0.0, BND.PORT=0.
            sock.sendall(b"\x05\x08\x00\x01" + b"\x00" * 6)
            return
        # Read through the buffered reader so no bytes are skipped or short-read.
        port = struct.unpack('>H', self.read_bytes(2))[0]

        if command == 3:  # 3. UDP associate
            return self.handle_udp_associate(sock, addr, port, addrtype_pack, addr_pack)

        if command != 1:  # 1. Tcp connect
            xlog.warn("request not supported command mode:%d", command)
            # REP=0x07 "Command not supported", BND.ADDR=0.0.0.0, BND.PORT=0.
            sock.sendall(b"\x05\x07\x00\x01" + b"\x00" * 6)
            return

        reply = b"\x05\x00\x00" + addrtype_pack + addr_pack + struct.pack(">H", port)
        sock.sendall(reply)

        if addrtype in [1, 4]:
            handle_ip_proxy(sock, addr, port, self.client_address)
        else:
            handle_domain_proxy(sock, addr, port, self.client_address)

    def https_handler(self):
        """Serve an HTTP ``CONNECT host:port`` tunnel request."""
        line = self.read_crlf_line()
        words = line.split()
        if len(words) == 3:
            command, path, _version = words
        elif len(words) == 2:
            command, path = words
        else:
            xlog.warn("https req line fail:%s", line)
            return

        if command != b"CONNECT":
            xlog.warn("https req line fail:%s", line)
            return

        host, sep, port_text = path.rpartition(b':')
        if not (sep and host and port_text.isdigit() and 0 < int(port_text) < 65536):
            # Missing host, or missing / non-numeric / out-of-range port.
            xlog.warn("https req line fail:%s", line)
            return
        port = int(port_text)

        # Consume the request headers; their content is not used.
        self.read_headers()
        sock = self.conn

        sock.sendall(b'HTTP/1.1 200 OK\r\n\r\n')

        # NOTE(review): bytes the client sent after the header block stay in
        # self.read_buffer and are not passed to handle_domain_proxy.
        handle_domain_proxy(sock, host, port, self.client_address)

    @staticmethod
    def _split_absolute_url(url):
        """Split an absolute ``http://`` URL into (scheme + authority, remainder).

        The authority ends at the first "/", "?" or "#" after the 7-byte
        "http://" prefix. The remainder is returned unchanged and is b"" when
        the URL has no path, query or fragment.
        """
        authority_end = len(url)
        for delimiter in (b"/", b"?", b"#"):
            pos = url.find(delimiter, 7)
            if pos != -1 and pos < authority_end:
                authority_end = pos
        return url[:authority_end], url[authority_end:]

    def http_handler(self):
        """Serve a plain-HTTP request.

        Absolute-form (``http://...``) requests are proxied. Anything else is
        treated as a request to this server: ``/dns-query`` goes to the DoH
        handler, everything else to the PAC handler. The request is only
        peeked (MSG_PEEK), so it is still unread on the socket for whichever
        handler takes over.
        """
        # NOTE(review): a single peek can return only part of the request line
        # if the client's first segment is short; such requests are rejected
        # below with "http req line fail".
        req_data = self.conn.recv(65537, socket.MSG_PEEK)
        req_line = req_data.partition(b"\r\n")[0]

        words = req_line.split()
        if len(words) == 3:
            method, url, _http_version = words
        elif len(words) == 2:
            method, url = words
        else:
            xlog.warn("http req line fail:%s", req_line)
            return

        if not url.lower().startswith(b"http://"):
            # not proxy request
            parsed_url = urlparse(utils.to_str(url))
            kv = parse_qs(parsed_url.query)
            if parsed_url.path == "/dns-query":
                return self.DoH_handler(kv)
            xlog.debug("PAC %s %s from:%s", method, url, self.client_address)
            handler = pac_server.PacHandler(self.conn, self.client_address, None, xlog)
            return handler.handle()

        o = urlparse(url)
        host, port = netloc_to_host_port(o.netloc)

        prefix, path = self._split_absolute_url(url)
        if path.startswith(b"/"):
            replacement = b""
        else:
            # Empty path ("http://host" or "http://host?q"): an origin-form
            # request target must start with "/", so insert one instead of
            # leaving an empty target.
            replacement = b"/"
            path = b"/" + path

        sock = SocketWrap(self.conn, self.client_address[0], self.client_address[1])
        # Presumably SocketWrap rewrites prefix -> replacement in the forwarded
        # request, turning the absolute-form URL into origin-form; verify in
        # socket_wrap.
        sock.replace_pattern = [prefix, replacement]

        xlog.debug("http %r connect to %s:%d %s %s", self.client_address, host, port, method, path)
        handle_domain_proxy(sock, host, port, self.client_address)

    @staticmethod
    def _parse_dns_type(value):
        """Convert a DoH ``type`` query value to a DNS record type.

        Integer strings become ints, and "A" / "AAAA" (any case) map to 1 / 28.
        Anything else is returned unchanged.
        """
        try:
            return int(value)
        except ValueError:
            return {"A": 1, "AAAA": 28}.get(value.upper(), value)

    def DoH_handler(self, kv):
        """Answer a DNS-over-HTTPS JSON query (``/dns-query?name=...&type=...``).

        :param kv: parsed query string as returned by ``parse_qs`` (each key
            maps to a list of values).
        """
        handler = pac_server.PacHandler(self.conn, self.client_address, None, xlog)
        name = kv.get("name", [None])[0]
        if not name:
            xlog.warn("DoH request no name")
            return handler.send_response(content=b'{"error":"no name"}', status=400)

        dns_type = self._parse_dns_type(kv.get("type", ["1"])[0])

        ips = utils.to_str(g.dns_query.query(name, dns_type))
        answers = []
        for ip in ips:
            # Drop results whose address family does not match the record type
            # (1 = A / IPv4, 28 = AAAA / IPv6).
            if dns_type == 1 and not utils.check_ip_valid4(ip):
                continue
            if dns_type == 28 and not utils.check_ip_valid6(ip):
                continue

            answers.append({
                "name": name,
                "type": dns_type,
                "data": ip
            })

        info = {
            "Status": 0,
            "Answer": answers
        }
        res = json.dumps(info)
        headers = {
            "Content-Type": "application/dns-json",
            "Access-Control-Allow-Origin": "*"
        }
        return handler.send_response(content=res, headers=headers, status=200)
