#!/usr/bin/env python3
import os
import pickle
import tempfile
from datetime import datetime
from pathlib import Path
from typing import *

import hydra
import mlflow
import optuna
import pytz
from omegaconf import DictConfig, OmegaConf
from optuna import Trial

import approaches
import utils
from approaches.abst_appr import AbstractAppr
from dataloader import get_shuffled_dataloder
from mymetrics import MyMetrics
from utils import BColors, myprint as print, suggest_float, suggest_int


def instance_appr(trial: Trial, cfg: DictConfig,
                  list__ncls: List[int], inputsize: Tuple[int, ...],
                  dict__idx_task__dataloader: Dict[int, Dict[str, Any]]) -> AbstractAppr:
    """Instantiate the approach named by ``cfg.appr.name`` for one trial.

    Hyperparameters are obtained through ``suggest_float`` / ``suggest_int``.
    When ``cfg.seed == 1`` the approach-specific values are looked up under
    their own keys (e.g. ``'lamb'``, ``'drop1'``, ``('appr', 'sparsity')``);
    for any other seed they are looked up under
    ``('appr', 'tuned', cfg.seq.name, <key>)``.

    Args:
        trial: Optuna trial forwarded to the ``suggest_*`` helpers.
        cfg: Experiment configuration.
        list__ncls: Number of classes per task, forwarded to the approach.
        inputsize: Input size, forwarded to the approach.
        dict__idx_task__dataloader: Per-task dataloaders; only forwarded to
            the ``cat`` approach.

    Returns:
        The constructed approach object.

    Raises:
        NotImplementedError: If ``cfg.appr.name`` (case-insensitive) is not
            one of the names handled below.
    """
    if cfg.device is None:
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = cfg.device
    # endif
    print(f'device: {device}', bcolor=BColors.OKBLUE)

    expname = cfg.expname
    lr = suggest_float(trial, cfg, 'lr')
    lr_factor = suggest_float(trial, cfg, 'lr_factor')
    lr_min = suggest_float(trial, cfg, 'lr_min')
    epochs_max = cfg.epochs_max
    patience_max = cfg.patience_max
    seed_pt = cfg.seed_pt
    batch_size = cfg.batch_size

    # Keyword arguments shared by every approach constructor.
    common = dict(device=device, list__ncls=list__ncls, inputsize=inputsize,
                  lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                  epochs_max=epochs_max, patience_max=patience_max)

    def tuned_keys(key: str) -> Tuple[Any, ...]:
        # Key path used for every seed other than 1.
        return 'appr', 'tuned', cfg.seq.name, key
    # enddef

    def drops() -> Tuple[float, float]:
        if cfg.seed == 1:
            drop1 = suggest_float(trial, cfg, 'drop1')
            drop2 = suggest_float(trial, cfg, 'drop2')
        else:
            drop1 = suggest_float(trial, cfg, *tuned_keys('drop1'))
            drop2 = suggest_float(trial, cfg, *tuned_keys('drop2'))
        # endif
        return drop1, drop2
    # enddef

    def lamb_value() -> float:
        if cfg.seed == 1:
            return suggest_float(trial, cfg, 'lamb')
        # endif
        return suggest_float(trial, cfg, *tuned_keys('lamb'))
    # enddef

    # NOTE: the order of the suggest_* calls inside each branch is kept
    # exactly as before so a seeded sampler draws the same sequence.
    name_appr = cfg.appr.name.lower()
    if name_appr == 'hat':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = drops()
        lamb = lamb_value()
        appr = approaches.appr_hat.Appr(**common,
                                        smax=smax, lamb=lamb,
                                        drop1=drop1, drop2=drop2)
    elif name_appr == 'pathnet':
        drop1, drop2 = drops()
        appr = approaches.appr_pathnet_wrap.Appr(**common,
                                                 batch_size=batch_size,
                                                 drop1=drop1, drop2=drop2)
    elif name_appr == 'acl':
        drop1, drop2 = drops()
        checkpoint = str(Path(f'./checkpoint_{cfg.expname}').resolve())
        # The path is identical for every trial of a study, so the directory
        # may already exist; os.mkdir() would raise FileExistsError on the
        # second trial.
        os.makedirs(checkpoint, exist_ok=True)
        appr = approaches.appr_acl_wrap.Appr(**common,
                                             batch_size=batch_size,
                                             drop1=drop1, drop2=drop2,
                                             checkpoint=checkpoint)
    elif name_appr == 'stl':
        drop1, drop2 = drops()
        appr = approaches.appr_stl.Appr(**common, drop1=drop1, drop2=drop2)
    elif name_appr == 'ncl':
        drop1, drop2 = drops()
        appr = approaches.appr_ncl.Appr(**common, drop1=drop1, drop2=drop2)
    elif name_appr == 'supsup':
        log_dir = str(Path(f'./logs_{expname}').resolve())
        if cfg.seed == 1:
            sparsity = suggest_int(trial, cfg, 'appr', 'sparsity')
            momentum = suggest_float(trial, cfg, 'appr', 'momentum')
        else:
            sparsity = suggest_int(trial, cfg, *tuned_keys('sparsity'))
            momentum = suggest_float(trial, cfg, *tuned_keys('momentum'))
        # endif
        appr = approaches.appr_supsup_wrap.Appr(**common,
                                                sparsity=sparsity, momentum=momentum,
                                                expname=expname, log_dir=log_dir,
                                                batch_size=batch_size)
    elif name_appr == 'hatewc':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = drops()
        lamb = lamb_value()
        appr = approaches.appr_hatewc.Appr(**common,
                                           smax=smax, lamb=lamb,
                                           drop1=drop1, drop2=drop2)
    elif name_appr == 'cat':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = drops()
        lamb = lamb_value()
        if cfg.seed == 1:
            nheads = suggest_int(trial, cfg, 'appr', 'nheads')
        else:
            nheads = suggest_int(trial, cfg, *tuned_keys('nheads'))
        # endif
        appr = approaches.appr_cat_wrap.Appr(**common,
                                             smax=smax, lamb=lamb,
                                             drop1=drop1, drop2=drop2, nheads=nheads,
                                             dict__idx_task__dataloader=dict__idx_task__dataloader)
    elif name_appr == 'prm':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = drops()
        lamb = lamb_value()
        # NOTE(review): 'prm' reads cfg.ablation while 'prmwo2so' reads
        # cfg.appr.ablation -- confirm against the config which is intended.
        ablation = cfg.ablation
        appr = approaches.appr_prm.Appr(**common,
                                        smax=smax, lamb=lamb,
                                        seed_pt=seed_pt, ablation=ablation,
                                        drop1=drop1, drop2=drop2)
    elif name_appr == 'prmwo2so':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = drops()
        lamb = lamb_value()
        ablation = cfg.appr.ablation
        appr = approaches.appr_prmwo2so.Appr(**common,
                                             smax=smax, lamb=lamb,
                                             seed_pt=seed_pt, ablation=ablation,
                                             drop1=drop1, drop2=drop2)
    elif name_appr == 'hypernet':
        out_dir = str(Path(f'./logs_{expname}').resolve())
        lamb = lamb_value()
        appr = approaches.appr_hypernet_wrap.Appr(**common,
                                                  lamb=lamb, batch_size=batch_size,
                                                  out_dir=out_dir, expname=expname)
    else:
        raise NotImplementedError(f'unknown approach: {cfg.appr.name}')
    # endif

    return appr


