import argparse
import multiprocessing
import os
import random
import subprocess
from copy import deepcopy

import yaml
from tqdm.contrib.concurrent import process_map

from scripts import utils
from utils.common_utils import merge_excels

# Command-line interface of the hyper-parameter search launcher. Several of
# these options (method, root, data, source, target, epochs, iterations,
# print frequency, seed) are forwarded to the sub-commands built in process().
parser = argparse.ArgumentParser(description='automl')
parser.add_argument('--cfg_file', default='./configs/cdan_visda.yml',
                    type=str, help='the config file for hyper parameters setting')
parser.add_argument('--root', metavar='DIR', default='/home/username/datasets',
                    help='root path of dataset')
parser.add_argument('-d', '--data', metavar='DATA', default='VisDA2017', choices=utils.get_dataset_names(),
                    help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                         ' (default: VisDA2017)')
# nargs='+' produces a list when the flag is given, so the defaults must be
# lists as well: argparse returns a string default untouched, and
# ' '.join('Syn') would then expand to 'S y n' on the generated command line.
parser.add_argument('-s', '--source', default=['Syn'], help='source domain(s)', nargs='+')
parser.add_argument('-t', '--target', default=['Real'], help='target domain(s)', nargs='+')
parser.add_argument('--method', type=str, default='CDAN',
                    help="The DA training method")
