#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : utils.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

import imp
import os.path as osp
import numpy as np
import time
import pandas as pd

# Public API of this module (names exported by `from ... import *`).
# The formatting helpers f2str and f2str_with_proper_length are defined at
# module level but are not listed here.
__all__ = [
    "prod",
    "get_inds_given_nums",
    "HyperParametersChoices",
    "get_hps_choices",
    "get_inds",
    "get_hps_dict",
    "get_selected_names",
    "hps_kv_to_str",
    "get_name_by_hps_dict",
    "get_num_choices",
    "load_source",
    "get_localtime_str",
    "get_stats",
    "get_stats_str",
    "get_summary",
    "is_better",
    "print_list",
]


def prod(int_list):
    """Multiply together the integers in ``int_list``.

    Every element must be exactly of type ``int``; anything else (floats,
    bools, NumPy integers, ...) trips an ``AssertionError``. The product of an
    empty iterable is 1.
    """
    product = 1
    for factor in int_list:
        assert type(factor) is int
        product *= factor
    return product


def get_inds_given_nums(ind, nums, control_variable=False):
    """Decode a flat choice index into one position per slot.

    ``nums[i]`` is the number of options in slot ``i``.

    * Grid search (default): ``ind`` is read as a mixed-radix number whose
      least significant digit is the first slot.
    * ``control_variable=True``: index 0 selects position 0 everywhere, and
      each further index moves exactly one slot away from 0. The first slot
      uses indices ``1 .. nums[0]-1``, the next slot the following
      ``nums[1]-1`` indices, and so on.
    """
    remaining = ind
    positions = []
    for size in nums:
        if not control_variable:
            remaining, digit = divmod(remaining, size)
            positions.append(digit)
            continue
        if remaining < size:
            # this slot absorbs the rest of the index; later slots stay at 0
            positions.append(remaining)
            remaining = 0
        else:
            # skip past this slot's (size - 1) non-default options
            positions.append(0)
            remaining -= size - 1
    return positions


class HyperParametersChoices:
    """One group of hyper-parameters together with the candidate values of each.

    Args:
        config: unless ``is_bool`` is set, a mapping from hyper-parameter name
            to an indexable sequence of candidate values. With ``is_bool``,
            ``config`` is only iterated for flag names (strings). Each flag has
            two states. A leading ``~`` on a name makes that flag ``True`` at
            choice position 0; other flags are ``False`` at position 0.
        is_combo: every candidate value is itself a mapping that is merged
            into the resulting hyper-parameter dict; the key is only a label.
        is_bool: treat ``config`` as a collection of on/off flags (see above).
        control_variable: enumerate choices by changing one hyper-parameter at
            a time away from its first candidate, instead of a full grid.

    Attributes:
        num: total number of choices this group can produce.
    """

    def __init__(self, config, is_combo=False, is_bool=False, control_variable=False):
        self.config = config
        self.is_combo = is_combo
        self.is_bool = is_bool
        self.length = len(config)  # not read inside this class
        self.control_variable = control_variable
        self.num = self.get_num()

    def _candidate_counts(self):
        """List how many candidates each hyper-parameter has, in config order."""
        if self.is_bool:
            return [2] * len(self.config)
        return [len(values) for values in self.config.values()]

    def get_num(self):
        """Return the total number of choices.

        Grid search multiplies the candidate counts. Control-variable mode
        counts the shared all-first-candidate setting once, plus every
        single-parameter deviation from it.
        """
        counts = self._candidate_counts()
        if self.control_variable:
            return 1 + sum(count - 1 for count in counts)
        return prod(counts)

    def get_kv_pairs(self, index):
        """Decode ``index`` into a list of ``(name, value)`` pairs in config order."""
        positions = get_inds_given_nums(
            index, self._candidate_counts(), control_variable=self.control_variable
        )
        if not self.is_bool:
            return [
                (name, values[pos])
                for (name, values), pos in zip(self.config.items(), positions)
            ]
        pairs = []
        for name, pos in zip(self.config, positions):
            # Position 1 means True by default; a "~" prefix flips it to 0.
            true_pos = 1
            if name.startswith("~"):
                name, true_pos = name[1:], 0
            pairs.append((name, pos == true_pos))
        return pairs

    def get_item(self, index):
        """Return the hyper-parameter dict selected by ``index``."""
        pairs = self.get_kv_pairs(index)
        if not self.is_combo:
            return dict(pairs)
        merged = {}
        for _, bundle in pairs:
            # combo candidates are whole dicts of hyper-parameters
            merged.update(bundle)
        return merged

    def get_selected_names(self, index):
        """Return ``(name, formatted_value)`` pairs for the choice at ``index``."""
        return [(name, "{}".format(value)) for name, value in self.get_kv_pairs(index)]


