import argparse
from asyncore import write
import datetime
import json
import os
import random
from re import L
import time
from collections import OrderedDict
from unittest import result

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm, trange
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

import sys
sys.path.append('..')
import models
from models.meta_sgd import MetaSGD
import optimizers
from config.configuration import *


from util import enlist_transformation
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler
from meta_train_metaSGD import meta_train
from meta_test_metaSGD import meta_test
import util

def _resolve_dataset_count(requested, available):
    """Return how many datasets to use out of ``available``.

    Args:
        requested: either the string ``"all"`` or an integer count taken
            from the configuration.
        available: number of dataset names actually listed.

    Returns:
        ``available`` when ``requested`` is ``"all"``; otherwise the
        requested count, capped at ``available``.
    """
    if requested == "all":
        return available
    return min(int(requested), available)


def _make_episode_loader(config, task, split_file, transformation, num_task, verbose):
    """Build the DataLoader for one dataset split.

    Args:
        config: experiment configuration; reads ``data_dir``, ``device``,
            ``num_way``, ``num_shot`` and ``num_query_per_cls``.
        task: dataset name, also used as the sub-directory holding the split.
        split_file: file name of the split to load (``metatrain.npy`` or
            ``metatest.npy``).
        transformation: transform passed to ``FewShotImageDataset``.
        num_task: value forwarded as ``num_task`` to ``SuppQueryBatchSampler``.
        verbose: label forwarded as ``verbose`` to ``FewShotImageDataset``.

    Returns:
        A ``DataLoader`` driven by a ``SuppQueryBatchSampler``.
    """
    # NOTE: the split files are always read from the '5-shot' directory,
    # independently of config['num_shot'].
    split_dir = os.path.join(config['data_dir'], '5-shot', task)
    task_list = np.load(os.path.join(split_dir, split_file))
    dataset = FewShotImageDataset(task_list=task_list, transform=transformation,
                                  device=config['device'], task_name=task, verbose=verbose)
    sampler = SuppQueryBatchSampler(
        dataset=dataset, num_way=config['num_way'], num_shot=config['num_shot'],
        num_query_per_cls=config['num_query_per_cls'], num_task=num_task
    )
    return DataLoader(dataset, batch_sampler=sampler)


