import numpy as np
import torch
from tqdm import tqdm
import torch.nn.functional as F
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator


def l2_norm(input):
    """Scale every row of ``input`` to unit Euclidean length.

    A small epsilon (1e-12) is added to each squared row norm before the
    square root, so an all-zero row maps to zeros instead of NaNs.

    Args:
        input: tensor normalised along dim 1. The per-row norms are reshaped
            to ``(-1, 1)`` and expanded to ``input``'s shape, so in practice
            this expects a 2-D ``(N, D)`` tensor.

    Returns:
        Tensor with the same shape as ``input``.
    """
    shape = input.size()
    sq_norm = input.pow(2).sum(dim=1) + 1e-12
    per_row = sq_norm.sqrt().view(-1, 1).expand_as(input)
    return (input / per_row).view(shape)


def calc_recall_at_k(T, Y, k):
    """Fraction of samples whose target label appears in its first ``k`` predictions.

    T : [nb_samples] (target labels)
    Y : [nb_samples x >=k] (predicted labels / neighbour labels per sample);
        rows may be lists, numpy arrays or tensors of any numeric dtype.
    k : number of leading predictions per row to check.

    Returns the recall@k as a Python float. Raises ZeroDivisionError if ``T``
    is empty.
    """
    hits = 0
    for t, y in zip(T, Y):
        # torch.as_tensor keeps integer rows as int64. The legacy
        # torch.Tensor(...) constructor forced float32, which rejects integer
        # tensors and rounds label ids above 2**24. Slicing before the cast
        # only converts the k entries actually inspected.
        top_k = torch.as_tensor(y)[:k].long()
        if t in top_k:
            hits += 1
    return hits / (1.0 * len(T))