def get_hps_choices(config, control_variable=False):
    """Build the four hyper-parameter groups described by ``config``.

    ``config`` must provide the keys ``"combo_args"``, ``"single_args"``,
    ``"list_args"`` and ``"bool_args"``.

    Returns:
        ``(combo_hps, single_hps, list_hps, bool_hps)`` as
        ``HyperParametersChoices`` instances sharing ``control_variable``.
    """

    def make_group(section, **kind):
        return HyperParametersChoices(
            config[section], control_variable=control_variable, **kind
        )

    return (
        make_group("combo_args", is_combo=True),
        make_group("single_args"),
        make_group("list_args"),
        make_group("bool_args", is_bool=True),
    )


def get_inds(ind, all_hps_choices, control_variable=False):
    """Split a global choice index into one local index per hyper-parameter group.

    Each group contributes its ``num`` as the slot size for
    ``get_inds_given_nums``.
    """
    slot_sizes = [group.num for group in all_hps_choices]
    return get_inds_given_nums(ind, slot_sizes, control_variable=control_variable)


def get_hps_dict(ind, all_hps_choices, control_variable=False):
    """Return the merged hyper-parameter dict for global choice ``ind``.

    Each group's ``get_item`` result is merged in order, so later groups
    override earlier ones on key collisions.
    """
    local_inds = get_inds(ind, all_hps_choices, control_variable=control_variable)
    merged = {}
    for group, local in zip(all_hps_choices, local_inds):
        merged.update(group.get_item(local))
    return merged


def get_selected_names(ind, all_hps_choices, control_variable=False):
    """Return every group's ``(name, formatted_value)`` pairs for choice ``ind``."""
    local_inds = get_inds(ind, all_hps_choices, control_variable=control_variable)
    return [
        pair
        for group, local in zip(all_hps_choices, local_inds)
        for pair in group.get_selected_names(local)
    ]


def hps_kv_to_str(k, v, to_cmd=False):
    """Render one hyper-parameter as a name fragment or a command-line option.

    With ``to_cmd`` the key gets a ``-`` prefix and parts are space-separated.
    Otherwise one leading ``-`` is stripped, remaining dashes become
    underscores, and parts are joined with ``_``.

    Booleans render as the bare key when True and as ``""`` when False.
    Lists render their elements after the key, and an empty list renders
    as ``""``.
    """
    assert type(k) is str, "key should be str"
    if to_cmd:
        key, sep = "-" + k, " "
    else:
        key = k[1:] if k.startswith("-") else k
        key, sep = key.replace("-", "_"), "_"

    kind = type(v)
    if kind is bool:
        return key if v else ""
    if kind is list:
        return "{}{}{}".format(key, sep, sep.join(map(str, v))) if v else ""
    return "{}{}{}".format(key, sep, v)


def get_name_by_hps_dict(hps_dict, prefix=""):
    """Build a run name from ``prefix`` and the entries of ``hps_dict``.

    Every non-empty ``hps_kv_to_str(key, value)`` fragment is appended in
    dict order. Fragments are separated by ``_``, and no separator is inserted
    while the name is still empty.
    """
    name = prefix
    for key, value in hps_dict.items():
        piece = hps_kv_to_str(key, value)
        if len(piece) == 0:
            continue
        separator = "_" if len(name) > 0 else ""
        name = name + separator + piece
    return name


