import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import datetime
import os
import random
import json
from platform import system
from tqdm import tqdm

from training import model as models
from training.laplace import LaplaceApprox
from training.boml import BayesianOnlineMetaLearnLaplaceApprox
from training.util import enlist_transformation
from data_generate.task_generator import TaskGenerator
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler

# Global cuDNN configuration for reproducible runs: benchmark (autotuning) mode
# off, deterministic algorithms requested, and cuDNN itself left enabled.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = True


def _log(message):
    """Print ``message`` to stdout and flush it immediately.

    Replaces the former Linux-only ``os.system("echo ...")`` calls. Those spawned a
    shell and interpolated ``run_spec`` (derived from the config file name) into the
    command string, so a quote character in the name could break or alter the
    command. ``print(..., flush=True)`` emits the same text without a shell and
    without waiting on Python's stdout buffer, identically on every platform.
    """
    print(message, flush=True)


def _load_tasklist(config, split_dir):
    """Return the sequential task list for this experiment.

    If ``config['tasklist_path']`` is None, a new list is generated by
    ``TaskGenerator`` from ``metatrain.npy`` in ``split_dir`` (with ``save_npy=True``
    and ``tasklist_dir`` set to the experiment directory) and read back via
    ``load_task('tasklist.npy')``. Otherwise the given ``.npy`` file is loaded and a
    copy is saved as ``<experiment_dir>/tasklist.npy``.
    """
    if config['tasklist_path'] is None:
        taskgen = TaskGenerator(num_way=config['net_kwargs']['num_way'], supercls=config['supercls'],
                                split_dir=split_dir, tasklist_dir=config['experiment_dir'])
        taskgen.generate_task_list(split_npyfilename='metatrain.npy', num_class_excl=taskgen.num_way,
                                   num_task_per_supercls=config['num_task_per_supercls'], save_npy=True)
        taskgen.load_task(task_npyfilename='tasklist.npy')
        return taskgen.tasklist

    # Reusing another task list: copy it into the current experiment directory.
    # NOTE: allow_pickle=True executes pickle data -- only load task lists you trust.
    tasklist = np.load(config['tasklist_path'], allow_pickle=True).tolist()
    np.save(os.path.join(config['experiment_dir'], 'tasklist.npy'), tasklist)
    return tasklist


def _save_final_laplace_state(lapl_approx, result_dir, task_idx, include_mpar):
    """Save the Laplace mean and ``tpar_*`` factors (plus ``mpar_*`` if requested).

    Each attribute ``<name>`` of ``lapl_approx`` is written with ``torch.save`` to
    ``<result_dir>/<name><task_idx>.pt``. The ``mpar_*`` factors are only saved when
    ``include_mpar`` is true, which is when they are computed in ``train``.
    """
    attr_names = ['mean', 'tpar_act_cov', 'tpar_grad_cov', 'tpar_bn_fisher']
    if include_mpar:
        attr_names += ['mpar_act_cov', 'mpar_grad_cov', 'mpar_bn_fisher']
    for attr_name in attr_names:
        torch.save(getattr(lapl_approx, attr_name),
                   f=os.path.join(result_dir, '{}{}.pt'.format(attr_name, task_idx)))


