import argparse
from asyncore import write
import datetime
import json
import os
import random
from re import L
import time
from collections import OrderedDict
from unittest import result

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm, trange
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

import sys
sys.path.append('..')
sys.path.append('../models')
import models
from models.sharpMaml.model import ModelWideConv, ModelResNet12
import optimizers
from config.configuration import *


from util import enlist_transformation
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler
from meta_train_sharpMaml import meta_train
from meta_test_sharpMaml import meta_test
import util

from models.sharpMaml.metalearners import ModelAgnosticMetaLearning
from models.sam import SAM



# Encoders for which main() knows how to build a model.
_SUPPORTED_ENCODERS = ('wide-convnet4', 'resnet12')

# Meta-test runs on the first iteration and then every this many iterations.
_EVAL_INTERVAL = 500


def _resolve_count(requested, available):
    """Return the number of datasets to use.

    ``requested`` is either an integer or the string ``"all"``; ``"all"``
    expands to ``len(available)``.
    """
    if requested == "all":
        return len(available)
    return requested


def _build_save_path(config, run_spec, run_name):
    """Compose the output directory of this run.

    The layout is ``<run_dir>/<num_shot>_shot/<run_name>/<leaf>`` where the
    leaf is ``run_spec`` plus a suffix that depends on ``config['adjust']``
    and, when a ``'prompt'`` key is present, on ``config['prompt_args']``.
    """
    if config['adjust']:
        # time-stamped leaf so that repeated runs do not share a directory
        experiment_date = '{:%Y-%m-%d_%H:%M:%S}'.format(datetime.datetime.now())
        leaf = run_spec + '_{}'.format(experiment_date)
    elif 'prompt' in config:
        prompt_args = config['prompt_args']
        if prompt_args['dynamic_lr']:
            leaf = run_spec + '{}_{}'.format(prompt_args['dim'], prompt_args['kl_weight'])
        elif prompt_args['bayesian']:
            leaf = run_spec + '{}_bayesian'.format(prompt_args['dim'])
        else:
            leaf = run_spec + '{}'.format(prompt_args['dim'])
    else:
        leaf = run_spec

    return os.path.join(config['run_dir'], '{}_shot'.format(config['num_shot']), run_name, leaf)


def _build_loaders(config, task_names, split_file, verbose_fmt, transformation, num_task,
                   **sampler_kwargs):
    """Create one episodic DataLoader per task name.

    ``split_file`` is the ``.npy`` file loaded from each task directory,
    ``verbose_fmt`` is formatted with the task name and passed to the dataset
    as ``verbose``, and ``sampler_kwargs`` are forwarded to
    ``SuppQueryBatchSampler``.
    """
    loaders = []
    for task in task_names:
        # Split files are always read from the '5-shot' sub-directory,
        # independently of config['num_shot'].
        split_dir = os.path.join(config['data_dir'], '5-shot', task)
        task_list = np.load(os.path.join(split_dir, split_file))
        dataset = FewShotImageDataset(task_list=task_list, transform=transformation,
                                      device=config['device'], task_name=task,
                                      verbose=verbose_fmt.format(task))
        sampler = SuppQueryBatchSampler(
            dataset=dataset, num_way=config['num_way'], num_shot=config['num_shot'],
            num_query_per_cls=config['num_query_per_cls'], num_task=num_task,
            **sampler_kwargs
        )
        loaders.append(DataLoader(dataset, batch_sampler=sampler))
    return loaders