def load_dataloader(cfg: DictConfig) -> Dict[int, Dict[str, Any]]:
    """Load (or build and cache) the per-task dataloaders for ``cfg``.

    The cache lives in ``<original cwd>/data`` under the base name
    ``seq={cfg.seq.name}_seed={cfg.seed}``: a ``.pkl`` file holding the
    dataloader dict and a ``.txt`` file holding a textual summary of it
    (task name, class count and train/val/test dataset sizes). If the
    summary file already exists, the freshly computed summary must equal it.

    Args:
        cfg: Experiment configuration; ``cfg.seq.name`` and ``cfg.seed`` select
            the cache files, and ``cfg`` is passed to ``get_shuffled_dataloder``
            when no pickle exists yet.

    Returns:
        Mapping from task index to that task's dict.

    Raises:
        AssertionError: If the stored summary differs from the computed one.
    """
    basename_data = f'seq={cfg.seq.name}_seed={cfg.seed}'
    dirpath_data = os.path.join(hydra.utils.get_original_cwd(), 'data')

    # load data
    filepath_pkl = os.path.join(dirpath_data, f'{basename_data}.pkl')
    if os.path.exists(filepath_pkl):
        # NOTE: unpickling can execute arbitrary code; only point this at
        # cache files produced by this project.
        with open(filepath_pkl, 'rb') as f:
            dict__idx_task__dataloader = pickle.load(f)
        # endwith
        print(f'Loaded from {filepath_pkl}', bcolor=BColors.OKBLUE)
    else:
        dict__idx_task__dataloader = get_shuffled_dataloder(cfg)
        # Make sure the cache directory exists so the freshly built
        # dataloaders are not lost to a FileNotFoundError on write.
        os.makedirs(dirpath_data, exist_ok=True)
        with open(filepath_pkl, 'wb') as f:
            pickle.dump(dict__idx_task__dataloader, f)
        # endwith
    # endif

    # compute summary (one line per task, indexed 0..num_tasks-1)
    num_tasks = len(dict__idx_task__dataloader)
    lines = []
    for idx_task in range(num_tasks):
        task = dict__idx_task__dataloader[idx_task]
        name = task['fullname']
        ncls = task['ncls']
        num_train = len(task['train'].dataset)
        num_val = len(task['val'].dataset)
        num_test = len(task['test'].dataset)

        lines.append(f'idx_task: {idx_task}, name: {name}, ncls: {ncls}, '
                     f'num: {num_train}/{num_val}/{num_test}')
    # endfor
    summary = '\n'.join(lines)

    # check summary
    filepath_hash = os.path.join(dirpath_data, f'{basename_data}.txt')
    if os.path.exists(filepath_hash):
        with open(filepath_hash, 'rt') as f:
            summary_saved = f.read()
        # endwith
        # Raised explicitly (instead of an ``assert`` statement) so the check
        # still runs under ``python -O``.
        if summary_saved != summary:
            raise AssertionError(f'Data summary does not match {filepath_hash}:\n'
                                 f'--- stored ---\n{summary_saved}\n'
                                 f'--- computed ---\n{summary}')
        # endif

        print(f'Succesfully matched to {filepath_hash}', bcolor=BColors.OKBLUE)
    else:
        # save summary
        with open(filepath_hash, 'wt') as f:
            f.write(summary)
        # endwith
    # endif

    return dict__idx_task__dataloader


