import base64
import decimal
import errno
import hashlib
import itertools
import logging
import multiprocessing
import os
import platform
import random
import re
import secrets
import shlex
import shutil
import socket
import stat
import string
import subprocess
import sys
import textwrap
import time
from getpass import getpass
from typing import Iterable, NoReturn, Union

import netifaces
import paramiko
import pkg_resources
import prettytable
import psutil
from dns import resolver, reversename
from paramiko.sftp_client import SFTPClient
from prettytable import PrettyTable
from scp import SCPClient
from termcolor import colored

from rdaf import CliException

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# An exit code of 5 from sshpass command implies Invalid/incorrect password
# (checked by copy_ssh_key to turn that failure into a CLI error message).
_SSHPASS_CMD_INCORRECT_PASSWORD_EXIT_CODE = 5


class cliSSH:
    """SSH session to a remote host, built on ``paramiko.SSHClient``.

    Provides helpers to run commands remotely and to copy files/directories
    in either direction over SCP or SFTP. Call :meth:`close` when done.
    """

    def __init__(self, **params):
        """Connect to ``params['host']``.

        Keyword arguments read from ``params``:
            host: host name or IP address to connect to.
            user: SSH user name; when set, ``password`` is passed on as well.
            password: password for ``user``.
            keyfile: path to a private key file (ignored when empty).
            passphrase: when set, passed as paramiko's ``password`` argument
                in place of ``password``.

        Unknown host keys are accepted (``AutoAddPolicy``) and the connect
        timeout is 30 seconds. The client is closed again if connecting fails.
        """
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.load_system_host_keys()
        ssh_params = {'hostname': params.get('host'),
                      'username': params.get('user'),
                      'timeout': 30}
        if params.get('user', False):
            ssh_params['password'] = params.get('password')
        if params.get('keyfile'):
            ssh_params['key_filename'] = params.get('keyfile')
        if params.get('passphrase', False):
            # paramiko also uses ``password`` to decrypt an encrypted private
            # key when no explicit ``passphrase`` argument is given
            ssh_params['password'] = params.get('passphrase')
        try:
            client.connect(**ssh_params)
        except Exception:
            # don't leak the half-initialised client (and its transport)
            client.close()
            raise
        self.session = client

    def executeCommand(self, command):
        """Run ``command`` remotely (with a pty) and return its exit status.

        The command's output is read and discarded; use
        :meth:`command_over_ssh_session` to print it or
        :meth:`command_over_ssh_session_v2` to capture it.
        """
        stdin, stdout, stderr = self.session.exec_command(command,
                                                          get_pty=True)
        stdin.flush()
        # drain the output first: waiting for the exit status while unread
        # output fills the channel window can block forever
        for _ in iter(lambda: stdout.read(32768), b''):
            pass
        return stdout.channel.recv_exit_status()

    def scp_put_file(self, local_file_path, remote_file_path='.'):
        """Upload ``local_file_path`` to ``remote_file_path`` over SCP."""
        scp_client = SCPClient(self.session.get_transport())
        try:
            scp_client.put(local_file_path, remote_file_path)
        finally:
            scp_client.close()

    def scp_get_file(self, remote_file_path, local_file_path=''):
        """Download ``remote_file_path`` to ``local_file_path`` over SCP."""
        scp_client = SCPClient(self.session.get_transport())
        try:
            scp_client.get(remote_file_path, local_file_path)
        finally:
            scp_client.close()

    def is_remote_file_exists(self, remote_file_path):
        """Return True if ``remote_file_path`` can be stat'ed over SFTP."""
        sftp_client = SFTPClient.from_transport(self.session.get_transport())
        try:
            sftp_client.stat(remote_file_path)
            return True
        except IOError:
            return False
        finally:
            sftp_client.close()

    def transfer_file(self, local_file: os.path, remote_file: str, mkdirs=True, sudo=False):
        """Upload ``local_file`` to ``remote_file`` over SFTP.

        When ``mkdirs`` is set, the remote parent directory is first created
        with ``mkdir -p`` (prefixed with ``sudo`` when ``sudo`` is set).
        """
        if mkdirs:
            dir_path = os.path.dirname(remote_file)
            command = 'mkdir -p ' + dir_path
            if sudo:
                command = 'sudo ' + command
            self.command_over_ssh_session(command)
        sftp_client = SFTPClient.from_transport(self.session.get_transport())
        try:
            sftp_client.put(local_file, remote_file)
        finally:
            sftp_client.close()

    def transfer_dir(self, local_dir: os.path, remote_dir: str, sudo=False):
        """Recursively upload the contents of ``local_dir`` into ``remote_dir``.

        Remote directories (including nested ones) are created with
        ``mkdir -p``, prefixed with ``sudo`` when ``sudo`` is set.
        """
        # create the target remote dir first and then recursively copy
        # the local dir contents into it
        command = 'mkdir -p ' + remote_dir
        if sudo:
            command = 'sudo ' + command
        self.command_over_ssh_session(command)
        sftp_client = SFTPClient.from_transport(self.session.get_transport())
        try:
            for file_name in os.listdir(local_dir):
                f = os.path.join(local_dir, file_name)
                if os.path.isdir(f):
                    # propagate sudo so nested directories are created the
                    # same way as the top-level one
                    self.transfer_dir(f, remote_dir + '/' + file_name + '/', sudo=sudo)
                else:
                    sftp_client.put(f, remote_dir + '/' + file_name)
        finally:
            sftp_client.close()

    def fetch(self, remote_file: str, local_file: os.path, mkdirs=True, sudo=False):
        """Download ``remote_file`` to ``local_file`` over SFTP.

        When ``mkdirs`` is set, the *local* parent directory of
        ``local_file`` is created if missing.

        NOTE(review): ``sudo`` is accepted for backward compatibility but is
        not applied: the directory is created locally as the current user
        (a root-owned directory would not be writable for the download).
        """
        if mkdirs:
            dir_path = os.path.dirname(local_file)
            # the destination is local, so create its directory locally
            # rather than running mkdir on the remote host
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
        sftp_client = SFTPClient.from_transport(self.session.get_transport())
        try:
            sftp_client.get(remote_file, local_file)
        finally:
            sftp_client.close()

    def fetch_dir(self, remote_dir: str, local_dir: os.path):
        """Recursively download ``remote_dir`` into ``local_dir`` over SFTP.

        Does nothing if ``remote_dir`` does not exist. Entries that vanish
        while listing are skipped; other SFTP errors propagate.
        """
        sftp_client = SFTPClient.from_transport(self.session.get_transport())
        try:
            # check if remote dir exists
            try:
                sftp_client.stat(remote_dir)
            except IOError as e:
                if e.errno == errno.ENOENT:
                    return
                raise
            # create local dir
            os.makedirs(local_dir, exist_ok=True)

            for file_name in sftp_client.listdir(remote_dir):
                f = remote_dir + '/' + file_name
                try:
                    f_stat = sftp_client.stat(f)
                except IOError as e:
                    if e.errno == errno.ENOENT:
                        logger.debug('No such (remote) file ' + str(f))
                        continue
                    raise
                if stat.S_ISDIR(f_stat.st_mode):
                    self.fetch_dir(f + '/', os.path.join(local_dir, file_name))
                else:
                    sftp_client.get(f, os.path.join(local_dir, file_name))
        finally:
            sftp_client.close()

    def command_over_ssh_session(self, command):
        """Run ``command`` remotely (with a pty), echo its output line by line
        to stdout, and return the exit status.

        With a pty, the remote process's stderr normally arrives on the same
        stream as stdout, so it is echoed as well.
        """
        stdin, stdout, stderr = self.session.exec_command(command, get_pty=True)
        stdin.close()
        for line in iter(lambda: stdout.readline(2048), ""):
            print(line.strip('\n'))
            sys.stdout.flush()
        return stdout.channel.recv_exit_status()

    def command_over_ssh_session_v2(self, command, environment=None):
        """Run ``command`` remotely (with a pty) and capture its output.

        :param environment: optional dict of environment variables passed to
            ``exec_command`` (the server may ignore them).
        :returns: ``(exit_code, stdout_text, stderr_text)``, text decoded as UTF-8.
        """
        stdin, stdout, stderr = self.session.exec_command(command, get_pty=True,
                                                          environment=environment)
        stdin.close()
        # read the output before waiting for the exit status; the reverse
        # order can hang once the output exceeds the channel window size
        stdout_content = stdout.read().decode(encoding='UTF-8')
        stderr_content = stderr.read().decode(encoding='UTF-8')
        exit_code = stdout.channel.recv_exit_status()
        return exit_code, stdout_content, stderr_content

    def close(self):
        """Close the underlying SSH client."""
        self.session.close()


