import copy
from logging.handlers import WatchedFileHandler
import multiprocessing
import pickle
import subprocess
from copy import deepcopy
import datetime
import itertools
import os
import os.path as osp
from pathlib import Path
import random
import shutil
import sys
from collections import namedtuple
import time
import traceback
from tqdm import tqdm

from rlkit.core.logging import green, logger
from rlkit.launchers import print_welcome
from rlkit.core.logging import bold, colors, red
from rlkit.data_management.mdp_path_loader import load_file
from rlkit.core.multiprocessing import NestablePool
from rlkit import conf
from rlkit.launchers.pipeline import Pipelines
from rlkit.torch.pytorch_util import set_gpu_mode
import rlkit.launchers as l
import wandb
import torch

# Git state of one repository, as collected by ``get_git_infos`` and written
# next to an experiment's logs by ``setup_logger``. Fields:
#   directory        - the path that was opened as a git repository
#   code_diff        - output of ``git diff`` (unstaged working-tree changes)
#   code_diff_staged - output of ``git diff --staged``
#   commit_hash      - hexsha of the HEAD commit
#   branch_name      - active branch name, or "[DETACHED]"
GitInfo = namedtuple(
    "GitInfo",
    "directory code_diff code_diff_staged commit_hash branch_name",
)


def generate_snapshot(root, dest_path):
    """Copy the repo at ``root`` into ``dest_path`` using rsync.

    Every file that is not matched by a pattern in ``root/.gitignore`` is
    copied (not only ``*.py`` files). The point is to recover the exact code
    state later, e.g. when replaying a trained model or launching a queued
    grid-search job.

    Args:
        root (str): the path to the repo.
        dest_path (str): the path to generate a snapshot of the repo in.

    Raises:
        FileNotFoundError: if ``root`` has no ``.gitignore``.
        subprocess.CalledProcessError: if rsync exits with a non-zero status.
    """

    def rsync(src, target, includes, excludes):
        # rsync applies the first matching filter rule. The excludes therefore
        # come before the catch-all includes, and the final ``--exclude=*``
        # drops anything that was not explicitly included.
        args = ["rsync", "-rI"]
        args += ["--exclude=%s" % i for i in excludes]
        args += ["--include=%s" % i for i in includes]
        args += ["--exclude=*"]
        args += [src, target]
        # Pass an argument list (no shell). This way .gitignore patterns
        # reach rsync verbatim instead of being word-split or glob-expanded
        # by the shell.
        subprocess.check_call(args, stdout=sys.stdout, stderr=sys.stdout)

    excludes = []
    with open(os.path.join(root, ".gitignore"), "r") as fin:
        for line in fin.read().splitlines():
            # Trailing whitespace is not significant in .gitignore.
            # Blank lines and comment lines are not patterns at all.
            pattern = line.rstrip()
            if pattern and not pattern.startswith("#"):
                excludes.append(pattern)

    includes = ["*"]
    rsync(root, dest_path, includes, excludes)

   


def get_git_infos(dirs):
    """Collect the git state of each directory in ``dirs``.

    Directories that do not exist or are not git repositories are reported
    and skipped.

    Args:
        dirs: iterable of directory paths.

    Returns:
        A list of ``GitInfo`` tuples, one per valid repository, or ``None``
        when GitPython cannot be imported.
    """
    try:
        import git
    except ImportError:
        # GitPython is optional; without it, no git metadata is recorded.
        return None

    git_infos = []
    for directory in dirs:
        try:
            repo = git.Repo(directory)
        except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
            print("Not a valid git repo: {}".format(directory))
            continue
        try:
            branch_name = repo.active_branch.name
        except TypeError:
            # ``active_branch`` raises TypeError when HEAD is detached.
            branch_name = "[DETACHED]"
        git_infos.append(
            GitInfo(
                directory=directory,
                code_diff=repo.git.diff(None),
                code_diff_staged=repo.git.diff("--staged"),
                commit_hash=repo.head.commit.hexsha,
                branch_name=branch_name,
            )
        )
    return git_infos