parser.add_argument('--output_path', type=str,
                    default='/home/username/DAmetric_logs/visda_cdan',
                    help="The log dir")
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                    help='Number of iterations per epoch')
parser.add_argument('-p', '--print-freq', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')
parser.add_argument('--seed', default=12345, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--hp_limit', default=16, type=int,
                    help='maximum number of hp to search. ')
# --gpus takes precedence over --n_gpu when both are supplied (see __main__).
parser.add_argument('--n_gpu', default=4, type=int, help='number of gpu')
parser.add_argument('--gpus', default=None, nargs='+', help='specify the GPUs used')
parser.add_argument('--workers', default=4, type=int, help='the number of process or thread')
parser.add_argument('--eval_mode', action='store_true', default=False,
                    help='evaluation the model')
parser.add_argument('--per-class-eval', action='store_true',
                    help='whether output per-class accuracy during evaluation')
args = parser.parse_args()


def to_numerical(obj: dict):
    """Convert every string value of ``obj`` to ``float``, in place.

    Values that are not strings are left untouched. Nothing is returned; the
    caller observes the change through the dict it passed in. A string that
    ``float()`` cannot parse raises ``ValueError``.
    """
    # Snapshot the affected keys first, then rewrite them one by one.
    textual_keys = [name for name, entry in obj.items() if isinstance(entry, str)]
    for name in textual_keys:
        obj[name] = float(obj[name])


def _expand_range(first, last, advance, hp_name):
    """Collect ``first, advance(first), ...`` for as long as the value is <= ``last``.

    Raises:
        ValueError: if ``advance`` does not produce a strictly larger value
            while the upper bound has not been passed, since the expansion
            could then never reach ``last``.
    """
    values = []
    current = first
    while current <= last:
        values.append(current)
        following = advance(current)
        if following <= current:
            # e.g. interval <= 0, multiple == 1 or start == 0: without this
            # guard the loop would spin forever and exhaust memory.
            raise ValueError(
                f"range specification for '{hp_name}' does not make progress "
                f"towards its upper bound ({current!r} -> {following!r})")
        current = following
    return values


def transform_cfg(ori_cfg):
    """Expand range specifications in a config into explicit candidate lists.

    Works on a deep copy, so ``ori_cfg`` is never modified. Every value that
    is a dict is replaced by a list, chosen by the first matching key:

    * ``'step'`` present: ``low, low + interval, ...`` while <= ``high``
      (the keys read are ``low``, ``high`` and ``interval``).
    * ``'multiple'`` present: ``start, start * multiple, ...`` while <= ``end``.
    * ``'value'`` present: the single-element list ``[value]``.
    * none of the above: an empty list.

    String values inside a matched dict are converted to ``float`` first.
    Values that are not dicts are kept as they are.

    Args:
        ori_cfg: mapping from hyper-parameter name to its specification.

    Returns:
        A new dict with the same keys as ``ori_cfg``.

    Raises:
        ValueError: if an additive or multiplicative range cannot advance
            towards its upper bound (it would otherwise loop forever).
    """
    cfg = deepcopy(ori_cfg)
    for k, v_obj in cfg.items():
        if not isinstance(v_obj, dict):
            continue
        v_list = []
        if 'step' in v_obj:
            to_numerical(v_obj)
            # The increment is looked up lazily so it is only required when
            # the range is non-empty. Values are accumulated by repeated
            # addition, exactly as before, so generated values are unchanged.
            v_list = _expand_range(
                v_obj['low'], v_obj['high'],
                lambda v, spec=v_obj: v + spec['interval'], k)
        elif 'multiple' in v_obj:
            to_numerical(v_obj)
            v_list = _expand_range(
                v_obj['start'], v_obj['end'],
                lambda v, spec=v_obj: v * spec['multiple'], k)
        elif 'value' in v_obj:
            to_numerical(v_obj)
            v_list.append(v_obj['value'])
        cfg[k] = v_list
    return cfg


class HPSearcher:
    """Enumerate combinations of hyper-parameter values (a Cartesian product).

    ``cfg`` maps each hyper-parameter name to an iterable of candidate values.
    A combination is a list of ``(name, value)`` pairs, one pair per
    hyper-parameter, in the order the names appear in ``cfg``.
    """

    def __init__(self, cfg):
        self.cfg_list = list(cfg.items())
        self.n_hp = len(self.cfg_list)
        self.hp_combinations = []

    def search(self, i, path):
        """Depth-first expansion starting at hyper-parameter index ``i``.

        ``path`` holds the ``(name, value)`` pairs chosen for indices below
        ``i``. Completed combinations are appended to ``self.hp_combinations``.
        """
        if i >= self.n_hp:
            # Nothing left to expand (this is also the empty-config case);
            # indexing cfg_list here would raise IndexError.
            return
        final_hp = i == self.n_hp - 1
        hp_name, hp_values = self.cfg_list[i]
        for hp_value in hp_values:
            if final_hp:
                self.hp_combinations.append(path + [(hp_name, hp_value)])
            else:
                self.search(i + 1, path + [(hp_name, hp_value)])

    def reset(self, cfg=None):
        """Discard collected combinations and optionally install a new config."""
        # "is not None" so that an explicitly passed empty config is honoured
        # instead of silently keeping the previous one.
        if cfg is not None:
            self.cfg_list = list(cfg.items())
            self.n_hp = len(self.cfg_list)
        self.hp_combinations = []

    def get_hp_combinations(self, limit=-1):
        """Return all combinations, or a random subset of ``limit`` of them.

        Args:
            limit: maximum number of combinations to return. A value <= 0, or
                one that is not smaller than the total, returns everything.

        Returns:
            A list of combinations in enumeration order. A limited result
            contains ``limit`` distinct combinations.
        """
        # Start from a clean slate so repeated calls do not accumulate
        # duplicates of the combinations found by an earlier call.
        self.hp_combinations = []
        self.search(0, [])
        total_comb = len(self.hp_combinations)
        if limit <= 0 or limit >= total_comb:
            return self.hp_combinations
        # Sample WITHOUT replacement: random.choices could pick the same
        # combination several times, yielding duplicate trials that share one
        # output directory and fewer distinct trials than requested.
        indexes = sorted(random.sample(range(total_comb), k=limit))
        return [self.hp_combinations[i] for i in indexes]


def process(resources):
    """Run one training or evaluation command on a GPU taken from a shared queue.

    Args:
        resources: a ``(hp_combination, devices)`` pair. ``hp_combination`` is
            a sequence of ``(name, value)`` pairs; ``devices`` is a queue of
            GPU ids shared between workers.

    Returns:
        The ``subprocess.CompletedProcess`` of the command that was run. In
        training mode, if ``metric_scores.xlsx`` already exists in the output
        directory, training is skipped and the result of an ``echo`` of that
        file path is returned instead.
    """
    import shlex  # function-scope import: only this block needs it

    def quote(value):
        # Keep paths/values containing spaces or shell metacharacters intact.
        return shlex.quote(str(value))

    hp_combination, devices = resources
    output_path = os.path.join(args.output_path, '-'.join([f'#{k},{v}#' for k, v in hp_combination]))
    device = devices.get()  # Blocks until a GPU id is available.
    try:
        visible = f"CUDA_VISIBLE_DEVICES={quote(device)}"
        if args.eval_mode:
            cmd = f"{visible} python evaluation.py --method {quote(args.method)} --log {quote(output_path)} " \
                  f"-d {quote(args.data)}"
            if args.per_class_eval:
                cmd += " --per-class-eval"
            print(output_path)
            return subprocess.run(cmd, shell=True)

        result_file = os.path.join(output_path, "metric_scores.xlsx")
        if os.path.exists(result_file):
            # Already trained: only report the existing result file.
            return subprocess.run(f"echo {quote(result_file)}", shell=True)

        sources = ' '.join(quote(s) for s in args.source)
        targets = ' '.join(quote(t) for t in args.target)
        options = ' '.join(f'{quote(k)} {quote(v)}' for k, v in hp_combination)
        cmd = f"{visible} python train.py --method {quote(args.method)} --log {quote(output_path)} " \
              f"--root {quote(args.root)} -d {quote(args.data)} -s {sources} -t {targets} " \
              f"--epochs {args.epochs} -i {args.iters_per_epoch} -p {args.print_freq} --seed {args.seed} " \
              f"--opt {options} "
        print(output_path)
        return subprocess.run(cmd, shell=True)
    finally:
        # Always hand the GPU back, even if launching the command raised;
        # otherwise the other workers could wait on the queue forever.
        devices.put(device)


if __name__ == "__main__":
    # Load the search-space description and expand it into candidate lists.
    with open(args.cfg_file, 'r', encoding='utf8') as f:
        # NOTE(review): yaml.full_load can build more than plain data types;
        # prefer yaml.safe_load if the config file may come from an untrusted
        # source.
        ori_cfg = yaml.full_load(f)
    cfg = transform_cfg(ori_cfg)
    print(cfg)
    hp_searcher = HPSearcher(cfg=cfg)
    hp_combinations = hp_searcher.get_hp_combinations(limit=args.hp_limit)
    print(hp_combinations)
    print(f"total combinations of hype-parameters: {len(hp_combinations)}")

    # One worker per GPU; the managed queue hands GPU ids out to the workers
    # and receives them back when a job ends. --gpus overrides --n_gpu.
    gpu_ids = args.gpus if args.gpus else range(args.n_gpu)
    n_gpu = len(gpu_ids)
    gpus = multiprocessing.Manager().Queue(maxsize=n_gpu)
    for i in gpu_ids:
        gpus.put_nowait(i)
    results = process_map(process, [(hp_combination, gpus) for hp_combination in hp_combinations], max_workers=n_gpu)

    # Count failing jobs instead of summing exit codes: a code such as 2 would
    # be counted twice, and negative codes (process killed by a signal) could
    # cancel out positive ones or hide the failure altogether.
    n_failed = sum(1 for result in results if result.returncode != 0)
    if n_failed > 0:
        for result in results:
            print(result)
    print(f"finished. total: {len(results)}, failed: {n_failed}")
    merge_excels(args.output_path)