def outer_objective(cfg: DictConfig, expid: str) -> Callable[[Trial], float]:
    """Build the Optuna objective for the experiment described by ``cfg``.

    The dataloaders are loaded once here and shared by every trial through
    the returned closure.

    Args:
        cfg: Experiment configuration, forwarded to ``load_dataloader`` and
            ``instance_appr``.
        expid: MLflow experiment id under which each trial starts a run.

    Returns:
        A function that takes a trial, trains the approach on every task in
        index order, tests it on all tasks learned so far after each one, and
        returns ``metrics_final[num_tasks - 1]['acc__Overall']``.
    """
    dict__idx_task__dataloader = load_dataloader(cfg)

    num_tasks = len(dict__idx_task__dataloader)
    list__name = [dict__idx_task__dataloader[idx_task]['name'] for idx_task in range(num_tasks)]
    list__ncls = [dict__idx_task__dataloader[idx_task]['ncls'] for idx_task in range(num_tasks)]
    inputsize = dict__idx_task__dataloader[0]['inputsize']  # type: Tuple[int, ...]

    def objective(trial: Trial) -> float:
        # The approach is built before the MLflow run starts, so trial.params
        # is already populated when it is logged below.
        appr = instance_appr(trial, cfg, list__ncls=list__ncls, inputsize=inputsize,
                             dict__idx_task__dataloader=dict__idx_task__dataloader)

        with mlflow.start_run(experiment_id=expid):
            mlflow.log_params(trial.params)
            # Single message: a stray comma here used to split the banner into
            # two positional arguments.
            print(f'\n'
                  f'******* trial params *******\n'
                  f'{trial.params}\n'
                  f'****************************', bcolor=BColors.OKBLUE)

            list__dl_test = [dict__idx_task__dataloader[idx_task]['test']
                             for idx_task in range(num_tasks)]
            mm = MyMetrics(list__name, list__dl_test=list__dl_test)

            for idx_task in range(num_tasks):
                # dataloader
                task = dict__idx_task__dataloader[idx_task]
                dl_train = task['train']
                dl_val = task['val']

                time_consumed = appr.train(idx_task=idx_task, dl_train=dl_train, dl_val=dl_val,
                                           args_on_forward={},
                                           args_on_after_backward={})
                appr.complete_learning(idx_task=idx_task)
                mm.add_record_time(idx_task, time_consumed)

                # test for all previous tasks (including the one just learned)
                for t_prev in range(idx_task + 1):
                    results_test = appr.test(t_prev, list__dl_test[t_prev],
                                             args_on_forward={})
                    loss_test, acc_test = results_test['loss_test'], results_test['acc_test']

                    # record
                    mm.add_record(idx_task_learned=idx_task, idx_task_tested=t_prev,
                                  loss=loss_test, acc=acc_test)
                # endfor | t_prev

                # Artifacts are written to a throwaway directory and logged to
                # MLflow before the directory is removed.
                with tempfile.TemporaryDirectory() as dirpath_tmp:
                    metrics_final, list__artifacts = mm.save(dirpath_tmp, idx_task)
                    for k, v in metrics_final[idx_task].items():
                        trial.set_user_attr(k, v)
                    # endfor

                    mlflow.log_metrics(metrics_final[idx_task], step=idx_task)
                    for artifact in list__artifacts:
                        mlflow.log_artifact(artifact)
                    # endfor
                # endwith
            # endfor | idx_task

            # metrics_final holds the result of the last mm.save() call.
            obj = metrics_final[num_tasks - 1]['acc__Overall']
        # endwith | mlflow.start_run()

        return obj
    # enddef

    return objective


