#!/usr/bin/env python3
import os
import pickle
import tempfile
from datetime import datetime
from pathlib import Path
from typing import *

import hydra
import mlflow
import optuna
import pytz
import torch
from omegaconf import DictConfig, OmegaConf
from optuna import Trial

import approaches
import utils
from approaches.abst_appr import AbstractAppr
from approaches.param_consumable import ParamConsumable
from dataloader import get_shuffled_dataloder
from mymetrics import MyMetrics
from utils import BColors, myprint as print, suggest_float, suggest_int


def instance_appr(trial: Trial, cfg: DictConfig,
                  list__ncls: List[int], inputsize: Tuple[int, ...],
                  dict__idx_task__dataloader: Dict[int, Dict[str, Any]]) -> AbstractAppr:
    """Instantiate the approach selected by ``cfg.appr.name`` for one trial.

    Hyperparameters are obtained through ``suggest_float`` / ``suggest_int``;
    the order of those calls is kept stable (lr, lr_factor, lr_min, seed_pt,
    then the approach-specific ones) because it may influence what a seeded
    sampler returns -- NOTE(review): confirm against ``utils.suggest_*``.

    Args:
        trial: Trial handed to the ``suggest_*`` helpers.
        cfg: Experiment configuration (device, backbone, seq, appr, ...).
        list__ncls: Number of classes of each task, indexed by task.
        inputsize: Input size forwarded to the approach constructor.
        dict__idx_task__dataloader: Per-task entries; the ``'val'`` loaders
            are used by ``ucl`` and the whole mapping is passed to ``cat``.

    Returns:
        The constructed approach object.

    Raises:
        NotImplementedError: If the backbone is neither ``'mlp'`` nor
            ``'alexnet'`` (the backbone name is attached to the error), or
            if ``cfg.appr.name`` is not one of the known approaches.
    """
    if cfg.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = cfg.device
    # endif
    print(f'device: {device}', bcolor=BColors.OKBLUE)

    expname = cfg.expname
    lr = suggest_float(trial, cfg, 'backbone', 'lr')
    lr_factor = suggest_float(trial, cfg, 'lr_factor')
    lr_min = suggest_float(trial, cfg, 'lr_min')
    epochs_max = cfg.epochs_max
    patience_max = cfg.patience_max
    batch_size = cfg.seq.batch_size
    backbone = cfg.backbone.name
    psearch = cfg.psearch
    nhid = cfg.backbone.nhid
    ablation = cfg.ablation

    seed_pt = suggest_int(trial, cfg, 'seed_pt')
    utils.set_seed_pt(seed_pt + cfg.seedoffset)

    log_dir = str(Path(f'./logs_{expname}').resolve())
    appr_name = cfg.appr.name.lower()

    def fetch_param(suggest, *pnames: str):
        # With psearch enabled the given name path is used as-is; otherwise
        # only the last name is looked up under ('appr', 'tuned', <seq name>).
        if psearch:
            return suggest(trial, cfg, *pnames)
        # endif
        return suggest(trial, cfg, 'appr', 'tuned', cfg.seq.name, pnames[-1])
    # enddef

    def fetch_param_float(*pnames: str) -> float:
        return fetch_param(suggest_float, *pnames)
    # enddef

    def fetch_param_int(*pnames: str) -> int:
        return fetch_param(suggest_int, *pnames)
    # enddef

    def require_supported_backbone() -> None:
        # Every approach below is only wired up for these two backbones.
        if backbone not in ('mlp', 'alexnet'):
            raise NotImplementedError(backbone)
        # endif
    # enddef

    def fetch_drops() -> Tuple[float, float]:
        # Validate the backbone first so no dropout value is requested for
        # an unsupported configuration.
        require_supported_backbone()
        d1 = fetch_param_float('drop1')
        d2 = fetch_param_float('drop2')
        return d1, d2
    # enddef

    # Keyword arguments accepted by every approach constructor used below.
    common = dict(device=device, list__ncls=list__ncls, inputsize=inputsize,
                  lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                  epochs_max=epochs_max, patience_max=patience_max,
                  backbone=backbone, nhid=nhid)

    # NOTE: the `approaches.appr_*` attributes are resolved only inside the
    # selected branch, exactly as before.
    if appr_name == 'mtl':
        drop1, drop2 = fetch_drops()
        small_lr = cfg.seq.name.lower() in ['imagenet_100']

        appr = approaches.appr_mtl.Appr(**common, drop1=drop1, drop2=drop2,
                                        batch_size=batch_size, small_lr=small_lr)
    elif appr_name == 'ncl':
        drop1, drop2 = fetch_drops()

        appr = approaches.appr_ncl.Appr(**common, drop1=drop1, drop2=drop2)
    elif appr_name == 'hat':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = fetch_drops()
        # The configured value is an exponent of 10.
        lamb = 10 ** fetch_param_float('appr', 'lamb')

        appr = approaches.appr_hat.Appr(**common, drop1=drop1, drop2=drop2,
                                        smax=smax, lamb=lamb,
                                        hat_enabled=True)
    elif appr_name == 'cat':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = fetch_drops()
        nheads = fetch_param_int('appr', 'nheads')
        lamb = 10 ** fetch_param_float('appr', 'lamb')

        appr = approaches.appr_cat_wrap.Appr(**common, drop1=drop1, drop2=drop2,
                                             smax=smax, lamb=lamb, nheads=nheads,
                                             dict__idx_task__dataloader=dict__idx_task__dataloader)
    elif appr_name == 'spg':
        drop1, drop2 = fetch_drops()
        shift = 1

        appr = approaches.appr_spg.Appr(**common, drop1=drop1, drop2=drop2,
                                        batch_size=batch_size,
                                        shift=shift,
                                        ablation=ablation,
                                        seqname=cfg.seq.name,
                                        )
    elif appr_name == 'ewcgi':
        drop1, drop2 = fetch_drops()
        shift = 1
        lamb = 10 ** fetch_param_float('appr', 'lamb')

        appr = approaches.appr_ewcgi.Appr(**common, drop1=drop1, drop2=drop2,
                                          batch_size=batch_size, lamb=lamb,
                                          shift=shift, ablation=ablation, seqname=cfg.seq.name)
    elif appr_name == 'supsup':
        require_supported_backbone()
        # Dropout is fixed to 0 for this approach instead of being fetched.
        drop1 = 0
        drop2 = 0
        sparsity = 10 ** (fetch_param_float('appr', 'sparsity'))
        momentum = fetch_param_float('appr', 'momentum')

        appr = approaches.appr_supsup_wrap.Appr(**common, drop1=drop1, drop2=drop2,
                                                sparsity=sparsity, momentum=momentum,
                                                expname=expname, log_dir=log_dir, batch_size=batch_size,
                                                )
    elif appr_name == 'tag':
        drop1, drop2 = fetch_drops()

        appr = approaches.appr_tag.Appr(**common, drop1=drop1, drop2=drop2)
    elif appr_name == 'ucl':
        drop1, drop2 = fetch_drops()
        lamb = 10 ** fetch_param_float('appr', 'lamb')
        alpha = cfg.appr.alpha
        beta = cfg.appr.beta
        ratio = fetch_param_float('appr', 'ratio')
        # Only this approach consumes the validation loaders up front.
        list__dl_val = [dict__idx_task__dataloader[t]['val'] for t in range(len(list__ncls))]

        appr = approaches.appr_ucl_wrap.Appr(**common, drop1=drop1, drop2=drop2,
                                             batch_size=batch_size,
                                             lamb=lamb, ratio=ratio,
                                             list__dl_val=list__dl_val, expname=expname,
                                             alpha=alpha, beta=beta,
                                             log_dir=log_dir,
                                             )
    elif appr_name == 'pathnet':
        drop1, drop2 = fetch_drops()
        expand_factor = fetch_param_float('appr', 'expand_factor')
        # N and M are fetched as exponents of 2; M is never smaller than N.
        N = fetch_param_int('appr', 'N')
        M = max(N, fetch_param_int('appr', 'M'))
        N = 2 ** N
        M = 2 ** M

        appr = approaches.appr_pathnet_wrap.Appr(**common, drop1=drop1, drop2=drop2,
                                                 batch_size=batch_size,
                                                 expand_factor=expand_factor, M=M, N=N)
    elif appr_name == 'pgn':
        drop1, drop2 = fetch_drops()
        expand_factor = fetch_param_float('appr', 'expand_factor')

        appr = approaches.appr_pnn_wrap.Appr(**common, drop1=drop1, drop2=drop2,
                                             batch_size=batch_size,
                                             expand_factor=expand_factor)
    elif appr_name == 'one':
        drop1, drop2 = fetch_drops()

        appr = approaches.appr_one.Appr(**common, drop1=drop1, drop2=drop2)
    elif appr_name == 'agem':
        drop1, drop2 = fetch_drops()
        buffer_size = cfg.appr.buffer_size
        buffer_percent = cfg.appr.buffer_percent

        appr = approaches.appr_agem_wrap.Appr(**common, drop1=drop1, drop2=drop2,
                                              buffer_size=buffer_size, buffer_percent=buffer_percent)
    elif appr_name == 'ewc':
        smax = suggest_float(trial, cfg, 'smax')
        drop1, drop2 = fetch_drops()
        lamb = 10 ** fetch_param_float('appr', 'lamb')

        appr = approaches.appr_ewc.Appr(**common, drop1=drop1, drop2=drop2,
                                        smax=smax, lamb=lamb)
    elif appr_name == 'spgfi':
        drop1, drop2 = fetch_drops()

        appr = approaches.appr_spgfi.Appr(**common, drop1=drop1, drop2=drop2,
                                          batch_size=batch_size)
    elif appr_name == 'si':
        drop1, drop2 = fetch_drops()
        lamb = 10 ** fetch_param_float('appr', 'lamb')

        appr = approaches.appr_si.Appr(**common, drop1=drop1, drop2=drop2,
                                       lamb=lamb)
    else:
        raise NotImplementedError(cfg.appr.name)
    # endif

    return appr


