import wandb
import torch
from torch.nn import functional as F
from sklearn.metrics import accuracy_score
import numpy as np

def get_pred(meta_module, dataloader, params=None, embedding=None):
    """Run ``meta_module`` over every batch of ``dataloader`` and collect the results.

    Each batch is a mapping with ``"image_features"``, ``"text_features"``,
    ``"ques_emb"`` and ``"label"`` entries. The module is called as
    ``meta_module(image_features, text_features, embedding, params=params)``.

    Args:
        meta_module: callable producing per-batch outputs (e.g. similarity logits).
        dataloader: iterable of batch mappings as described above.
        params: forwarded unchanged as the ``params=`` keyword of every call.
        embedding: embedding passed to every call. When None, ``ques_emb[0]``
            of the FIRST batch is taken and then reused for all later batches.

    Returns:
        ``(outputs, labels)``: the per-batch outputs and labels, each
        concatenated along dim 0. Labels are moved to the device of that
        batch's image features before concatenation.

    An empty ``dataloader`` leaves nothing to concatenate, so ``torch.cat`` fails.
    """
    output_chunks, label_chunks = [], []
    for batch in dataloader:
        image_feats = batch["image_features"]
        text_feats = batch["text_features"]
        # Once bound (by the caller or by the first batch), the embedding is
        # kept for every subsequent batch.
        embedding = batch["ques_emb"][0] if embedding is None else embedding
        label_chunks.append(batch["label"].to(image_feats.device))
        output_chunks.append(
            meta_module(image_feats, text_feats, embedding, params=params)
        )
    return torch.cat(output_chunks), torch.cat(label_chunks)

def test_accuracy(meta_module, dataloader, params=None, embedding=None, params_list=None):
    """Evaluate ``meta_module`` on ``dataloader`` and return ``(accuracy, loss)``.

    Outputs are produced by ``get_pred`` under ``torch.no_grad()`` with the
    module in eval mode. On exit, including when an error is raised, the
    module is put back into whichever train/eval mode it was in on entry.

    Args:
        meta_module: model with ``eval()``/``train()`` methods, called through
            ``get_pred``.
        dataloader: iterable of batch mappings consumed by ``get_pred``.
        params: parameter set used when ``params_list`` is None.
        embedding: optional embedding forwarded to ``get_pred``.
        params_list: optional iterable of parameter sets. When given, the
            outputs for each set are averaged before scoring, and ``params``
            is ignored.

    Returns:
        ``(acc, loss)``:
        - ``acc`` is ``sklearn.metrics.accuracy_score`` of the arg-max
          predictions against the labels.
        - ``loss`` is the cross-entropy of the (averaged) outputs, as a
          Python float.

    Raises:
        ValueError: if ``params_list`` is given but yields no parameter sets.
    """
    # Remember the caller's mode instead of unconditionally forcing train mode
    # afterwards; default to train mode (the old behavior) if unknown.
    was_training = getattr(meta_module, "training", True)
    meta_module.eval()
    try:
        with torch.no_grad():
            if params_list is not None:
                output_list = []
                for p in params_list:
                    # NOTE(review): y_true comes from the last pass; averaging
                    # assumes every pass iterates dataloader in the same order
                    # (no shuffling) — confirm for the loaders used here.
                    output, y_true = get_pred(meta_module, dataloader, params=p, embedding=embedding)
                    output_list.append(output)
                if not output_list:
                    raise ValueError("params_list must contain at least one parameter set")
                output = torch.stack(output_list).mean(0)
            else:
                output, y_true = get_pred(meta_module, dataloader, params=params, embedding=embedding)
            # argmax returns the first maximal index and a 1-D tensor, unlike
            # topk(1), which yields an (N, 1) column.
            y_pred = output.argmax(dim=-1)
            loss = F.cross_entropy(output, y_true)
    finally:
        if was_training:
            meta_module.train()
        else:
            meta_module.eval()

    acc = accuracy_score(y_true.cpu().numpy(), y_pred.cpu().numpy())
    return acc, loss.item()

def append_dict(dictionary, new_dict):
    """Append each value of ``new_dict`` to the list kept under the same key in ``dictionary``.

    Keys missing from ``dictionary`` are first given an empty list.
    ``dictionary`` is mutated in place; the function returns None.
    """
    for name, value in new_dict.items():
        dictionary.setdefault(name, []).append(value)

def mean_dict(dictionary):
    """Return a new dict with the list-valued entries of ``dictionary`` averaged.

    For each key:
    - a list whose first element is a list is averaged element-wise across
      the outer list: entry ``i`` of the result is the mean of every inner
      list's ``i``-th item, and the number of entries is taken from the
      first inner list;
    - any other non-empty list is reduced to a single ``np.mean``;
    - an empty list maps to ``np.nan``;
    - non-list values are copied through unchanged.

    ``dictionary`` itself is not modified.
    """
    out_dict = {}
    for key, values in dictionary.items():
        if not isinstance(values, list):
            out_dict[key] = values
        elif not values:
            # Nothing recorded for this key; previously this raised IndexError
            # on values[0]. Use NaN, matching np.mean's result on empty input.
            out_dict[key] = np.nan
        elif isinstance(values[0], list):
            out_dict[key] = [
                np.mean([row[i] for row in values]) for i in range(len(values[0]))
            ]
        else:
            out_dict[key] = np.mean(np.array(values))
    return out_dict

def log_metric(log_dict, prefix=""):
    """Send the entries of ``log_dict`` to wandb, with ``prefix`` prepended to each key.

    All non-list values go out together in one ``wandb.log`` call, even when
    there are none. Then every list-valued entry is sent one element per
    ``wandb.log`` call, in key order and then element order.
    """
    scalars = {}
    series = []
    for name, value in log_dict.items():
        if isinstance(value, list):
            series.append((name, value))
        else:
            scalars[prefix + name] = value

    wandb.log(scalars)

    # NOTE(review): each element is a separate wandb.log call, so list entries
    # are logged one after another rather than side by side — confirm this
    # step layout is intended.
    for name, values in series:
        for item in values:
            wandb.log({prefix + name: item})
