import math
import numpy as np
from sklearn.datasets import load_files
from tqdm import tqdm

import time
import os
import torch
import torch.nn.functional as F

import util

def _episode_logits(model, support_img, query_img, support_lbl, inner_args, config):
    """Return the query logits of one episode.

    When ``model.is_prompt`` is true the model is called
    ``config['prompt_args']['sample_num']`` times and the logits are averaged.
    """
    if not model.is_prompt:
        logits, _ = model(support_img, query_img, support_lbl, inner_args, meta_train=False)
        return logits

    sample_num = config['prompt_args']['sample_num']
    if sample_num < 1:
        raise ValueError(
            "config['prompt_args']['sample_num'] must be >= 1, got {}".format(sample_num)
        )

    # Accumulate starting from the first sample instead of a pre-allocated
    # zero tensor, so no way/query count has to be hard-coded in the shape.
    logits = None
    for _ in range(sample_num):
        sample_logits, _ = model(support_img, query_img, support_lbl, inner_args, meta_train=False)
        logits = sample_logits if logits is None else logits + sample_logits
    return logits / sample_num


def _ci95_half_width(values):
    """Return 1.96 * sample std (ddof=1) / sqrt(n) for ``values``; NaN if n < 2."""
    n_values = len(values)
    if n_values < 2:
        # The sample standard deviation is undefined for fewer than 2 values.
        return float('nan')
    return 1.96 * np.std(values, ddof=1) / math.sqrt(n_values)


def _log_scalars(writer, scalars, step):
    """Write every ``tag -> value`` pair to ``writer``; no-op when writer is None."""
    if writer is None:
        return
    for tag, value in scalars.items():
        writer.add_scalar(tag=tag, scalar_value=value, global_step=step)


def meta_test(model, testloaders, inner_args, config, itr_index, writer=None, result_dir=None):
    """Evaluate ``model`` on every test loader and report per-dataset statistics.

    Each batch yielded by a loader is treated as one episode: the first
    ``config['num_way'] * config['num_shot']`` items are the support set and
    the remaining items are the query set.

    Args:
        model: object exposing ``eval()``, an ``is_prompt`` attribute and a call
            signature ``model(support_img, query_img, support_lbl, inner_args,
            meta_train=False)`` that returns ``(logits, _)``.
        testloaders: sequence of iterables yielding ``(images, labels)`` tensors.
        inner_args: passed through to ``model`` unchanged.
        config: mapping; the keys read here are ``'device'``, ``'num_way'``,
            ``'num_shot'``, ``'test_dataset_ls'`` (only when ``result_dir`` is
            given) and ``'prompt_args'['sample_num']`` (only when
            ``model.is_prompt`` is true).
        itr_index: value used as ``global_step`` for every logged scalar.
        writer: optional object with an ``add_scalar(tag=, scalar_value=,
            global_step=)`` method. When ``None`` nothing is logged.
        result_dir: optional directory in which the per-episode timings are
            saved with ``np.save``. When ``None`` timings are not saved.

    Returns:
        Four lists with one entry per loader:
        ``(all_loss_mean, all_loss_std, all_acc_mean, all_acc_std)``.
        Despite their names, the two ``*_std`` lists hold 95% confidence
        interval half-widths, not standard deviations.
    """
    model.eval()
    device = config['device']
    supp_idx = config['num_way'] * config['num_shot']

    all_loss_mean = []
    all_loss_std = []
    all_acc_mean = []
    all_acc_std = []

    # One pass per dataset.
    for index_, testloader in enumerate(testloaders):
        # Per-episode results collected for this dataset only.
        loss_one_dis = []
        accuracy_one_dis = []
        test_times = {}

        for num, (images, labels) in enumerate(testloader):
            # NOTE(review): on a CUDA device kernels may still be running when
            # the clock is read, so this can under-report without an explicit
            # synchronize -- confirm whether that matters for these timings.
            test_begin_time = time.perf_counter()

            # Move the episode to the target device and split support / query.
            imgs = images.to(device)
            lbls = labels.to(device)
            support_img = imgs[:supp_idx].unsqueeze(dim=0)
            query_img = imgs[supp_idx:].unsqueeze(dim=0)
            support_lbl = lbls[:supp_idx].unsqueeze(dim=0)
            query_lbl = lbls[supp_idx:].unsqueeze(dim=0)

            logits = _episode_logits(
                model, support_img, query_img, support_lbl, inner_args, config
            )
            test_times[num] = time.perf_counter() - test_begin_time

            # Drop the leading episode dimension before scoring.
            logits = logits.flatten(0, 1)
            query_labels = query_lbl.flatten()

            pred = torch.argmax(logits, dim=1)
            acc = util.compute_acc(pred, query_labels)
            loss = F.cross_entropy(logits, query_labels)

            accuracy_one_dis.append(acc)
            loss_one_dis.append(loss.item())

        if result_dir is not None:
            # File name kept as-is so existing consumers of the timings still find it.
            np.save(
                os.path.join(result_dir, config['test_dataset_ls'][index_] + 'test_test.npy'),
                test_times,
            )

        # Mean and 95% CI over the episodes that were actually evaluated.
        loss_mean = np.mean(loss_one_dis)
        acc_mean = np.mean(accuracy_one_dis)
        loss_95ci = _ci95_half_width(loss_one_dis)
        acc_95ci = _ci95_half_width(accuracy_one_dis)

        _log_scalars(
            writer,
            {
                'loss_meta_eval_task{}'.format(index_): loss_mean,
                'accuracy_meta_eval_task{}'.format(index_): acc_mean,
                'loss95ci_meta_eval_task{}'.format(index_): loss_95ci,
                'acc95ci_meta_eval_task{}'.format(index_): acc_95ci,
            },
            itr_index,
        )

        # Keep the per-dataset statistics for the return value.
        all_loss_mean.append(loss_mean)
        all_loss_std.append(loss_95ci)
        all_acc_mean.append(acc_mean)
        all_acc_std.append(acc_95ci)

    if writer is not None:
        # Averages across datasets.
        _log_scalars(
            writer,
            {
                'loss_meta_eval_task_average': np.mean(all_loss_mean),
                'accuracy_meta_eval_task_average': np.mean(all_acc_mean),
                'loss95ci_meta_eval_task_average': np.mean(all_loss_std),
                'acc95ci_meta_eval_task_average': np.mean(all_acc_std),
            },
            itr_index,
        )

    return all_loss_mean, all_loss_std, all_acc_mean, all_acc_std
    
    