def get_num_choices(hps_choices, control_variable=False):
    """Return the total number of global choices across all groups.

    Grid search multiplies each group's ``num``. Control-variable mode counts
    the shared all-default setting once, plus ``num - 1`` deviations per group.
    """
    counts = [group.num for group in hps_choices]
    if not control_variable:
        return prod(counts)
    return 1 + sum(count - 1 for count in counts)


def get_localtime_str():
    """Return the current local time as ``YYYY_MM_DD__HH_MM_SS``.

    The string contains only digits and underscores.
    """
    fmt = "%Y_%m_%d__%H_%M_%S"
    now = time.localtime()
    return time.strftime(fmt, now)


def get_stats(data):
    """Return the mean, standard deviation, maximum and minimum of ``data``.

    ``data`` is anything ``np.array`` accepts (list, nested list, array,
    scalar). The statistics are taken over all elements, and ``std`` is
    NumPy's default (``ddof=0``).

    For empty input, a warning is printed and every statistic is 0 instead of
    raising.

    Returns:
        dict with keys ``mean``, ``std``, ``max``, ``min``.
    """
    data = np.array(data)
    # Test the element count rather than len(). len() raises on 0-d (scalar)
    # input and reports 1 for an empty nested input such as [[]], on which
    # max()/min() would then raise.
    if data.size == 0:
        data = np.array([0])
        print("[Warning] getting stats for empty data")
    return dict(mean=data.mean(), std=data.std(), max=data.max(), min=data.min())


def f2str(x, sig_fig=None, precision=None):
    """Format a number as a string.

    ``sig_fig`` selects general format with that many significant figures and
    takes priority over ``precision``. ``precision`` selects fixed-point with
    that many decimals. With neither set, the default formatting of ``x`` is
    used.
    """
    if sig_fig is None and precision is None:
        return f"{x}"
    spec = f".{sig_fig}" if sig_fig is not None else f".{precision}f"
    return format(x, spec)


def f2str_with_proper_length(x, length):
    """Format ``x`` fixed-point with a precision derived from ``length``.

    Values below 1 reserve 2 characters (for ``"0."``), so non-negative such
    values come out ``length`` characters long. Values of at least 1 reserve
    3 characters, which gives exactly ``length`` characters only when the
    integer part has two digits.
    """
    reserved = 2 if x < 1 else 3
    return f2str(x, precision=length - reserved)


def get_stats_str(stats, precision=5, proper_length=None):
    """Render the mean/std/min/max entries of ``stats`` on one line.

    Produces ``"[mean: m, std: s, min: lo, max: hi]"``. Each value is
    formatted with ``precision`` decimal places, unless ``proper_length`` is
    given; in that case ``f2str_with_proper_length`` picks the precision.
    """

    def render(value):
        if proper_length is None:
            return f2str(value, precision=precision)
        return f2str_with_proper_length(value, proper_length)

    fields = (f"{key}: {render(stats[key])}" for key in ("mean", "std", "min", "max"))
    return "[" + ", ".join(fields) + "]"


def is_better(x, y, smaller_better=False):
    """Return whether candidate ``x`` should replace the incumbent ``y``.

    Args:
        x: candidate value.
        y: current best value, or ``None`` if there is none yet.
        smaller_better: if True, smaller values are better; otherwise larger.

    Returns:
        True if ``y`` is missing (``None`` or NaN) or ``x`` is at least as
        good as ``y``. Ties return True, so the latest value wins. A NaN ``x``
        never beats a real ``y``, because comparisons with NaN are False.
    """
    # NaN is the only value not equal to itself. Treating a NaN incumbent like
    # a missing one stops a single NaN (e.g. an empty CSV cell) from blocking
    # every later candidate, since both x >= NaN and x <= NaN are False.
    if y is None or y != y:
        return True
    # choose latest: on ties the newer value x wins
    if smaller_better:
        return x <= y
    return x >= y