def recursive_items(dictionary):
    """
    Yield every (key, value) pair of a possibly nested dictionary, depth-first.

    When a value is itself a dict (including dict subclasses such as
    ``OrderedDict`` or ``defaultdict``), its own pair is yielded first. The
    pairs nested inside it follow immediately after.

    Usage:

    ```
    x = {
        'foo' : {
            'bar' : 5
        }
    }
    list(recursive_items(x))
    # -> [('foo', {'bar': 5}), ('bar', 5)]
    ```
    :param dictionary: the (possibly nested) dict to walk.
    :return: a generator of (key, value) tuples.
    """
    for key, value in dictionary.items():
        yield key, value
        if isinstance(value, dict):
            yield from recursive_items(value)


def save_experiment_data(dictionary, log_dir):
    """Pickle ``dictionary`` into ``<log_dir>/experiment.pkl``, overwriting any old file."""
    target = log_dir + "/experiment.pkl"
    with open(target, "wb") as out_file:
        pickle.dump(dictionary, out_file, pickle.HIGHEST_PROTOCOL)


def run_hyperparameters(parallel_cls, variant, hyperparameters: dict):
    """Launch a grid search over ``hyperparameters`` on the process pool.

    Each element of the Cartesian product of the value lists produces one
    deep copy of ``variant``:

    - every chosen value is written both into its ``trainer_kwargs`` and at
      the top level;
    - ``__gridsearch`` holds a label like ``"l=0.1-b=256"``, built from the
      first character of each key.

    Every such variant is then paired with each seed and env of
    ``parallel_cls``, and all jobs are handed to ``pool_run``.

    Raises:
        Exception: if ``hyperparameters`` is None.
    """
    print_welcome()

    if hyperparameters is None:
        raise Exception("No Hyperparameters given")

    keys = list(hyperparameters.keys())
    jobs = []
    for values in itertools.product(*hyperparameters.values()):
        grid_trainer_kwargs = copy.deepcopy(variant["trainer_kwargs"])
        grid_variant = deepcopy(variant)
        labels = []
        for key, value in zip(keys, values):
            labels.append(f"{key[0]}={value}")
            grid_trainer_kwargs[key] = value
            grid_variant[key] = value
        grid_variant["trainer_kwargs"] = grid_trainer_kwargs
        grid_variant["__gridsearch"] = "-".join(labels)
        jobs.extend(
            itertools.product(
                parallel_cls.seeds,
                parallel_cls.envs,
                (grid_variant,),
            )
        )

    pool_run(list(enumerate(jobs)))


def run_parallel_pipeline_here(parallel_cls, variant):
    """Run ``variant`` once for every (seed, env) pair of ``parallel_cls`` on the pool."""
    print_welcome()
    combos = itertools.product(parallel_cls.seeds, parallel_cls.envs, [variant])
    indexed_jobs = list(enumerate(combos))
    pool_run(indexed_jobs)


def pool_run(experiment_combinations, processes_per_gpu=3):
    """Run every experiment on a process pool behind a tqdm progress bar.

    Args:
        experiment_combinations: list of ``(index, (seed, env_id, variant))``
            tuples, as built by ``run_hyperparameters`` and
            ``run_parallel_pipeline_here``.
        processes_per_gpu: number of worker processes per visible CUDA
            device. When no CUDA device is visible, this many workers are
            started in total.

    A manager-backed dict is passed to every job. Workers use it to remember
    which GPU they were assigned.
    """
    # device_count() is 0 without CUDA. A pool cannot be created with zero
    # processes (NOTE(review): assumes NestablePool forwards this count to
    # multiprocessing.Pool), so count at least one device.
    num_processes = max(1, torch.cuda.device_count()) * processes_per_gpu
    with torch.multiprocessing.Manager() as manager:
        gpu_assignments = manager.dict()
        jobs = [(gpu_assignments, e) for e in experiment_combinations]
        with NestablePool(num_processes) as pool:
            # Drain the iterator so every job finishes before the pool's
            # context manager shuts the workers down.
            for _ in tqdm(
                pool.imap_unordered(parallel_run_experiment_here_wrapper, jobs),
                total=len(experiment_combinations),
            ):
                pass


