from collections import OrderedDict
import math
import numpy as np
from sklearn.datasets import load_files
from tqdm import tqdm

import torch
import torch.nn.functional as F

import util

def meta_test(model, testloaders, inner_args, config, itr_index, writer=None):
    """Evaluate the model on every test loader with one adaptation step per episode.

    For each batch (episode) the first ``num_way * num_shot`` samples are used
    as the support set and the remaining samples as the query set.  The
    parameters are adapted with a single gradient step on the support loss,
    scaled by the per-parameter step size stored in ``model.task_lr``, and the
    adapted parameters are then scored on the query set.

    Args:
        model: network to evaluate.  It must be callable as ``model(x)`` and
            ``model(x, state_dict)`` and expose ``cloned_state_dict()`` and a
            ``task_lr`` mapping indexed by parameter name.
        testloaders: sequence of iterables yielding ``(images, labels)``
            batches, one iterable per evaluation dataset.
        inner_args: not used by this function; kept for interface
            compatibility with callers.
        config: mapping providing ``'device'``, ``'num_way'``, ``'num_shot'``
            and ``'num_val_task'``.
        itr_index: global step attached to every logged scalar.
        writer: optional object with a TensorBoard-style ``add_scalar(tag=...,
            scalar_value=..., global_step=...)`` method.  When ``None``
            nothing is logged.

    Returns:
        Tuple ``(all_loss_mean, all_loss_std, all_acc_mean, all_acc_std)`` of
        lists holding one entry per loader.  Despite their names, the two
        ``*_std`` lists contain the 95% confidence-interval half-widths.
    """

    def log_scalar(tag, value):
        # ``writer`` defaults to None, so logging has to be optional.
        if writer is not None:
            writer.add_scalar(
                tag=tag, scalar_value=value, global_step=itr_index
            )

    model.eval()
    all_loss_mean = []
    all_loss_std = []
    all_acc_mean = []
    all_acc_std = []

    # Loop invariants: size of the support split and the CI denominator.
    supp_idx = config['num_way'] * config['num_shot']
    # NOTE(review): the interval uses config['num_val_task'] rather than the
    # number of episodes actually evaluated -- confirm both always match.
    sqrt_nsample = math.sqrt(config['num_val_task'])

    #* for each dataset
    for index_, testloader in enumerate(testloaders):

        #* results collected over the episodes of the same distribution
        loss_one_dis = []
        accuracy_one_dis = []

        for images, labels in testloader:

            # move the data to the device and split it into support / query
            imgs = images.to(config['device'])
            lbls = labels.to(config['device'])
            support_img, query_img = imgs[:supp_idx], imgs[supp_idx:]
            support_lbl, query_lbl = lbls[:supp_idx], lbls[supp_idx:]

            # The adaptation step needs gradients even when the caller runs
            # the evaluation inside a ``torch.no_grad()`` context.
            with torch.enable_grad():
                Y_sup_hat = model(support_img)
                support_loss = F.cross_entropy(Y_sup_hat, support_lbl)
                zero_grad(model.parameters())
                grads = torch.autograd.grad(support_loss, model.parameters())

            adapted_state_dict = model.cloned_state_dict()

            # Nothing is back-propagated through the adapted parameters or
            # the query pass, so avoid building an autograd graph for them.
            with torch.no_grad():
                for (key, val), grad in zip(model.named_parameters(), grads):
                    # NOTE Here Meta-SGD is different from naive MAML:
                    # each parameter has its own step size in model.task_lr.
                    adapted_state_dict[key] = val - model.task_lr[key] * grad
                Y_que_hat = model(query_img, adapted_state_dict)
                query_loss = F.cross_entropy(Y_que_hat, query_lbl)

            pred = torch.argmax(Y_que_hat, dim=1)
            acc = util.compute_acc(pred, query_lbl)

            accuracy_one_dis.append(acc)
            loss_one_dis.append(query_loss.item())

        #* compute the mean and the 95% CI of the result on one distribution
        loss_mean = np.mean(loss_one_dis)
        acc_mean = np.mean(accuracy_one_dis)
        loss_95ci = 1.96 * np.std(loss_one_dis, ddof=1) / sqrt_nsample
        acc_95ci = 1.96 * np.std(accuracy_one_dis, ddof=1) / sqrt_nsample

        log_scalar('loss_meta_eval_task{}'.format(index_), loss_mean)
        log_scalar('accuracy_meta_eval_task{}'.format(index_), acc_mean)
        log_scalar('loss95ci_meta_eval_task{}'.format(index_), loss_95ci)
        log_scalar('acc95ci_meta_eval_task{}'.format(index_), acc_95ci)

        # save the data to return
        all_loss_mean.append(loss_mean)
        all_loss_std.append(loss_95ci)
        all_acc_mean.append(acc_mean)
        all_acc_std.append(acc_95ci)

    # Only aggregate across datasets when there is somewhere to log it.
    if writer is not None:
        log_scalar('loss_meta_eval_task_average', np.mean(all_loss_mean))
        log_scalar('accuracy_meta_eval_task_average', np.mean(all_acc_mean))
        log_scalar('loss95ci_meta_eval_task_average', np.mean(all_loss_std))
        log_scalar('acc95ci_meta_eval_task_average', np.mean(all_acc_std))

    return all_loss_mean, all_loss_std, all_acc_mean, all_acc_std
    

def zero_grad(params):
    """Zero, in place, the ``.grad`` tensor of every parameter that has one.

    Args:
        params: iterable of objects exposing a ``grad`` attribute.  Entries
            whose ``grad`` is ``None`` are skipped.
    """
    populated_grads = (
        param.grad for param in params if param.grad is not None
    )
    for gradient in populated_grads:
        gradient.zero_()