class cliHost(object):
    """Resolve a host name or IPv4 address and classify it.

    Attributes:
        name: the value passed in.
        ipv4_addr: IPv4 address of the host, or None if it could not be resolved.
        fqdn: the host name as given, or (when an IP was given) the PTR
            record's name; None if that reverse lookup fails.
        islocal: True when ipv4_addr is configured on a local interface
            other than docker0/lo.
        reachable: True only when check_ssh_port_reachable was requested and
            TCP port 22 accepted a connection.
    """

    # interfaces whose addresses are ignored when deciding whether the host
    # is this machine
    _RESERVED_INTERFACES = ('docker0', 'lo')

    def __init__(self, host, check_ssh_port_reachable=False):
        self.name = host
        self.ipv4_addr = None
        self.fqdn = None
        self.islocal = False
        self.reachable = False

        if isIPAddress(host):
            self.ipv4_addr = host
            try:
                addr = reversename.from_address(host)
                self.fqdn = str(resolver.query(addr, "PTR")[0]).strip('.')
            except Exception:
                # best effort: without a PTR record fqdn simply stays None
                logger.debug('Reverse DNS lookup failed for ' + host)
        else:
            self.fqdn = host
            if isHostResolvable(host):
                self.ipv4_addr = getIPAddress(host)

        if self.ipv4_addr:
            if check_ssh_port_reachable:
                self.reachable = isSocketOpen(ip=self.ipv4_addr, port=22)
            if self.ipv4_addr in self._local_ipv4_addresses():
                self.islocal = True

    @classmethod
    def _local_ipv4_addresses(cls):
        """Return every IPv4 address configured on non-reserved local interfaces."""
        local_ips = set()
        for iface in netifaces.interfaces():
            if iface in cls._RESERVED_INTERFACES:
                continue
            # consider all addresses on the interface (aliases / secondary
            # IPs included), not just the first one
            for entry in netifaces.ifaddresses(iface).get(netifaces.AF_INET, []):
                addr = entry.get('addr')
                if addr:
                    local_ips.add(addr)
        return local_ips

