import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import datetime
import os
import random
import json
import shutil
from platform import system
from tqdm import tqdm

from training import model as models
from training.variational import VariationalApprox
from training.boml import BayesianOnlineMetaLearnVariationalInference_InnerOnMean as BayesianOnlineMetaLearnVariationalInference
from training.util import enlist_transformation
from data_generate.task_generator import TaskGenerator
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler

# Reproducibility over speed: keep cuDNN on, but disable the benchmark
# auto-tuner (which can select different kernels between runs) and restrict
# cuDNN to deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = True


def _log(message):
    """Print *message* to stdout and flush it immediately.

    Replaces the previous ``os.system('echo "..."')`` call used on Linux,
    presumably there so progress lines appear promptly in buffered job logs.
    Flushing achieves that without spawning a shell, so quotes, ``$`` or
    backticks in ``run_spec`` are printed literally instead of being
    interpreted by a shell.
    """
    print(message, flush=True)


def _load_tasklist(config, split_dir):
    """Return the sequential task list for this run.

    When ``config['tasklist_path']`` is None, a new list is generated from
    ``metatrain.npy`` in *split_dir* with ``TaskGenerator``. That object is
    given ``tasklist_dir=config['experiment_dir']`` and ``save_npy=True``, and
    is then reloaded from ``tasklist.npy`` (presumably it writes that file
    into the experiment dir). Otherwise the given file is loaded and copied
    verbatim to ``<experiment_dir>/tasklist.npy``.
    """
    if config['tasklist_path'] is None:
        taskgen = TaskGenerator(num_way=config['net_kwargs']['num_way'], supercls=config['supercls'],
                                split_dir=split_dir, tasklist_dir=config['experiment_dir'])
        taskgen.generate_task_list(split_npyfilename='metatrain.npy', num_class_excl=taskgen.num_way,
                                   num_task_per_supercls=config['num_task_per_supercls'], save_npy=True)
        taskgen.load_task(task_npyfilename='tasklist.npy')
        return taskgen.tasklist

    tasklist = np.load(config['tasklist_path'], allow_pickle=True).tolist()
    # Copy the file byte-for-byte instead of re-saving the ``.tolist()`` result:
    # ``np.save`` on a ragged nested list raises ValueError on NumPy >= 1.24.
    shutil.copyfile(config['tasklist_path'], os.path.join(config['experiment_dir'], 'tasklist.npy'))
    return tasklist


def _build_outer_optimiser(config, var_approx):
    """Create the outer-loop optimiser (and optional LR scheduler) for one task.

    The optimiser is built over the values of ``var_approx.mean`` and
    ``var_approx.covar``. It returns ``(optimiser, scheduler)``, where
    ``scheduler`` is None when ``config['lr_sch_outer_name']`` is None.
    """
    params = list(var_approx.mean.values()) + list(var_approx.covar.values())
    optimiser = getattr(optim, config['optim_outer_name'])(params, **config['optim_outer_kwargs'])
    if config['lr_sch_outer_name'] is None:
        return optimiser, None
    scheduler = getattr(lr_scheduler, config['lr_sch_outer_name'])(optimiser, **config['lr_sch_outer_kwargs'])
    return optimiser, scheduler


def _build_task_loader(config, task, transformation):
    """Return a DataLoader yielding support/query batches for a single meta-training task."""
    trainset = FewShotImageDataset(
        task_list=task, supercls=config['supercls'], img_lvl=1, transform=transformation, relabel=None,
        device=config['device'], cuda_img_tensor=config['cuda_img_tensor'], verbose=None
    )
    trainsampler = SuppQueryBatchSampler(
        dataset=trainset, seqtask=config['seqtask'], num_batch=config['num_batch_per_outer'],
        num_shot=config['num_shot']
    )
    return DataLoader(trainset, batch_sampler=trainsampler)


