import torch
# import torch.optim as optim  # disabled: the `optim` imported below is used instead (presumably project-local; confirm)
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import datetime
import os
import json
import random
from platform import system
from warnings import filterwarnings

import optim
from config.configuration import get_run_name
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler
from training import model as models
from training.variational import VariationalApprox
from training.boml import BayesianOnlineMetaLearnVariationalInference_InnerOnMean as BayesianOnlineMetaLearnVariationalInference
from training.util import enlist_transformation


# cuDNN backend configuration, applied once at import time.
# Turn off the auto-tuner and request deterministic kernels while keeping
# cuDNN itself switched on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = True


def _announce(message):
    """Write ``message`` to stdout and flush it immediately.

    Replaces the former ``os.system("echo ...")`` calls so that no
    caller-supplied text (e.g. the run name) is interpolated into a shell
    command line.
    """
    print(message, flush=True)


def _run_tasks(config, writer, result_dir, completed_result_dir):
    """Build the model and meta-train it on every pending dataset in turn.

    Args:
        config: run configuration dictionary (see ``train``).
        writer: tensorboard ``SummaryWriter`` handed to ``metatrain``.
        result_dir: directory the final artefacts are saved to.
        completed_result_dir: result directory of the run being continued,
            or ``None`` when ``config['completed_task_idx']`` is ``None``.
    """
    completed_task_idx = config['completed_task_idx']
    resuming = completed_task_idx is not None

    # define model
    model = getattr(models, config['net'])(**config['net_kwargs']).to(device=config['device'])
    if resuming:
        model.load_state_dict(
            torch.load(os.path.join(completed_result_dir, 'model{}.pt'.format(completed_task_idx)))
        )

    # define variational object
    var_approx = VariationalApprox(model=model, num_mc_sample=config['num_mc_sample'])

    # detach model meta-parameters
    var_approx.detach_model_params()

    if resuming:
        # restore the state saved by the previous (partial) run
        var_approx.mean = torch.load(
            os.path.join(completed_result_dir, 'mean{}.pt'.format(completed_task_idx)))
        var_approx.covar = torch.load(
            os.path.join(completed_result_dir, 'covar{}.pt'.format(completed_task_idx)))
        prev_glob_step = torch.load(
            os.path.join(completed_result_dir, 'prev_glob_step{}.pt'.format(completed_task_idx)))
        evalset = torch.load(
            os.path.join(completed_result_dir, 'evalset{}.pt'.format(completed_task_idx)))
    else:
        prev_glob_step = 0
        evalset = []

    # clone & detach mean and covar as mean_old and covar_old
    var_approx.update_mean_cov()

    # define meta-training object
    bomvi = BayesianOnlineMetaLearnVariationalInference(model, variational_obj=var_approx, device=config['device'])

    # run partial num of datasets or all
    if config['num_dataset_to_run'] == 'all':
        num_dataset_to_run = len(config['dataset_ls'])
    else:
        num_dataset_to_run = config['num_dataset_to_run']

    for task_idx, task in enumerate(config['dataset_ls'][:num_dataset_to_run]):
        # tasks up to and including completed_task_idx were handled by the run being continued
        if resuming and completed_task_idx >= task_idx:
            continue

        task_cfg = config[task]

        # split directory for this dataset
        split_dir = os.path.join(config['data_dir'], config['split_folder'], task)

        # define optimiser and lr scheduler over the variational mean and covariance
        outer_params = list(var_approx.mean.values()) + list(var_approx.covar.values())
        optim_outer = getattr(optim, task_cfg['optim_outer_name'])(outer_params, **task_cfg['optim_outer_kwargs'])
        if task_cfg['lr_sch_outer_name'] is None:
            scheduler_outer = None
        else:
            scheduler_outer = getattr(lr_scheduler, task_cfg['lr_sch_outer_name'])(
                optim_outer, **task_cfg['lr_sch_outer_kwargs']
            )

        # define transformation of images
        transformation = transforms.Compose(
            enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                                  device=config['device'], img_normalise=task_cfg['img_normalise'])
        )

        trainset = FewShotImageDataset(
            task_list=np.load(os.path.join(split_dir, 'metatrain.npy'), allow_pickle=True).tolist(),
            supercls=task_cfg['supercls'], img_lvl=int(task_cfg['supercls']) + 1, transform=transformation,
            relabel=None, device=config['device'], cuda_img_tensor=config['cuda_img_tensor'],
            verbose='{} trainset'.format(task)
        )
        # define & append meta-evaluation dataset; the list keeps growing so
        # that earlier datasets are passed to metatrain alongside the new one
        evalset.append(FewShotImageDataset(
            task_list=np.load(os.path.join(split_dir, 'metatest.npy'), allow_pickle=True).tolist(),
            supercls=task_cfg['eval_supercls'], img_lvl=int(task_cfg['eval_supercls']) + 1,
            transform=transformation, relabel=None, device=config['device'],
            cuda_img_tensor=config['cuda_img_tensor'], verbose='{} evalset'.format(task)
        ))

        # meta-training
        bomvi.metatrain(
            trainset=trainset, evalset=evalset, optimiser_outer=optim_outer, lr_scheduler_outer=scheduler_outer,
            nstep_outer=task_cfg['nstep_outer'], nstep_inner=task_cfg['nstep_inner'],
            lr_inner=task_cfg['lr_inner'], first_order=task_cfg['first_order'], seqtask=config['seqtask'],
            num_task_per_itr=task_cfg['num_task_per_itr'], task_by_supercls=task_cfg['task_by_supercls'],
            num_way=config['net_kwargs']['num_way'], num_shot=task_cfg['num_shot'],
            num_query_per_cls=task_cfg['num_query_per_cls'], eval_prev_task=True,
            eval_per_num_iter=task_cfg['eval_per_num_iter'], num_eval_task=task_cfg['num_eval_task'],
            eval_task_by_supercls=task_cfg['eval_task_by_supercls'],
            nstep_inner_eval=task_cfg['nstep_inner_eval'], writer=writer, task_idx=task_idx,
            prev_glob_step=prev_glob_step, verbose=task
        )
        # update global step
        prev_glob_step += task_cfg['nstep_outer']

        # update mean and covariance of meta-parameters
        var_approx.update_mean_cov()

        # artefacts are written only once, after the last dataset of this run
        if (task_idx + 1) == num_dataset_to_run:
            # save previous info for continued training
            if config['num_dataset_to_run'] != "all":
                torch.save(prev_glob_step, f=os.path.join(result_dir, 'prev_glob_step{}.pt'.format(task_idx)))
                torch.save(evalset, f=os.path.join(result_dir, 'evalset{}.pt'.format(task_idx)))
            # save task mean, covs
            torch.save(var_approx.mean, f=os.path.join(result_dir, 'mean{}.pt'.format(task_idx)))
            torch.save(var_approx.covar, f=os.path.join(result_dir, 'covar{}.pt'.format(task_idx)))
            # save model
            with open(os.path.join(result_dir, 'model{}.pt'.format(task_idx)), 'wb') as f:
                torch.save(model.state_dict(), f)

        torch.cuda.empty_cache()