def predict_batchwise(model, dataloader, device=None):
    """Run ``model`` over every batch of ``dataloader`` and gather the results.

    The first field of each batch is moved to ``device`` and replaced by the
    model output. All other fields (e.g. labels, indices) are kept as-is. The
    model runs in eval mode under ``torch.no_grad()``, and its previous
    training mode is restored afterwards, even if an exception is raised.
    A tqdm progress bar is shown while iterating.

    Args:
        model: ``torch.nn.Module`` producing one output row per input sample.
        dataloader: iterable of batches; each batch is a sequence of fields
            (per the original layout: images, labels, and optionally indices).
        device: where to send the inputs. Defaults to the device of the
            model's first parameter, or ``"cuda"`` if the model has none.

    Returns:
        List with one tensor per batch field, concatenated across batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = (
            first_param.device if first_param is not None else torch.device("cuda")
        )

    model_is_training = model.training
    model.eval()
    columns = None
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                if columns is None:
                    # Field count taken from the data itself rather than by
                    # loading an extra sample via dataset[0].
                    columns = [[] for _ in batch]
                for i, field in enumerate(batch):
                    if i == 0:
                        field = model(field.to(device))
                    if not isinstance(field, torch.Tensor):
                        # e.g. a collated list of tensors: stack its items.
                        field = torch.stack(list(field))
                    columns[i].append(field)
    finally:
        model.train(model_is_training)  # revert to previous training state

    if columns is None:
        raise ValueError("dataloader yielded no batches")
    # Concatenating whole batches equals stacking their individual rows.
    return [torch.cat(column) for column in columns]


def proxy_init_calc(model, dataloader):
    """Return one initial proxy per class: the mean embedding of that class.

    Embeddings and labels come from ``predict_batchwise``. Any extra fields it
    returns are ignored. Classes are ``0 .. nb_classes() - 1`` as reported by
    ``dataloader.dataset``. A class with no samples yields a mean over an
    empty selection, which presumably produces a NaN row (TODO confirm if
    that case can occur).

    Returns:
        Tensor of shape ``(nb_classes, embedding_dim)``.
    """
    num_classes = dataloader.dataset.nb_classes()
    embeddings, labels, *_ = predict_batchwise(model, dataloader)

    class_means = []
    for label in range(num_classes):
        members = embeddings[labels == label]
        class_means.append(members.mean(0))
    return torch.stack(class_means)


def evaluate_cos(model, dataloader):
    """Compute retrieval metrics on ``dataloader`` with its own samples as gallery.

    Embeddings are computed with ``model``, row-normalised with ``l2_norm``,
    and passed to pytorch-metric-learning's ``AccuracyCalculator``. Each
    sample is both a query and a reference
    (``embeddings_come_from_same_source=True``). AMI and NMI are excluded.
    The resulting metrics are printed and returned.
    """
    acc = AccuracyCalculator()

    # Calculate embeddings and targets. Any extra per-sample fields (e.g. the
    # dataset indices predict_batchwise may return) are discarded, so
    # datasets yielding (image, label, index) no longer fail to unpack.
    X, T, *_ = predict_batchwise(model, dataloader)
    X = l2_norm(X)

    X = X.detach().cpu().numpy()
    T = T.detach().cpu().numpy()
    # NOTE(review): newer pytorch-metric-learning releases appear to rename
    # `embeddings_come_from_same_source` to `ref_includes_query`; confirm
    # against the installed version.
    pml_acc = acc.get_accuracy(
        query=X,
        reference=X,
        query_labels=T,
        reference_labels=T,
        embeddings_come_from_same_source=True,
        exclude=("AMI", "NMI"),
    )
    print(pml_acc)
    return pml_acc


def evaluate_cos_Inshop(model, query_dataloader, gallery_dataloader):
    """Compute retrieval metrics with separate query and gallery splits.

    Query and gallery embeddings are computed with ``model`` and
    row-normalised with ``l2_norm``. Queries are then matched against the
    gallery using pytorch-metric-learning's ``AccuracyCalculator``
    (``embeddings_come_from_same_source=False``), with AMI and NMI excluded.
    The resulting metrics are printed and returned.
    """
    acc = AccuracyCalculator()

    # Calculate embeddings and targets. Extra per-sample fields (e.g.
    # indices) are discarded so 3-field datasets unpack correctly.
    query_X, query_T, *_ = predict_batchwise(model, query_dataloader)
    gallery_X, gallery_T, *_ = predict_batchwise(model, gallery_dataloader)

    query_X = l2_norm(query_X)
    gallery_X = l2_norm(gallery_X)

    # NOTE(review): newer pytorch-metric-learning releases appear to rename
    # `embeddings_come_from_same_source` to `ref_includes_query`; confirm
    # against the installed version.
    pml_acc = acc.get_accuracy(
        query=query_X.detach().cpu().numpy(),
        reference=gallery_X.detach().cpu().numpy(),
        query_labels=query_T.detach().cpu().numpy(),
        reference_labels=gallery_T.detach().cpu().numpy(),
        embeddings_come_from_same_source=False,
        exclude=("AMI", "NMI"),
    )
    print(pml_acc)

    return pml_acc


def evaluate_cos_SOP(model, dataloader):
    """Compute retrieval metrics on ``dataloader`` with its own samples as gallery.

    This has the same procedure as ``evaluate_cos``. Embeddings from ``model``
    are row-normalised with ``l2_norm`` and scored with pytorch-metric-learning's
    ``AccuracyCalculator``, using each sample as both query and reference.
    AMI and NMI are excluded. The resulting metrics are printed and returned.
    """
    acc = AccuracyCalculator()

    # Calculate embeddings and targets. Extra per-sample fields (e.g.
    # indices) are discarded so 3-field datasets unpack correctly.
    X, T, *_ = predict_batchwise(model, dataloader)
    X = l2_norm(X)

    X = X.detach().cpu().numpy()
    T = T.detach().cpu().numpy()
    # NOTE(review): newer pytorch-metric-learning releases appear to rename
    # `embeddings_come_from_same_source` to `ref_includes_query`; confirm
    # against the installed version.
    pml_acc = acc.get_accuracy(
        query=X,
        reference=X,
        query_labels=T,
        reference_labels=T,
        embeddings_come_from_same_source=True,
        exclude=("AMI", "NMI"),
    )
    print(pml_acc)

    return pml_acc