def get_summary(summary_file, smaller_better=False, key="res", complete_runs=False):
    """Pick the best epoch of every run recorded in a summary CSV.

    The CSV must contain the columns ``run``, ``epoch``, ``train_<key>``,
    ``val_<key>`` and ``test_<key>``; ``time_s`` is optional. For each run, the
    row with the best ``val_<key>`` according to ``is_better`` is selected.

    Args:
        summary_file: anything ``pandas.read_csv`` accepts.
        smaller_better: if True, a smaller validation value is better.
        key: metric suffix used to build the column names.
        complete_runs: if True and the file's last row has not reached the
            largest epoch in the file, the last run is treated as unfinished
            and dropped from every returned statistic.

    Returns:
        ``(runs, (best_epoch_ids, best_train, best_val, best_test, mean_time))``
        where:
        - ``runs`` is the raw DataFrame;
        - the four lists hold one entry per (kept) run;
        - ``mean_time`` is the mean over (kept) runs of the ``time_s`` value
          of each run's last row (0.0 per run when the column is absent).
    """
    runs = pd.read_csv(summary_file)
    # plain int so the list repetitions below never see a NumPy scalar
    num_runs = int(runs["run"].max()) + 1
    best_val = [None] * num_runs
    best_epoch = [None] * num_runs
    last_time = [0.0] * num_runs
    cur_epoch = 0
    for row in runs.itertuples():
        val_res = getattr(row, f"val_{key}")
        run_id = row.run
        # overwritten on every row, so this ends up holding the run's last time
        last_time[run_id] = getattr(row, "time_s", 0.0)
        if is_better(val_res, best_val[run_id], smaller_better=smaller_better):
            best_val[run_id] = val_res
            best_epoch[run_id] = row
        cur_epoch = row.epoch
    if complete_runs and cur_epoch < runs["epoch"].max():
        # The final run stopped early. Drop it from the times as well as from
        # the best epochs, so every returned statistic covers the same runs.
        best_epoch = best_epoch[:-1]
        last_time = last_time[:-1]
    best_epoch_id = [row.epoch for row in best_epoch]
    best_train_ress = [getattr(row, f"train_{key}") for row in best_epoch]
    best_val_ress = [getattr(row, f"val_{key}") for row in best_epoch]
    best_test_ress = [getattr(row, f"test_{key}") for row in best_epoch]
    return runs, (
        best_epoch_id,
        best_train_ress,
        best_val_ress,
        best_test_ress,
        np.mean(last_time),
    )


def print_list(name, a):
    """Print ``name`` followed by a tab and the formatted items of ``a``.

    Items of exact type ``int`` are right-aligned in 7 columns; everything
    else is printed with 5 decimals. Each item is followed by three spaces.
    Nothing is printed for an empty ``a``. Always returns None.
    """
    if len(a) == 0:
        return None
    cells = [f"{item:7d}   " if type(item) is int else f"{item:.5f}   " for item in a]
    print(name + "\t" + "".join(cells))


# https://github.com/vacancy/Jacinle/blob/master/jacinle/utils/imp.py#L39
def load_source(filename):
    """Execute a Python source file and return it as a module object.

    The module name is the file's basename without a trailing ``.py``, with
    any remaining dots replaced by underscores (``cfg.v2.py`` -> ``cfg_v2``).
    As with the former ``imp.load_source``, any file extension is accepted and
    the module is registered in ``sys.modules`` under that name.

    Implemented with ``importlib`` because ``imp`` was deprecated in
    Python 3.4 and removed in Python 3.12.

    Raises:
        Whatever executing the file raises. On failure, the previous
        ``sys.modules`` entry (if any) is restored.
    """
    import importlib.machinery
    import importlib.util
    import sys

    basename = osp.basename(filename)
    if basename.endswith(".py"):
        basename = basename[:-3]
    basename = basename.replace(".", "_")

    # An explicit SourceFileLoader accepts any extension, as imp.load_source did.
    loader = importlib.machinery.SourceFileLoader(basename, filename)
    spec = importlib.util.spec_from_file_location(basename, filename, loader=loader)
    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(basename)
    sys.modules[basename] = module
    try:
        loader.exec_module(module)
    except BaseException:
        # don't leave a half-initialised module behind
        if previous is None:
            sys.modules.pop(basename, None)
        else:
            sys.modules[basename] = previous
        raise
    return module