def train(config, run_spec, seed=0):
    """Meta-train sequentially over the datasets listed in ``config``.

    Creates the experiment directory, writes the configuration to it as JSON,
    opens a tensorboard writer and then meta-trains on each dataset of
    ``config['dataset_ls']`` (limited to ``config['num_dataset_to_run']``
    unless that is ``'all'``). After the last dataset the mean, covariance and
    model state dict are saved under ``<experiment_dir>/result``; when only a
    subset of the datasets is run, the global step and the evaluation sets are
    saved as well so that a later run can continue from there.

    When ``config['completed_task_idx']`` is not ``None`` the model, mean,
    covariance, global step and evaluation sets are loaded from the result
    directory of ``config['completed_exp_name']`` and every task with an index
    up to and including ``completed_task_idx`` is skipped.

    Args:
        config: run configuration dictionary. It is modified in place: the
            keys ``'experiment_parent_dir'`` and ``'experiment_dir'`` are
            added before it is written to disk.
        run_spec: run name, used in the experiment directory name and in the
            name of the saved configuration file.
        seed: seed passed to ``torch.manual_seed``; also appended to the
            experiment directory name.
    """
    # NOTE(review): only torch's RNG is seeded here; if the samplers/datasets
    # draw from `numpy.random` or `random`, those streams are not seeded - confirm.
    torch.manual_seed(seed)

    start_datetime = datetime.datetime.now()
    # NOTE(review): the timestamp contains ':' which is not allowed in Windows
    # path components - confirm whether non-Linux runs need another format.
    experiment_date = '{:%Y-%m-%d_%H:%M:%S}'.format(start_datetime)
    config['experiment_parent_dir'] = os.path.join(config['run_dir'], get_run_name(config['dataset_ls']))
    config['experiment_dir'] = os.path.join(config['experiment_parent_dir'],
                                            '{}_{}_{}'.format(run_spec, experiment_date, seed))

    _announce('running {}_{} seed {}'.format(run_spec, experiment_date, seed))

    completed_task_idx = config['completed_task_idx']

    # save config json file
    os.makedirs(config['experiment_dir'], exist_ok=True)
    config_file_idx = 0 if completed_task_idx is None else completed_task_idx + 1
    config_file_name = 'config{}_{}.json'.format(config_file_idx, run_spec)
    with open(os.path.join(config['experiment_dir'], config_file_name), 'w') as outfile:
        json.dump(config, outfile, indent=4)

    # define result directory and previous result directory if applicable
    if completed_task_idx is not None:
        completed_result_dir = os.path.join(
            config['experiment_parent_dir'], config['completed_exp_name'], 'result'
        )
    else:
        completed_result_dir = None

    # define tensorboard writer
    writer = SummaryWriter(os.path.join(config['experiment_dir'], 'logtb'))
    try:
        result_dir = os.path.join(config['experiment_dir'], 'result')
        os.makedirs(result_dir, exist_ok=True)
        _run_tasks(config, writer, result_dir, completed_result_dir)
    finally:
        # flush pending events and release the event file even if training fails
        writer.close()

    # check how long it ran
    _announce('\ncompleted in {}'.format(datetime.datetime.now() - start_datetime))


if __name__ == '__main__':
    # Command-line entry point: load a JSON config and start one training run.
    import argparse

    parser = argparse.ArgumentParser('BOMVI Sequential Dataset')
    # required: without a path there is no configuration to train from
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path of .json file to import config from')
    args = parser.parse_args()

    # load config file (the context manager closes the handle once parsed)
    with open(args.config_path) as jsonfile:
        config = json.load(jsonfile)

    # the run is named after the config file, without directory or extension
    run_spec = os.path.splitext(os.path.basename(args.config_path))[0]

    # train with a fresh random 24-bit seed
    train(config=config, run_spec=run_spec, seed=random.getrandbits(24))
