"""This is a subset of the original torch_geometric.utils.metric from PyG 2.0.4.

PyG removed metrics in favor of torchmetrics, but we still need the PyG batching.
See: https://github.com/pyg-team/pytorch_geometric/discussions/7434
"""

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def intersection_and_union(
    pred: Tensor, target: Tensor, num_classes: int, batch: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
    r"""Computes per-class intersection and union counts of predictions.

    Both inputs are one-hot encoded; the intersection counts the positions
    where prediction and target agree on a class, the union counts the
    positions where either of them selects that class.

    Args:
        pred (LongTensor): The predicted class indices.
        target (LongTensor): The target class indices.
        num_classes (int): The number of classes.
        batch (LongTensor, optional): The assignment vector which maps each
            pred-target pair to an example. If :obj:`None`, all pairs are
            reduced together. (default: :obj:`None`)

    :rtype: (LongTensor, LongTensor) -- counts reduced over the first
        dimension when :obj:`batch` is :obj:`None`; otherwise one row per
        example, with ``batch.max() + 1`` rows in total.
    """
    pred, target = F.one_hot(pred, num_classes), F.one_hot(target, num_classes)
    overlap, combined = pred & target, pred | target

    if batch is None:
        return overlap.sum(dim=0), combined.sum(dim=0)

    # Pure-torch scatter-add along dim 0: the file does not import
    # torch_scatter, so the reduction is done with ``index_add_``. The number
    # of output rows follows the largest example index (0 rows if empty).
    num_examples = int(batch.max()) + 1 if batch.numel() > 0 else 0
    out_shape = (num_examples, *overlap.shape[1:])
    i = overlap.new_zeros(out_shape).index_add_(0, batch, overlap)
    u = combined.new_zeros(out_shape).index_add_(0, batch, combined)

    return i, u


def mean_iou(
    pred: Tensor,
    target: Tensor,
    num_classes: int,
    batch: Optional[Tensor] = None,
    omitnans: bool = False,
) -> Tensor:
    r"""Computes the mean intersection over union score of predictions.

    A class that appears in neither the prediction nor the target has an
    empty union, which makes its IoU ``0 / 0 = NaN``; :obj:`omitnans`
    selects how those entries are handled.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example.
        omitnans (bool, optional): If set to True, NaN entries are dropped
            and a single mean over all remaining entries is returned.
            Otherwise, NaN entries are treated as 1 and the mean is taken
            over the class dimension. (default: False)

    :rtype: Tensor
    """
    inter, union = intersection_and_union(pred, target, num_classes, batch)
    ratio = inter.to(torch.float) / union.to(torch.float)
    undefined = torch.isnan(ratio)

    if not omitnans:
        # Count classes with an empty union as a perfect score.
        ratio[undefined] = 1.0
        return ratio.mean(dim=-1)

    # Boolean indexing flattens the tensor, so this is one global mean.
    return ratio[~undefined].mean()


def my_mse(pred, target, batch):
    """Compute MSE loss for each pytorch geometric batch.

    The squared error is first averaged over the last (feature) dimension of
    every node and then averaged over the nodes that share a batch index.

    Args:
        pred: (n_nodes, num_features)
        target: (n_nodes, num_features)
        batch: (n_nodes,) integers denoting the batch index of each node

    Returns:
        loss: (num_batches,) tensor with MSE loss for each batch, where
            num_batches is ``batch.max() + 1``. A batch index that owns no
            node gets a loss of 0.
    """
    per_node = ((pred - target) ** 2).mean(-1)

    # Pure-torch scatter-mean along dim 0: the file does not import
    # torch_scatter. Out-of-place ``index_add`` keeps the result
    # differentiable with respect to ``pred`` and ``target``.
    num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
    totals = per_node.new_zeros((num_graphs, *per_node.shape[1:])).index_add(
        0, batch, per_node
    )
    counts = per_node.new_zeros(num_graphs).index_add(
        0, batch, torch.ones_like(batch, dtype=per_node.dtype)
    )
    # Clamp so that batch indices without nodes yield 0 instead of 0 / 0.
    counts = counts.clamp(min=1).view((num_graphs,) + (1,) * (totals.dim() - 1))
    return totals / counts


def correlation(preds, target):
    """Compute the Pearson correlation between predictions and targets.

    The correlation is computed over the second-to-last dimension (points)
    and then averaged over the last dimension (trajectories).

    Args:
        preds: (n_timestep, n_points, n_trajectories)
        target: same shape as ``preds``

    Returns:
        Tensor with the last two dimensions reduced, i.e. (n_timestep,) for
        the documented input shape: the mean correlation per timestep.

    Raises:
        AssertionError: if ``preds`` and ``target`` differ in shape.
    """
    assert preds.shape == target.shape, f"Shapes do not match: preds: {preds.shape} target: {target.shape}"

    preds_mean = torch.mean(preds, dim=-2, keepdim=True)
    target_mean = torch.mean(target, dim=-2, keepdim=True)
    # The covariance below divides by N (torch.mean), so the standard
    # deviations must be the population ones as well. With the default
    # Bessel-corrected std the result is scaled by (N - 1) / N and a single
    # point yields NaN.
    preds_std = torch.std(preds, dim=-2, unbiased=False)
    target_std = torch.std(target, dim=-2, unbiased=False)

    covariance = torch.mean((preds - preds_mean) * (target - target_mean), dim=-2)
    # Clamp the denominator so constant signals give 0 instead of dividing by 0.
    corr = covariance / (preds_std * target_std).clamp(min=1e-12)

    # calculate mean correlation per timestep
    mean_corr_per_timestep = corr.mean(dim=-1)
    return mean_corr_per_timestep