def main(config, run_spec):
    """Meta-train a model with SAM and periodically meta-test it.

    The function creates the run directory, stores the effective config there,
    trains for ``config['num_epoch']`` iterations, evaluates on the test
    datasets on the first iteration and every 500th one, saves the weights
    whenever the average test accuracy improves, and finally dumps the
    collected training/test statistics as ``.npy`` files.

    Args:
        config: Configuration dictionary (as loaded from the JSON config).
            It is modified in place: ``dataset_ls`` is truncated to the number
            of datasets to run, ``classifier_args['n_way']`` is set to
            ``num_way`` and the encoder name is appended to ``run_dir``.
            ``num_dataset_to_run`` / ``num_dataset_to_test`` may be integers
            or the string ``"all"``.
        run_spec: Name of the run; used for the output directory leaf and for
            the name of the saved config file.

    Raises:
        ValueError: If ``config['encoder']`` is not one of the supported
            encoders.
    """
    begin_time = time.time()

    # Fail before touching the file system: no model can be built otherwise.
    if config['encoder'] not in _SUPPORTED_ENCODERS:
        raise ValueError(
            "unsupported encoder {!r}; expected one of {}".format(
                config['encoder'], ', '.join(_SUPPORTED_ENCODERS)))

    # Resolve the "all" sentinel once so that every later slice, comparison,
    # division and range() works with plain integers.
    num_dataset_to_run = _resolve_count(config['num_dataset_to_run'], config['dataset_ls'])
    num_dataset_to_test = _resolve_count(config['num_dataset_to_test'], config['test_dataset_ls'])

    #* change the config
    config['dataset_ls'] = config['dataset_ls'][:num_dataset_to_run]
    config['classifier_args']['n_way'] = config['num_way']

    #* set the save file path
    run_name = get_run_name(config['dataset_ls'])
    if num_dataset_to_run < num_dataset_to_test:
        # more test datasets than training datasets
        run_name = 'cross_' + run_name

    if config['encoder'] != 'convnet4':
        # plain string concatenation: no path separator is inserted here
        config['run_dir'] += config['encoder']

    meta_save_path = _build_save_path(config, run_spec, run_name)
    ensure_path(meta_save_path)

    #* save the config file
    with open(os.path.join(meta_save_path, 'config_{}.json'.format(run_spec)), 'w') as outfile:
        outfile.write(json.dumps(config, indent=4))

    #* set the tensorboard writer
    writer = SummaryWriter(os.path.join(meta_save_path, 'log'))
    result_dir = os.path.join(meta_save_path, 'result')
    os.makedirs(result_dir, exist_ok=True)

    #* load the model (the encoder name was validated above)
    if config['encoder'] == "wide-convnet4":
        model = ModelWideConv(config['num_way']).to(config['device'])
    else:
        model = ModelResNet12().to(config['device'])

    #* define the optimizer: SAM wrapping Adam, with rho taken from config['alpha']
    base_optimizer = torch.optim.Adam
    meta_optimizer = SAM(model.parameters(), base_optimizer, rho=config['alpha'],
                         adaptive=config['adap'], lr=config['optimizer_args']['lr'])

    #* define the meta-learner
    metalearner = ModelAgnosticMetaLearning(model,
                                            meta_optimizer,
                                            adap=config['adap'],
                                            alpha=config['alpha'],
                                            SAM_lower=config['SAM_lower'],
                                            first_order=config['inner_args']['first_order'],
                                            num_adaptation_steps=config['num-steps'],
                                            step_size=config['step-size'],
                                            m=config['m'],
                                            loss_function=F.cross_entropy,
                                            device=config['device'])

    #* define the datasets
    transformation = transforms.Compose(
        enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                              device=config['device'], img_normalise=config['img_normalise'])
    )

    # the meta-batch is split evenly across the training datasets
    task_per_itr = int(config['batch_size'] / num_dataset_to_run)

    trainloaders = _build_loaders(
        config, config['dataset_ls'][:num_dataset_to_run], 'metatrain.npy', '{} trainset',
        transformation, task_per_itr)
    testloaders = _build_loaders(
        config, config['test_dataset_ls'][:num_dataset_to_test], 'metatest.npy', '{} testset',
        transformation, config['num_val_task'], train_shuffle=False)

    #* train and test the model
    inner_args = util.config_inner_args(config['inner_args'])

    # per-iteration training statistics, keyed by iteration index
    train_loss_dict = dict()
    train_acc_dict = dict()
    train_times = dict()

    # one {iteration: value} dict per test dataset
    test_loss_mean_dict_list = [dict() for _ in range(num_dataset_to_test)]
    test_loss_std_dict_list = [dict() for _ in range(num_dataset_to_test)]
    test_acc_mean_dict_list = [dict() for _ in range(num_dataset_to_test)]
    test_acc_std_dict_list = [dict() for _ in range(num_dataset_to_test)]

    checkpoint = 0.0   # best average test accuracy seen so far
    max_index = 0      # iteration at which it was reached

    for itr in trange(config['num_epoch'], desc='meta-train', ncols=100):

        model.train()
        epoch_begin_time = time.time()
        train_loss, train_accuracy = meta_train(metalearner, meta_optimizer, trainloaders, inner_args,
                                                config, itr, seq_task=config['seq_task'], writer=writer)
        train_times[itr] = time.time() - epoch_begin_time

        train_loss_dict[itr] = train_loss
        train_acc_dict[itr] = train_accuracy

        if (itr + 1) % _EVAL_INTERVAL != 0 and itr != 0:
            continue

        test_loss_mean_list, test_loss_std_list, test_acc_mean_list, test_acc_std_list \
            = meta_test(metalearner, testloaders, inner_args, config, itr, writer, result_dir)

        for i in range(num_dataset_to_test):
            test_loss_mean_dict_list[i][itr] = test_loss_mean_list[i]
            test_loss_std_dict_list[i][itr] = test_loss_std_list[i]
            test_acc_mean_dict_list[i][itr] = test_acc_mean_list[i]
            test_acc_std_dict_list[i][itr] = test_acc_std_list[i]
            print('{}: acc_mean={} and acc_std={}'.format(
                config['test_dataset_ls'][i], test_acc_mean_list[i], test_acc_std_list[i]))

        avg_acc_mean = np.mean(test_acc_mean_list)
        print('Average Performance: acc_mean={} and acc_std={}'.format(
            avg_acc_mean, np.mean(test_acc_std_list)))

        # keep the weights only when the average test accuracy improves
        if avg_acc_mean > checkpoint:
            checkpoint = avg_acc_mean
            max_index = itr
            torch.save(model.state_dict(), f=os.path.join(result_dir, 'step{}.pt'.format(itr)))

    #* save result
    writer.close()
    np.save(os.path.join(result_dir, 'train_test.npy'), train_times)
    np.save(os.path.join(result_dir, 'train_loss.npy'), train_loss_dict)
    np.save(os.path.join(result_dir, 'train_acc.npy'), train_acc_dict)

    for i in range(num_dataset_to_test):
        dataset_name = config['test_dataset_ls'][i]
        np.save(os.path.join(result_dir, '{}_test_loss_mean.npy'.format(dataset_name)), test_loss_mean_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_loss_std.npy'.format(dataset_name)), test_loss_std_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_acc_mean.npy'.format(dataset_name)), test_acc_mean_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_acc_std.npy'.format(dataset_name)), test_acc_std_dict_list[i])

    # Report the best evaluation. Skipped when nothing was evaluated (no
    # training iterations or no test datasets), where the lookups below would
    # raise KeyError / ZeroDivisionError.
    if num_dataset_to_test > 0 and max_index in test_acc_mean_dict_list[0]:
        max_avg_mean = 0.0
        max_avg_std = 0.0
        for i in range(num_dataset_to_test):
            best_mean = test_acc_mean_dict_list[i][max_index]
            best_std = test_acc_std_dict_list[i][max_index]
            max_avg_mean += best_mean
            max_avg_std += best_std
            print("the best performance of {} is {:.2f} + {:.2f}".format(
                config['test_dataset_ls'][i], best_mean * 100, best_std * 100))

        print("the best average performance is {:.2f} + {:.2f}".format(
            max_avg_mean * 100 / num_dataset_to_test, max_avg_std * 100 / num_dataset_to_test))

    end_time = time.time()
    total_run_time = end_time - begin_time
    print('-------------------------------------')
    print('All the processes finish and the total time cost is {}'.format(total_run_time))


if __name__ == '__main__':
    def _str_to_bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

        Without this converter argparse stores the raw string, so that
        ``--adjust False`` would yield the truthy string ``"False"``.
        """
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        help='configuration file',
                        required=True)
    # NOTE: --efficient is parsed but not forwarded to main() by this script.
    parser.add_argument('--efficient',
                        help='if True, enables gradient checkpointing',
                        action='store_true',
                        default=False)
    # Both flags accept an explicit value ("--seq_task True") or can be given
    # bare ("--seq_task"), which means True.
    parser.add_argument('--seq_task',
                        help='if True, the mini-batch consists of tasks only from one task distribution',
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=False)
    parser.add_argument('--adjust',
                        help='if True, the saved file will be create based on the experimental time',
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=False)
    args = parser.parse_args()

    # load config file; the context manager closes the handle after reading
    with open(str(args.config_path)) as jsonfile:
        config = json.load(jsonfile)
    config['adjust'] = args.adjust
    config['seq_task'] = args.seq_task

    # the run is named after the config file, without directory and extension
    run_spec = os.path.splitext(os.path.split(args.config_path)[-1])[0]

    main(config, run_spec)