def cleanup_stdout(output):
    """Given the output of rdac command, remove any prefix/suffix stdout and sanitize it to proper json.

    The JSON is taken to start at the first line that begins (after optional
    whitespace) with ``[`` or ``{``. It ends at the last ``]``/``}`` that is
    preceded only by whitespace on its own line. If no such closing bracket
    exists, a bare ``[]``/``{}`` is still recognised. Otherwise ``output`` is
    returned unchanged.
    """
    opening = re.search(r'^\s*[\[\{]', output, re.MULTILINE)
    if not opening:
        return output
    start = opening.start()

    # Searching the reversed text finds the last closing bracket that sits at
    # the start of a line (ignoring whitespace) in the original text.
    closing = re.search(r'[\]\}]\s*$', output[::-1], re.MULTILINE)
    if not closing:
        # Sometimes, the string is [] or {} and the ] or } may not be at the beginning of the line.
        candidate = output[start:].strip()
        if candidate in ("[]", "{}"):
            return candidate
        return output
    # closing.start() counts characters from the end of ``output``; the bracket
    # itself sits at index len(output) - 1 - closing.start(), so this slice
    # ends exactly on it (no trailing character included)
    return output[start:len(output) - closing.start()]

def query_yes_no(question, default=None):
    """Ask ``question`` on stdin/stdout and return the answer as a bool.

    Accepted answers (case-insensitive, surrounding whitespace ignored) are
    "yes", "no" and "n"; any other input re-prompts.

    :param default: "yes", "no" or None. When set, an empty answer returns
        the default, and the default is capitalised in the prompt.
    :raises ValueError: if ``default`` is not one of the accepted values.
    """
    valid = {"yes": True,
             "no": False, "n": False}

    if default is None:
        prompt = " [yes/no]: "
    elif default == "yes":
        prompt = " [Yes/no]: "
    elif default == "no":
        prompt = " [yes/No]: "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    # colour the question once; colouring it again on every retry would wrap
    # it in yet another layer of ANSI escape codes each time
    bold_question = colored(question, attrs=['bold'])
    while True:
        sys.stdout.write(bold_question + prompt)
        choice = input().lower().strip()
        if default is not None and choice == '':
            return valid[default]
        if choice in valid:
            return valid[choice]
        sys.stdout.write("Please respond with 'yes' or 'no|n' "
                         "\n")


def isIPAddress(input):
    """Return a match object if ``input`` is a dotted-quad IPv4 address, else None.

    Each octet must be 0-255 without leading zeros. The whole string must
    match, so trailing characters (including a newline) are rejected.
    """
    octet = r"([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])"
    # fullmatch instead of match + "$": "$" also matches before a trailing "\n"
    return re.fullmatch(r"(" + octet + r"\.){3}" + octet, input)


def execScript(**params):
    """Run ``params['command']`` with ``subprocess.check_call``.

    Optional keys: ``stdin``, ``stdout``, ``stderr`` (default None),
    ``shell`` (default False), ``cwd`` (default "."), ``env`` (default a copy
    of ``os.environ``) and ``noexit``.

    :returns: 1 on success. If the command fails, returns 0 when ``noexit``
        is truthy; otherwise exits the interpreter via ``sys.exit`` with an
        "execShell: ..." message.
    """
    call_options = {
        'stdin': params.get('stdin', None),
        'stdout': params.get('stdout', None),
        'stderr': params.get('stderr', None),
        'shell': params.get('shell', False),
        'cwd': params.get('cwd', "."),
        'env': params.get('env', dict(os.environ)),
    }
    try:
        subprocess.check_call(params.get('command'), **call_options)
    except subprocess.CalledProcessError as failure:
        if not params.get('noexit', False):
            sys.exit("execShell: {}".format(str(failure)))
        return 0
    return 1


