import argparse
from asyncore import write
import datetime
import json
import os
import random
from re import L
import time
from collections import OrderedDict
from unittest import result

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm, trange
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

import sys
sys.path.append('..')
import models
from models.ProtoNet import ProtoNet
import optimizers
from config.configuration import *


from util import enlist_transformation
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler
from meta_train_proto import meta_train
from meta_test_proto import meta_test
import util



def _resolve_count(value, items):
    """Resolve a dataset-count config value to an integer.

    ``num_dataset_to_run`` / ``num_dataset_to_test`` may be an integer or the
    string ``"all"``; ``"all"`` means ``len(items)``. Integers are returned
    unchanged.
    """
    return len(items) if value == "all" else value


def _build_save_path(config, run_spec, run_name, experiment_date):
    """Return the directory that will hold this run's config, logs and results.

    Layout: ``<run_dir>/<num_shot>_shot/<run_name>/<leaf>``, where ``leaf`` is
    ``run_spec`` followed by the experiment timestamp when ``config['adjust']``
    is truthy, otherwise by prompt hyper-parameters when the config contains a
    ``'prompt'`` key, otherwise by nothing.
    """
    shot_dir = os.path.join(config['run_dir'], '{}_shot'.format(config['num_shot']), run_name)
    if config['adjust']:
        return os.path.join(shot_dir, run_spec + '_{}'.format(experiment_date))
    if 'prompt' not in config:
        return os.path.join(shot_dir, run_spec)

    # NOTE(review): presence is tested on 'prompt' but values are read from
    # 'prompt_args'; configs are presumably expected to carry both keys.
    prompt_args = config['prompt_args']
    if prompt_args['dynamic_lr']:
        suffix = '{}_{}'.format(prompt_args['dim'], prompt_args['kl_weight'])
    elif prompt_args['bayesian']:
        suffix = '{}_bayesian'.format(prompt_args['dim'])
    else:
        suffix = '{}'.format(prompt_args['dim'])
    return os.path.join(shot_dir, run_spec + suffix)


def _build_optimizer(model, config):
    """Create ``(optimizer, lr_scheduler)`` through ``optimizers.make``.

    For the ``'prompt'`` optimizer the parameters are passed as two groups,
    ``[prompt_parameters, other_parameters]``, where the first group holds the
    parameters named ``prompt_mean`` and ``prompt_cov``. Otherwise all model
    parameters are passed as a single iterable.
    """
    if config['optimizer'] == 'prompt':
        prompt_names = ('prompt_mean', 'prompt_cov')
        prompt_parameters = []
        other_parameters = []
        for name, param in model.named_parameters():
            if name in prompt_names:
                prompt_parameters.append(param)
            else:
                other_parameters.append(param)
        params = [prompt_parameters, other_parameters]
    else:
        params = model.parameters()
    return optimizers.make(config['optimizer'], params, **config['optimizer_args'])


def _build_loaders(config, datasets, split_file, num_task, transformation, label):
    """Build one episodic ``DataLoader`` per dataset name in *datasets*.

    Each dataset's task list is loaded from
    ``<data_dir>/5-shot/<dataset>/<split_file>`` (the ``5-shot`` directory is
    hard-coded regardless of ``num_shot``). *label* (e.g. ``'trainset'``) is
    only used in the dataset's ``verbose`` description.
    """
    loaders = []
    for task in datasets:
        split_dir = os.path.join(config['data_dir'], '5-shot', task)
        task_list = np.load(os.path.join(split_dir, split_file))
        dataset = FewShotImageDataset(task_list=task_list, transform=transformation,
                                      device=config['device'], task_name=task,
                                      verbose='{} {}'.format(task, label))
        sampler = SuppQueryBatchSampler(
            dataset=dataset, num_way=config['num_way'], num_shot=config['num_shot'],
            num_query_per_cls=config['num_query_per_cls'], num_task=num_task,
            train_shuffle=False,
        )
        loaders.append(DataLoader(dataset, batch_sampler=sampler))
    return loaders


