from functools import partial

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchmetrics
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss

from pytorch3d.loss import chamfer_distance

from register import Registry

def register_metrics(cfg, loaders):
    """Re-register dataset-dependent metrics with their extra arguments bound.

    When ``loaders`` exposes both a ``'mean'`` and a ``'std'`` key, the
    regression losses (``relax_mse``, ``relax_mae``, ``mse``, ``mae``) and
    regression accuracies (``mae``, ``mse``, ``relax_mae``) are replaced by
    ``functools.partial`` wrappers of whatever is currently registered under
    the same name, with ``mean``/``std`` fixed.

    When ``cfg['acc'] == 'multi'``, the ``multi`` accuracy is likewise wrapped
    with ``num_classes=cfg['num_classes']``.

    Note: each call wraps the currently registered callable again, since the
    new entry is built from ``registry.get(name)``.
    """
    available = loaders.keys()
    if 'mean' in available and 'std' in available:
        stats = {'mean': loaders['mean'], 'std': loaders['std']}
        # Order matches the original registration sequence.
        targets = (
            (LOSS_REGISTRY, ('relax_mse', 'relax_mae', 'mse', 'mae')),
            (ACC_REGISTRY, ('mae', 'mse', 'relax_mae')),
        )
        for registry, metric_names in targets:
            for metric_name in metric_names:
                bound = partial(registry.get(metric_name), **stats)
                registry.register(metric_name, bound)

    if cfg['acc'] == 'multi':
        multi_bound = partial(ACC_REGISTRY.get('multi'), num_classes=cfg['num_classes'])
        ACC_REGISTRY.register('multi', multi_bound)

"""
LOSS
=======
"""
def ce_loss(pred, label):
    """Cross-entropy between logits ``pred`` and class targets ``label``.

    Uses ``torch.nn.CrossEntropyLoss`` with its default settings (mean reduction).
    """
    criterion = CrossEntropyLoss()
    return criterion(pred, label)
def bce_loss(pred, label):
    """Binary cross-entropy on raw logits ``pred`` against 0/1 ``label``.

    ``BCEWithLogitsLoss`` requires the target to have exactly the shape of
    ``pred``. The label is cast to float and, when it holds the same number of
    elements as ``pred``, reshaped to ``pred``'s shape. Otherwise it is
    squeezed as before, and a genuine size mismatch still raises.

    Returns:
        Scalar tensor with the mean loss.
    """
    target = label.float()
    if target.numel() == pred.numel():
        # A blanket squeeze would also drop the batch axis for a batch of one
        # sample (e.g. (1, 1) -> ()), leaving it mismatched with pred.
        target = target.reshape(pred.shape)
    else:
        target = torch.squeeze(target)
    return BCEWithLogitsLoss()(pred, target)
def mse_loss(pred, label, mean=0, std=1):
    """Mean squared error in standardized label space.

    Both tensors are flattened. The label is standardized as
    ``(label - mean) / std`` before comparison, so ``pred`` is expected to be
    on the standardized scale.
    """
    standardized = (label.view(-1) - mean) / std
    flat_pred = pred.view(-1)
    return nn.MSELoss()(flat_pred, standardized)
def mae_loss(pred, label, mean=0, std=1):
    """Mean absolute error in standardized label space.

    Both tensors are flattened. The label is standardized as
    ``(label - mean) / std`` before comparison, so ``pred`` is expected to be
    on the standardized scale.
    """
    standardized = (label.view(-1) - mean) / std
    flat_pred = pred.view(-1)
    return nn.L1Loss()(flat_pred, standardized)