def prompt_and_validate(question, default, password=False,
                        help_desc_banner=None, apply_password_validation=True,
                        lower_case_input=True, password_min_length=6):
    """Prompt the user for a value and return it.

    :param question: prompt text; "[default]: " or ": " is appended.
    :param default: value returned for an empty answer (None -> '' for plain
        input; for passwords an empty answer with no default is validated
        like any other value).
    :param password: read the value with ``getpass``, optionally validate it
        with ``validate_password``, and (unless the default was taken) ask
        for it a second time and require both entries to match.
    :param help_desc_banner: optional text printed (dedented) before prompting.
    :param apply_password_validation: run ``validate_password`` on passwords.
    :param lower_case_input: lower-case non-password answers.
    :param password_min_length: minimum length passed to ``validate_password``.
    """
    if help_desc_banner:
        print(textwrap.dedent(help_desc_banner))
    prompt = question
    if default is not None:
        prompt = prompt + "[" + default + "]: "
    else:
        prompt = prompt + ": "

    if not password:
        sys.stdout.write(prompt)
        choice = input().lower().strip() if lower_case_input else input().strip()
        if choice == '':
            return default if default is not None else ''
        return choice

    while True:
        choice = getpass(prompt).strip()
        # recomputed on every attempt: a default that failed validation must
        # not let the next, explicitly typed password skip confirmation
        default_chosen = choice == '' and default is not None
        if default_chosen:
            choice = default
        # validate the password
        if apply_password_validation:
            valid, reason = validate_password(choice, password_min_length=password_min_length)
            if not valid:
                print('Not a valid password, reason - ' + reason)
                # let it continue in the loop to prompt for the password again
                continue
        if default_chosen:
            return choice
        # valid pass, ask for it to be re-entered only if the user had
        # explicitly entered the value
        second_attempt = getpass('Re-enter ' + prompt).strip()
        if second_attempt != choice:
            print('Passwords don\'t match. Please retry')
            continue
        return choice


def printProgress(num_seconds):
    """Draw a spinning cursor on stdout for roughly ``num_seconds`` seconds.

    Each frame (one of - / | \\) is written, flushed, then backspaced over,
    with a 0.5 second pause between frames. A falsy ``num_seconds`` returns
    immediately.
    """
    if not num_seconds:
        return
    frames = itertools.cycle(['-', '/', '|', '\\'])
    began = time.time()
    now = 0
    while now - began < num_seconds:
        sys.stdout.write(next(frames))
        sys.stdout.flush()
        sys.stdout.write('\b')
        time.sleep(0.5)
        now = time.time()


def mkdir_p(path):
    """Create ``path`` and any missing parents, like ``mkdir -p``.

    Succeeds silently if ``path`` already is a directory. An empty ``path``
    (e.g. ``os.path.dirname`` of a bare file name) is a no-op.

    :raises OSError: for any other failure, e.g. FileExistsError when
        ``path`` exists but is not a directory, or PermissionError.
    """
    if not path:
        return
    # exist_ok tolerates only an existing *directory*; the old errno check was
    # inverted and swallowed real errors and non-directory collisions
    os.makedirs(path, exist_ok=True)


def getIPAddress(hostname):
    """Resolve ``hostname`` with ``socket.gethostbyname``.

    :returns: the IPv4 address string, or None if resolution fails or the
        result does not look like a dotted-quad address.
    """
    try:
        resolved = socket.gethostbyname(hostname)
    except socket.error:
        return None
    if not isIPAddress(resolved):
        return None
    logger.debug("IP address for host: %s is %s" % (hostname, resolved))
    return resolved


def isHostResolvable(hostname):
    """Return True if a DNS query for ``hostname`` succeeds, else False.

    Any exception raised by the resolver counts as "not resolvable". Each
    answer record is logged at debug level.
    """
    dns_resolver = resolver.Resolver()
    try:
        answers = dns_resolver.query(hostname)
    except Exception:
        logger.debug("Failed to resolve the hostname: " + hostname)
        return False
    for record in answers:
        logger.debug("Resolved Attributes: " + str(record))
    return True


def isSocketOpen(**params):
    """Return True if a TCP connection to ``(params['ip'], params['port'])`` succeeds.

    ``params['timeout']`` (seconds, default 5) bounds the connection attempt.
    """
    from contextlib import closing
    target = (params.get('ip'), params.get('port'))
    timeout = params.get('timeout', 5)
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        probe.settimeout(timeout)
        return probe.connect_ex(target) == 0