def load_dataloader(cfg: DictConfig) -> Dict[int, Dict[str, Any]]:
    """Load the per-task dataloaders from the pickle cache, creating it if absent.

    The cache lives in ``<original cwd>/data`` and is keyed by the sequence
    name, batch size and seed. Next to the pickle a text summary (one line
    per task: index, full name, class count and train/val/test sizes) is
    stored; when that summary file already exists, the freshly computed
    summary must match it exactly.

    Args:
        cfg: Configuration providing ``seq.name``, ``seq.batch_size`` and
            ``seed`` (and whatever ``get_shuffled_dataloder`` reads).

    Returns:
        Mapping from task index to that task's entry.

    Raises:
        AssertionError: If the computed summary differs from the stored one.
    """
    basename_data = f'seq={cfg.seq.name}_bs={cfg.seq.batch_size}_seed={cfg.seed}'
    dirpath_data = os.path.join(hydra.utils.get_original_cwd(), 'data')

    # load data
    filepath_pkl = os.path.join(dirpath_data, f'{basename_data}.pkl')
    if os.path.exists(filepath_pkl):
        # NOTE: pickle.load executes arbitrary code from the file; this cache
        # must only ever be a file written by this script.
        with open(filepath_pkl, 'rb') as f:
            dict__idx_task__dataloader = pickle.load(f)
        # endwith
        print(f'Loaded from {filepath_pkl}', bcolor=BColors.OKBLUE)
    else:
        dict__idx_task__dataloader = get_shuffled_dataloder(cfg)
        # Make sure the cache directory exists before the first write.
        os.makedirs(dirpath_data, exist_ok=True)
        with open(filepath_pkl, 'wb') as f:
            pickle.dump(dict__idx_task__dataloader, f)
        # endwith
    # endif

    # compute the textual summary used as a consistency check
    num_tasks = len(dict__idx_task__dataloader)
    lines = []
    for idx_task in range(num_tasks):
        entry = dict__idx_task__dataloader[idx_task]
        name = entry['fullname']
        ncls = entry['ncls']
        num_train = len(entry['train'].dataset)
        num_val = len(entry['val'].dataset)
        num_test = len(entry['test'].dataset)

        lines.append(f'idx_task: {idx_task}, name: {name}, ncls: {ncls}, num: {num_train}/{num_val}/{num_test}')
    # endfor
    summary = '\n'.join(lines)

    # check the summary against the stored one
    filepath_hash = os.path.join(dirpath_data, f'{basename_data}.txt')
    if os.path.exists(filepath_hash):
        with open(filepath_hash, 'rt') as f:
            summary_target = f.read()
        # endwith
        # Explicit raise (not `assert`) so the check also runs under `python -O`.
        if summary_target != summary:
            raise AssertionError(f'Dataset summary does not match {filepath_hash}')
        # endif

        print(f'Succesfully matched to {filepath_hash}', bcolor=BColors.OKBLUE)
        print(summary)
    else:
        # save the summary for future runs
        with open(filepath_hash, 'wt') as f:
            f.write(summary)
        # endwith
    # endif

    return dict__idx_task__dataloader