def train(config, run_spec, seed):
    """Run BOMLA sequential-task meta-training for a single seed.

    Args:
        config: experiment configuration dict (as loaded from the JSON config file).
            It is updated in place with ``experiment_parent_dir`` and
            ``experiment_dir`` before a copy is written to the experiment directory.
        run_spec: run identifier (the config file's base name when run as a script),
            used in the experiment directory and saved config file names.
        seed: integer seed applied to torch, numpy and Python's ``random``.

    Side effects:
        Creates ``<run_dir>/<dataset>/<run_spec>_<date>_<seed>/`` holding the config
        copy, the task list, TensorBoard logs (``logtb``) and a ``result/`` directory.
        After the final task, ``result/`` receives the model state dict and, when
        ``lapl_approx_reg`` is set, the Laplace mean and factors.
    """
    # Seed every global RNG the pipeline may draw from, so a run can be reproduced
    # from its seed. The original seeded torch only. numpy requires a seed in
    # [0, 2**32), hence the modulo.
    random.seed(seed)
    np.random.seed(seed % 2 ** 32)
    torch.manual_seed(seed)

    # add additional experiment details to config
    start_datetime = datetime.datetime.now()
    # ':' is not allowed in Windows path components, so use '-' there. Other
    # platforms keep the original HH:MM:SS naming for backward compatibility.
    time_fmt = '%H-%M-%S' if system() == 'Windows' else '%H:%M:%S'
    experiment_date = start_datetime.strftime('%Y-%m-%d_' + time_fmt)
    config['experiment_parent_dir'] = os.path.join(config['run_dir'], config['dataset'])
    config['experiment_dir'] = os.path.join(config['experiment_parent_dir'],
                                            '{}_{}_{}'.format(run_spec, experiment_date, seed))

    _log('running {}_{} seed {}'.format(run_spec, experiment_date, seed))

    # save config json file
    os.makedirs(config['experiment_dir'], exist_ok=True)
    with open(os.path.join(config['experiment_dir'], 'config{}_{}.json'.format(0, run_spec)), 'w') as outfile:
        json.dump(config, outfile, indent=4)

    # split directory for this dataset
    split_dir = os.path.join(config['data_dir'], config['split_folder'], config['dataset'])
    # Generate sequential tasks or use existing task list
    tasklist = _load_tasklist(config, split_dir)
    num_task = len(tasklist)

    # define summarywriter; closed in `finally` so buffered events are always flushed
    writer = SummaryWriter(os.path.join(config['experiment_dir'], 'logtb'))
    try:
        result_dir = os.path.join(config['experiment_dir'], 'result')
        os.makedirs(result_dir, exist_ok=True)

        # define model
        model = getattr(models, config['net'])(**config['net_kwargs']).to(device=config['device'])

        # define laplace object
        lapl_approx = LaplaceApprox(
            model=model, is_lapl_list=config['is_lapl_list'], nll_supp_wrt_metaparam=config['nll_supp_wrt_metaparam'],
            hessian_xterm=config['hessian_xterm'], kfac_init_mult=config['kfac_init_mult'],
            upd_scale=config['upd_scale'], device=config['device']
        )

        # define meta-training object
        bomla = BayesianOnlineMetaLearnLaplaceApprox(model, laplace_obj=lapl_approx, device=config['device'])

        # define transformation of images
        transformation = transforms.Compose(
            enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                                  device=config['device'], img_normalise=config['img_normalise'])
        )

        # meta-evaluation dataset
        evalset = FewShotImageDataset(
            task_list=np.load(os.path.join(split_dir, 'metatest.npy'), allow_pickle=True).tolist(),
            supercls=config['eval_supercls'], img_lvl=int(config['eval_supercls']) + 1, transform=transformation,
            relabel=None, device=config['device'], cuda_img_tensor=config['cuda_img_tensor'], verbose=None
        )

        for task_idx, task in tqdm(enumerate(tasklist), desc='bomla seqtask', total=num_task):
            is_last_task = (task_idx + 1) == num_task

            # define optimiser and lr scheduler (re-created for every task)
            optim_outer = getattr(optim, config['optim_outer_name'])(
                model.meta_parameters(), **config['optim_outer_kwargs'])
            if config['lr_sch_outer_name'] is None:
                scheduler_outer = None
            else:
                scheduler_outer = getattr(lr_scheduler, config['lr_sch_outer_name'])(
                    optim_outer, **config['lr_sch_outer_kwargs'])

            # meta-training dataset
            trainset = FewShotImageDataset(
                task_list=task, supercls=config['supercls'], img_lvl=1, transform=transformation, relabel=None,
                device=config['device'], cuda_img_tensor=config['cuda_img_tensor'], verbose=None
            )
            trainsampler = SuppQueryBatchSampler(
                dataset=trainset, seqtask=config['seqtask'], num_batch=config['num_batch_per_outer'],
                num_shot=config['num_shot']
            )
            trainloader = DataLoader(trainset, batch_sampler=trainsampler)

            # meta-training
            bomla.metatrain_seqtask(
                trainloader=trainloader, evalset=evalset, optimiser_outer=optim_outer,
                lr_scheduler_outer=scheduler_outer, lapl_approx_reg=config['lapl_approx_reg'],
                lr_inner=config['lr_inner'], nstep_outer=config['nstep_outer'],
                nstep_inner=config['nstep_inner'], first_order=config['first_order'],
                eval_per_num_epoch=config['eval_per_num_epoch'], num_eval_task=config['num_eval_task'],
                eval_task_by_supercls=config['eval_task_by_supercls'], nstep_inner_eval=config['nstep_inner_eval'],
                writer=writer, task_idx=task_idx, verbose=None, prev_glob_step=None
            )
            # update mean
            lapl_approx.update_mean(max_lapl_list_len=config['max_lapl_list_len'])

            if config['lapl_approx_reg']:
                new_tpar_act_cov, new_tpar_grad_cov, new_tpar_bn_fisher = lapl_approx.get_fisher_bd_kfac_covs(
                    dataloader=trainloader, param=None, run_inner=True, nstep_inner=config['nstep_inner'],
                    lr_inner=config['lr_inner'], exclude_module_name=None, verbose=None
                )
                # update hessian
                lapl_approx.update_hessian(
                    term='tpar', new_act_cov=new_tpar_act_cov, new_grad_cov=new_tpar_grad_cov,
                    new_bn_fisher=new_tpar_bn_fisher, upd_scale=config['upd_scale'], norm=config['norm'],
                    max_lapl_list_len=config['max_lapl_list_len']
                )

                # update hessian if likelihood term included
                if config['nll_supp_wrt_metaparam']:
                    new_mpar_act_cov, new_mpar_grad_cov, new_mpar_bn_fisher = lapl_approx.get_fisher_bd_kfac_covs(
                        dataloader=trainloader, param=None, run_inner=False, exclude_module_name=None, verbose=None
                    )
                    lapl_approx.update_hessian(
                        term='mpar', new_act_cov=new_mpar_act_cov, new_grad_cov=new_mpar_grad_cov,
                        new_bn_fisher=new_mpar_bn_fisher, upd_scale=config['upd_scale'], norm=config['norm'],
                        max_lapl_list_len=config['max_lapl_list_len']
                    )

                if is_last_task:
                    # save task mean, covs
                    _save_final_laplace_state(lapl_approx, result_dir, task_idx,
                                              include_mpar=config['nll_supp_wrt_metaparam'])

            if is_last_task:
                # save model
                with open(os.path.join(result_dir, 'model{}.pt'.format(task_idx)), 'wb') as f:
                    torch.save(model.state_dict(), f)
    finally:
        writer.close()

    # check how long it ran
    _log('\ncompleted in {}'.format(datetime.datetime.now() - start_datetime))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser('BOMLA Sequential Task')
    # required: without it, args.config_path is None and the old code tried to open a file named 'None'
    parser.add_argument('--config_path', type=str, required=True, help='Path of .json file to import config from')
    args = parser.parse_args()
    # load config file; the context manager closes the handle, and JSON is UTF-8 by spec
    with open(args.config_path, encoding='utf-8') as jsonfile:
        config = json.load(jsonfile)
    # run_spec is the config file name without its directory or extension
    run_spec = os.path.splitext(os.path.basename(args.config_path))[0]
    # train with a fresh random 24-bit seed
    train(config=config, run_spec=run_spec, seed=random.getrandbits(24))
