import argparse
from asyncore import write
import json
import os
import random
import time
from collections import OrderedDict
from unittest import result

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm, trange
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

import sys
sys.path.append('..')
import models
import optimizers
from config.configuration import *


from util import enlist_transformation
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler
from meta_train import meta_train
from meta_test import meta_test
import util



def _resolve_num_datasets(config):
    """Return how many datasets from ``config['dataset_ls']`` this run uses.

    ``config['num_dataset_to_run']`` may be the string ``"all"`` (use every
    listed dataset) or a count. A count larger than the list is clamped to the
    list length so later per-dataset indexing stays in range.
    """
    requested = config['num_dataset_to_run']
    available = len(config['dataset_ls'])
    if requested == "all":
        return available
    return min(int(requested), available)


def _make_loader(split_dir, split_file, task, split_name, transformation, config, num_task):
    """Build an episodic DataLoader over one split file of one dataset.

    Args:
        split_dir: directory holding the dataset's split files.
        split_file: split file name inside ``split_dir`` (e.g. ``'metatrain.npy'``).
        task: dataset name; combined with ``split_name`` into the dataset's
            ``verbose`` argument.
        split_name: suffix for the ``verbose`` argument, e.g. ``'trainset'``.
        transformation: image transform passed to the dataset.
        config: experiment config (reads ``device``, ``num_way``, ``num_shot``,
            ``num_query_per_cls``).
        num_task: forwarded to the sampler's ``num_task`` argument.

    Returns:
        A ``DataLoader`` whose batches are produced by ``SuppQueryBatchSampler``.
    """
    task_list = np.load(os.path.join(split_dir, split_file))
    dataset = FewShotImageDataset(task_list=task_list, transform=transformation,
                                  device=config['device'],
                                  verbose='{} {}'.format(task, split_name))
    sampler = SuppQueryBatchSampler(
        dataset=dataset, num_way=config['num_way'], num_shot=config['num_shot'],
        num_query_per_cls=config['num_query_per_cls'], num_task=num_task,
    )
    return DataLoader(dataset, batch_sampler=sampler)


