# -*- coding: utf-8 -*-
import argparse
import os
from copy import deepcopy
import glob
import re

import pickle
import numpy as np
import optuna
import yaml
from optuna.trial import TrialState

from utils.common_utils import merge_excels

# Options of the hyper-parameter search driver itself.  The same parser object
# is handed to ``train.get_args`` below, presumably so that module can register
# the training options on top of these — confirm in train.py.
parser = argparse.ArgumentParser(description='automl')
parser.add_argument('--cfg_file', default='./configs/cdan_officehome_search.yml',
                    type=str, help='the config file for hyper parameters setting')
parser.add_argument('--gpu_id', default=None, type=str, help='specify the GPUs used')
parser.add_argument('--n_trials', default=50, type=int, help='number of trials of optuna')
# Key looked up in every metric-score record / Excel column by ``objective``.
parser.add_argument('--opt_metric', type=str, default='accuracy')
# Only 'random' and 'TPE' are handled by the sampler selection in ``__main__``.
parser.add_argument('--sampler', type=str, default='random')
parser.add_argument('--output_path', type=str, 
                    default='/home/username/DAmetric_logs/cdan_officehome_search', 
                    help="The log dir")
# parse_known_args tolerates flags that are not registered yet (presumably the
# training options added later by ``get_args``).
args, unknown = parser.parse_known_args()
if args.gpu_id is not None:
    print(args.gpu_id)
    # Must happen before ``train`` is imported so the GPU selection is already
    # in the environment when that module and its dependencies are loaded.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
from train import get_args, main
# Namespace of training arguments; ``objective`` overwrites searched fields on
# this shared object before every trial.
args_main = get_args(parser)


class OptParamParser(object):
    """Resolve hyper-parameters from a search-space config via an Optuna trial.

    Each entry of ``cfg`` maps a parameter name to one of:

    * a ``list``  -> sampled with ``trial.suggest_categorical``;
    * a ``dict`` with a ``low`` key (plus any other keyword arguments accepted
      by ``suggest_int`` / ``suggest_float``) -> sampled with ``suggest_int``
      when ``low`` is an ``int``, otherwise with ``suggest_float``;
    * anything else -> returned unchanged (a fixed value).
    """

    def __init__(self, opt_metric, cfg=None, trial=None):
        # Name of the metric being optimised; stored for callers, not used by parse.
        self.opt_metric = opt_metric
        self.cfg = cfg
        self.trial = trial

    def set_cfg(self, cfg):
        """Install the search-space config consulted by :meth:`parse`."""
        self.cfg = cfg

    def set_trial(self, trial):
        """Install the Optuna trial used for sampling in :meth:`parse`."""
        self.trial = trial

    def parse(self, name, cast=None):
        """Return the value of parameter *name* for the current trial.

        Args:
            name: key of the parameter in ``self.cfg``.
            cast: optional callable applied to the resolved value.

        Returns:
            The sampled (or fixed) value, passed through ``cast`` if given.

        Raises:
            RuntimeError: if the config or the trial has not been set.
            KeyError: if *name* is not present in the config.
        """
        # Compare against None explicitly: an empty config *is* set and should
        # reach the clearer missing-parameter error below.
        if self.cfg is None:
            raise RuntimeError("cfg is None, please set it before parse")
        if self.trial is None:
            raise RuntimeError("trial is None, please set it before parse")
        if name not in self.cfg:
            raise KeyError(f"param {name} can't be parsed, please set it in conf file")
        value = self.cfg[name]
        if isinstance(value, list):
            value = self.trial.suggest_categorical(name, value)
        elif isinstance(value, dict):
            # The type of ``low`` decides between an integer and a float range.
            if isinstance(value['low'], int):
                value = self.trial.suggest_int(name, **value)
                print(f'{name} int: {value}')
            else:
                value = self.trial.suggest_float(name, **value)
                # Rounded to 8 decimals (presumably to keep printed values and
                # the directory names built from them short).
                value = np.round(value, 8)
                print(f'{name} float: {value}')
        if cast:
            value = cast(value)
        return value

def to_numerical(obj: dict):
    """Cast, in place, every ``str`` value of *obj* to ``float``.

    Non-string values are left untouched.  Returns ``None``; the dict passed in
    is modified directly.
    """
    text_keys = [key for key, val in obj.items() if isinstance(val, str)]
    for key in text_keys:
        obj[key] = float(obj[key])

def transform_cfg(ori_cfg):
    """Return a deep copy of *ori_cfg* with numeric strings in range specs cast.

    Only dict-valued entries (range specifications) are touched: each of them
    has its string values converted to ``float`` via ``to_numerical``.  The
    input config is not modified.
    """
    cfg = deepcopy(ori_cfg)
    range_specs = [spec for spec in cfg.values() if isinstance(spec, dict)]
    for spec in range_specs:
        to_numerical(spec)
    return cfg

# Shared parameter resolver used by ``objective``: its config is installed in
# the ``__main__`` section and its trial is set at the start of each objective call.
opt_param_paser = OptParamParser(opt_metric=args.opt_metric)