def outer_objective(cfg: DictConfig, expid: str) -> Callable[[Trial], float]:
    """Create the objective function that is handed to ``study.optimize``.

    The dataloaders are loaded once here and shared by every trial. Each
    call of the returned objective builds a fresh approach, trains it on the
    tasks in order inside one MLflow run, tests all tasks learned so far
    after each task, and logs the saved metrics and artifacts.

    Args:
        cfg: Experiment configuration.
        expid: MLflow experiment id under which the runs are started.

    Returns:
        A callable mapping a trial to the ``'acc__Overall'`` metric obtained
        after the last task.
    """
    dict__idx_task__dataloader = load_dataloader(cfg)

    # Shifting all tasks by one to make room for a dummy first task was tied
    # to the 'hatcompbatch' approach; it is hard-disabled, the flag is kept
    # so that it can be switched back on in one place.
    use_dummy_1sttask = False

    if use_dummy_1sttask:
        for idx_task in range(len(dict__idx_task__dataloader), 0, -1):
            dict__idx_task__dataloader[idx_task] = dict__idx_task__dataloader[idx_task - 1]
        # endfor
    # endif

    num_tasks = len(dict__idx_task__dataloader)
    list__name = [dict__idx_task__dataloader[idx_task]['name'] for idx_task in range(num_tasks)]
    list__ncls = [dict__idx_task__dataloader[idx_task]['ncls'] for idx_task in range(num_tasks)]
    inputsize = dict__idx_task__dataloader[0]['inputsize']  # type: Tuple[int, ...]

    # Task indices excluded when saving metrics (only the dummy task, if any).
    indices_task_ignored = [0] if use_dummy_1sttask else []

    def objective(trial: Trial) -> float:
        appr = instance_appr(trial, cfg, list__ncls=list__ncls, inputsize=inputsize,
                             dict__idx_task__dataloader=dict__idx_task__dataloader)

        with mlflow.start_run(experiment_id=expid):
            mlflow.log_params(trial.params)
            print(f'\n'
                  f'******* trial params *******\n'
                  f'{trial.params}\n',
                  f'****************************', bcolor=BColors.OKBLUE)

            list__dl_test = [dict__idx_task__dataloader[idx_task]['test']
                             for idx_task in range(num_tasks)]
            mm = MyMetrics(list__name, list__dl_test=list__dl_test)

            for idx_task in range(num_tasks):
                # dataloader
                dl_train = dict__idx_task__dataloader[idx_task]['train']
                dl_val = dict__idx_task__dataloader[idx_task]['val']

                results_train = appr.train(idx_task=idx_task, dl_train=dl_train, dl_val=dl_val,
                                           args_on_forward={},
                                           args_on_after_backward={},
                                           list__dl_test=list__dl_test[:idx_task + 1],
                                           )
                time_consumed = results_train['time_consumed']
                # Not every approach reports the epoch count.
                epoch = results_train.get('epoch', 0)

                appr.complete_learning(idx_task=idx_task, dl_train=dl_train, dl_val=dl_val)

                if isinstance(appr, ParamConsumable):
                    param_consumed = appr.compute_param_consumed(idx_task)
                else:
                    param_consumed = 0
                # endif

                mm.add_record_misc(idx_task,
                                   epoch=epoch,
                                   time_consumed=time_consumed,
                                   param_consumed=param_consumed,
                                   )

                # test for all tasks learned so far (including the current one)
                for t_prev in range(idx_task + 1):
                    results_test = appr.test(t_prev, list__dl_test[t_prev],
                                             args_on_forward={})
                    loss_test, acc_test = results_test['loss_test'], results_test['acc_test']
                    print(f'[{t_prev}] acc: {acc_test:.3f}')

                    # record
                    mm.add_record(idx_task_learned=idx_task, idx_task_tested=t_prev,
                                  loss=loss_test, acc=acc_test)
                # endfor | t_prev

                # Artifacts are logged while the temporary directory still exists.
                with tempfile.TemporaryDirectory() as dirpath_tmp:
                    print(f'ordinary train/test after learning {idx_task}')
                    metrics_final, list__artifacts = mm.save(dirpath_tmp, idx_task,
                                                             indices_task_ignored=indices_task_ignored)
                    for k, v in metrics_final[idx_task].items():
                        trial.set_user_attr(k, v)
                    # endfor

                    mlflow.log_metrics(metrics_final[idx_task], step=idx_task)
                    for artifact in list__artifacts:
                        mlflow.log_artifact(artifact)
                    # endfor
                # endwith

            # endfor | idx_task

            # Objective value: overall accuracy after the final task.
            obj = metrics_final[num_tasks - 1]['acc__Overall']
        # endwith | mlflow.start_run()

        print(f'Emptying CUDA cache...')
        torch.cuda.empty_cache()

        return obj
    # enddef

    return objective