def main(config, run_spec):
    """Meta-train a model on the configured few-shot datasets and save results.

    Args:
        config: experiment configuration dict (loaded from the JSON config
            file). It is mutated in place before being saved:
            ``num_dataset_to_run`` is replaced by the resolved integer count,
            ``dataset_ls`` is truncated to that many datasets, and
            ``classifier_args['n_way']`` is set to ``num_way``.
        run_spec: run identifier (the config file's stem when run as a
            script). It is used in the output directory name and in the saved
            config's file name.

    Outputs are written under
    ``<run_dir>/<num_shot>_shot/<run_name>/<run_spec><prompt_args['dim']>/``:
    - the effective config;
    - a TensorBoard ``log`` directory;
    - a ``result`` directory with a checkpoint every 100 epochs, plus ``.npy``
      files holding ``{epoch: value}`` dicts of the train metrics and the
      per-dataset test metrics.
    """
    begin_time = time.time()
    # NOTE: no random seeds are set, so runs are not bit-for-bit reproducible.

    #* normalise the config. "all" must be resolved before any slicing or
    #* range() over the dataset count, otherwise those raise TypeError.
    num_datasets = _resolve_num_datasets(config)
    config['num_dataset_to_run'] = num_datasets
    config['dataset_ls'] = config['dataset_ls'][:num_datasets]
    config['classifier_args']['n_way'] = config['num_way']
    dataset_ls = config['dataset_ls']

    #* set the save file path
    run_name = get_run_name(dataset_ls)
    meta_save_path = os.path.join(config['run_dir'], '{}_shot'.format(config['num_shot']), run_name,
                                  run_spec + '{}'.format(config['prompt_args']['dim']))
    ensure_path(meta_save_path)

    #* save the effective (already modified) config next to the results
    with open(os.path.join(meta_save_path, 'config_{}.json'.format(run_spec)), 'w') as outfile:
        outfile.write(json.dumps(config, indent=4))

    result_dir = os.path.join(meta_save_path, 'result')
    os.makedirs(result_dir, exist_ok=True)

    #* define the per-dataset train / test loaders
    transformation = transforms.Compose(
        enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                              device=config['device'], img_normalise=config['img_normalise'])
    )
    trainloaders = []
    testloaders = []
    for task in dataset_ls:
        # NOTE(review): split files are always read from the '5-shot' folder,
        # independent of config['num_shot'] - confirm this is intended.
        split_dir = os.path.join(config['data_dir'], '5-shot', task)
        trainloaders.append(_make_loader(split_dir, 'metatrain.npy', task, 'trainset',
                                         transformation, config, config['num_task_per_itr']))
        testloaders.append(_make_loader(split_dir, 'metatest.npy', task, 'testset',
                                        transformation, config, config['num_val_task']))

    #* load the model; prompt arguments are only passed when prompting is enabled
    model_args = [config['encoder'], config['encoder_args'], config['classifier'],
                  config['classifier_args'], config['img_resize'], config['device']]
    if config['prompt']:
        model_args += [config['prompt'], config['prompt_args']]
    model = models.make(*model_args)
    optimizer, lr_scheduler = optimizers.make(config['optimizer'], model.parameters(),
                                              **config['optimizer_args'])
    inner_args = util.config_inner_args(config['inner_args'])

    #* result containers: {epoch: value} dicts. Each test metric holds one dict
    #* per dataset; the names follow the order of meta_test's return values.
    train_loss_dict = {}
    train_acc_dict = {}
    test_metric_names = ('test_loss_mean', 'test_loss_std', 'test_acc_mean', 'test_acc_std')
    test_metrics = {name: [{} for _ in range(num_datasets)] for name in test_metric_names}

    #* train and test the model; the writer is closed even if training fails
    writer = SummaryWriter(os.path.join(meta_save_path, 'log'))
    try:
        for itr in trange(config['num_epoch'], desc='meta-train', ncols=100):
            model.train()
            train_loss, train_accuracy = meta_train(model, optimizer, lr_scheduler, trainloaders,
                                                    inner_args, config, itr, writer)
            train_loss_dict[itr] = train_loss
            train_acc_dict[itr] = train_accuracy

            # evaluate after the first epoch and then every 10th epoch
            if itr == 0 or (itr + 1) % 10 == 0:
                test_results = meta_test(model, testloaders, inner_args, config, itr, writer)
                for name, per_dataset_values in zip(test_metric_names, test_results):
                    for i, record in enumerate(test_metrics[name]):
                        record[itr] = per_dataset_values[i]

            # checkpoint every 100 epochs (file named by the 0-based epoch index)
            if (itr + 1) % 100 == 0:
                torch.save(model.state_dict(), f=os.path.join(result_dir, 'step{}.pt'.format(itr)))
    finally:
        writer.close()

    #* save result
    np.save(os.path.join(result_dir, 'train_loss.npy'), train_loss_dict)
    np.save(os.path.join(result_dir, 'train_acc.npy'), train_acc_dict)
    for name, per_dataset in test_metrics.items():
        for dataset_name, record in zip(dataset_ls, per_dataset):
            np.save(os.path.join(result_dir, '{}_{}.npy'.format(dataset_name, name)), record)

    total_run_time = time.time() - begin_time
    print('-------------------------------------')
    print('All the processes finish and the total time cost is {}'.format(total_run_time))



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # required: without it the script would try to open the path 'None'
    parser.add_argument('--config_path', required=True,
                        help='configuration file')
    # NOTE(review): --efficient is parsed but never read in this script.
    parser.add_argument('--efficient',
                        help='if True, enables gradient checkpointing',
                        action='store_true',
                        default=False)
    args = parser.parse_args()

    # load the JSON config file; the context manager closes the handle promptly
    with open(args.config_path, encoding='utf-8') as jsonfile:
        config = json.load(jsonfile)

    # the config file's stem (e.g. 'foo' for 'configs/foo.json') identifies the run
    run_spec = os.path.splitext(os.path.basename(args.config_path))[0]

    main(config, run_spec)