import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import datetime
import os
import random
import json
from platform import system

from config.configuration import get_run_name
from training import model as models
from training.laplace import LaplaceApprox
from training.boml import BayesianOnlineMetaLearnLaplaceApprox
from training.util import enlist_transformation
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler

# Global cuDNN configuration, applied once at import time.
_cudnn = torch.backends.cudnn
# Leave cuDNN switched on ...
_cudnn.enabled = True
# ... but trade speed for repeatability: no per-shape algorithm auto-tuning,
# and only deterministic cuDNN algorithms.
_cudnn.benchmark = False
_cudnn.deterministic = True


def _announce(message):
    """Print ``message`` on stdout and flush immediately.

    Replaces the former ``os.system('echo ...')`` call so the message never
    passes through a shell, while still showing up promptly in job logs.
    """
    print(message, flush=True)


def _checkpoint_path(result_dir, stem, task_idx):
    """Return the path ``<result_dir>/<stem><task_idx>.pt``."""
    return os.path.join(result_dir, '{}{}.pt'.format(stem, task_idx))


def _laplace_state_names(include_mpar):
    """Return the ``LaplaceApprox`` attribute names persisted between runs.

    The same list drives saving and loading so that both always agree on the
    attribute names and the file names.
    """
    names = ['mean', 'tpar_act_cov', 'tpar_grad_cov', 'tpar_bn_fisher']
    if include_mpar:
        names.extend(['mpar_act_cov', 'mpar_grad_cov', 'mpar_bn_fisher'])
    return names


def _save_laplace_state(lapl_approx, result_dir, task_idx, include_mpar):
    """Write the mean and the tpar (optionally mpar) terms to ``result_dir``."""
    for name in _laplace_state_names(include_mpar):
        torch.save(getattr(lapl_approx, name), f=_checkpoint_path(result_dir, name, task_idx))


def _restore_laplace_state(lapl_approx, result_dir, task_idx, include_mpar, device):
    """Load what ``_save_laplace_state`` wrote back onto ``lapl_approx``."""
    for name in _laplace_state_names(include_mpar):
        setattr(lapl_approx, name,
                torch.load(_checkpoint_path(result_dir, name, task_idx), map_location=device))


def _update_laplace_hessians(lapl_approx, trainset, config, task):
    """Compute the K-FAC covariances on ``trainset`` and fold them into ``lapl_approx``.

    The 'tpar' term is always updated (computed with the inner loop run); the
    'mpar' term is updated in addition when ``config['nll_supp_wrt_metaparam']``
    is set (computed without running the inner loop).
    """
    task_cfg = config[task]

    # compute activation and pre-activation gradient outer products
    kfacsampler = SuppQueryBatchSampler(
        dataset=trainset, seqtask=False, num_task=task_cfg['num_task_for_kfac'],
        task_by_supercls=task_cfg['task_by_supercls'], num_way=config['net_kwargs']['num_way'],
        num_shot=task_cfg['num_shot'], num_query_per_cls=task_cfg['num_query_per_cls']
    )
    kfacloader = DataLoader(trainset, batch_sampler=kfacsampler)

    new_tpar_act_cov, new_tpar_grad_cov, new_tpar_bn_fisher = lapl_approx.get_fisher_bd_kfac_covs(
        dataloader=kfacloader, param=None, run_inner=True, nstep_inner=task_cfg['nstep_inner'],
        lr_inner=task_cfg['lr_inner'], exclude_module_name=None,
        verbose='{} tpar-kfac calc'.format(task)
    )
    lapl_approx.update_hessian(
        term='tpar', new_act_cov=new_tpar_act_cov, new_grad_cov=new_tpar_grad_cov,
        new_bn_fisher=new_tpar_bn_fisher, upd_scale=config['upd_scale'], norm=config['norm']
    )

    # update hessian if nll_supp term included
    if config['nll_supp_wrt_metaparam']:
        new_mpar_act_cov, new_mpar_grad_cov, new_mpar_bn_fisher = lapl_approx.get_fisher_bd_kfac_covs(
            dataloader=kfacloader, param=None, run_inner=False, exclude_module_name=None,
            verbose='{} mpar-kfac calc'.format(task)
        )
        lapl_approx.update_hessian(
            term='mpar', new_act_cov=new_mpar_act_cov, new_grad_cov=new_mpar_grad_cov,
            new_bn_fisher=new_mpar_bn_fisher, upd_scale=config['upd_scale'], norm=config['norm']
        )


