import argparse
from asyncore import write
import datetime
import json
import os
import random
from re import L
import time
from collections import OrderedDict
from unittest import result

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm, trange
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

import sys
sys.path.append('..')
import models
import optimizers
from config.configuration import *


from util import enlist_transformation
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler
import util

def main(config, checkpoint_path, config_path):
    """Sample task-specific prompts from a trained model and save them to disk.

    Builds the model described by ``config``, loads the weights stored at
    ``checkpoint_path`` (non-strictly), then for every selected test dataset
    runs the model ``num_sample`` times on each episode yielded by its test
    loader with ``return_prompt=True``. The collected prompts are written to
    ``../results/<config parent dir>/<config file stem>/<dataset>.npy``,
    together with ``meta_mean.npy`` and ``meta_cov.npy``.

    Args:
        config: Parsed experiment configuration (dict loaded from JSON).
        checkpoint_path: Path to a state dict loadable with ``torch.load``.
        config_path: Path of the JSON config file; its parent directory name
            and file stem determine the output folder.
    """
    # ---- model ----------------------------------------------------------
    model_args = [config['encoder'], config['encoder_args'],
                  config['classifier'], config['classifier_args'],
                  config['img_resize'], config['device']]
    if 'prompt' in config:
        model_args += [config['prompt'], config['prompt_args'],
                       config['decoder'], config['decoder_args']]
    model = models.make(*model_args)

    # map_location lets a checkpoint saved on another device load onto the configured one.
    checkpoint = torch.load(checkpoint_path, map_location=config['device'])
    # strict=False tolerates missing/unexpected keys; report them so a mismatch is visible.
    load_result = model.load_state_dict(checkpoint, strict=False)
    print(load_result)

    # ---- data -----------------------------------------------------------
    transformation = transforms.Compose(
        enlist_transformation(img_resize=config['img_resize'], is_grayscale=config['is_grayscale'],
                              device=config['device'], img_normalise=config['img_normalise'])
    )
    if config['num_dataset_to_test'] == "all":
        num_dataset_to_test = len(config['test_dataset_ls'])
    else:
        num_dataset_to_test = config['num_dataset_to_test']
    # Use this resolved list everywhere below: config['num_dataset_to_test'] may be
    # the string "all" or exceed the number of available datasets.
    test_datasets = config['test_dataset_ls'][:num_dataset_to_test]

    save_path = '../results'
    config_dir = os.path.basename(os.path.dirname(config_path))
    config_stem = os.path.splitext(os.path.basename(config_path))[0]
    save_folder = os.path.join(save_path, config_dir, config_stem)
    os.makedirs(save_folder, exist_ok=True)
    print(save_folder)

    testloaders = []
    for task in test_datasets:
        # NOTE(review): the split directory is fixed to '5-shot' regardless of config['num_shot'].
        split_dir = os.path.join(config['data_dir'], '5-shot', task)
        test_task_list = np.load(os.path.join(split_dir, 'metatest.npy'))
        testset = FewShotImageDataset(task_list=test_task_list, transform=transformation,
                                      device=config['device'], task_name=task,
                                      verbose='{} testset'.format(task))
        testsampler = SuppQueryBatchSampler(
            dataset=testset, num_way=config['num_way'], num_shot=config['num_shot'],
            num_query_per_cls=config['num_query_per_cls'], num_task=config['num_val_task']
        )
        testloaders.append(DataLoader(testset, batch_sampler=testsampler))

    # ---- prompt sampling --------------------------------------------------
    inner_args = util.config_inner_args(config['inner_args'])
    prompts = {name: [] for name in test_datasets}

    model.eval()

    num_sample = 100
    # Each batch is laid out as num_way*num_shot support items followed by the queries.
    supp_idx = config['num_way'] * config['num_shot']
    for name, testloader in zip(test_datasets, testloaders):
        for images, labels in tqdm(testloader):
            # The episode is identical for every sample, so move and split it once.
            imgs = images.to(config['device'])
            lbls = labels.to(config['device'])
            support_img = imgs[:supp_idx].unsqueeze(dim=0)
            query_img = imgs[supp_idx:].unsqueeze(dim=0)
            support_lbl = lbls[:supp_idx].unsqueeze(dim=0)

            # Deliberately not wrapped in torch.no_grad(): the forward pass receives
            # inner_args and may need gradients for adaptation (NOTE(review): confirm).
            task_prompt = [
                model(support_img, query_img, support_lbl, inner_args,
                      meta_train=False, return_prompt=True)
                for _ in range(num_sample)
            ]
            prompts[name].append(np.array(task_prompt).squeeze())
            print(prompts[name][0].shape)

    # ---- save -------------------------------------------------------------
    np.save(os.path.join(save_folder, 'meta_mean.npy'), model.prompt_mean.cpu().detach().numpy())
    # NOTE(review): this stores prompt_mean again under the "cov" file name, which looks
    # like a copy-paste slip; the model's covariance attribute is not visible here — confirm.
    np.save(os.path.join(save_folder, 'meta_cov.npy'), model.prompt_mean.cpu().detach().numpy())

    for name in test_datasets:
        np.save(os.path.join(save_folder, '{}.npy'.format(name)), np.array(prompts[name]))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--efficient', 
                        help='if True, enables gradient checkpointing',
                        action='store_true',
                        default=False)
    parser.add_argument('--seq_task',
                        help='if True, the mini-batch consists of tasks only from one task distribution',
                        default=False)
    parser.add_argument('--adjust',
                        help='if True, the saved file will be create based on the experimental time',
                        default=False)
    # NOTE(review): the parsed args are not used below; 'adjust' and 'seq_task' are
    # overridden with False in the config regardless of the command line.
    args = parser.parse_args()

    # Load the config file (alternative: '../config/5-shot/testSoft.json').
    config_path = '../config/5-shot/testHard.json'
    with open(config_path) as jsonfile:
        config = json.load(jsonfile)
    config['adjust'] = False
    config['seq_task'] = False
    # These overrides appear to mirror the "64_1.0" suffix of the checkpoint run name below.
    config['prompt_args']['dim'] = 64
    config['prompt_args']['kl_weight'] = 1.0

    config['dataset_ls'] = config['dataset_ls'][:config['num_dataset_to_run']]
    config['classifier_args']['n_way'] = config['num_way']

    # Alternative checkpoint:
    # '/root/tf-logs/runs/wide-convnet4/5_shot/aircraft-cifar_fs-mini_imagenet-miniQuickDraw/testSoft64_1.0/result/step9499.pt'
    checkpoint_path = '/root/tf-logs/runs/wide-convnet4/5_shot/aircraft-cifar_fs-mini_imagenet-miniQuickDraw/testHard64_1.0/result/step9499.pt'

    main(config, checkpoint_path, config_path)