def isNotEmpty(s):
    """Return True if ``s`` contains at least one non-whitespace character."""
    if not s:
        return False
    return bool(s.strip())


def isInteger(input):
    """Return a match object if ``input`` consists only of digits, else None.

    Group 1 holds the digits. Signs, whitespace and trailing newlines are
    rejected.
    """
    # raw string avoids the invalid "\d" escape warning; fullmatch (unlike
    # "$") does not accept a trailing newline
    return re.fullmatch(r"(\d+)", input)


def checkForIPAndPort(endpoint):
    """Validate an endpoint string of the form ``<host>:<port>``.

    Checks, in order: exactly one ':' separator, non-blank host and port,
    host resolvable to an IPv4 address, and port an integer in 0-65535.
    The first failing check prints an explanatory message.

    :returns: True if all checks pass, otherwise False.
    """
    parts = endpoint.split(':')
    if len(parts) != 2:
        print("\nPlease enter the end point as <host>:<port>\n")
        return False
    host, port = parts
    if not (isNotEmpty(host) and isNotEmpty(port)):
        print("\nHost or Port is empty. Please enter the end point as <host>:<port>\n")
        return False
    if getIPAddress(host) is None:
        print("\nHostname '" + host
              + "' cannot be resolved. Please check your DNS settings and re-enter.\n")
        return False
    if not isInteger(port) or not 0 <= int(port) <= 65535:
        print("\nPort is not valid. Please enter the end point as <host>:<port>\n")
        return False
    return True


def printPretty(printList, sortby=None, reversesort=None, fields=None):
    """Print ``printList`` as a left-aligned PrettyTable.

    :param printList: sequence whose first element is the header row and the
        rest are data rows. It is not modified.
    :param sortby: column name to sort by (None keeps insertion order).
    :param reversesort: when truthy, sort in descending order.
    :param fields: optional list of column names to display (None or empty
        means all columns).
    :returns: 0
    """
    # read the header without popping it, so the caller's list is left intact
    tblHeader, rows = printList[0], printList[1:]
    prettyTbl = PrettyTable(tblHeader)
    prettyTbl.sortby = sortby
    if reversesort:
        prettyTbl.reversesort = reversesort
    prettyTbl.align = "l"  # Left align
    prettyTbl.padding_width = 2
    for row in rows:
        prettyTbl.add_row(row)
    print(prettyTbl.get_string(fields=fields if fields is not None else []))
    return 0


def systemMemory():
    """Return ``(total, available)`` system memory in KiB as floats (via psutil)."""
    vm = psutil.virtual_memory()
    return vm.total / 1024, vm.available / 1024


def _linux_distribution():
    """Return ``(name, version, id)`` describing the Linux distribution.

    Uses ``platform.linux_distribution()`` where it still exists (it was
    removed in Python 3.8). Otherwise reads NAME, VERSION_ID and
    VERSION_CODENAME from /etc/os-release. Missing values are empty strings.
    """
    legacy = getattr(platform, 'linux_distribution', None)
    if legacy is not None:
        return legacy()
    info = {}
    try:
        with open('/etc/os-release', encoding='utf-8') as release_file:
            for line in release_file:
                key, sep, value = line.strip().partition('=')
                if sep:
                    info[key] = value.strip('"\'')
    except OSError:
        # informational only: an unreadable/missing file yields empty fields
        pass
    return info.get('NAME', ''), info.get('VERSION_ID', ''), info.get('VERSION_CODENAME', '')


def printSystemInformation():
    """Print a table of OS, kernel, distribution, memory and CPU core details."""
    printList = []
    printList.append(["System", "Details"])
    printList.append(['OS', platform.system()])
    printList.append(['Kernel', platform.release()])
    distribution = _linux_distribution()
    printList.append(['Distribution', "Name:    " + distribution[0]])
    printList.append(['', "Version: " + distribution[1]])
    printList.append(['', "Type:    " + distribution[2]])
    total, available = systemMemory()
    printList.append(['Memory Total', str(total) + " KB"])
    printList.append(['Memory Available', str(available) + " KB"])
    printList.append(['CPU Cores', str(multiprocessing.cpu_count())])
    print("\nSystem Details: ")
    printPretty(printList)


def printError(*args, **kwargs):
    """``print`` ``args`` to stderr wrapped in bright-red ANSI colour codes.

    Extra keyword arguments (``sep``, ``end``, ``flush``) are forwarded to
    ``print``. The colour codes are printed as separate items, so ``sep``
    also appears between them and the text.
    """
    print('\033[91m', *args, '\033[0m', file=sys.stderr, **kwargs)