def relax_mae_loss(pred, label, mean=0, std=1):
    """Best-of-candidates L1 loss, averaged over the batch.

    ``pred`` is viewed as rows of length ``label.shape[-1]``. For each sample
    ``b``, row ``b`` is compared (L1, mean reduction) against every standardized
    candidate ``(label[b, k] - mean) / std``. Only the smallest loss is kept, and
    those minima are averaged over ``label.shape[0]`` samples.

    Assumes ``label`` is indexed as (sample, candidate, feature) — TODO confirm.
    """
    criterion = nn.L1Loss()
    rows = pred.view(-1, label.shape[-1])
    best_per_sample = [
        min(
            criterion(rows[b], (label[b, k] - mean) / std)
            for k in range(label.shape[1])
        )
        for b in range(label.shape[0])
    ]
    return sum(best_per_sample) / len(best_per_sample)
def relax_mse_loss(pred, label, mean=0, std=1):
    """Best-of-candidates squared-error loss, averaged over the batch.

    ``pred`` is viewed as rows of length ``label.shape[-1]``. For each sample
    ``b``, row ``b`` is compared (MSE, mean reduction) against every
    standardized candidate ``(label[b, k] - mean) / std``. Only the smallest
    loss is kept, and those minima are averaged over ``label.shape[0]`` samples.

    Assumes ``label`` is indexed as (sample, candidate, feature) — TODO confirm.
    """
    criterion = nn.MSELoss()
    rows = pred.view(-1, label.shape[-1])
    best_per_sample = [
        min(
            criterion(rows[b], (label[b, k] - mean) / std)
            for k in range(label.shape[1])
        )
        for b in range(label.shape[0])
    ]
    return sum(best_per_sample) / len(best_per_sample)
def chamfer_loss(pred, label):
    """Chamfer distance between the predicted and target 3-D point clouds.

    Both inputs are flattened into a single cloud of xyz points, with shape
    ``(1, -1, 3)``. The label is no longer forced to have the same number of
    points as ``pred``: ``chamfer_distance`` accepts clouds of different sizes.
    Inputs whose point counts match give the same result as before.

    Returns:
        The Chamfer distance (first element of ``chamfer_distance``'s result).
    """
    pred_cloud = pred.reshape(1, -1, 3)
    label_cloud = label.reshape(1, -1, 3)
    loss, _ = chamfer_distance(pred_cloud, label_cloud)
    return loss
def relax_chamfer_loss(pred, label, mean=0, std=1):
    """Best-of-candidates Chamfer distance, averaged over the batch.

    For each sample ``b``, that sample's predicted point cloud is compared with
    every candidate cloud ``label[b, k]`` (each viewed as xyz points). The
    smallest distance is kept, and the minima are averaged over
    ``label.shape[0]`` samples.

    The original code compared the *entire* flattened ``pred`` (all samples'
    points merged) against every sample's labels. ``pred`` is now split into
    one cloud per sample. For a batch of one sample the result is unchanged.

    Args:
        pred: predictions holding ``label.shape[0]`` equally sized point clouds
            (presumably (batch, points * 3) or (batch, points, 3) — TODO confirm).
        label: candidate targets indexed as ``label[b, k]``.
        mean, std: accepted for signature parity with the other relaxed losses;
            unused here.
    """
    num_samples = label.shape[0]
    # Row b holds the xyz points predicted for sample b.
    clouds = pred.reshape(num_samples, -1, 3)
    losses = []
    for batch in range(num_samples):
        pred_b = clouds[batch].unsqueeze(0)
        candidates = []
        for i in range(label.shape[1]):
            label_i = label[batch, i].reshape(1, -1, 3)
            dist, _ = chamfer_distance(pred_b, label_i)
            candidates.append(dist)
        losses.append(min(candidates))
    return sum(losses) / len(losses)


# Name -> loss callable, looked up through LOSS_REGISTRY.get(name).
LOSS_REGISTRY = Registry()
for _loss_name, _loss_fn in (
    ('crossentropy', ce_loss),
    ('bce', bce_loss),
    ('relax_mse', relax_mse_loss),
    ('relax_mae', relax_mae_loss),
    ('mse', mse_loss),
    ('mae', mae_loss),
    ('relax_chamfer', relax_chamfer_loss),
    ('chamfer', chamfer_loss),
):
    LOSS_REGISTRY.register(_loss_name, _loss_fn)