def parallel_run_experiment_here_wrapper(experiment_tuple):
    """A wrapper around run_pipeline_here that uses just a single argument to work with multiprocessing pool map.

    Args:
        experiment_tuple: ``(gpu_assignments, (i, (seed, env_id, variant)))``.
            ``gpu_assignments`` is a shared dict mapping a worker's process
            ident to its GPU id; ``i`` is the job index.

    The first job a worker handles picks GPU ``i % device_count()``. Later
    jobs on the same worker reuse that GPU. Without CUDA the run is started
    with ``use_gpu=False``.

    Raises:
        Exception: if the worker's process ident is still ``None`` after 30s.
    """
    d, (i, (seed, env_id, variant)) = experiment_tuple

    cp = torch.multiprocessing.current_process().ident
    deadline = time.time() + 30
    while cp is None:
        # Only time out while the ident is genuinely still missing.
        if time.time() > deadline:
            raise Exception("Couldn't get current process id!")
        time.sleep(1)
        cp = torch.multiprocessing.current_process().ident

    bold(f"Running env_id: {env_id}, seed: {seed} with process {cp}")
    if torch.cuda.is_available():
        gpu_id = d.get(cp)
        if gpu_id is None:
            gpu_id = int(i % torch.cuda.device_count())
            d[cp] = gpu_id
    else:
        gpu_id = None

    variant = deepcopy(variant)
    variant["seed"] = seed
    variant["env_id"] = env_id
    run_pipeline_here(
        variant=variant,
        # Do not ask for a GPU when none was assigned (no CUDA available).
        use_gpu=gpu_id is not None,
        gpu_id=gpu_id,
        parallel=True,
        # Fall back to run_pipeline_here's own defaults when the variant
        # does not specify snapshot settings.
        snapshot_mode=variant.get("snapshot_mode", "gap_and_last"),
        snapshot_gap=variant.get("snapshot_gap", 100),
    )


def resume(variant, previous_path):
    """Reload the algorithm saved at ``previous_path`` and continue training it.

    Steps:
    - override the loaded algorithm's ``num_epochs`` with
      ``variant["num_epochs"]``;
    - pass ``variant["trainer_kwargs"]["post_pretrain_hyperparams"]`` (empty
      if absent) to its trainer's ``set_algorithm_weights``;
    - call ``train()``.

    Returns None.
    """
    # NOTE(review): callers pass a log directory here; confirm that
    # load_file accepts a directory rather than a snapshot file.
    algorithm = load_file(previous_path)["algorithm"]
    algorithm.num_epochs = variant["num_epochs"]
    trainer_kwargs = variant["trainer_kwargs"]
    algorithm.trainer.set_algorithm_weights(
        **trainer_kwargs.get("post_pretrain_hyperparams", {})
    )
    algorithm.train()


def _report_crash(variant, start):
    """Print the active exception and, if wandb is on, raise a wandb alert."""
    started = start.strftime("%I:%M %p %a %b %d %Y")
    red(
        f'{variant.get("algorithm")} seed: {variant.get("seed")} env_id: {variant.get("env_id")} started at {started}, has crashed'
    )
    # Called from inside an ``except`` block, so format_exc sees the error.
    print(traceback.format_exc(), flush=True)
    sys.stdout.flush()
    sys.stderr.flush()
    if conf.Wandb.is_on and wandb.run is not None:
        wandb.alert(
            title="Experiment Crash",
            text=f'{variant.get("algorithm")} started at {started}, has crashed',
            level="ERROR",
        )