def validate_password(password, password_min_length=6):
    """Check ``password`` against the CLI's password rules.

    Rules, checked in order (the first violation wins): not None, not in the
    banned list, not blank, does not start with one of the listed special
    characters, and at least ``password_min_length`` characters long
    (whitespace counts towards the length).

    :returns: ``(is_valid, reason)``, where reason is 'Valid password' on success.
    """
    banned_passwords = ['Nl71dm31@12']
    leading_special_chars = '"!@#$%^&*()-+?_=,<>/"'
    # each rule is evaluated lazily, so later rules may assume earlier ones passed
    rules = (
        (lambda: password is None,
         'Password cannot be empty'),
        (lambda: password in banned_passwords,
         'Password belongs to banned set of passwords'),
        (lambda: len(password.strip()) == 0,
         'Password cannot be empty'),
        (lambda: password[0] in leading_special_chars,
         'Password cannot start with a special character'),
        (lambda: password_min_length > len(password),
         'Password length must be at least ' + str(password_min_length) + ' characters'),
    )
    for violated, reason in rules:
        if violated():
            return False, reason
    return True, 'Valid password'


def str_base64_encode(s, encoding='utf-8'):
    """Base64-encode the text ``s`` and return the result as text.

    ``s`` is converted to bytes with ``encoding``. None is passed through.
    """
    if s is not None:
        raw_bytes = bytes(s, encoding)
        return str(base64.b64encode(raw_bytes), encoding=encoding)
    return None


def str_base64_decode(s, encoding='utf-8'):
    """Decode the base64 text ``s`` and return the payload as text in ``encoding``.

    None is passed through.
    """
    if s is not None:
        payload = base64.b64decode(bytes(s, encoding))
        return str(payload, encoding=encoding)
    return None


def bytes_base64_encode(b, encoding='utf-8'):
    """Base64-encode the bytes ``b`` and return the encoding as text; None passes through."""
    return None if b is None else str(base64.b64encode(b), encoding=encoding)


def bytes_base64_decode(b, encoding='utf-8'):
    """Decode base64 data ``b`` (bytes or ASCII str) and return the payload as text.

    The decoded payload is converted to ``str`` using ``encoding``. None is
    passed through.
    """
    if b is None:
        return None
    # previously called b64encode here, which re-encoded instead of decoding
    b64_decoded_binary_string = base64.b64decode(b)
    return str(b64_decoded_binary_string, encoding=encoding)


def print_tabular(column_headers, rows, col_max_width=None,
                  show_row_separator=False, add_row_spacing=False,
                  target_file=None):
    """Print ``rows`` as a left-aligned PrettyTable.

    :param column_headers: list of column header strings.
    :param rows: iterable of rows; each row is a sequence of cell values.
    :param col_max_width: maximum width for every column. When None, a width
        is derived from the terminal size so the table fits on screen.
    :param show_row_separator: draw a horizontal rule between all rows.
    :param add_row_spacing: insert an empty row after each data row.
    :param target_file: file object to print to; None means stdout.
    """
    pretty_table = PrettyTable(column_headers)
    pretty_table.align = 'l'
    if show_row_separator:
        pretty_table.hrules = prettytable.ALL
    if col_max_width is not None:
        pretty_table.max_width = col_max_width
    elif column_headers:
        # skipped for a header-less table, which would otherwise divide by zero
        pretty_table.max_width = _fit_column_width_to_terminal(len(column_headers))
    for row in rows:
        pretty_table.add_row(row)
        if add_row_spacing:
            # add an empty row to add spacing between rows
            pretty_table.add_row([''] * len(row))
    print(pretty_table.get_string(), file=target_file)