def main(config, run_spec):
    """Meta-train a MetaSGD model and periodically evaluate it.

    The function creates the output directory, dumps the effective
    configuration as JSON, builds one loader per training / test dataset,
    runs ``config['num_epoch']`` calls to ``meta_train`` and evaluates with
    ``meta_test`` at iteration 0 and after every 100th iteration. Whenever the
    mean test accuracy improves, the model state dict is saved. At the end the
    per-iteration train and test statistics are written as ``.npy`` files into
    the ``result`` sub-directory.

    Args:
        config: experiment configuration dictionary. It is modified in place:
            ``dataset_ls`` is truncated to the number of datasets to run and
            ``classifier_args['n_way']`` is set to ``num_way``.
            ``num_dataset_to_run`` and ``num_dataset_to_test`` may each be an
            integer or the string ``"all"``.
        run_spec: name of the run; used for the output directory and for the
            name of the saved configuration file.

    Raises:
        ValueError: if no training dataset is selected.
    """
    begin_time = time.time()

    #* change the config
    # Resolve "all" (and cap over-large counts) once, so that every later
    # slice, division and loop works with a plain integer.
    num_dataset_to_run = _resolve_dataset_count(
        config['num_dataset_to_run'], len(config['dataset_ls']))
    num_dataset_to_test = _resolve_dataset_count(
        config['num_dataset_to_test'], len(config['test_dataset_ls']))
    if num_dataset_to_run < 1:
        raise ValueError(
            'at least one training dataset is required, got num_dataset_to_run={!r}'.format(
                config['num_dataset_to_run']))

    config['dataset_ls'] = config['dataset_ls'][:num_dataset_to_run]
    config['classifier_args']['n_way'] = config['num_way']
    test_names = config['test_dataset_ls'][:num_dataset_to_test]

    #* set the save file path
    start_datetime = datetime.datetime.now()
    experiment_date = '{:%Y-%m-%d_%H:%M:%S}'.format(start_datetime)
    run_name = get_run_name(config['dataset_ls'])

    # With 'adjust' set, the start time is appended to the run directory name.
    run_folder = run_spec + '_{}'.format(experiment_date) if config['adjust'] else run_spec
    meta_save_path = os.path.join(config['run_dir'], '{}_shot'.format(config['num_shot']),
                                  run_name, run_folder)
    ensure_path(meta_save_path)

    #* save the config file
    with open(os.path.join(meta_save_path, 'config_{}.json'.format(run_spec)), 'w') as outfile:
        outfile.write(json.dumps(config, indent=4))

    result_dir = os.path.join(meta_save_path, 'result')
    os.makedirs(result_dir, exist_ok=True)

    model = MetaSGD(config)
    model.to(config['device'])
    model.define_task_lr_params()
    # The task learning rates are optimised together with the model parameters.
    model_params = list(model.parameters()) + list(model.task_lr.values())

    optimizer, lr_scheduler = optimizers.make(config['optimizer'], model_params,
                                              **config['optimizer_args'])

    #* define the training and testing datasets
    transformation = transforms.Compose(
        enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                              device=config['device'], img_normalise=config['img_normalise'])
    )

    # The batch is split evenly across the training datasets.
    task_per_itr = int(config['batch_size'] / num_dataset_to_run)

    trainloaders = [
        _make_episode_loader(config, task, 'metatrain.npy', transformation,
                             task_per_itr, '{} trainset'.format(task))
        for task in config['dataset_ls']
    ]
    testloaders = [
        _make_episode_loader(config, task, 'metatest.npy', transformation,
                             config['num_val_task'], '{} testset'.format(task))
        for task in test_names
    ]

    #* train and test the model
    inner_args = util.config_inner_args(config['inner_args'])

    # define the variables to save the result (iteration -> value)
    train_loss_dict = dict()
    train_acc_dict = dict()

    # one dict per test dataset
    test_loss_mean_dict_list = [dict() for _ in test_names]
    test_loss_std_dict_list = [dict() for _ in test_names]
    test_acc_mean_dict_list = [dict() for _ in test_names]
    test_acc_std_dict_list = [dict() for _ in test_names]

    # best mean test accuracy seen so far
    checkpoint = 0.0

    #* set the tensorboard writer
    writer = SummaryWriter(os.path.join(meta_save_path, 'log'))
    try:
        for itr in trange(config['num_epoch'], desc='meta-train', ncols=100):

            model.train()

            train_loss, train_accuracy = meta_train(model, optimizer, lr_scheduler, trainloaders,
                                                    inner_args, config, itr, config['seq_task'], writer)
            train_loss_dict[itr] = train_loss
            train_acc_dict[itr] = train_accuracy

            if (itr+1)%100 == 0 or itr==0:
                test_loss_mean_list, test_loss_std_list, test_acc_mean_list, test_acc_std_list \
                                = meta_test(model, testloaders, inner_args, config, itr, writer)

                for i, test_name in enumerate(test_names):
                    test_loss_mean_dict_list[i][itr] = test_loss_mean_list[i]
                    test_loss_std_dict_list[i][itr] = test_loss_std_list[i]
                    test_acc_mean_dict_list[i][itr] = test_acc_mean_list[i]
                    test_acc_std_dict_list[i][itr] = test_acc_std_list[i]
                    print('{}: acc_mean={} and acc_std={}'.format(test_name, test_acc_mean_list[i], test_acc_std_list[i]))

                mean_acc = np.mean(test_acc_mean_list)
                print('Average Performance: acc_mean={} and acc_std={}'.format(mean_acc, np.mean(test_acc_std_list)))

                # keep a snapshot every time the average test accuracy improves
                if mean_acc > checkpoint:
                    checkpoint = mean_acc
                    torch.save(model.state_dict(), f=os.path.join(result_dir, 'step{}.pt'.format(itr)))

            lr_scheduler.step()
    finally:
        # Close the writer even when training is interrupted by an exception.
        writer.close()

    #* save result
    np.save(os.path.join(result_dir, 'train_loss.npy'), train_loss_dict)
    np.save(os.path.join(result_dir, 'train_acc.npy'), train_acc_dict)

    for i, test_name in enumerate(test_names):
        np.save(os.path.join(result_dir, '{}_test_loss_mean.npy'.format(test_name)), test_loss_mean_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_loss_std.npy'.format(test_name)), test_loss_std_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_acc_mean.npy'.format(test_name)), test_acc_mean_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_acc_std.npy'.format(test_name)), test_acc_std_dict_list[i])

    end_time = time.time()
    total_run_time = end_time - begin_time
    print('-------------------------------------')
    print('All the processes finish and the total time cost is {}'.format(total_run_time))

        










def str2bool(value):
    """Convert a command-line token into a real boolean.

    Without this, argparse stores the raw string, so ``--adjust False``
    would yield the truthy string ``'False'``.

    Args:
        value: a ``bool`` or a string such as ``True``, ``false``, ``1``, ``no``.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_cli_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``config_path``, ``efficient``,
        ``seq_task`` and ``adjust``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        help='configuration file',
                        required=True)
    parser.add_argument('--efficient',
                        help='if True, enables gradient checkpointing',
                        action='store_true',
                        default=False)
    # nargs='?' with const=True accepts both the bare flag ("--seq_task") and
    # an explicit value ("--seq_task True" / "--seq_task False").
    parser.add_argument('--seq_task',
                        help='if True, the mini-batch consists of tasks only from one task distribution',
                        type=str2bool,
                        nargs='?',
                        const=True,
                        default=False)
    parser.add_argument('--adjust',
                        help='if True, the saved file will be create based on the experimental time',
                        type=str2bool,
                        nargs='?',
                        const=True,
                        default=False)
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_cli_args()
    # load config file
    with open(args.config_path, encoding='utf-8') as jsonfile:
        config = json.load(jsonfile)
    # NOTE: args.efficient is parsed but not forwarded to main() by this script.
    config['adjust'] = args.adjust
    config['seq_task'] = args.seq_task

    # the run is named after the config file, without directory and extension
    run_spec = os.path.splitext(os.path.basename(args.config_path))[0]

    main(config, run_spec)