def run_pipeline_here(
    variant,
    use_gpu=True,
    gpu_id=0,
    snapshot_mode="gap_and_last",
    snapshot_gap=100,
    git_infos=None,
    script_name=None,
    base_log_dir=None,
    force_randomize_seed=False,
    parallel=False,
    **setup_logger_kwargs,
):
    """
    Run a single experiment in this process, without any serialization.

    Steps:
    - set ``variant["seed"]``; a random seed is drawn if the seed is missing
      or ``force_randomize_seed`` is set;
    - create the log directory via ``setup_logger``, which also sets
      ``variant["version"]`` to ``"normal"`` if unset;
    - store the launch arguments in ``experiment.pkl`` and the directory in
      ``variant["log_dir"]``;
    - run ``Pipelines.run_pipeline(variant)``.

    Exceptions from the run are caught. A ``FileExistsError`` (log
    directory already exists) is skipped silently. Anything else is reported
    and, when ``conf.DEBUG`` is set, re-raised.

    :param variant: experiment config; must contain ``"env_id"``.
    :param use_gpu: run with GPU. Defaults to True.
    :param gpu_id: device index passed to ``set_gpu_mode``.
    :param snapshot_mode: logger snapshot mode.
    :param snapshot_gap: snapshot gap passed to the logger.
    :param git_infos: pre-computed git info; collected when None.
    :param script_name: name of the running script, saved to the log dir.
    :param base_log_dir: root log directory; defaults to ``conf.Log.basedir``.
    :param force_randomize_seed: ignore ``variant["seed"]`` and draw a new one.
    :param parallel: True when called from a pool worker. This skips the
        welcome banner and redirects stdout/stderr into the log dir.
    :return: the result of ``Pipelines.run_pipeline`` (or ``resume``) on
        success, ``None`` if the run raised.
    """

    if not parallel:
        print_welcome()

    start = datetime.datetime.today()
    try:
        seed = variant.get("seed")
        algorithm = variant.get("algorithm")

        if force_randomize_seed or seed is None:
            seed = random.randint(0, 100000)
            variant["seed"] = seed
        l.reset_execution_environment()

        actual_log_dir, conflict_exists_and_continue = setup_logger(
            algorithm=algorithm,
            variant=variant,
            seed=seed,
            snapshot_mode=snapshot_mode,
            snapshot_gap=snapshot_gap,
            base_log_dir=base_log_dir,
            git_infos=git_infos,
            script_name=script_name,
            parallel=parallel,
            env_id=variant["env_id"],
            **setup_logger_kwargs,
        )

        l.set_seed(seed)
        set_gpu_mode(use_gpu, gpu_id)

        if conflict_exists_and_continue:
            result = resume(variant, actual_log_dir)
        else:
            run_experiment_here_kwargs = dict(
                variant=variant,
                seed=seed,
                use_gpu=use_gpu,
                algorithm=algorithm,
                snapshot_mode=snapshot_mode,
                snapshot_gap=snapshot_gap,
                git_infos=git_infos,
                parallel=parallel,
                script_name=script_name,
                base_log_dir=base_log_dir,
                **setup_logger_kwargs,
            )
            save_experiment_data(
                dict(run_experiment_here_kwargs=run_experiment_here_kwargs),
                actual_log_dir,
            )

            variant["log_dir"] = actual_log_dir

            result = Pipelines.run_pipeline(variant)
    except Exception as e:
        # KeyboardInterrupt is not an Exception subclass, so it is never
        # caught here. An already-existing log dir is treated as a skip.
        if not isinstance(e, FileExistsError):
            _report_crash(variant, start)
            if conf.DEBUG:
                raise
        return None
    else:
        # The try body must not ``return`` directly: an ``else`` clause is
        # skipped when the try body returns, so the message would never
        # be printed.
        green("Successfully finished")
        return result