def main(config, run_spec):
    """Meta-train a ProtoNet on the configured datasets and periodically meta-test it.

    Side effects:
      * ``config`` is mutated: ``dataset_ls`` is truncated to the datasets to
        run, ``classifier_args['n_way']`` is set to ``num_way``, and the encoder
        name is appended to ``run_dir`` when the encoder is not ``convnet4``.
        The mutated config is then saved as ``config_<run_spec>.json``.
      * TensorBoard logs go to ``<save_path>/log``; ``.npy`` results and model
        checkpoints (``step<itr>.pt``, written whenever the mean test accuracy
        improves) go to ``<save_path>/result``.

    Evaluation runs at iteration 0 and every ``config.get('eval_interval', 500)``
    iterations.

    Args:
        config: parsed configuration dict.
        run_spec: run name, used in the save path and the saved config file name.
    """
    begin_time = time.time()

    #* resolve "all" once so every later use works with an integer count
    num_dataset_to_run = _resolve_count(config['num_dataset_to_run'], config['dataset_ls'])
    num_dataset_to_test = _resolve_count(config['num_dataset_to_test'], config['test_dataset_ls'])

    #* change the config
    config['dataset_ls'] = config['dataset_ls'][:num_dataset_to_run]
    config['classifier_args']['n_way'] = config['num_way']
    test_datasets = config['test_dataset_ls'][:num_dataset_to_test]
    # number of test datasets actually evaluated (slicing clamps to the list length)
    num_test = len(test_datasets)

    #* set the save file path
    start_datetime = datetime.datetime.now()
    experiment_date = '{:%Y-%m-%d_%H:%M:%S}'.format(start_datetime)
    run_name = get_run_name(config['dataset_ls'])
    if num_dataset_to_run < num_dataset_to_test:
        run_name = 'cross_' + run_name

    if config['encoder'] != 'convnet4':
        config['run_dir'] += config['encoder']

    meta_save_path = _build_save_path(config, run_spec, run_name, experiment_date)
    ensure_path(meta_save_path)

    #* save the config file
    with open(os.path.join(meta_save_path, 'config_{}.json'.format(run_spec)), 'w') as outfile:
        outfile.write(json.dumps(config, indent=4))

    #* set the tensorboard writer; closed in `finally` even if training fails
    result_dir = os.path.join(meta_save_path, 'result')
    writer = SummaryWriter(os.path.join(meta_save_path, 'log'))
    try:
        os.makedirs(result_dir, exist_ok=True)

        #* load the model
        model = ProtoNet(config['encoder'], n_way=config['num_way'], n_support=config['num_shot'])
        model.to(config['device'])
        optimizer, lr_scheduler = _build_optimizer(model, config)

        #* define the training and testing datasets
        transformation = transforms.Compose(
            enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                                  device=config['device'], img_normalise=config['img_normalise'])
        )
        # tasks sampled per training dataset per iteration: batch_size split evenly across datasets
        task_per_itr = int(config['batch_size'] / num_dataset_to_run)
        trainloaders = _build_loaders(config, config['dataset_ls'], 'metatrain.npy',
                                      task_per_itr, transformation, 'trainset')
        testloaders = _build_loaders(config, test_datasets, 'metatest.npy',
                                     config['num_val_task'], transformation, 'testset')

        #* train and test the model
        inner_args = util.config_inner_args(config['inner_args'])
        eval_interval = config.get('eval_interval', 500)

        # per-iteration training results, keyed by iteration index
        train_times = {}
        train_loss_dict = {}
        train_acc_dict = {}

        # one dict per test dataset, keyed by the iteration at which it was evaluated
        test_loss_mean_dict_list = [{} for _ in range(num_test)]
        test_loss_std_dict_list = [{} for _ in range(num_test)]
        test_acc_mean_dict_list = [{} for _ in range(num_test)]
        test_acc_std_dict_list = [{} for _ in range(num_test)]

        best_avg_acc = 0.0
        max_index = 0

        for itr in trange(config['num_epoch'], desc='meta-train', ncols=100):
            model.train()
            epoch_begin_time = time.time()
            train_loss, train_accuracy = meta_train(model, optimizer, lr_scheduler, trainloaders, inner_args,
                                                    config, itr, config['seq_task'], writer)
            train_times[itr] = time.time() - epoch_begin_time
            train_loss_dict[itr] = train_loss
            train_acc_dict[itr] = train_accuracy

            if itr == 0 or (itr + 1) % eval_interval == 0:
                test_loss_mean_list, test_loss_std_list, test_acc_mean_list, test_acc_std_list \
                    = meta_test(model, testloaders, inner_args, config, itr, writer, result_dir)

                for i, dataset in enumerate(test_datasets):
                    test_loss_mean_dict_list[i][itr] = test_loss_mean_list[i]
                    test_loss_std_dict_list[i][itr] = test_loss_std_list[i]
                    test_acc_mean_dict_list[i][itr] = test_acc_mean_list[i]
                    test_acc_std_dict_list[i][itr] = test_acc_std_list[i]
                    print('{}: acc_mean={} and acc_std={}'.format(dataset, test_acc_mean_list[i], test_acc_std_list[i]))

                avg_acc = np.mean(test_acc_mean_list)
                print('Average Performance: acc_mean={} and acc_std={}'.format(avg_acc, np.mean(test_acc_std_list)))

                # checkpoint only when the average test accuracy improves
                if avg_acc > best_avg_acc:
                    best_avg_acc = avg_acc
                    max_index = itr
                    torch.save(model.state_dict(), f=os.path.join(result_dir, 'step{}.pt'.format(itr)))

            # NOTE(review): optimizers.make may return no scheduler (None) when no
            # schedule is configured; guard rather than crash on `.step()`.
            if lr_scheduler is not None:
                lr_scheduler.step()
    finally:
        writer.close()

    #* save result
    # 'train_test.npy' stores the per-iteration wall-clock training times; the
    # file name is kept as-is for compatibility.
    np.save(os.path.join(result_dir, 'train_test.npy'), train_times)
    np.save(os.path.join(result_dir, 'train_loss.npy'), train_loss_dict)
    np.save(os.path.join(result_dir, 'train_acc.npy'), train_acc_dict)

    for i, dataset in enumerate(test_datasets):
        np.save(os.path.join(result_dir, '{}_test_loss_mean.npy'.format(dataset)), test_loss_mean_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_loss_std.npy'.format(dataset)), test_loss_std_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_acc_mean.npy'.format(dataset)), test_acc_mean_dict_list[i])
        np.save(os.path.join(result_dir, '{}_test_acc_std.npy'.format(dataset)), test_acc_std_dict_list[i])

    # Report the per-dataset results at the best iteration. max_index is only a
    # recorded key if at least one evaluation ran (i.e. num_epoch > 0).
    if num_test > 0 and max_index in test_acc_mean_dict_list[0]:
        max_avg_mean = 0.0
        max_avg_std = 0.0
        for i, dataset in enumerate(test_datasets):
            acc_mean = test_acc_mean_dict_list[i][max_index]
            acc_std = test_acc_std_dict_list[i][max_index]
            max_avg_mean += acc_mean
            max_avg_std += acc_std
            print("the best performance of {} is {:.2f} + {:.2f}".format(dataset, acc_mean*100, acc_std*100))

        print("the best average performance is {:.2f} + {:.2f}".format(max_avg_mean*100/num_test, max_avg_std*100/num_test))

    end_time = time.time()
    total_run_time = end_time - begin_time
    print('-------------------------------------')
    print('All the processes finish and the total time cost is {}'.format(total_run_time))


def _str2bool(value):
    """Parse a boolean command-line value.

    Accepts real bools unchanged and the strings true/false, t/f, yes/no, y/n,
    1/0 in any case. Used as an argparse ``type`` so that ``--flag False`` is
    actually false (a raw string ``"False"`` would be truthy).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', required=True,
                        help='configuration file')
    parser.add_argument('--efficient',
                        help='if True, enables gradient checkpointing',
                        action='store_true',
                        default=False)
    # `--seq_task True/False` keeps working; a bare `--seq_task` means True.
    parser.add_argument('--seq_task',
                        help='if True, the mini-batch consists of tasks only from one task distribution',
                        type=_str2bool, nargs='?', const=True,
                        default=False)
    parser.add_argument('--adjust',
                        help='if True, the saved file will be create based on the experimental time',
                        type=_str2bool, nargs='?', const=True,
                        default=False)
    args = parser.parse_args()
    # load config file (closed automatically)
    with open(args.config_path, encoding='utf-8') as jsonfile:
        config = json.load(jsonfile)
    config['adjust'] = args.adjust
    config['seq_task'] = args.seq_task

    # run name = config file name without directory or extension
    run_spec = os.path.splitext(os.path.split(args.config_path)[-1])[0]

    main(config, run_spec)