import argparse
import configparser
import datetime
import json
import logging
import os
import shutil
import socket
import stat
import string
import subprocess
import tempfile
import time
from typing import Callable, Any, List, Tuple

import yaml
from docker.models.containers import ContainerCollection

from rdaf import get_templates_dir_root
import rdaf.component
import rdaf.util.requestsutil
from rdaf import rdafutils
from rdaf.component import Component, InfraCategoryOrder, execute_command, run_command, create_file, do_potential_scp, do_potential_scp_fetch
from rdaf.component import _host_dir_storer, _comma_delimited_to_list, find_all_files, \
    _list_to_comma_delimited, _host_dir_loader, _apply_data_dir_defaults, \
    run_potential_ssh_command, run_command_exitcode, copy_content_to_root_owned_file, remove_dir_contents, \
check_potential_remote_file_exists
from rdaf.rdafutils import str_base64_encode, str_base64_decode, cli_err_exit
import termcolor

# Module-level logger for the MariaDB infra component.
logger = logging.getLogger(__name__)
# Component name handed to the base Component constructor in MariaDB.__init__.
COMPONENT_NAME = 'mariadb'


class MariaDB(rdaf.component.Component):
    _option_data_dir = 'datadir'
    _option_user = 'user'
    _option_password = 'password'
    _option_host = 'host'
    _option_master_id = 'master_id'

    _default_data_dirs = ['/var/mysql']
    _backup_stream_file_name = 'mariadb-backup-mbstream'
    _backup_stream_gzip_file_name = _backup_stream_file_name + '.gz'
    _backup_bin_log_tar_file_name = 'mariadb-bin-log.tar.gz'

    def __init__(self):
        """Create the MariaDB 'infra' component with its infra category ordering."""
        category_order = InfraCategoryOrder.MARIADB.value
        super().__init__(COMPONENT_NAME, 'mariadb', 'infra', category_order)

    def _get_config_loader(self, config_name: str) -> Callable[[str], Any]:
        """Return the function that deserializes a stored config value, or None.

        Hosts are stored comma-delimited and load into a list; data dirs are
        stored as a comma-separated value that loads into a list of tuples.
        """
        known_loaders = (
            (self._option_host, _comma_delimited_to_list),
            (self._option_data_dir, _host_dir_loader),
        )
        for option_name, loader in known_loaders:
            if config_name == option_name:
                return loader
        return None

    def _get_config_storer(self, config_name: str) -> Callable[[Any], str]:
        """Return the function that serializes a config value for storage, or None.

        Host lists become a comma-delimited string; the list of (host, dirs)
        tuples becomes a comma-separated value.
        """
        known_storers = (
            (self._option_host, _list_to_comma_delimited),
            (self._option_data_dir, _host_dir_storer),
        )
        for option_name, storer in known_storers:
            if config_name == option_name:
                return storer
        return None

    def _init_default_configs(self):
        """Return a fresh dict of every MariaDB option set to its default.

        All options default to None except the master id, which defaults to 0.
        """
        return {
            self._option_data_dir: None,
            self._option_user: None,
            self._option_password: None,
            self._option_host: None,
            self._option_master_id: 0,
        }

    def get_hosts(self) -> list:
        """Return the list of configured MariaDB host names."""
        configured_hosts = self.configs[self._option_host]
        return configured_hosts

    def _get_host_data_dirs(self) -> List[Tuple[str, List[str]]]:
        """Return the configured ``(host, [data_dir, ...])`` pairs."""
        host_dir_pairs = self.configs[self._option_data_dir]
        return host_dir_pairs

    def _get_host_data_dir(self, hostname: str):
        """Return the first data dir configured for ``hostname``.

        :param hostname: host whose data dir should be looked up.
        :return: the first entry of that host's data-dir list, or None when
            ``hostname`` has no data-dir entry.
        """
        # ``data_dirs`` instead of ``dir`` so the builtin is not shadowed
        for host, data_dirs in self.configs[self._option_data_dir]:
            if host == hostname:
                return data_dirs[0]
        return None

    def get_ports(self) -> tuple:
        """Return ``(hosts, ports)`` that MariaDB uses on those hosts.

        A single host only needs 3306; a multi-host (cluster) setup also
        needs 4444, 4567 and 4568.
        """
        hosts = self.get_hosts()
        ports = ['3306']
        if len(hosts) > 1:
            ports += ['4444', '4567', '4568']
        return hosts, ports

    def get_user(self) -> str:
        """Return the MariaDB admin user name (stored base64-encoded in config)."""
        encoded_user = self.configs[self._option_user]
        return str_base64_decode(encoded_user)

    def get_password(self) -> str:
        """Return the MariaDB password in plain text (stored base64-encoded in config)."""
        encoded_password = self.configs[self._option_password]
        return str_base64_decode(encoded_password)

    def get_escaped_password(self) -> str:
        """Return the decoded password with ``$`` and ``&`` backslash-escaped.

        Meant for interpolating the password unquoted into shell command
        lines such as ``mysql -p<password>``. Other shell metacharacters are
        left untouched.
        """
        escaped = str_base64_decode(self.configs[self._option_password])
        for special_char in ('$', '&'):
            escaped = escaped.replace(special_char, '\\' + special_char)
        return escaped
       
    def healthcheck(self, component_name, host, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser):
        """Check that MariaDB on ``host`` answers ``SHOW DATABASES`` from the local machine.

        :return: a status row ``[component_name, "Service Status", status, detail]``;
            ``status`` is ``"OK"`` (detail ``"N/A"``) on success, otherwise a red
            ``"Failed"`` with the error text (stderr, or a generic message).
        """
        sql = 'SHOW DATABASES'
        mysql_cli = 'mysql -u {} -p{} -h {} -e "{}"'.format(
            self.get_user(), self.get_escaped_password(), host, sql)
        try:
            exit_code, _, stderr = run_command_exitcode(mysql_cli, socket.gethostname(), config_parser)
            if exit_code != 0:
                raise Exception(stderr if stderr else "Unable to connect to mariadb")
        except Exception as err:
            return [component_name, "Service Status", termcolor.colored("Failed", color='red'), str(err)]
        return [component_name, "Service Status", "OK", "N/A"]

    def gather_minimal_setup_inputs(self, cmd_args, config_parser):
        """Configure a single-host MariaDB without prompting.

        Uses the default host with the default data dir, the ``rdafadmin``
        user and a freshly generated 8-character password (both stored
        base64-encoded).
        """
        configs = self._init_default_configs()
        default_host = self.get_default_host()
        configs.update({
            self._option_host: [default_host],
            self._option_data_dir: [(default_host, self._default_data_dirs)],
            self._option_user: str_base64_encode('rdafadmin'),
            self._option_password: str_base64_encode(rdafutils.gen_password(num_chars=8)),
        })
        self._mark_configured(configs, config_parser)

    def gather_setup_inputs(self, cmd_args, config_parser):
        """Collect MariaDB hosts/data dirs, admin user and password.

        Each value comes from ``cmd_args`` or, unless ``no_prompt`` is set,
        from an interactive prompt. Each host may have exactly one data dir;
        more than one aborts via ``cli_err_exit``. The user and password are
        stored base64-encoded. The password defaults to a generated
        8-character value.
        """
        configs = self._init_default_configs()
        default_host_name = rdaf.component.Component.get_default_host()
        host_dirs = Component._parse_or_prompt_host_dirs(
            cmd_args.mariadb_server_host,
            default_host_name,
            'No MariaDB server host specified. Use --mariadb-server-host '
            'to specify a MariaDB server host',
            'What is the "host/path-on-host" on which you want the '
            'MariaDB server to be provisioned?',
            'MariaDB server host/path',
            cmd_args.no_prompt)
        configs[self._option_data_dir] = _apply_data_dir_defaults(host_dirs, self._default_data_dirs)

        # only one data dir is supported per MariaDB host
        for host, data_dirs in configs[self._option_data_dir]:
            if len(data_dirs) > 1:
                rdafutils.cli_err_exit('MariaDB allows only one data dir per host, '
                                       + host + ' is configured with more than 1: '
                                       + str(data_dirs))
        configs[self._option_host] = [host for host, _ in host_dirs]

        user = rdaf.component.Component._parse_or_prompt_value(
            cmd_args.mariadb_user,
            'rdafadmin',
            'No MariaDB user specified. Use --mariadb-user to specify one',
            'What is the user name you want to give for MariaDB admin user that '
            'will be created and used by the RDAF platform?',
            'MariaDB user',
            cmd_args.no_prompt)
        configs[self._option_user] = str_base64_encode(user)

        passwd = rdaf.component.Component._parse_or_prompt_value(
            cmd_args.mariadb_password,
            rdafutils.gen_password(num_chars=8),
            'No MariaDB password specified. Use --mariadb-password'
            ' to specify one',
            'What is the password you want to use for the newly created'
            ' MariaDB root user?',
            'MariaDB password',
            cmd_args.no_prompt,
            password=True)
        configs[self._option_password] = str_base64_encode(passwd)
        self._mark_configured(configs, config_parser)

    def do_setup(self, cmd_args, config_parser):
        """Prepare data dirs, log files and the custom MariaDB config on every host.

        The replication template ``mariadb_custom_replication.cnf`` is used when
        ``cmd_args.primary`` or ``cmd_args.secondary`` is set, otherwise
        ``mariadb_custom.cnf``. The replication identity (SERVER_ID, REPORT_HOST,
        GTID_DOMAIN_ID) is derived from a master id:

        * ``cmd_args.primary`` -> master id 1;
        * otherwise, if ``cmd_args.primary_config`` names a config file, this
          node takes the master id the primary does not use (the primary's id
          is read from ``[mariadb] master_id``, default 1).

        When a master id is chosen it is saved into ``self.configs`` and written
        via ``write_configs``. The rendered template goes to
        ``<conf_dir>/my_custom.cnf`` on each host.

        Aborts via ``cli_err_exit`` if ``cmd_args.primary_config`` cannot be read.
        """
        for host, data_dirs in self.configs[self._option_data_dir]:
            data_dir = data_dirs[0]
            command = 'sudo mkdir -p ' + data_dir + ' && sudo chmod -R 777 ' + data_dir
            run_potential_ssh_command(host, command, config_parser)

        # log files
        mysql_log = os.path.join(self.get_logs_dir(), 'mariadb.log')
        mysql_slow_log = os.path.join(self.get_logs_dir(), 'mariadb-slow.log')
        is_primary = getattr(cmd_args, 'primary', False)
        is_secondary = getattr(cmd_args, 'secondary', False)
        cnf_filename = 'mariadb_custom_replication.cnf' \
            if (is_primary or is_secondary) else 'mariadb_custom.cnf'
        config_template = os.path.join(get_templates_dir_root(), cnf_filename)
        with open(config_template, 'r') as f:
            template_content = f.read()

        # template values for each master id of the two-node replication setup
        replication_identities = {
            1: {'SERVER_ID': 1, 'REPORT_HOST': 'master1', 'GTID_DOMAIN_ID': 10},
            2: {'SERVER_ID': 2, 'REPORT_HOST': 'master2', 'GTID_DOMAIN_ID': 20},
        }
        master_id = None
        if is_primary:
            master_id = 1
        else:
            primary_config_file = getattr(cmd_args, 'primary_config', None)
            if primary_config_file:
                primary_config = configparser.ConfigParser(allow_no_value=True)
                # ConfigParser.read() silently skips files it cannot open; without
                # this check a bad path would quietly make this node 'master2'
                if not primary_config.read(primary_config_file):
                    cli_err_exit('Unable to read the primary config file: {}'.format(primary_config_file))
                primary_master_id = primary_config.get('mariadb', 'master_id', fallback=1)
                # take whichever master id the primary is not using
                master_id = 2 if int(primary_master_id) == 1 else 1

        replacement_values = dict()
        if master_id is not None:
            replacement_values.update(replication_identities[master_id])
            # persist the chosen master id
            self.configs[self._option_master_id] = master_id
            self._mark_configured(self.configs, config_parser)
            self.write_configs(config_parser)
        # NOTE(review): with cmd_args.secondary but no primary_config the replication
        # template is rendered with no values; substitute() raises KeyError if it has
        # placeholders — confirm secondary always comes with a primary config
        substituted_content = string.Template(template_content).substitute(replacement_values)
        conf_dest = os.path.join(self.get_conf_dir(), 'my_custom.cnf')
        for host in self.get_hosts():
            command = 'mkdir -p ' + self.get_logs_dir() + ' && touch ' \
                      + mysql_log + ' ' + mysql_slow_log + \
                      ' && sudo chmod -R 777 ' + self.get_logs_dir()
            run_potential_ssh_command(host, command, config_parser)

            # custom conf
            command = 'mkdir -p ' + self.get_conf_dir()
            run_potential_ssh_command(host, command, config_parser)
            create_file(host, substituted_content.encode(encoding='UTF-8'), conf_dest)

    def get_k8s_component_name(self):
        """Return the Helm release / chart directory name used for MariaDB."""
        release_name = 'rda-mariadb'
        return release_name

    def get_k8s_chart_name(self):
        """Return the chart name used for the MariaDB Kubernetes deployment."""
        chart_name = 'rda_mariadb'
        return chart_name

    def do_k8s_setup(self, cmd_args, config_parser):
        """Render ``mariadb-values.yaml`` for the Helm chart into the deployment-scripts dir.

        The template is taken from ``k8s-aws`` for AWS deployments and from
        ``k8s-local`` otherwise. In the local case, ``/var/mysql`` is chowned to
        ``1001:1001`` on every host first. Unknown placeholders are left as-is
        (``safe_substitute``).
        """
        decoded_password = str_base64_decode(self.configs[self._option_password])
        replacements = self._get_docker_repo()
        replacements['REPLICAS'] = len(self.get_hosts())
        replacements['MARIADB_USER'] = str_base64_decode(self.configs[self._option_user])
        replacements['MARIADB_PASSWORD'] = decoded_password
        replacements['MARIADB_BACKUP_PASSWORD'] = decoded_password

        is_aws = self.get_deployment_type(config_parser) == "aws"
        if not is_aws:
            for host in self.get_hosts():
                run_potential_ssh_command(host, 'sudo chown -R 1001:1001 /var/mysql', config_parser)
        template_dir = 'k8s-aws' if is_aws else 'k8s-local'

        template_path = os.path.join(get_templates_dir_root(), template_dir, 'mariadb-values.yaml')
        dest_path = os.path.join('/opt', 'rdaf', 'deployment-scripts', 'mariadb-values.yaml')
        with open(template_path, 'r') as template_file:
            rendered = string.Template(template_file.read()).safe_substitute(replacements)
        with open(dest_path, 'w') as dest_file:
            dest_file.write(rendered)

    def get_k8s_install_args(self, cmd_args):
        """Return extra ``helm`` arguments pinning the image tag to ``cmd_args.tag``."""
        return f'--set image.tag={cmd_args.tag} '
    
    def k8s_pull_images(self, cmd_args, config_parser):
        """Pull the MariaDB image on every host; only for ``k8s`` deployment types."""
        if self.get_deployment_type(config_parser) == 'k8s':
            repo = self._get_docker_repo()['DOCKER_REPO']
            for host in self.get_hosts():
                logger.info(f'Pulling {self.component_name} images on host {host}')
                pull_command = 'docker pull {}/rda-platform-mariadb:{}'.format(repo, cmd_args.tag)
                run_potential_ssh_command(host, pull_command, config_parser)

    def k8s_install(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser):
        """Install MariaDB on Kubernetes with Helm.

        Outside AWS, the local persistent volume manifest (cluster variant for
        more than one host) is applied first. After ``helm install``, it waits
        for the ``rdaf-mariadb-public`` node port to answer ``mysqladmin ping``
        on the first host and then grants user privileges. The wait and the
        grants are also skipped on AWS.
        """
        namespace = self.get_namespace(config_parser)
        release = self.get_k8s_component_name()
        scripts_dir = os.path.join('/opt', 'rdaf', 'deployment-scripts')

        if self.get_deployment_type(config_parser) != "aws":
            pv_file = 'mariadb-pv.yaml' if len(self.get_hosts()) <= 1 else 'mariadb-pv-cluster.yaml'
            pv_path = os.path.join(get_templates_dir_root(), 'k8s-local', pv_file)
            run_command('kubectl apply -f {} -n {}'.format(pv_path, namespace))

        chart_copy = os.path.join(scripts_dir, 'helm', release)
        self.copy_helm_chart(os.path.join(rdaf.get_helm_charts_dir(), release), chart_copy)
        values_yaml = os.path.join(scripts_dir, 'mariadb-values.yaml')
        run_command('helm install  --create-namespace -n {} -f {} {} {} {} '.format(
            namespace, values_yaml, self.get_k8s_install_args(cmd_args), release, chart_copy))

        if self.get_deployment_type(config_parser) == "aws":
            return
        node_port = self.get_service_node_port('rdaf-mariadb-public', config_parser)
        first_host = self.get_hosts()[0]
        ping_command = 'mysqladmin ping -u{} -p{} -h{} -P {} --wait=60 --silent'.format(
            self.get_user(), self.get_escaped_password(), first_host, node_port)
        run_potential_ssh_command(first_host, ping_command, config_parser)
        logger.info("Granting user privileges...")
        # creating the configured users and health user for haproxy
        self.grant_privileges_to_user(first_host, config_parser, user=self.get_user(), port=node_port)

    def k8s_upgrade(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser):
        """Refresh the local Helm chart copy and ``helm upgrade --install`` the MariaDB release."""
        namespace = self.get_namespace(config_parser)
        release = self.get_k8s_component_name()
        scripts_dir = os.path.join('/opt', 'rdaf', 'deployment-scripts')
        chart_copy = os.path.join(scripts_dir, 'helm', release)
        self.copy_helm_chart(os.path.join(rdaf.get_helm_charts_dir(), release), chart_copy)
        values_yaml = os.path.join(scripts_dir, 'mariadb-values.yaml')
        run_command('helm upgrade --install --create-namespace -n {} -f {} {} {} {} '.format(
            namespace, values_yaml, self.get_k8s_install_args(cmd_args), release, chart_copy))

    def k8s_down(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser):
        """Scale the MariaDB StatefulSet down to zero replicas.

        Lists the component's pods by label and scales the first StatefulSet
        found among their owners, then returns. Does nothing when no pod with
        a StatefulSet owner is found.

        Aborts via ``cli_err_exit`` if the ``kubectl get pods`` call fails,
        instead of failing later with a JSON decode error on empty output.
        """
        namespace = self.get_namespace(config_parser)
        label = self.get_k8s_component_label()
        cmd = f'kubectl get pods -l {label} -n {namespace} -o json'
        ret, stdout, stderr = execute_command(cmd)
        if ret != 0:
            cli_err_exit('Failed to list mariadb pods, due to: {}.'.format(str(stderr)))
        for pod in json.loads(stdout).get("items", []):
            # pods without owners (or owned by something else) cannot be scaled
            # through 'statefulset.apps/<name>'
            for owner in pod.get("metadata", {}).get("ownerReferences", []):
                if owner.get("kind") != 'StatefulSet':
                    continue
                name = owner['name']
                command = f'kubectl scale statefulset.apps/{name} -n {namespace} --replicas=0'
                run_command(command)
                return

    def k8s_up(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser):
        """Scale every StatefulSet of the MariaDB Helm release back to one replica per host.

        Resources are found with the ``app.kubernetes.io/instance=<release>``
        label. Aborts via ``cli_err_exit`` if the ``kubectl get all`` call fails,
        instead of failing later with a JSON decode error on empty output.
        """
        namespace = self.get_namespace(config_parser)
        label = "app.kubernetes.io/instance={}".format(self.get_k8s_component_name())
        cmd = f'kubectl get all -l {label} -n {namespace} -o json'
        ret, stdout, stderr = execute_command(cmd)
        if ret != 0:
            cli_err_exit('Failed to list mariadb resources, due to: {}.'.format(str(stderr)))
        replicas = len(self.get_hosts())
        for resource in json.loads(stdout).get("items", []):
            if resource.get("kind") == 'StatefulSet':
                comp_name = resource['metadata']['name']
                command = 'kubectl scale statefulset.apps/{} -n {} --replicas={}' \
                    .format(comp_name, namespace, replicas)
                run_command(command)

    def install(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser):
        """Start MariaDB with docker-compose on every configured host.

        Only for the first host: waits for ``mysqladmin ping`` as root, grants
        user privileges and, for GeoDR deployments, creates replication users.
        With more than one host, the auto-restart script is installed and,
        outside GeoDR, the cluster status is checked.
        """
        self.open_ports(config_parser)
        for position, host in enumerate(self.get_hosts()):
            compose_file_path = self.create_compose_file(cmd_args, host)
            up_command = '/usr/local/bin/docker-compose --project-name infra -f {file} up -d {service}' \
                .format(file=compose_file_path, service=self.component_name)
            run_potential_ssh_command(host, up_command, config_parser)
            if position != 0:
                continue

            ping_command = 'mysqladmin ping -uroot -p{} -h{} --wait=30 --silent' \
                .format(self.get_escaped_password(), host)
            run_potential_ssh_command(host, ping_command, config_parser)
            logger.info("Granting user privileges...")
            # creating the configured users and health user for haproxy
            self.grant_privileges_to_user(host, config_parser)
            if self.is_geodr_deployment(config_parser):
                self.create_replication_users(host, config_parser)

        if len(self.get_hosts()) > 1:
            self.add_auto_restart_script(config_parser)
            # cluster status check does not apply to GeoDR
            if not self.is_geodr_deployment(config_parser):
                self._check_cluster_status(config_parser)

    def upgrade(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser):
        """Stop MariaDB, then bring it back up on each host from a newly generated compose file."""
        self.down(cmd_args, config_parser)
        compose_up = '/usr/local/bin/docker-compose --project-name infra -f {file} up -d {service}'
        for host in self.get_hosts():
            command = compose_up.format(file=self.create_compose_file(cmd_args, host),
                                        service=self.component_name)
            run_potential_ssh_command(host, command, config_parser)

    def backup_data(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                    backup_state: configparser.ConfigParser, backup_dir_root: os.path):
        """Take a gzipped ``mariabackup`` stream backup of the first MariaDB host.

        A one-off container of the MariaDB image runs ``mariabackup --backup``
        against the first host on port 3306 (``--galera-info`` is added for
        multi-host setups). ``/var/mysql`` is mounted as the data dir. The
        stream is written to ``<backup_dir_root>/data/<name>/mariadb-backup-mbstream.gz``.
        That directory is created locally and, when the first host is remote,
        on that host too. Credentials are passed to the container as the
        ``user``/``password`` environment variables and referenced as
        ``$user``/``$password`` in the ``sh -c`` command, rather than being
        embedded in the command string.

        Aborts via ``cli_err_exit`` unless the last container output line
        contains ``completed OK!``; this includes the case of no output at all.
        """
        mariadb_host = self.get_hosts()[0]
        user = self.get_user()
        password = self.get_password()
        mariadb_backup_dir = os.path.join(backup_dir_root, 'data', self.get_name())
        mkdir_command = 'sudo mkdir -p ' + mariadb_backup_dir + ' && sudo chmod -R 777 ' + mariadb_backup_dir
        run_command(mkdir_command)
        if not self.is_local_host(mariadb_host):
            run_potential_ssh_command(mariadb_host, mkdir_command, config_parser)
        backup_dir_in_container = '/opt/rdaf/mariadb-backup'
        backup_command = 'sh -c "mariabackup --backup --user=$user'\
                         + ' --password=$password' + \
                         ' --host=' + mariadb_host + ' --port=3306 ' + \
                         (' --galera-info' if len(self.get_hosts()) > 1 else '') + \
                         ' --target-dir=' + backup_dir_in_container + \
                         ' --stream=xbstream | gzip > ' \
                         + os.path.join(backup_dir_in_container,
                                        MariaDB._backup_stream_gzip_file_name) + '"'

        image_name = self._get_mariadb_image_name()
        with Component.new_docker_client_(mariadb_host) as docker_client:
            container_collection = ContainerCollection(client=docker_client.client)
            container = container_collection.run(image=image_name,
                                                 environment={'user': user, 'password': password},
                                                 volumes={mariadb_backup_dir: {
                                                     'bind': backup_dir_in_container,
                                                     'mode': 'rw'},
                                                     '/var/mysql': {
                                                         'bind': '/bitnami/mariadb/data/',
                                                         'mode': 'rw'}},
                                                 command=backup_command, detach=True,
                                                 remove=True)

            # stays empty when the container prints nothing, so the check below
            # reports a failure instead of raising UnboundLocalError
            statement = ''
            for line in container.logs(stream=True):
                statement = line.decode("utf-8").strip()
                print(statement)

            # mariabackup's final line reports overall success
            if 'completed OK!' in statement:
                logger.debug('mariadb backup completed with output')
            else:
                rdafutils.cli_err_exit('mariadb backup failed due to an error: ' + statement)


    def k8s_backup_data(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                        backup_state: configparser.ConfigParser, backup_dir_root: os.path):
        """Back up MariaDB running on Kubernetes through a temporary backup pod.

        Steps:
        1. Render ``mariadb-backup.yaml`` with NAMESPACE and
           BACKUP_PATH=``<backup_dir_root>/data/<name>`` and apply it.
        2. Wait up to 600s for the pod to become Ready.
        3. Exec ``mariabackup --backup`` in the pod against the
           ``rda-mariadb-mariadb-galera-0`` pod. The gzipped stream goes to
           ``/opt/rdaf/mariadb-backup`` inside the backup pod (presumably the
           template mounts BACKUP_PATH there; confirm in the template).

        The backup pod is deleted afterwards, also when waiting for it or the
        backup itself fails, so a failed run does not leave the pod behind.

        Aborts via ``cli_err_exit`` if the pod never becomes Ready, or if the
        last line of the exec output does not contain ``completed OK!``.
        """
        namespace = self.get_namespace(config_parser)
        mariadb_host = f'rda-mariadb-mariadb-galera-0.rda-mariadb-mariadb-galera-headless.{namespace}.svc.cluster.local'
        mariadb_backup_dir = os.path.join(backup_dir_root, 'data', self.get_name())
        command = 'sudo mkdir -p ' + mariadb_backup_dir + ' && sudo chmod -R 777 ' + mariadb_backup_dir
        run_command(command)
        backup_dir_in_container = '/opt/rdaf/mariadb-backup'
        backup_command = 'sh -c "mariabackup --backup --user=' + self.get_user() \
                         + ' --password=' + self.get_password() +\
                         ' --host=' + mariadb_host + ' --port=3306 ' + \
                         (' --galera-info' if len(self.get_hosts()) > 1 else '') + \
                         ' --target-dir=' + backup_dir_in_container + \
                         ' --stream=xbstream | gzip > ' \
                         + os.path.join(backup_dir_in_container,
                                        MariaDB._backup_stream_gzip_file_name) + '"'

        mariadb_backup_pod = os.path.join(get_templates_dir_root(), 'mariadb-backup.yaml')
        with open(mariadb_backup_pod, 'r') as f:
            template_content = f.read()
        replacements = self._get_docker_repo()
        replacements['NAMESPACE'] = namespace
        replacements['BACKUP_PATH'] = mariadb_backup_dir
        content = string.Template(template_content).substitute(replacements)

        with tempfile.TemporaryDirectory(prefix='rdaf') as tmp:
            deployment_file = os.path.join(tmp, 'mariadb-backup.yaml')
            with open(deployment_file, 'w+') as f:
                f.write(content)
                f.flush()

            run_command('kubectl apply -f ' + deployment_file)
            # always remove the backup pod, including when waiting or the backup fails
            try:
                logger.info("Waiting for mariadb backup pod to be up and running...")
                time.sleep(5)
                pod_status_command = f'kubectl wait --for=condition=Ready pod --timeout=600s -n {namespace} ' \
                                     f'-l app_component=rda-mariadb-backup'
                ret, stdout, stderr = execute_command(pod_status_command)
                if ret != 0:
                    cli_err_exit("Failed to get status of mariadb backup creation pod, due to: {}.".format(str(stderr)))

                logger.info("Triggering mariadb backup..")
                backup_pod = self.get_pods_names(config_parser, 'app_component=rda-mariadb-backup')[0]
                ret, stdout, stderr = execute_command(f'kubectl exec -it {backup_pod} -n {namespace} -- {backup_command}')
                print(str(stdout))
                # mariabackup's final line reports overall success
                last_line = str(stdout).strip().split('\n')[-1]
                if 'completed OK!' in last_line:
                    logger.debug('mariadb backup completed with output')
                else:
                    rdafutils.cli_err_exit('mariadb backup failed due to an error: ' + last_line)
            finally:
                run_command('kubectl delete -f ' + deployment_file)

    def restore_conf(self, config_parser: configparser.ConfigParser,
                     backup_content_root_dir: os.path):
        """Restore config files via the base class, then recreate MariaDB log files on every host.

        On each host the logs dir is created, ``mariadb.log`` and
        ``mariadb-slow.log`` are touched, and the dir is made world-writable.
        """
        super().restore_conf(config_parser, backup_content_root_dir)
        logs_dir = self.get_logs_dir()
        log_files = ' '.join([os.path.join(logs_dir, 'mariadb.log'),
                              os.path.join(logs_dir, 'mariadb-slow.log')])
        prepare_logs = 'mkdir -p ' + logs_dir + ' && touch ' + log_files + ' && sudo chmod -R 777 ' + logs_dir
        for host in self.get_hosts():
            run_potential_ssh_command(host, prepare_logs, config_parser)

    def required_container_state_before_restore(self):
        """Return the container state MariaDB must be in before a restore (``'exited'``)."""
        required_state = 'exited'
        return required_state

    def restore_data(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                     backup_content_root_dir: os.path,
                     backup_cfg_parser: configparser.ConfigParser):
        mariadb_backup_dir = os.path.join(backup_content_root_dir, 'data', self.get_name())
        if not os.path.isdir(mariadb_backup_dir):
            logger.warning(
                mariadb_backup_dir + ' is missing. Skipping restoration of MariaDB data')
            return
        prepare_done = False
        # do the --prepare through the container
        backup_dir_in_container = '/opt/rdaf/mariadb-backup'
        now = datetime.datetime.now()
        current_run_name = str(now.date()) + '-' + str(now.timestamp())
        prep_restore_dir_name = 'prepare-restore-' + current_run_name
        prepare_dir_in_container = os.path.join(backup_dir_in_container, prep_restore_dir_name)
        data_dir_path_in_container = '/bitnami/mariadb/data/'
        extracted_dir_in_container = os.path.join(prepare_dir_in_container, 'extracted')
        prepare_command = 'sh -c "mkdir -p ' + prepare_dir_in_container \
                          + ' && gzip -dc < ' \
                          + os.path.join(backup_dir_in_container,
                                         MariaDB._backup_stream_gzip_file_name) \
                          + ' > ' + os.path.join(prepare_dir_in_container,
                                                 MariaDB._backup_stream_file_name) \
                          + ' && mkdir -p ' + extracted_dir_in_container + ' && mbstream -x -C ' \
                          + extracted_dir_in_container \
                          + ' < ' + os.path.join(prepare_dir_in_container,
                                                 MariaDB._backup_stream_file_name) \
                          + ' && mariabackup --prepare --target-dir=' + \
                          extracted_dir_in_container + '"'
        restore_command = 'sh -c "mariabackup --copy-back --target-dir=' \
                          + extracted_dir_in_container \
                          + ' --datadir=' + data_dir_path_in_container \
                          + '"'

        for mariadb_host in self.get_hosts():
            image_name = self._get_mariadb_image_name()
            with Component.new_docker_client_(mariadb_host) as docker_client:
                container_collection = ContainerCollection(client=docker_client.client)
                if not prepare_done:
                    # clean up data dir(s)
                    logger.info('Initiating a data cleanup before restoration of ' + self.get_name())
                    self._delete_data(config_parser)
                    # initiate a --prepare (only once)
                    logger.info('Initiating a mariadb restoration --prepare on host ' + mariadb_host)
                    container = container_collection.run(
                        image=image_name, network_mode='host',
                        volumes={mariadb_backup_dir: {
                            'bind': backup_dir_in_container, 'mode': 'rw'},
                            '/var/mysql': {'bind': '/bitnami/mariadb/data/', 'mode': 'rw'}},
                        command=prepare_command, detach=True, remove=True)
                    for line in container.logs(stream=True):
                        statement = line.decode("utf-8").strip()
                        print(statement)

                    # holding to the state of backup
                    galera_info='00000000-0000-0000-0000-000000000000:-1'
                    galera_info_file = os.path.join(mariadb_backup_dir, prep_restore_dir_name,
                                                    'extracted', 'xtrabackup_galera_info')
                    if os.path.exists(galera_info_file):
                        # with open(galera_info_file, 'r') as f:
                        #     galera_info = f.read().strip()
                        galera_info = subprocess.check_output('sudo cat ' + galera_info_file, shell=True)
                        galera_info = galera_info.decode('UTF-8').strip()

                    state = galera_info.split(':', 1)
                    prepare_done = True

                logger.info('Initiating Mariadb data restoration on host ' + mariadb_host)
                container = container_collection.run(
                    image=image_name, network_mode='host',
                    volumes={mariadb_backup_dir: {
                        'bind': backup_dir_in_container, 'mode': 'rw'},
                        '/var/mysql': {
                            'bind': '/bitnami/mariadb/data/', 'mode': 'rw'}},
                    command=restore_command, detach=True, remove=True)

                for line in container.logs(stream=True):
                    statement = line.decode("utf-8").strip()
                    print(statement)

            state_dat = '''
# GALERA saved state
version: 2.1
uuid:    %s
seqno:   %s
safe_to_bootstrap: 1
  ''' % (state[0], state[1])

            copy_content_to_root_owned_file(mariadb_host, state_dat, '/var/mysql/grastate.dat',
                                            config_parser)
            permissions = 'sudo chown -R 1001:root /var/mysql'
            run_potential_ssh_command(mariadb_host, permissions, config_parser)

        # cleanup the dir that we used to prepare for restoration
        prep_dir = os.path.join(mariadb_backup_dir, prep_restore_dir_name)
        if os.path.isdir(prep_dir):
            try:
                shutil.rmtree(prep_dir, ignore_errors=True)
            except Exception:
                logger.warning('Ignoring error that happened while deleting directory '
                               + prep_dir, exc_info=1)


    def check_pod_state_for_restore(self, config_parser: configparser.ConfigParser):
        """Abort via ``cli_err_exit`` if any MariaDB pod is still in the ``Running`` phase."""
        namespace = self.get_namespace(config_parser)
        phase_query = ' '.join(['kubectl get pods -n', namespace,
                                '--selector=app_component=rda-mariadb',
                                '-o jsonpath="{.items[*].status.phase}"'])
        _, phases, _ = execute_command(phase_query)
        if 'Running' not in str(phases):
            return
        cli_err_exit('Mariadb pod should not be Running for restore...')

    def k8s_restore_data(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                         backup_content_root_dir: os.path, backup_cfg_parser: configparser.ConfigParser):
        """Restore MariaDB data from a backup into the Kubernetes deployment.

        For every MariaDB host, a temporary restore pod is rendered from the
        ``mariadb-restore.yaml`` template, with ``NAMESPACE``, ``NODE_NAME`` (the host's
        k8s node) and ``BACKUP_PATH`` substituted in. The pod is applied and the code
        waits until it is Ready. Then:

        * first host only: the existing data is deleted via ``_delete_data``, and the
          gzip'd mbstream backup is unpacked and run through ``mariabackup --prepare``;
        * every host: ``mariabackup --copy-back`` copies the prepared data into
          ``/bitnami/mariadb/data`` inside the pod. The restore pod is deleted and
          awaited. A Galera ``grastate.dat`` is written to
          ``/var/mysql/data/grastate.dat`` with ``safe_to_bootstrap: 1``, using the
          uuid/seqno from the backup's ``xtrabackup_galera_info`` (default: zero uuid,
          seqno -1). Finally ``/var/mysql`` is chowned to ``1001:1001``.

        The local prepare directory is removed at the end on a best-effort basis.

        Returns early, with a warning, if ``<backup_content_root_dir>/data/<component
        name>`` does not exist. Calls ``cli_err_exit`` if the restore pod does not
        become Ready, or is not deleted, within 600s.

        NOTE(review): the return codes of the ``--prepare`` and ``--copy-back``
        ``kubectl exec`` calls are not checked. A failed step is only visible in the
        printed output, and ``grastate.dat`` is still written afterwards.
        """
        namespace = self.get_namespace(config_parser)
        mariadb_backup_dir = os.path.join(backup_content_root_dir, 'data', self.get_name())
        if not os.path.isdir(mariadb_backup_dir):
            logger.warning(
                mariadb_backup_dir + ' is missing. Skipping restoration of MariaDB data')
            return

        # do the --prepare through the container
        # (paths below are as seen inside the restore pod; the prepare directory name
        # is made unique per run from the current date and timestamp)
        backup_dir_in_container = '/opt/rdaf/mariadb-backup'
        now = datetime.datetime.now()
        current_run_name = str(now.date()) + '-' + str(now.timestamp())
        prep_restore_dir_name = 'prepare-restore-' + current_run_name
        prepare_dir_in_container = os.path.join(backup_dir_in_container, prep_restore_dir_name)
        data_dir_path_in_container = '/bitnami/mariadb/data'
        extracted_dir_in_container = os.path.join(prepare_dir_in_container, 'extracted')
        # gunzip the mbstream backup into the prepare dir, extract it into
        # <prepare dir>/extracted and run `mariabackup --prepare` on the result
        prepare_command = 'sh -c "mkdir -p ' + prepare_dir_in_container \
                          + ' && gzip -dc < ' \
                          + os.path.join(backup_dir_in_container,
                                         MariaDB._backup_stream_gzip_file_name) \
                          + ' > ' + os.path.join(prepare_dir_in_container,
                                                 MariaDB._backup_stream_file_name) \
                          + ' && mkdir -p ' + extracted_dir_in_container + ' && mbstream -x -C ' \
                          + extracted_dir_in_container \
                          + ' < ' + os.path.join(prepare_dir_in_container,
                                                 MariaDB._backup_stream_file_name) \
                          + ' && mariabackup --prepare --target-dir=' + \
                          extracted_dir_in_container + '"'
        # copy the prepared data into the pod's MariaDB data directory
        restore_command = 'sh -c "mariabackup --copy-back --target-dir=' \
                          + extracted_dir_in_container \
                          + ' --datadir=' + data_dir_path_in_container \
                          + '"'

        # presumably maps each MariaDB host to its k8s node name (used as NODE_NAME
        # so the restore pod lands on that host) - TODO confirm
        nodes = self.get_k8s_nodes()
        for host in self.get_hosts():
            mariadb_restore_pod = os.path.join(get_templates_dir_root(), 'mariadb-restore.yaml')
            with open(mariadb_restore_pod, 'r') as f:
                template_content = f.read()
            replacements = self._get_docker_repo()
            replacements['NAMESPACE'] = namespace
            replacements['NODE_NAME'] = nodes[host]
            replacements['BACKUP_PATH'] = mariadb_backup_dir
            content = string.Template(template_content).substitute(replacements)

            with tempfile.TemporaryDirectory(prefix='rdaf') as tmp:
                deployment_file = os.path.join(tmp, 'mariadb-restore.yaml')
                with open(deployment_file, 'w+') as f:
                    f.write(content)
                    f.flush()

                run_command('kubectl apply -f ' + deployment_file)
                logger.info("Waiting for mariadb restore pod to be up and running...")
                # presumably gives the pod time to be created before `kubectl wait`
                # looks it up by label - TODO confirm
                time.sleep(5)
                pod_status_command = 'kubectl wait --for=condition=Ready pod --timeout=600s -n {} ' \
                                     '-l app_component=rda-mariadb-restore'.format(namespace)
                ret, stdout, stderr = execute_command(pod_status_command)
                if ret != 0:
                    cli_err_exit("Failed to get status of mariadb restore pod, due to: {}.".format(str(stderr)))

                restore_pod = self.get_pods_names(config_parser, 'app_component=rda-mariadb-restore')[0]
                if host == self.get_hosts()[0]:
                    # clean up data dir(s)
                    logger.info('Initiating a data cleanup before restoration of ' + self.get_name())
                    self._delete_data(config_parser)
                    # initiate a --prepare (only once)
                    # NOTE(review): `ret` is not checked; a failed prepare is only printed
                    logger.info('Initiating a mariadb restoration --prepare')
                    ret, stdout, stderr = execute_command(f'kubectl exec -it {restore_pod} -n {namespace} -- {prepare_command}')
                    print(str(stdout))

                    # holding to the state of backup
                    # The file is read on the CLI host's filesystem under
                    # mariadb_backup_dir. This assumes it is the same directory the pod
                    # mounts as BACKUP_PATH - TODO confirm for remote nodes.
                    galera_info = '00000000-0000-0000-0000-000000000000:-1'
                    galera_info_file = os.path.join(mariadb_backup_dir, prep_restore_dir_name,
                                                    'extracted', 'xtrabackup_galera_info')
                    if os.path.exists(galera_info_file):
                        galera_info = subprocess.check_output('sudo cat ' + galera_info_file, shell=True)
                        galera_info = galera_info.decode('UTF-8').strip()
                    # [uuid, seqno]; only assigned for the first host and reused for the rest
                    state = galera_info.split(':', 1)

                logger.info('Initiating Mariadb data restoration on host ' + host)
                # NOTE(review): `ret` is not checked; a failed copy-back is only printed
                ret, stdout, stderr = execute_command(f'kubectl exec -it {restore_pod} -n {namespace} -- {restore_command}')
                print(str(stdout))
                run_command('kubectl delete -f ' + deployment_file)

                logger.info("Waiting for mariadb restore pod to be deleted...")
                pod_status_command = 'kubectl wait --for=delete pod --timeout=600s -n {} ' \
                                     '-l app_component=rda-mariadb-restore'.format(namespace)
                ret, stdout, stderr = execute_command(pod_status_command)
                if ret != 0:
                    cli_err_exit("Failed to get status of mariadb restore pod, due to: {}.".format(str(stderr)))

            # Galera state file marking this node as safe to bootstrap from the restored data
            state_dat = '''
# GALERA saved state
version: 2.1
uuid:    %s
seqno:   %s
safe_to_bootstrap: 1
          ''' % (state[0], state[1])

            copy_content_to_root_owned_file(host, state_dat, '/var/mysql/data/grastate.dat', config_parser)
            permissions = 'sudo chown -R 1001:1001 /var/mysql'
            run_potential_ssh_command(host, permissions, config_parser)

        # cleanup the dir that we used to prepare for restoration
        prep_dir = os.path.join(mariadb_backup_dir, prep_restore_dir_name)
        if os.path.isdir(prep_dir):
            try:
                shutil.rmtree(prep_dir, ignore_errors=True)
            except Exception:
                logger.warning('Ignoring error that happened while deleting directory '
                               + prep_dir, exc_info=1)
        return

    def _get_data_dir_on_host(self, host: str) -> os.path:
        """Return the data directory configured for ``host``, or None if it has none.

        Each ``datadir`` config entry is a ``(host, [dirs...])`` pair. Only one data
        dir per host is expected, so the first one is returned.
        """
        first_dir_of_match = (dirs[0]
                              for configured_host, dirs in self.configs[self._option_data_dir]
                              if configured_host == host)
        return next(first_dir_of_match, None)

    def _copy_ssh_key(self, host: str, config_parser: configparser.ConfigParser):
        """Make the CLI user's ``~/.ssh/id_rsa`` key pair usable on ``host``.

        * Local host: the key pair is already present, so only the public key is
          registered in ``~/.ssh/authorized_keys``. This allows other nodes that use
          the same key pair password-less login to the CLI host. ``~/.ssh`` is then set
          to 0700. The key is appended only when it is not already listed, so repeated
          runs do not pile up duplicate entries. A newline is inserted first if the
          file's last line lacks one.
        * Remote host: the private and public keys are copied into ``~/.ssh`` on the
          host. Their permissions are then set to 600 (private), 644 (public) and 700
          (directory).

        A missing key file on the remote path is logged as a warning and the copy is
        skipped.
        """
        ssh_dir = os.path.expanduser("~/.ssh")
        public_key_path = os.path.join(ssh_dir, 'id_rsa.pub')
        if Component.is_local_host(host):
            # skip the copying of the keys, since we already have them locally.
            # however, add our own keys as authorized keys to allow the other
            # nodes using this same key to use password less login to the CLI host
            if os.path.isfile(public_key_path):
                authorized_keys = os.path.abspath(os.path.expanduser("~/.ssh/authorized_keys"))
                with open(public_key_path) as pk_file:
                    public_key = pk_file.read().strip()
                existing = ''
                if os.path.isfile(authorized_keys):
                    with open(authorized_keys) as auth_file:
                        existing = auth_file.read()
                already_authorized = public_key in (line.strip() for line in existing.splitlines())
                if public_key and not already_authorized:
                    logger.info('Updating the ' + authorized_keys + ' on host ' + host)
                    with open(authorized_keys, mode='a') as auth_file:
                        # never glue the key onto a last line that lacks a trailing newline
                        if existing and not existing.endswith('\n'):
                            auth_file.write('\n')
                        auth_file.write(public_key + '\n')
                # set the correct permissions on ~/.ssh dir (this shouldn't really be here
                # but this is a hack for now)
                os.chmod(ssh_dir, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
            return

        private_key_path = os.path.join(ssh_dir, 'id_rsa')
        for key_path in (public_key_path, private_key_path):
            if not os.path.isfile(key_path):
                logger.warning(key_path + ' is missing or not a file')
                return

        dest_dir = os.path.abspath(os.path.expanduser(os.path.join('~', '.ssh/')))
        private_key_dest_file = os.path.join(dest_dir, 'id_rsa')
        logger.info('Copying SSH private key to host ' + host)
        rdaf.component.do_potential_scp(host, private_key_path, private_key_dest_file)
        public_key_dest_file = os.path.join(dest_dir, 'id_rsa.pub')
        logger.info('Copying SSH public key to host ' + host)
        rdaf.component.do_potential_scp(host, public_key_path, public_key_dest_file)
        # set correct permissions on the files:
        # private key -rw-------, public key -rw-r--r--, ~/.ssh drwx------
        command = 'chmod {} {} && chmod {} {} && chmod {} {}'.format(
            600, private_key_dest_file, 644, public_key_dest_file, 700, ssh_dir)
        logger.info('Setting up correct permissions on SSH key files on host ' + host)
        rdaf.component.run_potential_ssh_command(host, command, config_parser)

    def get_deployment_env(self, host: str):
        """Return the variables used to render the MariaDB deployment for ``host``.

        Every ``$`` in the user name and password is doubled (``$$``). This presumably
        keeps the downstream templating from treating it as a variable reference -
        TODO confirm. ``RESTART_POLICY`` is ``unless-stopped`` for a single-node
        install and the quoted string ``"no"`` for a cluster.
        """
        single_node = len(self.get_hosts()) == 1
        return {
            'MARIADB_DATA_MOUNT': self._get_host_data_dir(host),
            'MARIADB_CONF_MOUNT': self.get_conf_dir(),
            'MARIADB_LOG_MOUNT': self.get_logs_dir(),
            'MARIADB_USER': self.get_user().replace("$", "$$"),
            'MARIADB_PASSWORD': self.get_password().replace("$", "$$"),
            'MARIADB_NODE': host,
            'RESTART_POLICY': 'unless-stopped' if single_node else '\"no\"',
        }

    def update_cluster_env(self, component_yaml: dict, host: str):
        """Append the Galera cluster settings to ``component_yaml['environment']``.

        Does nothing for a single-node install. Otherwise this adds the cluster name,
        replication credentials and the ``gcomm://`` address listing every MariaDB
        host. The first configured host additionally gets the bootstrap flag.
        """
        cluster_hosts = self.get_hosts()
        if len(cluster_hosts) != 1:
            env_entries = component_yaml['environment']
            env_entries.extend([
                'MARIADB_GALERA_CLUSTER_NAME=rdaf_galera',
                'MARIADB_REPLICATION_USER=rdaf_replica',
                'MARIADB_REPLICATION_PASSWORD=rdaf_replica',
                'MARIADB_GALERA_CLUSTER_ADDRESS=gcomm://' + ','.join(cluster_hosts),
            ])
            if cluster_hosts[0] == host:
                env_entries.append('MARIADB_GALERA_CLUSTER_BOOTSTRAP=yes')

    def get_db_credentials(self):
        """Return the configured MariaDB credentials as a ``(user, password)`` tuple."""
        db_user = self.get_user()
        db_password = self.get_password()
        return db_user, db_password

    def grant_privileges_to_user(self, db_host: str, config_parser: configparser.ConfigParser,
                                 user='root', port='3306'):
        """Grant all privileges to the configured user and recreate the health-check user.

        Two ``mysql`` client invocations are run via ``run_potential_ssh_command`` on
        ``db_host``, authenticating as ``user`` with this component's escaped password:

        1. ``GRANT ALL PRIVILEGES ON *.*`` to the configured user from any host;
        2. drop any ``health_check_user`` rows, then create that user for
           ``localhost`` and ``%``.
        """
        mysql_template = 'mysql -u{} -p{} -h {} -P {} -e "{}"'
        grant_sql = " GRANT ALL PRIVILEGES ON *.* TO '{}'@'%'; FLUSH PRIVILEGES;".format(self.get_user())
        health_check_sql = (" DELETE FROM mysql.user WHERE User = 'health_check_user'; FLUSH PRIVILEGES; "
                            "CREATE USER 'health_check_user'@'localhost'; "
                            "CREATE USER 'health_check_user'@'%'; FLUSH PRIVILEGES;")
        for sql in (grant_sql, health_check_sql):
            mysql_cmd = mysql_template.format(user, self.get_escaped_password(), db_host, port, sql)
            run_potential_ssh_command(db_host, mysql_cmd, config_parser)

    def create_replication_users(self, db_host: str, config_parser: configparser.ConfigParser,
                                 user='root', port='3306'):
        """Create the Geo-DR replication user and grant the backup user on ``db_host``.

        The replication user is ``rep_master1`` when ``[rdaf-cli] primary`` is true and
        ``rep_master2`` otherwise. It is granted ``REPLICATION SLAVE``.
        ``rdaf_backup`` connecting from ``db_host`` is granted all privileges.
        NOTE(review): both passwords are hard-coded literals.
        """
        is_primary = config_parser.getboolean("rdaf-cli", 'primary')
        replication_user = 'rep_master1' if is_primary else 'rep_master2'
        user_creation = (
            f"CREATE USER IF NOT EXISTS '{replication_user}'@'%' identified by 'admin1234'; "
            f"GRANT replication slave on *.* to '{replication_user}'@'%'; "
            f"GRANT ALL PRIVILEGES ON *.* TO 'rdaf_backup'@'{db_host}' IDENTIFIED BY 'rdaf_backup'; FLUSH PRIVILEGES"
        )
        mysql_cmd = 'mysql -u{} -p{} -h {} -P {} -e "{}"'.format(user, self.get_escaped_password(),
                                                                 db_host, port, user_creation)
        run_potential_ssh_command(db_host, mysql_cmd, config_parser)


    @staticmethod
    def _get_mariadb_image_name():
        """Return the mariadb image reference from the deployed ``infra.yaml`` files.

        Scans every ``infra.yaml`` found under ``/opt/rdaf/deployment-scripts`` and
        returns the ``image`` of the first ``services.mariadb`` entry that has one.
        Files that are empty, are not a mapping, or lack a ``services``/``mariadb``/
        ``image`` entry are skipped instead of crashing with TypeError/KeyError.

        Exits the CLI via ``cli_err_exit`` when no image can be found.
        """
        yamls = find_all_files('infra.yaml', os.path.join('/opt', 'rdaf', 'deployment-scripts'))
        image = None
        for entry in yamls:
            with open(entry, 'r') as f:
                content = yaml.safe_load(f.read())
            # safe_load returns None for an empty file
            if not isinstance(content, dict):
                continue
            services = content.get('services') or {}
            mariadb_service = services.get('mariadb') if isinstance(services, dict) else None
            if isinstance(mariadb_service, dict) and mariadb_service.get('image'):
                image = mariadb_service['image']
                break

        if not image:
            rdafutils.cli_err_exit("Unable to retrieve the mariadb image tag..")

        return image

    def add_auto_restart_script(self, config_parser):
        """Install the MariaDB auto-restart script and its ``@reboot`` cron job on every host.

        For each MariaDB host this:

        1. creates ``auto-restart.log`` in the logs dir, owned by the CLI user's uid
           and gid;
        2. renders ``mariadb-auto-restart.sh`` from the templates dir (node address,
           health-check user, data dir, SSH user, cluster hosts) into the conf dir;
        3. copies the CLI's SSH keys to the host, because the script SSHes between
           the MariaDB nodes;
        4. makes the script executable and registers it in the host's crontab.
        """
        import rdaf.component.ssh as comp_ssh
        from rdaf.contextual import COMPONENT_REGISTRY
        ssh_user = COMPONENT_REGISTRY.require(comp_ssh.SSHKeyManager.COMPONENT_NAME).get_ssh_user()
        hosts = self.get_hosts()
        if not hosts:
            return

        # everything below is the same for every host, so compute it once
        auto_restart_log_file = os.path.join(self.get_logs_dir(), 'auto-restart.log')
        # The group must come from getgid(); it used to be set from the uid, which is
        # only correct when uid == gid.
        touch_log_command = 'sudo touch ' + auto_restart_log_file + ' && sudo chown -R ' \
            + str(os.getuid()) + ' ' + auto_restart_log_file + ' && sudo chgrp -R ' \
            + str(os.getgid()) + ' ' + auto_restart_log_file
        script_template = os.path.join(get_templates_dir_root(), 'mariadb-auto-restart.sh')
        with open(script_template, 'r') as f:
            template = string.Template(f.read())
        # copy to the conf dir in the install root of the host
        script_file = os.path.join(self.get_conf_dir(), 'mariadb-auto-restart.sh')
        cluster_hosts = ','.join(hosts)

        for host in hosts:
            # adding the auto-restart log file
            run_potential_ssh_command(host, touch_log_command, config_parser)
            substituted_content = template.safe_substitute({
                'MARIADB_NODE_ADDR': host,
                'MARIADB_HEALTHCHK_USERNAME': 'health_check_user',
                'MARIADB_HEALTHCHK_PASSWORD': '',
                'MARIADB_DATA_DIR': '/var/mysql',
                'SSH_USER': ssh_user,
                'MARIADB_CLUSTER_HOSTS': cluster_hosts,
            })
            logger.info('Creating ' + script_file + ' on host ' + host)
            create_file(host, substituted_content.encode(encoding='UTF-8'), script_file)
            # this script internally uses ssh to do certain checks between the mariadb nodes
            # and since these nodes maybe different from the node where CLI is installed,
            # we need to copy the ssh keys into these nodes, so that they get used as the
            # private/public keys of those nodes
            self._copy_ssh_key(host, config_parser)
            run_potential_ssh_command(host, 'chmod +x ' + script_file, config_parser)
            self.update_crontab_job(host, config_parser, script_file, auto_restart_log_file)

    @staticmethod
    def update_crontab_job(host, config_parser, script_file, log_file):
        """Register ``script_file`` as an ``@reboot`` cron job on ``host``.

        Any existing crontab line mentioning ``script_file`` is removed first, so the
        job is never duplicated. Then a line is appended that runs the script with
        ``/bin/bash`` 60 seconds after boot and appends its output to ``log_file``.
        """
        # -F: match the path literally (otherwise '.' etc. act as regex wildcards);
        # 2>/dev/null: a host without a crontab yet is not worth reporting here
        command = "crontab -l 2>/dev/null | grep -F -v '{}' | crontab -".format(script_file)
        run_potential_ssh_command(host, command, config_parser)
        command = '(crontab -l 2>/dev/null || true; echo "@reboot sleep 60 && ' \
                  '/bin/bash {} >> {}") | crontab -'.format(script_file, log_file)
        run_potential_ssh_command(host, command, config_parser)

    def _check_cluster_status(self, config_parser):
        """Wait for every MariaDB node to report ``wsrep_local_state_comment`` = Synced.

        Each host is polled up to three times, 30 seconds apart, with no wait after the
        last attempt. Hosts that never report Synced are logged as warnings. The CLI
        exits via ``cli_err_exit`` when more than one host failed to sync.
        """
        logger.info("Checking all the nodes in the cluster to be up and sync.")
        max_attempts = 3
        wsrep_command = "SHOW STATUS LIKE 'wsrep%';"
        error = 0
        for host in self.get_hosts():
            # Use the escaped password, like every other mysql invocation in this class,
            # so passwords with shell-special characters don't break the command.
            command = 'mysql -u {} -p{} -h {} -e "{}" | grep wsrep_local_state_comment'\
                .format(self.get_user(), self.get_escaped_password(), host, wsrep_command)
            for attempt in range(1, max_attempts + 1):
                _, stdout, _ = run_command_exitcode(command, host, config_parser)
                if 'Synced' in str(stdout):
                    break
                if attempt == max_attempts:
                    logger.warning(f"{host} node failed to sync with cluster.")
                    error += 1
                    break
                logger.info("waiting for mariadb node {} to be up and sync.".format(host))
                time.sleep(30)
        # NOTE(review): a single unsynced node is tolerated (threshold is > 1) - confirm intended
        if error > 1:
            cli_err_exit("Failed to install Mariadb cluster as nodes are not in sync.")


    def geodr_prepare_replication(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                                peer_configs: configparser.ConfigParser):
        """Seed the peer (Geo-DR) site's MariaDB with a backup of this site.

        Everything is staged under ``/tmp/data/<component name>``:

        1. Clear that directory on this site's MariaDB hosts and on the peer's
           MariaDB hosts (``[mariadb] host`` in ``peer_configs``).
        2. Back up this site via ``self.backup_data(..., '/tmp')``.
        3. Fetch the backup from this site's first MariaDB host (unless it is the
           local host). Copy it to every peer host and ``chmod -R 777`` it there.
        4. On the peer hosts, run throw-away docker containers from the mariadb image
           named in ``infra.yaml``:
           * first peer host only: wipe ``/var/mysql/``, then unpack the backup and
             run ``mariabackup --prepare``;
           * every peer host: ``mariabackup --copy-back`` into ``/var/mysql``.
           Then write a Galera ``grastate.dat`` (uuid/seqno from
           ``xtrabackup_galera_info``, ``safe_to_bootstrap: 1``) and chown
           ``/var/mysql`` to ``1001:root``.

        NOTE(review): only the first peer host is wiped and prepared. For any further
        peer host, ``--copy-back`` reads ``prepare-restore/extracted`` from that host's
        own backup dir, which nothing here prepares (see the cluster TODO below). The
        ``prepare-restore`` cleanup also checks and removes the directory on the CLI
        host's filesystem, not on the peer host.
        """
        mariadb_backup_dir = os.path.join('/tmp', 'data', self.get_name())
        peer_hosts = peer_configs.get('mariadb', 'host').split(',')
        mariadb_hosts = self.get_hosts()

        # Clean up on mariadb host
        for host in mariadb_hosts:
            if check_potential_remote_file_exists(host, mariadb_backup_dir):
              logger.info(f'Cleaning up existing contents in mariadb backup directory on mariadb host {host}')
              remove_dir_contents(host, mariadb_backup_dir, config_parser, use_sudo=True)

        # Clean up on peer hosts
        for host in peer_hosts:
            if check_potential_remote_file_exists(host, mariadb_backup_dir):
                logger.info(f'Cleaning up existing contents in mariadb backup directory on peer host {host}')
                remove_dir_contents(host, mariadb_backup_dir, config_parser, use_sudo=True)

        # backup on the existing setup
        self.backup_data(cmd_args, config_parser, None, '/tmp')

        # TODO need to handle cluster scenario
        # restore backup on secondary
        # (bring the backup to the CLI host from the first mariadb host, then push it
        # to every peer host)
        mariadb_host = self.get_hosts()[0]
        if not self.is_local_host(mariadb_host):
            do_potential_scp_fetch(mariadb_host, mariadb_backup_dir, mariadb_backup_dir)
        for host in peer_hosts:
            logger.info(f"copying backup to {host}")
            do_potential_scp(host, mariadb_backup_dir, mariadb_backup_dir)
            run_potential_ssh_command(host, f'sudo chmod -R 777 {mariadb_backup_dir}', config_parser)

        # set once the wipe + --prepare has run on the first peer host
        prepare_done = False
        # do the --prepare through the container
        backup_dir_in_container = '/opt/rdaf/mariadb-backup'
        prep_restore_dir_name = 'prepare-restore'
        prepare_dir_in_container = os.path.join(backup_dir_in_container, prep_restore_dir_name)
        data_dir_path_in_container = '/bitnami/mariadb/data/'
        extracted_dir_in_container = os.path.join(prepare_dir_in_container, 'extracted')
        # gunzip the mbstream backup, extract it into <prepare dir>/extracted and run
        # `mariabackup --prepare` on the result
        prepare_command = 'sh -c "mkdir -p ' + prepare_dir_in_container \
                          + ' && gzip -dc < ' \
                          + os.path.join(backup_dir_in_container,
                                         MariaDB._backup_stream_gzip_file_name) \
                          + ' > ' + os.path.join(prepare_dir_in_container,
                                                 MariaDB._backup_stream_file_name) \
                          + ' && mkdir -p ' + extracted_dir_in_container + ' && mbstream -x -C ' \
                          + extracted_dir_in_container \
                          + ' < ' + os.path.join(prepare_dir_in_container,
                                                 MariaDB._backup_stream_file_name) \
                          + ' && mariabackup --prepare --target-dir=' + \
                          extracted_dir_in_container + '"'
        # copy the prepared data into the container's data dir (host /var/mysql)
        restore_command = 'sh -c "mariabackup --copy-back --target-dir=' \
                          + extracted_dir_in_container \
                          + ' --datadir=' + data_dir_path_in_container \
                          + '"'

        for mariadb_host in peer_hosts:
            image_name = self._get_mariadb_image_name()
            with Component.new_docker_client_(mariadb_host) as docker_client:
                container_collection = ContainerCollection(client=docker_client.client)
                if not prepare_done:
                    # clean up data dir(s)
                    logger.info('Initiating a data cleanup before restoration of ' + self.get_name())
                    remove_dir_contents(mariadb_host, '/var/mysql/', config_parser, use_sudo=True)

                    # initiate a --prepare (only once)
                    logger.info('Initiating a mariadb restoration --prepare on host ' + mariadb_host)
                    logger.info(f'{prepare_command}')
                    container = container_collection.run(
                        image=image_name, network_mode='host',
                        volumes={mariadb_backup_dir: {
                            'bind': backup_dir_in_container, 'mode': 'rw'},
                            '/var/mysql': {'bind': '/bitnami/mariadb/data/', 'mode': 'rw'}},
                        command=prepare_command, detach=True, remove=True)
                    # Echo the container output. With stream=True the docker SDK
                    # presumably follows the logs until the container exits, which makes
                    # this wait for the step - TODO confirm.
                    for line in container.logs(stream=True):
                        statement = line.decode("utf-8").strip()
                        print(statement)

                    # holding to the state of backup
                    # (the galera info file lives on the peer host, so it is fetched to a
                    # local temp dir before reading)
                    galera_info = '00000000-0000-0000-0000-000000000000:-1'
                    galera_info_file = os.path.join(mariadb_backup_dir, prep_restore_dir_name,
                                                    'extracted', 'xtrabackup_galera_info')
                    if check_potential_remote_file_exists(mariadb_host, galera_info_file):
                        with tempfile.TemporaryDirectory(prefix='rdaf') as tmp:
                            do_potential_scp_fetch(mariadb_host, galera_info_file, os.path.join(tmp, 'galera_info_file'),
                                                   is_dir=False)
                            galera_info = subprocess.check_output('sudo cat ' + os.path.join(tmp, 'galera_info_file'),
                                                                  shell=True)
                        galera_info = galera_info.decode('UTF-8').strip()

                    # [uuid, seqno]; only computed on the first peer host and reused afterwards
                    state = galera_info.split(':', 1)
                    prepare_done = True

                logger.info('Initiating Mariadb data restoration on host ' + mariadb_host)
                logger.info(f'{restore_command}')
                container = container_collection.run(
                    image=image_name, network_mode='host',
                    volumes={mariadb_backup_dir: {
                        'bind': backup_dir_in_container, 'mode': 'rw'},
                        '/var/mysql': {
                            'bind': '/bitnami/mariadb/data/', 'mode': 'rw'}},
                    command=restore_command, detach=True, remove=True)

                for line in container.logs(stream=True):
                    statement = line.decode("utf-8").strip()
                    print(statement)

            # Galera state file marking this node as safe to bootstrap from the restored data
            state_dat = '''
# GALERA saved state
version: 2.1
uuid:    %s
seqno:   %s
safe_to_bootstrap: 1
      ''' % (state[0], state[1])

            copy_content_to_root_owned_file(mariadb_host, state_dat, '/var/mysql/grastate.dat',
                                        config_parser)
            permissions = 'sudo chown -R 1001:root /var/mysql'
            run_potential_ssh_command(mariadb_host, permissions, config_parser)

            # cleanup the dir that we used to prepare for restoration
            # (NOTE(review): os.path / shutil act on the CLI host, not on mariadb_host)
            prep_dir = os.path.join(mariadb_backup_dir, prep_restore_dir_name)
            if os.path.isdir(prep_dir):
                try:
                    shutil.rmtree(prep_dir, ignore_errors=True)
                except Exception:
                    logger.warning('Ignoring error that happened while deleting directory '
                                   + prep_dir, exc_info=1)

        logger.info('Mariadb replication completed.')

    def geodr_start_replication(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                                peer_configs: configparser.ConfigParser):
        """Start GTID-based replication from this site's MariaDB into the peer site.

        The starting GTID comes from the last single-quoted value on the
        ``binlog_pos`` line of ``/var/mysql/mariadb_backup_info`` on the first peer
        host, when that file exists. Otherwise the peer's current
        ``@@global.gtid_slave_pos`` is used. The peer's ``gtid_slave_pos`` is set to
        that GTID, and the peer is pointed at this site's first MariaDB host as master,
        using the ``rep_master1`` user with ``MASTER_USE_GTID=slave_pos``.

        :raises Exception: if no GTID can be determined for the peer.
        """
        peer_hosts = peer_configs.get('mariadb', 'host').split(',')
        peer_host = peer_hosts[0]

        backup_info_file = os.path.join("/var/mysql", "mariadb_backup_info")
        gtid = None
        if check_potential_remote_file_exists(peer_host, backup_info_file):
            # make the file readable so it can be fetched to the CLI host
            command = 'sudo chmod -R 777 ' + backup_info_file
            run_potential_ssh_command(peer_host, command, config_parser)
            with tempfile.TemporaryDirectory(prefix='rdaf') as tmp:
                local_copy = os.path.join(tmp, 'backup_info')
                do_potential_scp_fetch(peer_host, backup_info_file, local_copy, is_dir=False)
                with open(local_copy, 'r') as f:
                    for line in f:
                        if 'binlog_pos' in line:
                            # the GTID is the last single-quoted value on the line;
                            # guard against a line without any quoted value
                            parts = line.split("'")
                            if len(parts) >= 3:
                                gtid = parts[-2]
                            break

        if not gtid:
            query = "SELECT @@global.gtid_slave_pos;"
            command = 'mysql -u {} -p{} -h {} -Nse "{}"'.format(self.get_user(), self.get_escaped_password(),
                                                              peer_host, query)
            exit_code, stdout, stderr = run_command_exitcode(command, socket.gethostname(), config_parser)
            if isinstance(stdout, bytes):
                stdout = stdout.decode('utf-8')
            # Accept the queried position only when it is non-empty. The previous check
            # was inverted: it raised on a valid GTID and accepted an empty one.
            gtid = (stdout or '').strip()
            if not gtid:
                raise Exception("Failed to get GTID for mariadb on server " + peer_host)

        set_gtid = f"set GLOBAL gtid_slave_pos='{gtid}';"
        command = 'mysql -u{} -p{} -h {} -P {} -e "{}"'.format(self.get_user(), self.get_escaped_password(),
                                                               peer_host, 3306, set_gtid)
        run_potential_ssh_command(peer_host, command, config_parser)

        enable_replication = (
            f"stop slave; CHANGE MASTER TO MASTER_HOST='{self.get_hosts()[0]}', MASTER_USER='rep_master1',"
            f"MASTER_PASSWORD='admin1234', MASTER_USE_GTID=slave_pos; start slave")

        command = 'mysql -u{} -p{} -h {} -P {} -e "{}"'.format(self.get_user(), self.get_escaped_password(),
                                                               peer_host, 3306, enable_replication)
        run_potential_ssh_command(peer_host, command, config_parser)
        logger.info("Replication started for mariadb.")
    
    def geodr_status(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                     peer_configs: configparser.ConfigParser):
        """Run ``SHOW SLAVE STATUS`` against the first peer MariaDB host.

        The ``mysql`` command is executed on that host through
        ``run_potential_ssh_command``. This method neither inspects nor returns the
        output.
        """
        target_host = peer_configs.get('mariadb', 'host').split(',')[0]
        status_cmd = (f'mysql -u{self.get_user()} -p{self.get_escaped_password()} '
                      f'-h {target_host} -P 3306 -e "SHOW SLAVE STATUS\\G"')
        logger.info("Checking replication status for mariadb")
        run_potential_ssh_command(target_host, status_cmd, config_parser)

    def switch_primary(self, cmd_args: argparse.Namespace, config_parser: configparser.ConfigParser,
                                peer_configs: configparser.ConfigParser):
        logger.info("Switching mariadb to be a primary instance.")

        command = 'mysql -u{} -p{} -h {} -P 3306 -e "STOP SLAVE; SHOW SLAVE STATUS\\G;"'.format(self.get_user(), self.get_escaped_password(),
                                                               self.get_hosts()[0])
        run_potential_ssh_command(self.get_hosts()[0], command, config_parser)