from tqdm import tqdm
import copy
import os
import pandas as pd
import numpy as np
import shutil
from datetime import datetime
from filelock import FileLock
import time

from .utils.parsing import get_config
from .utils.executing import execute_strategy


def recursive_strategy(li):
    ''' Expands a list of {"parents": [...], "values": [...]} entries into
        every possible combination, returned as a list of dicts

        Each combination maps the last name of an entry's "parents" path to
        {"parents": <rest of the path>, "value": <one of its "values">}.
        The first entry of @li varies fastest in the returned list.
        Entries whose paths end with the same name end up under one single
        key (the entry that comes first in @li overwrites the others).
        An empty @li yields one empty combination: [{}].
    '''
    combos = [{}]
    # Walk @li backwards so that its first entry is the last one expanded,
    # i.e. the one that varies fastest in the final list
    for entry in reversed(li):
        name = entry["parents"][-1]
        path = entry["parents"][:-1]
        candidates = entry["values"]
        combos = [
            {**partial, name: {"parents": path, "value": candidate}}
            for partial in combos
            for candidate in candidates
        ]
    return combos


def find_strategy_type(stype, cfg, parents=None):
    ''' Finds every dict nested in @cfg that holds the key @stype

        Returns a list with one {"parents": [...], "values": ...} dict per
        match, where "parents" is the list of keys leading from the root
        to the dict holding @stype and "values" is what is stored under
        @stype. A dict that holds @stype is not searched any deeper.

        @parents is the path of keys already walked (used by the
        recursion); the list given by the caller is never modified.
    '''
    if parents is None:  # @parents should NOT be modified
        parents = []
    if stype in cfg:
        return [{
            "parents": parents.copy(),
            "values": cfg[stype]
        }]
    ret = []
    for key, sub_cfg in cfg.items():
        # isinstance (not an exact type match) so that dict subclasses,
        # e.g. OrderedDict, are searched as well
        if isinstance(sub_cfg, dict):
            ret += find_strategy_type(stype, sub_cfg, parents=parents + [key])
    return ret


def _make_logdir(main_logspath):
    ''' Creates a fresh timestamped directory inside @main_logspath
        and returns its path
    '''
    # makedirs(exist_ok=True) removes the check-then-create race between
    # concurrent launches and also accepts a nested @main_logspath
    os.makedirs(main_logspath, exist_ok=True)
    lock = FileLock(f"{main_logspath}/mk_logdir.lock")
    with lock:
        # The directory name has a one-second resolution: wait until the
        # current timestamp is not taken yet
        while True:
            logdir = f"{main_logspath}/" + "{:%Y-%m-%dT%H%M%S}".format(datetime.now())
            if not os.path.exists(logdir):
                break
            time.sleep(1)
        os.mkdir(logdir)
    return logdir


def _adapt_cfg(cfg, strat):
    ''' Returns a deep copy of @cfg in which, for every key of @strat, the
        entry found by following its "parents" path is replaced by its
        "value" (@cfg itself is left untouched)
    '''
    new_cfg = copy.deepcopy(cfg)
    for key, params in strat.items():
        mod_dict = new_cfg
        for pkey in params["parents"]:
            mod_dict = mod_dict[pkey]
        mod_dict[key] = params["value"]
    return new_cfg


def _find_best_run(tmp_dir, n_runs, monitor, monitor_op):
    ''' Returns the index of the run with the best @monitor value

        Reads the column @monitor of <tmp_dir>/<i>/val.csv for every i in
        [0, n_runs). @monitor_op is "less" (lowest value wins) or "more"
        (highest value wins). Index 0 is returned when no run beats the
        initial +/- infinity (e.g. only NaN values).
    '''
    lower_is_better = monitor_op == "less"
    # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0
    best = np.inf if lower_is_better else -np.inf  # FIXME: be consistent with modelcheckpoint
    best_id = 0
    for t_cnt in range(n_runs):
        # TODO: add support for pruned-training where this should
        #       be inside!
        column = pd.read_csv(os.path.join(tmp_dir, str(t_cnt), "val.csv"))[monitor]
        # Summarise the run in the same direction as the comparison.
        # NOTE(review): assumes val.csv holds one row per validation
        # step, so that min/max is the best value reached -- confirm
        if lower_is_better:
            t_value = column.min()
            improved = t_value < best
        else:
            t_value = column.max()
            improved = t_value > best
        if improved:
            best_id = t_cnt
            best = t_value
    return best_id