def create_log_dir(
    algorithm,
    env_id,
    variant,
    version="normal",
    seed=0,
    parallel=False,
    base_log_dir=None,
):
    """
    Create the log directory for one run, resolving conflicts with an existing one.

    The path is ``<base>/<algorithm>/<version>/<env_id>[/<gridsearch>]/<seed>``.
    ``<base>`` is ``base_log_dir`` or ``conf.Log.basedir``. The optional
    ``<gridsearch>`` segment is ``variant["__gridsearch"]`` with spaces
    replaced by underscores.

    If the directory already exists:
    - in ``parallel`` mode, ``FileExistsError`` is raised;
    - otherwise ``conf.Log.conflict_policy`` decides:
      - ``"REPLACE"`` deletes and recreates the directory, after
        confirmation unless ``conf.DEBUG`` is set; declining raises
        ``FileExistsError``;
      - ``"CONTINUE"`` is not implemented and raises ``NotImplementedError``;
      - any other policy reuses the existing directory unchanged.

    :param algorithm: top-level sub-directory for this experiment.
    :param env_id: environment sub-directory.
    :param variant: experiment config; only ``"__gridsearch"`` is read.
    :param version: version sub-directory.
    :param seed: last path component.
    :param parallel: raise instead of prompting when the directory exists.
    :param base_log_dir: root directory; defaults to ``conf.Log.basedir``.
    :return: ``(log_dir as str, conflict_exists_and_continue)``. The flag is
        currently always False, because the CONTINUE policy raises.
    """
    conflict_exists_and_continue = False
    log_dir = Path(base_log_dir or conf.Log.basedir) / algorithm / version / env_id
    gridsearch = variant.get("__gridsearch")
    if gridsearch:
        log_dir = log_dir / gridsearch.replace(" ", "_")
    log_dir = log_dir / str(seed)

    if not osp.exists(log_dir):
        os.makedirs(log_dir, exist_ok=False)
        return str(log_dir), conflict_exists_and_continue

    print(
        colors.WARNING
        + "This experiment already exists: {}".format(log_dir)
        + colors.ENDC
    )
    if parallel:
        print("Exiting")
        raise FileExistsError

    policy = conf.Log.conflict_policy
    if policy == "REPLACE":
        confirmed = conf.DEBUG or l.query_yes_no(
            "Would you like to replace the existing directory?"
        )
        if not confirmed:
            print("Not replacing, exiting now")
            raise FileExistsError
        bold("Replacing this directory...")
        shutil.rmtree(log_dir)
        os.makedirs(log_dir, exist_ok=True)
        bold("Replaced")
    elif policy == "CONTINUE":
        bold("Continuing training from previous checkpoint")
        conflict_exists_and_continue = True
        raise NotImplementedError
    return str(log_dir), conflict_exists_and_continue


def _write_git_infos(log_dir, git_infos):
    """Write each repo's diffs and a commit/branch summary into ``log_dir``."""
    for (
        directory,
        code_diff,
        code_diff_staged,
        commit_hash,
        branch_name,
    ) in git_infos:
        directory = str(directory).rstrip("/")
        # Flatten the path into a file name ("/a/b" -> "a-b"). Only a leading
        # "/" is removed, so relative paths keep their first character.
        stem = directory.lstrip("/").replace("/", "-")
        if code_diff:
            with open(osp.join(log_dir, stem + ".patch"), "w") as f:
                f.write(code_diff + "\n")
        if code_diff_staged:
            with open(osp.join(log_dir, stem + "_staged.patch"), "w") as f:
                f.write(code_diff_staged + "\n")
        with open(osp.join(log_dir, "git_infos.txt"), "a") as f:
            f.write("directory: {}\n".format(directory))
            f.write("git hash: {}\n".format(commit_hash))
            f.write("git branch name: {}\n\n".format(branch_name))