@hydra.main(config_path='conf', config_name='config')
def main(cfg: DictConfig):
    """Entry point: create the MLflow experiment and run the Optuna study.

    Prints the resolved configuration, seeds via ``utils.set_seed``, creates
    an MLflow experiment and an Optuna study both named ``cfg.expname``,
    optimizes for ``cfg.n_trials`` trials and prints the results. The study
    user attribute ``'Completed'`` is ``False`` while optimizing and set to
    ``True`` once ``optimize`` has returned.
    """
    print(f'\n{OmegaConf.to_yaml(cfg)}')

    utils.set_seed(cfg.seed)
    mlflow.pytorch.autolog()

    expname = cfg.expname
    expid = mlflow.create_experiment(expname)

    # load_if_exists=False: a study of the same name is not reused.
    study_kwargs = dict(
        direction=cfg.optuna.direction,
        storage=cfg.optuna.storage,
        sampler=optuna.samplers.TPESampler(seed=cfg.seed + cfg.seedoffset),
        load_if_exists=False,
        study_name=expname,
    )
    study = optuna.create_study(**study_kwargs)

    # Mark the study as unfinished until every trial has been processed.
    study.set_user_attr('Completed', False)
    trial_budget = cfg.n_trials
    objective = outer_objective(cfg, expid)
    study.optimize(objective, n_trials=trial_budget,
                   gc_after_trial=True, show_progress_bar=True)
    study.set_user_attr('Completed', True)

    # Final report.
    print(f'best params: {study.best_params}')
    print(f'best value: {study.best_value}')
    print(study.trials_dataframe())
    print(f'{expname}')


if __name__ == '__main__':
    def _now_jst(pattern):
        # Current time in the Asia/Tokyo timezone, formatted with `pattern`.
        return datetime.now(pytz.timezone('Asia/Tokyo')).strftime(pattern)

    # Register the 'nowjst' resolver before main() runs.
    OmegaConf.register_new_resolver('nowjst', _now_jst)
    main()