@hydra.main(config_path='conf', config_name='config')
def main(cfg: DictConfig):
    """Run one Optuna study for the experiment described by ``cfg``.

    Prints the resolved config, calls ``utils.set_seed``, enables MLflow's
    PyTorch autologging, creates an MLflow experiment and an Optuna study both
    named ``cfg.expname``, optimizes the objective and prints the results.
    The study's ``'Completed'`` user attribute is ``False`` while optimizing
    and set to ``True`` once ``study.optimize`` returns.
    """
    print(OmegaConf.to_yaml(cfg))

    utils.set_seed(cfg.seed, cfg.seed_pt)
    mlflow.pytorch.autolog()
    expname = cfg.expname
    expid = mlflow.create_experiment(expname)

    direction = cfg.optuna.direction
    storage = cfg.optuna.storage
    sampler = optuna.samplers.TPESampler(seed=cfg.seed_pt)
    study = optuna.create_study(study_name=expname,
                                storage=storage,
                                sampler=sampler,
                                direction=direction,
                                load_if_exists=False)
    study.set_user_attr('Completed', False)

    # Only seed 1 runs the full search; every other seed runs a single trial.
    n_trials = cfg.optuna.n_trials if cfg.seed == 1 else 1

    objective = outer_objective(cfg, expid)
    study.optimize(objective, n_trials=n_trials,
                   gc_after_trial=True, show_progress_bar=True)
    study.set_user_attr('Completed', True)

    print(f'best params: {study.best_params}')
    print(f'best value: {study.best_value}')
    print(study.trials_dataframe())
    print(f'{expname}')


if __name__ == '__main__':
    def _now_jst(pattern):
        # Current Asia/Tokyo time rendered with the given strftime pattern.
        return datetime.now(pytz.timezone('Asia/Tokyo')).strftime(pattern)

    # Registered before main() runs; presumably referenced from the config
    # as ``${nowjst:<pattern>}`` -- verify against conf/.
    OmegaConf.register_new_resolver('nowjst', _now_jst)
    main()