def objective(trial):
    """Optuna objective: sample one configuration, train it and return its score.

    Every searched parameter is resolved through the module-level
    ``opt_param_paser`` and written onto the shared ``args_main`` namespace;
    which parameters are sampled depends on ``args_main.method`` and
    ``args_main.data``.  If a run with exactly the same parameter values
    (under any trial number) already left a ``metric_scores.xlsx`` below
    ``args.output_path``, that result is reused instead of retraining.

    Args:
        trial: the Optuna trial being evaluated.

    Returns:
        The maximum of ``args.opt_metric`` over the recorded metric scores,
        read back from the cached Excel file or produced by ``main``.
    """
    opt_param_paser.set_trial(trial)
    data_name = args_main.data.lower()

    args_main.lr = opt_param_paser.parse('lr', cast=float)
    args_main.weight_decay = opt_param_paser.parse('weight_decay', cast=float)
    if args_main.method == 'DANN':
        args_main.lr_multi_D = opt_param_paser.parse('lr_multi_D', cast=float)
    elif args_main.method == 'CDAN':
        args_main.lr_multi_D = opt_param_paser.parse('lr_multi_D', cast=float)
        args_main.entropy = opt_param_paser.parse('entropy')
        if "domainnet" in data_name:
            # NOTE(review): fixed randomized multilinear map for DomainNet,
            # presumably for memory reasons — confirm in train.py.
            args_main.randomized = True
            args_main.randomized_dim = 51200
    elif args_main.method == 'MCC':
        args_main.temperature = opt_param_paser.parse('temperature')
    elif args_main.method == 'MDD':
        args_main.margin = opt_param_paser.parse('margin')
    args_main.trade_off = opt_param_paser.parse('trade_off', cast=float)
    args_main.bottleneck_dim = opt_param_paser.parse('bottleneck_dim')
    args_main.batch_size = opt_param_paser.parse('batch_size')
    if "office" in data_name:
        args_main.k_fold = opt_param_paser.parse('k_fold')
    args_main.save_checkpoints = opt_param_paser.parse('save_checkpoints')
    args_main.imagenet_test = opt_param_paser.parse('imagenet_test')
    if "visda" in data_name:
        args_main.train_resizing = opt_param_paser.parse('train_resizing')
    args_main.arch = opt_param_paser.parse('arch')

    # Encode every configured parameter into the run directory name so that an
    # identical configuration from an earlier trial can be recognised.
    # getattr replaces the former eval(): same lookup, no code execution.
    param_suffix = '-'.join(f'#{k},{getattr(args_main, k)}#' for k in opt_param_paser.cfg)
    result_path = os.path.join(args.output_path, f'trial{trial.number}-{param_suffix}')
    result_file = os.path.join(result_path, "metric_scores.xlsx")
    print(f"result_file: {result_file}")

    # Match the same parameters under any trial number.  The literal parts are
    # escaped so '[', '*' or '?' in the output path or in parameter values are
    # not interpreted as glob syntax, and only the trial-number slot is wild.
    pattern = os.path.join(glob.escape(args.output_path),
                           'trial*-' + glob.escape(param_suffix),
                           "metric_scores.xlsx")
    exist_result_file = glob.glob(pattern)
    if len(exist_result_file) > 0:
        print("result_file exists:", exist_result_file)
        import pandas as pd
        df = pd.read_excel(exist_result_file[0], engine='openpyxl')
        opt_metrics = df[args.opt_metric].values
        return max(opt_metrics)

    args_main.log = result_path
    all_metric_scores = main(args_main, opt_param_paser)
    opt_metrics = [metric_score[args.opt_metric] for metric_score in all_metric_scores]
    return max(opt_metrics)

def save_study(study, trial, output_path=None):
    """Optuna callback: persist *study* to ``<output_path>/study.pkl``.

    Registered through ``study.optimize(callbacks=[...])`` so that the search
    can be resumed from the pickle.  The data is written to a temporary file
    first and then moved into place with ``os.replace``.  This way an
    interruption during the write cannot leave a truncated ``study.pkl``
    behind; a truncated file would make the resume step fail to load.

    Args:
        study: the study to persist.
        trial: the trial that just finished (unused; required by Optuna's
            callback signature).
        output_path: target directory; defaults to ``args.output_path``.
    """
    if output_path is None:
        output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)
    study_path = os.path.join(output_path, "study.pkl")
    tmp_path = study_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(study, f)
    # Replace the previous snapshot in a single rename (atomic on POSIX).
    os.replace(tmp_path, study_path)

if __name__ == "__main__":
    with open(args.cfg_file, 'r', encoding='utf8') as f:
        # NOTE(review): full_load can construct Python objects from YAML tags;
        # yaml.safe_load would suffice for a plain search-space file.
        ori_cfg = yaml.full_load(f)
    cfg = transform_cfg(ori_cfg)
    print(cfg)
    opt_param_paser.set_cfg(cfg)

    study_path = os.path.join(args.output_path, "study.pkl")
    if os.path.exists(study_path):
        # Resume an earlier search.  pickle.load can execute code embedded in
        # the file, so only point --output_path at directories you trust.
        # The stored sampler/pruner are reused; --sampler is ignored here.
        with open(study_path, "rb") as f:
            study = pickle.load(f)
        print("Successfully load study from", study_path)
    else:
        if args.sampler == "random":
            sampler = optuna.samplers.RandomSampler(seed=0)
        elif args.sampler == "TPE":
            sampler = optuna.samplers.TPESampler(n_startup_trials=10, multivariate=True, seed=0)
        else:
            # Previously fell through and crashed later with a NameError.
            raise ValueError(f"unknown sampler {args.sampler!r}; expected 'random' or 'TPE'")
        pruner = optuna.pruners.MedianPruner(n_startup_trials=1)
        study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)

    # Only run the trials still missing to reach --n_trials (supports resuming).
    n_trials = max(args.n_trials - len(study.trials), 0)
    study.optimize(objective, n_trials=n_trials, timeout=None, show_progress_bar=True,
                   callbacks=[save_study])

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))
    # study.best_trial raises when no trial has completed; report instead so
    # that the Excel merge below still runs.
    if complete_trials:
        print("Best trial:")
        trial = study.best_trial
        print("  Value: ", trial.value)
        print("  Params: ")
        for key, value in trial.params.items():
            print("    {}: {}".format(key, value))
    else:
        print("No complete trials; best trial unavailable.")
    merge_excels(args.output_path)

    
    