# Keep the loop variables out of the module namespace.
del _loss_name, _loss_fn


"""
ACC
=======
"""
def multi_acc(output, target, num_classes):
    """Multiclass accuracy of the argmax over dim 1 of ``output`` versus ``target``.

    Delegates to ``torchmetrics.functional.accuracy`` with ``task="multiclass"``.
    """
    predicted_class = output.argmax(dim=1)
    return torchmetrics.functional.accuracy(
        predicted_class, target, task="multiclass", num_classes=num_classes
    )
def mse_acc(output, target, mean=0, std=1):
    """Mean squared error on the original label scale.

    ``output`` is flattened and de-standardized as ``std * output + mean``
    before comparison with the flattened ``target``.
    """
    restored = std * output.view(-1) + mean
    flat_target = target.view(-1)
    return torchmetrics.functional.mean_squared_error(restored, flat_target)
def mae_acc(output, target, mean=0, std=1):
    """Mean absolute error on the original label scale.

    ``output`` is flattened and de-standardized as ``std * output + mean``
    before comparison with the flattened ``target``.
    """
    restored = std * output.view(-1) + mean
    flat_target = target.view(-1)
    return torchmetrics.functional.mean_absolute_error(restored, flat_target)
def relax_mae_acc(pred, label, mean=0, std=1):
    """Best-of-candidates MAE on the original label scale, averaged over the batch.

    ``pred`` is viewed as rows of length ``label.shape[-1]``. Row ``b`` is
    de-standardized as ``std * row + mean`` and scored (MAE) against each
    candidate ``label[b, k]``. The smallest error per sample is kept, and
    those minima are averaged over ``label.shape[0]`` samples.

    Assumes ``label`` is indexed as (sample, candidate, feature) — TODO confirm.
    """
    rows = pred.view(-1, label.shape[-1])
    num_samples, num_candidates = label.shape[0], label.shape[1]
    best_per_sample = [
        min(
            torchmetrics.functional.mean_absolute_error(std * rows[b] + mean, label[b, k])
            for k in range(num_candidates)
        )
        for b in range(num_samples)
    ]
    return sum(best_per_sample) / len(best_per_sample)

# Name -> accuracy/metric callable, looked up through ACC_REGISTRY.get(name).
ACC_REGISTRY = Registry()
for _acc_name, _acc_fn in (
    ('relax_mae', relax_mae_acc),
    ('relax_chamfer', relax_chamfer_loss),
    ('chamfer', chamfer_loss),
    ('multi', multi_acc),
    ('mse', mse_acc),
    ('mae', mae_acc),
    # NOTE(review): torchmetrics.ROC is a metric *class* (ROC curve), not a
    # score function. If it is called like the other entries, fn(output, target),
    # that would construct a metric object rather than return an AUROC value.
    # Confirm intent.
    ('roc_auc', torchmetrics.ROC),
):
    ACC_REGISTRY.register(_acc_name, _acc_fn)
# Keep the loop variables out of the module namespace.
del _acc_name, _acc_fn



def auroc(xhat, data, mode: str):
    """Binary AUROC of scores ``xhat`` against labels ``data.y``.

    If ``data`` has a ``<mode>_mask`` attribute (e.g. ``train_mask`` for
    ``mode="train"``), only the masked entries are scored.

    Args:
        xhat: prediction scores, same shape as the squeezed labels.
        data: object with a ``y`` attribute and optionally ``<mode>_mask``.
        mode: split name used to build the mask attribute name.

    Returns:
        The AUROC as a tensor.
    """
    # BinaryAUROC is not provided by the module-level imports, so it is
    # imported here; without it every call raised NameError.
    from torchmetrics.classification import BinaryAUROC

    # Place the metric on the input device, as the sibling metrics do.
    metric = BinaryAUROC().to(xhat.device)
    mask_name = f"{mode}_mask"
    # torchmetrics rejects floating-point targets for binary curves; cast
    # like average_multilabel_precision does.
    if hasattr(data, mask_name):
        mask = getattr(data, mask_name)
        return metric(xhat[mask], torch.squeeze(data.y[mask]).long())
    return metric(xhat, torch.squeeze(data.y).long())