def _keep_only_run(d_logdir, best_id):
    ''' Moves the content of <d_logdir>/tmp/<best_id> into @d_logdir, then
        deletes <d_logdir>/tmp with all the other runs
    '''
    tmp_dir = os.path.join(d_logdir, "tmp")
    best_dir = os.path.join(tmp_dir, str(best_id))
    for fname in os.listdir(best_dir):
        try:
            shutil.move(os.path.join(best_dir, fname), d_logdir)
        except shutil.Error:
            # Best effort: a file that cannot be moved is reported and
            # then lost with the rest of tmp/
            print(f"Could not move {fname}")
    shutil.rmtree(tmp_dir)


def main():
    ''' Entry point: runs the configuration once, or once per strategy
        combination when it has a "strategy" section

        A timestamped log directory is created under $LOGSPATH (default:
        "logs"). Without "strategy", the configuration is executed once in
        that directory. Otherwise:
        - strategies with a false "best_only" are multi-launched: each
          combination gets its own numbered sub-directory;
        - strategies with a true "best_only" are grid-searched inside each
          of those sub-directories (under tmp/); only the files of the run
          with the best "monitor" value are kept.

        Raises ValueError before anything is run if a best run has to be
        selected and "monitor_op" is neither "less" nor "more".
    '''
    cfg = get_config()
    logdir = _make_logdir(os.environ.get("LOGSPATH", "logs"))

    if "strategy" not in cfg:
        execute_strategy(cfg, logdir=logdir)
        # Plain return instead of exit(0): exit() is a helper injected by
        # the site module for interactive use and may not exist
        return

    # Find the suffixes that will be matched for either:
    # - multi-launching (diff_strats): execute slight variations on the cfg
    # - grid-searching (tmp_strats) : find the best model, erase the others
    diff_params = []
    best_only_params = []
    monitor_best = None
    monitor_best_op = "more"
    for strat in cfg["strategy"]:
        allstrats = find_strategy_type(strat["type"], cfg)
        if not strat["best_only"]:
            diff_params += allstrats
        else:
            best_only_params += allstrats
            # With several best_only strategies, the last one decides
            monitor_best = strat["monitor"]
            monitor_best_op = strat.get("monitor_op", monitor_best_op)

    diff_strats = recursive_strategy(diff_params)
    tmp_strats = recursive_strategy(best_only_params)

    # Fail before the (potentially long) runs rather than after them
    select_best = monitor_best is not None and len(tmp_strats) > 1
    if select_best and monitor_best_op not in ("less", "more"):
        raise ValueError(
            f"Unknown monitor_op {monitor_best_op!r}: expected 'less' or 'more'")

    for d_cnt, dstrat in enumerate(tqdm(diff_strats, desc="Diff. Strats.")):
        d_cfg = _adapt_cfg(cfg, dstrat)
        d_logdir = os.path.join(logdir, str(d_cnt))
        tmp_dir = os.path.join(d_logdir, "tmp")

        # Start hyper parameter search
        for t_cnt, tstrat in enumerate(tqdm(tmp_strats, desc="Excl. Strats.", leave=False)):
            t_cfg = _adapt_cfg(d_cfg, tstrat)
            execute_strategy(t_cfg, logdir=os.path.join(tmp_dir, str(t_cnt)))

        # Keep only best result with hyper parameter
        best_id = 0
        if select_best:
            best_id = _find_best_run(tmp_dir, len(tmp_strats), monitor_best, monitor_best_op)
        _keep_only_run(d_logdir, best_id)


# Run the launcher only when this file is executed as a script, not when it
# is imported. NOTE: the module uses relative imports, so it has to be run
# as part of its package (e.g. with `python -m`).
if __name__ == "__main__":
    main()