def train(config, run_spec, seed=0):
    """Meta-train sequentially over the datasets listed in ``config['dataset_ls']``.

    For every dataset that has not been completed yet, this builds the
    optimiser / scheduler / datasets from ``config[<dataset name>]``, runs
    ``BayesianOnlineMetaLearnLaplaceApprox.metatrain`` and then updates the
    Laplace approximation (mean, and Hessian terms if
    ``config['lapl_approx_reg']`` is set). After the last dataset of the run
    the model and, where applicable, the state needed to continue training
    are written to ``<experiment_dir>/result``.

    Args:
        config: experiment configuration dict (as loaded from the JSON config
            file). It is mutated: ``'experiment_parent_dir'`` and
            ``'experiment_dir'`` are added before it is dumped to disk.
            When ``config['completed_task_idx']`` is not ``None`` the run
            resumes from ``config['completed_exp_name']``.
        run_spec: name of the run; used in directory and file names.
        seed: seed passed to ``torch.manual_seed`` (only torch's RNG is
            seeded here) and appended to the experiment directory name.

    Returns:
        None.
    """
    torch.manual_seed(seed)

    start_datetime = datetime.datetime.now()
    experiment_date = '{:%Y-%m-%d_%H:%M:%S}'.format(start_datetime)
    config['experiment_parent_dir'] = os.path.join(config['run_dir'], get_run_name(config['dataset_ls']))
    config['experiment_dir'] = os.path.join(config['experiment_parent_dir'],
                                            '{}_{}_{}'.format(run_spec, experiment_date, seed))

    _announce('running {}_{} seed {}'.format(run_spec, experiment_date, seed))

    completed_task_idx = config['completed_task_idx']
    resuming = completed_task_idx is not None

    # save config json file
    os.makedirs(config['experiment_dir'], exist_ok=True)
    config_name = 'config{}_{}.json'.format(completed_task_idx + 1 if resuming else 0, run_spec)
    with open(os.path.join(config['experiment_dir'], config_name), 'w') as outfile:
        json.dump(config, outfile, indent=4)

    # define result directory and previous result directory if applicable
    if resuming:
        completed_result_dir = os.path.join(
            config['experiment_parent_dir'], config['completed_exp_name'], 'result'
        )
        # Map everything restored from disk onto the device of this run, so a
        # checkpoint saved on another device can still be resumed.
        load_device = torch.device(config['device'])
    else:
        completed_result_dir = None
        load_device = None
    result_dir = os.path.join(config['experiment_dir'], 'result')
    os.makedirs(result_dir, exist_ok=True)

    # define model
    model = getattr(models, config['net'])(**config['net_kwargs']).to(device=config['device'])
    if resuming:
        model.load_state_dict(
            torch.load(_checkpoint_path(completed_result_dir, 'model', completed_task_idx),
                       map_location=load_device)
        )

    # define laplace object
    lapl_approx = LaplaceApprox(
        model=model, is_lapl_list=config['is_lapl_list'], nll_supp_wrt_metaparam=config['nll_supp_wrt_metaparam'],
        hessian_xterm=config['hessian_xterm'], kfac_init_mult=config['kfac_init_mult'], upd_scale=config['upd_scale'],
        device=config['device']
    )
    if resuming:
        if config['lapl_approx_reg']:
            # Restore exactly the attributes / files written by _save_laplace_state
            # at the end of the previous run.
            _restore_laplace_state(
                lapl_approx, completed_result_dir, completed_task_idx,
                include_mpar=config['nll_supp_wrt_metaparam'], device=load_device
            )
        prev_glob_step = torch.load(
            _checkpoint_path(completed_result_dir, 'prev_glob_step', completed_task_idx),
            map_location=load_device
        )
        # NOTE(review): evalset is a pickled list of dataset objects; on PyTorch
        # versions where torch.load defaults to weights_only=True this call needs
        # weights_only=False -- confirm against the pinned torch version.
        evalset = torch.load(
            _checkpoint_path(completed_result_dir, 'evalset', completed_task_idx),
            map_location=load_device
        )
    else:
        prev_glob_step = 0
        evalset = []

    # define meta-training object
    bomla = BayesianOnlineMetaLearnLaplaceApprox(model, laplace_obj=lapl_approx, device=config['device'])

    dataset_list = config['dataset_ls']
    num_dataset_to_run = len(dataset_list) if config['num_dataset_to_run'] == 'all' else config['num_dataset_to_run']

    # define tensorboard writer
    writer = SummaryWriter(os.path.join(config['experiment_dir'], 'logtb'))
    try:
        for task_idx, task in enumerate(dataset_list[:num_dataset_to_run]):
            already_completed = resuming and completed_task_idx >= task_idx
            if not already_completed:
                task_cfg = config[task]
                is_last_task = (task_idx + 1) == num_dataset_to_run

                # split directory for this dataset
                split_dir = os.path.join(config['data_dir'], config['split_folder'], task)

                # define optimiser and lr scheduler
                optim_cls = getattr(optim, task_cfg['optim_outer_name'])
                optim_outer = optim_cls(model.meta_parameters(), **task_cfg['optim_outer_kwargs'])
                if task_cfg['lr_sch_outer_name'] is None:
                    scheduler_outer = None
                else:
                    scheduler_cls = getattr(lr_scheduler, task_cfg['lr_sch_outer_name'])
                    scheduler_outer = scheduler_cls(optim_outer, **task_cfg['lr_sch_outer_kwargs'])

                # define transformation of images
                transformation = transforms.Compose(
                    enlist_transformation(
                        img_resize=config['img_resize'], resize_interpolation=task_cfg['resize_interpolation'],
                        is_grayscale=config['is_grayscale'], device=config['device'],
                        img_normalise=task_cfg['img_normalise']
                    )
                )

                # define meta-training dataset
                trainset = FewShotImageDataset(
                    task_list=np.load(os.path.join(split_dir, 'metatrain.npy'), allow_pickle=True).tolist(),
                    supercls=task_cfg['supercls'], img_lvl=int(task_cfg['supercls']) + 1, transform=transformation,
                    relabel=None, device=config['device'], cuda_img_tensor=config['cuda_img_tensor'],
                    verbose='{} trainset'.format(task)
                )

                # define & append meta-evaluation dataset
                evalset.append(FewShotImageDataset(
                    task_list=np.load(os.path.join(split_dir, 'metatest.npy'), allow_pickle=True).tolist(),
                    supercls=task_cfg['eval_supercls'], img_lvl=int(task_cfg['eval_supercls']) + 1,
                    transform=transformation, relabel=None, device=config['device'],
                    cuda_img_tensor=config['cuda_img_tensor'], verbose='{} evalset'.format(task)
                ))

                # meta-training
                bomla.metatrain(
                    trainset=trainset, evalset=evalset, optimiser_outer=optim_outer,
                    lr_scheduler_outer=scheduler_outer,
                    lapl_approx_reg=config['lapl_approx_reg'], nstep_outer=task_cfg['nstep_outer'],
                    nstep_inner=task_cfg['nstep_inner'], lr_inner=task_cfg['lr_inner'],
                    first_order=task_cfg['first_order'], seqtask=config['seqtask'],
                    num_way=config['net_kwargs']['num_way'], num_shot=task_cfg['num_shot'],
                    num_query_per_cls=task_cfg['num_query_per_cls'], num_task_per_itr=task_cfg['num_task_per_itr'],
                    task_by_supercls=task_cfg['task_by_supercls'], eval_prev_task=True,
                    eval_per_num_iter=task_cfg['eval_per_num_iter'], num_eval_task=task_cfg['num_eval_task'],
                    eval_task_by_supercls=task_cfg['eval_task_by_supercls'],
                    nstep_inner_eval=task_cfg['nstep_inner_eval'], writer=writer, task_idx=task_idx,
                    prev_glob_step=prev_glob_step, verbose=task
                )

                # update global step
                prev_glob_step += task_cfg['nstep_outer']
                # update mean
                lapl_approx.update_mean()

                if config['lapl_approx_reg']:
                    _update_laplace_hessians(lapl_approx, trainset, config, task)
                    if is_last_task:
                        # save task mean, covs
                        _save_laplace_state(
                            lapl_approx, result_dir, task_idx, include_mpar=config['nll_supp_wrt_metaparam']
                        )

                if is_last_task:
                    # save previous info for continued training
                    if config['num_dataset_to_run'] != "all":
                        torch.save(prev_glob_step, f=_checkpoint_path(result_dir, 'prev_glob_step', task_idx))
                        torch.save(evalset, f=_checkpoint_path(result_dir, 'evalset', task_idx))
                    # save model
                    with open(_checkpoint_path(result_dir, 'model', task_idx), 'wb') as f:
                        torch.save(model.state_dict(), f)

            # runs for skipped datasets as well, as in the original loop
            torch.cuda.empty_cache()
    finally:
        # make sure pending tensorboard events reach disk even if training fails
        writer.close()

    # check how long it ran
    _announce('\ncompleted in {}'.format(datetime.datetime.now() - start_datetime))


if __name__ == '__main__':
    # Script entry point: read the JSON config given on the command line and
    # launch one training run with a freshly drawn random seed.
    import argparse
    parser = argparse.ArgumentParser('BOMLA Sequential Dataset')
    # required: without a path there is nothing to run (previously this fell
    # through to open('None') and failed with a confusing FileNotFoundError)
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path of .json file to import config from')
    args = parser.parse_args()
    # load config file; the context manager closes the handle once parsed
    with open(args.config_path) as jsonfile:
        config = json.load(jsonfile)
    # train; the run is named after the config file (basename without extension)
    run_spec = os.path.splitext(os.path.basename(args.config_path))[0]
    train(config=config, run_spec=run_spec, seed=random.getrandbits(24))
