import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List
from sklearn.metrics import roc_auc_score


def multi_roc(source: List[np.ndarray], target: np.ndarray) -> Tuple[float, List[float]]:
    """Compute a ROC AUC per task and the unweighted mean across tasks.

    For each task ``i`` the logits ``source[i]`` are converted to
    probabilities with a softmax over the last axis. The labels in column
    ``i`` of ``target`` have their NaN entries dropped and are one-hot
    encoded. ``roc_auc_score`` (with its default averaging) is then applied
    to the one-hot matrix against the probabilities.

    Args:
        source: One logit array per task. Row ``j`` is paired with the
            ``j``-th non-NaN label of the matching ``target`` column, so each
            array is presumably ``(n_labelled_i, n_classes_i)``. TODO: confirm
            with callers.
        target: 2-D array whose column ``i`` holds the class indices for task
            ``i``. Values are truncated with ``int``; NaN entries are skipped.

    Returns:
        ``(mean_auc, per_task_auc)``.

    Raises:
        ValueError: If ``source`` is empty, because there is nothing to average.
    """
    if len(source) == 0:
        raise ValueError("multi_roc requires at least one task in `source`.")

    list_roc: List[float] = []
    for i, logits in enumerate(source):
        labels = target[:, i]
        # Drop NaN labels; the remaining ones are paired with logit rows in order.
        labels = labels[~np.isnan(labels)]
        src = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        # Vectorised one-hot encoding: row j gets a 1 in column int(labels[j]).
        # astype(int) truncates toward zero exactly like int() did per element.
        tgt = np.zeros_like(src)
        tgt[np.arange(labels.shape[0]), labels.astype(int)] = 1.0
        try:
            roc = roc_auc_score(tgt, src)
        except ValueError:
            # roc_auc_score raises ValueError when the AUC is undefined (e.g. a
            # class column containing a single label value). Such tasks score
            # 1.0 for backward compatibility.
            # NOTE(review): this can inflate the mean; consider NaN or skipping.
            roc = 1.0
        list_roc.append(roc)
    return sum(list_roc) / len(list_roc), list_roc


def multi_mse_loss(source: torch.Tensor, target: torch.Tensor, explicit=False) -> torch.Tensor:
    """Mean squared error taken along dimension 0, optionally summed.

    The elementwise squared error between ``source`` and ``target`` is
    averaged over dimension 0. This leaves one MSE value per position along
    the remaining dimensions.

    Args:
        source: Predicted values.
        target: Reference values, combined elementwise with ``source``.
        explicit: If truthy, return the un-summed per-position MSE tensor.

    Returns:
        The per-position MSE tensor when ``explicit`` is truthy, otherwise
        its sum as a scalar tensor.
    """
    per_position = ((source - target) ** 2).mean(dim=0)
    return per_position if explicit else per_position.sum()


def multi_mae_loss(source: torch.Tensor, target: torch.Tensor, explicit=False) -> torch.Tensor:
    """Mean absolute error taken along dimension 0, optionally summed.

    The elementwise absolute error between ``source`` and ``target`` is
    averaged over dimension 0. This leaves one MAE value per position along
    the remaining dimensions.

    Args:
        source: Predicted values.
        target: Reference values, combined elementwise with ``source``.
        explicit: If truthy, return the un-summed per-position MAE tensor.

    Returns:
        The per-position MAE tensor when ``explicit`` is truthy, otherwise
        its sum as a scalar tensor.
    """
    per_position = (source - target).abs().mean(dim=0)
    return per_position if explicit else per_position.sum()


def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared error averaged over every element.

    Thin wrapper over ``F.mse_loss`` with the reduction spelled out.

    Args:
        source: Predicted values.
        target: Reference values, expected to match ``source``'s shape.

    Returns:
        Scalar tensor holding the mean of ``(source - target) ** 2``.
    """
    return F.mse_loss(source, target, reduction="mean")


def rmse_loss(source: torch.Tensor, target: torch.Tensor, *, eps: float = 0.0) -> torch.Tensor:
    """Root mean squared error over every element.

    Args:
        source: Predicted values.
        target: Reference values, expected to match ``source``'s shape.
        eps: Non-negative constant added to the MSE before the square root.
            The derivative of ``sqrt`` is infinite at 0, so when ``source``
            equals ``target`` exactly the backward pass produces NaN
            gradients. A tiny ``eps`` (e.g. ``1e-12``) keeps them finite.
            The default ``0.0`` reproduces the plain RMSE exactly.

    Returns:
        Scalar tensor ``sqrt(mean((source - target) ** 2) + eps)``.
    """
    return torch.sqrt(F.mse_loss(source, target) + eps)


def mae_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean absolute error averaged over every element.

    Args:
        source: Predicted values.
        target: Reference values, combined elementwise with ``source``.

    Returns:
        Scalar tensor holding the mean of ``|source - target|``.
    """
    abs_error = (source - target).abs()
    return abs_error.mean()