def _fit_column_width_to_terminal(num_columns):
    """Return a per-column max width that lets ``num_columns`` columns fit the terminal."""
    # the "borders" of each column (the "|" character)
    total_col_separators = num_columns + 1
    # one space of padding at the beginning and at the end of every cell
    total_col_padding_chars = num_columns * 2
    try:
        # don't rely on shutil.get_terminal_size() since that returns unusable
        # values when the output is piped.
        # more details at https://groups.google.com/d/topic/comp.lang.python/v9VszdDzpdE
        # Use os.get_terminal_size instead and then default to a fixed value if
        # it raises an error
        terminal_size_cols = os.get_terminal_size()[0]
    except OSError:
        terminal_size_cols = 104
    # columns usable for cell text (excluding borders and padding)
    usable_terminal_cols = terminal_size_cols - total_col_separators - total_col_padding_chars
    # ceiling division, clamped to 1 so a very narrow terminal (or very many
    # columns) never yields a zero/negative width
    per_col_max_width = max(1, -(-usable_terminal_cols // num_columns))
    logger.debug('Using column max width ' + str(per_col_max_width) + ' for '
                 + str(num_columns) + ' columns on a terminal with ' + str(terminal_size_cols)
                 + ' columns')
    return per_col_max_width


def create_row_for_tabular_display(num_columns):
    """Return a list of ``num_columns`` empty strings, or None if ``num_columns`` <= 0."""
    if num_columns > 0:
        return ['' for _ in range(num_columns)]
    return None


def cli_err_exit(message) -> NoReturn:
    """Abort the current operation by raising ``CliException(message)``.

    Raising, rather than exiting here, leaves it to the CLI entry point to
    decide how the error is reported.
    """
    error = CliException(message)
    raise error

def gen_password_with_uuid(uuid, num_chars=10):
    """Derive a deterministic ``num_chars``-long password from the string ``uuid``.

    Each character is picked from ASCII letters + digits using one hex digit
    of the SHA-256 of ``uuid``. Because a hex digit is 0-15, only the first 16
    alphabet characters ('a'-'p') can occur. This is kept unchanged so that
    previously derived passwords remain reproducible. Beyond 64 characters
    the hex stream is extended by re-hashing the previous digest (previously
    an IndexError). The first 64 characters are unaffected.

    :returns: the password, or None if ``num_chars`` < 1.
    """
    if num_chars < 1:
        return None
    chars = string.ascii_letters + string.digits
    digest = hashlib.sha256(uuid.encode()).hexdigest()
    hex_stream = digest
    while len(hex_stream) < num_chars:
        digest = hashlib.sha256(digest.encode()).hexdigest()
        hex_stream += digest
    return ''.join(chars[int(nibble, 16) % len(chars)] for nibble in hex_stream[:num_chars])

def gen_password(num_chars=10, allowed_special_chars=None):
    """Generate a random password of ``num_chars`` characters.

    Characters are drawn uniformly from ASCII letters, digits and the
    characters in ``allowed_special_chars`` (an iterable of strings, or None).
    The ``secrets`` module is used, since the result is a credential.

    :returns: the password, or None if ``num_chars`` < 1.
    """
    if num_chars < 1:
        return None
    alphabet = string.ascii_letters + string.digits
    if allowed_special_chars is not None:
        for c in allowed_special_chars:
            alphabet += c
    # secrets.choice is a CSPRNG; random.choice is not suitable for passwords
    return ''.join(secrets.choice(alphabet) for _ in range(num_chars))


def prompt_host_name(help_desc_banner, prompt_string, default_value, multiple_hosts=False) \
        -> Union[str, list]:
    """Prompt until a non-empty host name (or comma-separated list) is entered.

    Each entered host is checked with ``validate_host_name``, which raises
    CliException for a host without a resolvable IPv4 address.

    :param help_desc_banner: optional text printed (dedented) first.
    :param prompt_string: prompt text; defaults to
        'Fully qualified domain name for host'.
    :param default_value: value offered as the default (None/'' for none).
    :param multiple_hosts: accept a comma-separated list and return a list.
    """
    if help_desc_banner:
        print(textwrap.dedent(help_desc_banner))
    question = prompt_string if prompt_string is not None else \
        'Fully qualified domain name for host'
    # an empty default would be rendered as "[]: " in the prompt; pass None
    # instead (an empty answer still comes back as '')
    default = default_value if default_value else None
    while True:
        host_name = prompt_and_validate(question, default)
        if host_name == '':
            continue

        if multiple_hosts:
            host_names = delimited_to_list(host_name)
            if not host_names:
                # input consisted only of delimiters/whitespace, e.g. ", ,"
                continue
            for host in host_names:
                validate_host_name(host)
            return host_names
        validate_host_name(host_name)
        return host_name


def validate_host_name(host_name, check_ssh_port_reachable=False):
    """Ensure ``host_name`` resolves to an IPv4 address.

    :param check_ssh_port_reachable: additionally probe TCP port 22 and log
        a warning (not an error) if it is not reachable.
    :raises CliException: (via cli_err_exit) if no IPv4 address can be resolved.
    """
    # the flag must be forwarded: cliHost only probes port 22 when asked,
    # otherwise ``reachable`` is always False and the warning always fires
    resolved_host = cliHost(host_name, check_ssh_port_reachable=check_ssh_port_reachable)
    if not resolved_host.ipv4_addr:
        cli_err_exit('Cannot resolve a IPv4 address for ' + host_name)
    if check_ssh_port_reachable and not resolved_host.reachable:
        logger.warning(host_name + ' is not reachable on port 22')


def delimited_to_list(val, delimiter=','):
    """Split ``val`` on ``delimiter`` into a list of stripped, non-empty parts.

    :param val: the string to split (None returns None).
    :param delimiter: separator string; None splits on runs of whitespace.
    """
    if val is None:
        return None
    stripped_parts = (part.strip() for part in val.split(sep=delimiter))
    return [part for part in stripped_parts if part]


def user_group_ids():
    """Return ``(user_id, group_id)`` of the current process as strings.

    On non-POSIX systems (where os.getuid/os.getgid are unavailable) this
    returns ``(None, None)``.
    """
    if os.name == 'posix':
        try:
            return str(os.getuid()), str(os.getgid())
        except KeyError:
            return None, None
    return None, None


def create_ssh_key():
    """(Re)create a passphrase-less 2048-bit RSA key pair at ~/.ssh/id_rsa.

    Any existing ~/.ssh/id_rsa and ~/.ssh/id_rsa.pub are deleted first.

    :returns: ``(private_key_path, public_key_path)``.
    :raises subprocess.CalledProcessError: if ssh-keygen fails.
    """
    ssh_dir = os.path.expanduser("~/.ssh")
    id_rsa_file = os.path.join(ssh_dir, 'id_rsa')
    id_rsa_pub_file = os.path.join(ssh_dir, 'id_rsa.pub')
    for key_file in (id_rsa_file, id_rsa_pub_file):
        if os.path.exists(key_file):
            logger.info('Deleting already existing file ' + key_file)
            os.remove(key_file)
    logger.info('Creating SSH keys')
    # create ~/.ssh owner-only (the mode only applies if it does not exist yet)
    os.makedirs(ssh_dir, mode=0o700, exist_ok=True)
    # quote the path so a home directory containing spaces or shell
    # metacharacters is passed to ssh-keygen as a single argument
    cmd = 'ssh-keygen -N "" -b 2048 -t rsa -f ' + shlex.quote(id_rsa_file)
    logger.debug("SSH key gen command: " + cmd)
    subprocess.check_output([cmd], shell=True, text=True, stderr=subprocess.STDOUT)
    return id_rsa_file, id_rsa_pub_file


def ssh_add_as_known_host(hosts: Iterable[str]):
    """Append the SSH host keys of each host to ~/.ssh/known_hosts.

    Uses ``ssh-keyscan -H`` (hashed host names).

    :raises subprocess.CalledProcessError: if the scan/append command fails.
    """
    for host in hosts:
        logger.debug('Adding ' + host + ' to ssh known_hosts file')
        # quote the host: it may come from user input and must not be able
        # to inject shell syntax; the target path stays unquoted so "~" expands
        cmd = 'ssh-keyscan -H ' + shlex.quote(host) + ' >> ~/.ssh/known_hosts'
        subprocess.check_output([cmd], shell=True, text=True, stderr=subprocess.STDOUT)


def copy_ssh_key(target_host: str, ssh_user: str, ssh_pass: str, public_key_path: os.path = None):
    """Install a public key on ``target_host`` with ``sshpass -e ssh-copy-id``.

    :param public_key_path: key to copy; defaults to ~/.ssh/id_rsa.pub.
    :raises CliException: (via cli_err_exit) when sshpass reports an
        incorrect password (exit code 5).
    :raises subprocess.CalledProcessError: for any other failure.
    """
    if public_key_path is None:
        ssh_dir = os.path.expanduser("~/.ssh")
        public_key_path = os.path.join(ssh_dir, 'id_rsa.pub')
    # SSHPASS env var is used by sshpass command to pass the SSH password
    # for the ssh-copy-id command (keeps it off the command line)
    env_vars = dict(os.environ)
    env_vars['SSHPASS'] = ssh_pass
    # quote user/host/path so they cannot inject shell syntax
    cmd = ('sshpass -e ssh-copy-id -i ' + shlex.quote(public_key_path) + ' '
           + shlex.quote(ssh_user + '@' + target_host))
    logger.info('Copying ssh key ' + public_key_path + ' for user '
                + ssh_user + ' to host ' + target_host)
    try:
        subprocess.check_output([cmd], shell=True, text=True, env=env_vars,
                                stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as cpe:
        if cpe.returncode == _SSHPASS_CMD_INCORRECT_PASSWORD_EXIT_CODE:
            cli_err_exit('Incorrect SSH password for user ' + ssh_user + ' on host ' + target_host)
        raise


def center_text_on_terminal(text: str) -> str or None:
    """Center every line of ``text`` within the terminal width.

    Lines are split on '\\n' and padded with ``str.center``, using the width
    from ``shutil.get_terminal_size()``. Returns None when ``text`` is None.
    """
    if text is None:
        return None
    width = shutil.get_terminal_size()[0]
    return '\n'.join(line.center(width) for line in text.split('\n'))


def get_cli_version() -> str:
    """Return the installed version string of the ``rdafcli`` distribution."""
    distribution = pkg_resources.get_distribution('rdafcli')
    return distribution.version