def regression_acc(xhat, data, mode: str):
    """Fraction of rounded predictions that exactly equal ``data.y``.

    ``xhat`` is rounded to the nearest integer and compared elementwise with
    ``data.y``. If ``data`` has a ``<mode>_mask`` attribute, only masked
    entries are considered.

    Returns:
        A float32 scalar tensor in [0, 1].
    """
    mask_name = f"{mode}_mask"
    preds, target = xhat, data.y
    if hasattr(data, mask_name):
        selector = getattr(data, mask_name)
        preds, target = preds[selector], target[selector]
    hits = (torch.round(preds) == target).float()
    return torch.mean(hits, dtype=torch.float32)


def average_multilabel_precision(xhat, data, mode: str):
    """Multilabel average precision of ``xhat`` against ``data.y``.

    The number of labels is read from ``xhat.size(1)``. Labels are squeezed
    and cast to long. If ``data`` has a ``<mode>_mask`` attribute, only the
    masked rows are scored.

    Returns:
        The average precision as a tensor.
    """
    # MultilabelAveragePrecision is not provided by the module-level imports,
    # so it is imported here; without it every call raised NameError.
    from torchmetrics.classification import MultilabelAveragePrecision

    num_labels = xhat.size(1)
    metric = MultilabelAveragePrecision(num_labels=num_labels).to(xhat.device)
    mask_name = f"{mode}_mask"
    if hasattr(data, mask_name):
        mask = getattr(data, mask_name)
        return metric(xhat[mask], torch.squeeze(data.y[mask]).long())
    return metric(xhat, torch.squeeze(data.y).long())


def regression_precision(xhat, data, mode: str):
    """Binary precision after thresholding regression outputs and labels.

    A prediction counts as positive where ``xhat >= 0.5``; a label counts as
    positive where ``data.y >= 1``. If ``data`` has a ``<mode>_mask`` attribute,
    only masked entries are scored.

    Returns:
        The precision as a tensor.
    """
    # BinaryPrecision is not provided by the module-level imports, so it is
    # imported here; without it every call raised NameError.
    from torchmetrics.classification import BinaryPrecision

    metric = BinaryPrecision().to(xhat.device)
    xhat_greater = (xhat >= 0.5).long()
    y_greater = torch.squeeze(data.y >= 1).long()
    mask_name = f"{mode}_mask"
    if hasattr(data, mask_name):
        mask = getattr(data, mask_name)
        return metric(xhat_greater[mask], y_greater[mask])
    return metric(xhat_greater, y_greater)


def regression_recall(xhat, data, mode: str):
    """Binary recall after thresholding regression outputs and labels.

    A prediction counts as positive where ``xhat >= 0.5``; a label counts as
    positive where ``data.y >= 1``. If ``data`` has a ``<mode>_mask`` attribute,
    only masked entries are scored.

    Returns:
        The recall as a tensor.
    """
    # BinaryRecall is not provided by the module-level imports, so it is
    # imported here; without it every call raised NameError.
    from torchmetrics.classification import BinaryRecall

    metric = BinaryRecall().to(xhat.device)
    xhat_greater = (xhat >= 0.5).long()
    y_greater = torch.squeeze(data.y >= 1).long()
    mask_name = f"{mode}_mask"
    if hasattr(data, mask_name):
        mask = getattr(data, mask_name)
        return metric(xhat_greater[mask], y_greater[mask])
    return metric(xhat_greater, y_greater)


def num_true_positives(xhat, data, mode: str):
    """Sum of all entries of ``data.y``, returned as a Python number.

    ``xhat`` and ``mode`` are unused, and no ``<mode>_mask`` is applied.
    """
    labels = torch.squeeze(data.y)
    return labels.sum().item()