def setup_logger(
    algorithm="default",
    env_id=None,
    variant=None,
    text_log_file="debug.log",
    variant_log_file="variant.json",
    tabular_log_file="progress.csv",
    snapshot_mode="last",
    snapshot_gap=1,
    log_tabular_only=False,
    git_infos=None,
    script_name=None,
    wandb_entity=None,
    wandb_project=None,
    parallel=False,
    **create_log_dir_kwargs,
):
    """
    Set up the logger, log directory and (optionally) wandb for one run.

    Output is saved under (see ``create_log_dir``)

        basedir/<algorithm>/<version>/<env_id>[/<gridsearch>]/<seed>

    The logger prefix is the last path component of that directory.
    ``variant["version"]`` is set to ``"normal"`` if it is missing.

    :param algorithm: the sub-directory for this specific experiment.
    :param env_id: environment sub-directory.
    :param variant: experiment config; logged to ``variant_log_file``.
    :param text_log_file: file name for the text log.
    :param variant_log_file: file name for the logged variant.
    :param tabular_log_file: file name for the tabular log.
    :param snapshot_mode: logger snapshot mode.
    :param snapshot_gap: logger snapshot gap.
    :param log_tabular_only: passed to ``logger.set_log_tabular_only``.
    :param git_infos: git state to record; collected from
        ``conf.Log.repo_dir`` when None.
    :param script_name: if set, saved to ``script_name.txt``.
    :param wandb_entity: wandb entity; ``conf.Wandb.entity`` when None.
    :param wandb_project: wandb project; ``conf.Wandb.project`` when None.
    :param parallel: redirect stdout/stderr into the log dir, and raise
        instead of prompting when the directory already exists.
    :param create_log_dir_kwargs: forwarded to ``create_log_dir``
        (e.g. ``seed``, ``base_log_dir``).
    :return: ``(log_dir, conflict_exists_and_continue)``.
    """
    if variant is None:
        variant = {}
    if variant.get("version") is None:
        variant["version"] = "normal"

    log_dir, conflict_exists_and_continue = create_log_dir(
        algorithm,
        env_id,
        variant,
        version=variant["version"],
        parallel=parallel,
        **create_log_dir_kwargs,
    )
    if parallel:
        # Each worker writes its output into its own run directory.
        sys.stdout = open(osp.join(log_dir, "stdout.out"), "a")
        sys.stderr = open(osp.join(log_dir, "stderr.out"), "a")

    if conf.Wandb.is_on:
        # Resolve config defaults at call time, not import time, so that
        # later changes to ``conf.Wandb`` are honoured.
        if wandb_entity is None:
            wandb_entity = conf.Wandb.entity
        if wandb_project is None:
            wandb_project = conf.Wandb.project
        wandb_group = f"{algorithm}-{variant['version']}-{env_id}"
        if variant.get("__gridsearch"):
            wandb_name = f"seed-{variant['seed']}-hp-{variant.get('__gridsearch')}"
        else:
            wandb_name = f"seed-{variant['seed']}"

        wandb.init(
            project=wandb_project,
            entity=wandb_entity,
            group=wandb_group,
            name=wandb_name,
            config=variant,
            reinit=True,
        )
        wandb.run.log_code(os.path.join(conf.Log.repo_dir, "src"))

    if git_infos is None:
        git_infos = l.get_git_infos([conf.Log.repo_dir])

    logger.log_variant(osp.join(log_dir, variant_log_file), variant)

    logger.add_text_output(osp.join(log_dir, text_log_file))
    logger.add_tabular_output(osp.join(log_dir, tabular_log_file))
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    exp_name = osp.basename(log_dir)
    logger.push_prefix("[%s] " % exp_name)
    l.generate_snapshot(conf.Log.repo_dir, log_dir)
    if git_infos is not None:
        _write_git_infos(log_dir, git_infos)
    if script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(script_name)
    return log_dir, conflict_exists_and_continue