def train(config, run_spec, seed):
    """Run sequential-task Bayesian online meta-learning (BOMVI) for one seed.

    Creates ``<run_dir>/<dataset>/<run_spec>_<date>_<seed>/`` and writes into it:
    the run's config JSON, ``tasklist.npy``, TensorBoard logs (``logtb/``) and,
    after the last task, the variational mean/covariance and the model
    state dict (``result/``).

    Args:
        config: experiment configuration dict (as loaded from the JSON config).
            Mutated in place: ``experiment_parent_dir`` and ``experiment_dir``
            are added before it is saved.
        run_spec: run identifier used in the experiment directory and config
            file names.
        seed: integer seed for torch, NumPy and Python ``random``.
    """
    torch.manual_seed(seed)
    # torch.manual_seed leaves Python's and NumPy's global RNGs untouched; seed
    # them as well so any task/batch sampling drawing from them is reproducible
    # from the seed recorded in the experiment directory name.
    random.seed(seed)
    np.random.seed(seed % 2 ** 32)  # NumPy only accepts seeds in [0, 2**32)
    start_datetime = datetime.datetime.now()

    # add additional experiment details to config.
    # ':' is not allowed in Windows path components, so use '-' there; other
    # platforms keep the original '%H:%M:%S' naming.
    time_fmt = '%H-%M-%S' if system() == 'Windows' else '%H:%M:%S'
    experiment_date = start_datetime.strftime('%Y-%m-%d_' + time_fmt)
    config['experiment_parent_dir'] = os.path.join(config['run_dir'], config['dataset'])
    config['experiment_dir'] = os.path.join(config['experiment_parent_dir'],
                                            '{}_{}_{}'.format(run_spec, experiment_date, seed))

    _log('running {}_{} seed {}'.format(run_spec, experiment_date, seed))

    # save config json file
    os.makedirs(config['experiment_dir'], exist_ok=True)
    config_file = os.path.join(config['experiment_dir'], 'config{}_{}.json'.format(0, run_spec))
    with open(config_file, 'w') as outfile:
        json.dump(config, outfile, indent=4)

    # split directory for this dataset; generate sequential tasks or reuse an existing list
    split_dir = os.path.join(config['data_dir'], config['split_folder'], config['dataset'])
    tasklist = _load_tasklist(config, split_dir)

    # define summarywriter
    writer = SummaryWriter(os.path.join(config['experiment_dir'], 'logtb'))
    try:
        result_dir = os.path.join(config['experiment_dir'], 'result')
        os.makedirs(result_dir, exist_ok=True)

        # define model
        model = getattr(models, config['net'])(**config['net_kwargs']).to(device=config['device'])

        # define variational object
        var_approx = VariationalApprox(model=model, num_mc_sample=config['num_mc_sample'])
        # detach model meta-parameters
        var_approx.detach_model_params()
        # clone & detach mean and covar as mean_old and covar_old
        var_approx.update_mean_cov()

        # define meta-training object
        bomvi = BayesianOnlineMetaLearnVariationalInference(model, variational_obj=var_approx,
                                                            device=config['device'])

        # define transformation of images
        transformation = transforms.Compose(
            enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                                  device=config['device'], img_normalise=config['img_normalise'])
        )

        # meta-evaluation dataset
        evalset = FewShotImageDataset(
            task_list=np.load(os.path.join(split_dir, 'metatest.npy'), allow_pickle=True).tolist(),
            supercls=config['eval_supercls'], img_lvl=int(config['eval_supercls']) + 1, transform=transformation,
            relabel=None, device=config['device'], cuda_img_tensor=config['cuda_img_tensor'], verbose=None
        )

        for task_idx, task in tqdm(enumerate(tasklist), desc='bomvi seqtask', total=len(tasklist)):
            # A new optimiser/scheduler is built for every task (as before),
            # presumably so optimiser state does not carry over between tasks.
            optim_outer, scheduler_outer = _build_outer_optimiser(config, var_approx)
            trainloader = _build_task_loader(config, task, transformation)

            # meta-training
            bomvi.metatrain_seqtask(
                trainloader=trainloader, evalset=evalset, optimiser_outer=optim_outer,
                lr_scheduler_outer=scheduler_outer, lr_inner=config['lr_inner'],
                nstep_outer=config['nstep_outer'], nstep_inner=config['nstep_inner'],
                first_order=config['first_order'], eval_per_num_epoch=config['eval_per_num_epoch'],
                num_eval_task=config['num_eval_task'], eval_task_by_supercls=config['eval_task_by_supercls'],
                nstep_inner_eval=config['nstep_inner_eval'], writer=writer, task_idx=task_idx, verbose=None,
                prev_glob_step=None
            )
            # update mean and covariance of meta-parameters
            var_approx.update_mean_cov()

        if tasklist:
            # save the final task's mean, covariance and model state, indexed by the last task
            last_idx = len(tasklist) - 1
            torch.save(var_approx.mean, f=os.path.join(result_dir, 'mean{}.pt'.format(last_idx)))
            torch.save(var_approx.covar, f=os.path.join(result_dir, 'covar{}.pt'.format(last_idx)))
            torch.save(model.state_dict(), os.path.join(result_dir, 'model{}.pt'.format(last_idx)))
    finally:
        # Flush pending TensorBoard events and release the event file even on failure.
        writer.close()

    # check how long it ran
    _log('\ncompleted in {}'.format(datetime.datetime.now() - start_datetime))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser('BOMVI Sequential Task')
    # required: without it the script used to try opening a file named "None"
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path of .json file to import config from')
    args = parser.parse_args()
    # load config file (the context manager closes the handle after reading)
    with open(args.config_path) as jsonfile:
        config = json.load(jsonfile)
    # run_spec is the config file's name without its directory or extension
    run_spec = os.path.splitext(os.path.basename(args.config_path))[0]
    # train with a freshly drawn 24-bit seed
    train(config=config, run_spec=run_spec, seed=random.getrandbits(24))
