"""Energy Score helper functions."""

from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange


class ES(nn.Module):
    """
    Energy Score (ES) loss for ensemble predictions.

    Thin ``nn.Module`` wrapper that dispatches to ``compute_es_loss`` or
    ``compute_es_loss_mem_efficient``.

    Args:
        p: Order of the norm used for every distance in the score
            (``p=2`` is the Euclidean norm).
    """

    def __init__(self, p=2):
        super().__init__()
        self.p = p

    def forward(
        self, predictions, target, metadata, mem_efficient: Optional[bool] = False
    ):
        """
        Compute ES loss for ensemble predictions

        Args:
            predictions: [B, N_ensemble, L,  H, [W], [D], C]
            target: [B, N_ensemble, L,  H, [W], [D], C] (same shape as predictions)
            metadata: Metadata for the dataset (its ``spatial_resolution`` is read by
                the compute functions to determine the number of spatial dims)
            mem_efficient: Whether to use memory-efficient computation (computes one
                time slice at a time). ``None`` is treated like ``False``.
        Returns:
            ES loss [B, L]; spatial dimensions and channels are reduced by the
            p-norm of the flattened field
        """
        es_fn = compute_es_loss_mem_efficient if mem_efficient else compute_es_loss
        # Forward the configured norm order. Without it the helpers fall back to
        # their own default (p=2) and the constructor argument has no effect.
        return es_fn(predictions, target, metadata, p=self.p)


def compute_es_loss(predictions, target, metadata, p=2):
    """
    Compute ES loss for ensemble predictions

    For every batch element and time step, each ensemble member is flattened into
    one vector over the spatial and channel dimensions, and the score is

        ES = mean_i ||x_i - y_i||_p - 1 / (2 N (N - 1)) * sum_{i,j} ||x_i - x_j||_p

    With a single member (N == 1) the second term is skipped and the score is the
    p-norm error of that member.

    Args:
        predictions: [B, N_ensemble, L, H, [W], [D], C] - Ensemble predictions, or
            [B, L, H, [W], [D], C] for a deterministic prediction (an ensemble
            dimension of size 1 is inserted)
        target: Ground truth (same shape as predictions)
        metadata: Metadata for the dataset; ``len(metadata.spatial_resolution)`` is
            used as the number of spatial dimensions
        p: Order of the norm used for all distances (default 2)

    Returns:
        ES loss [B, L]. Spatial dimensions and channels are reduced by the p-norm of
        the flattened vector (they are not averaged).

    Raises:
        AssertionError: If predictions and target differ in shape, or the rank of
            predictions matches neither layout for the given metadata.
    """
    # Verify input shapes
    assert (
        predictions.shape == target.shape
    ), f"Predictions {predictions.shape} and target {target.shape} must have same shape"

    # Validate the rank against the metadata rather than a fixed minimum: a
    # deterministic input with one spatial dim is only 4D ([B, L, H, C]).
    num_spatial_dims = len(metadata.spatial_resolution)
    ensemble_ndim = 4 + num_spatial_dims
    assert predictions.ndim in (
        ensemble_ndim - 1,
        ensemble_ndim,
    ), f"Unexpected predictions shape {predictions.shape}"

    # If we have deterministic predictions (no ensemble axis), add ensemble dimension
    if predictions.ndim == ensemble_ndim - 1:
        predictions = predictions.unsqueeze(1)
        target = target.unsqueeze(1)

    B, N, L = predictions.shape[:3]

    # Flatten spatial and channel dimensions: [B, N, L, D_vec],
    # where D_vec = H * [W] * [D] * C
    predictions_flat = predictions.flatten(start_dim=3)
    target_flat = target.flatten(start_dim=3)
    D_vec = predictions_flat.shape[-1]

    # For efficient batched computation, combine B and L dims
    # Shape becomes [B*L, N, D_vec]
    predictions_cdist = predictions_flat.permute(0, 2, 1, 3).reshape(B * L, N, D_vec)
    target_cdist = target_flat.permute(0, 2, 1, 3).reshape(B * L, N, D_vec)

    # --- Term 1: E[||X - y||] ---
    # (Ensemble-Truth Distance), p-norm along D_vec: [B*L, N]
    dist_truth = torch.norm(predictions_cdist - target_cdist, p=p, dim=-1)

    # Single member: the score is just that member's error, and the pairwise term
    # below would divide by N * (N - 1) == 0.
    if N == 1:
        return dist_truth.squeeze(1).reshape(B, L)

    # Average over ensemble: [B*L]
    term1 = dist_truth.mean(dim=1)

    # --- Term 2: 0.5 * E[||X - X'||] ---
    # (Ensemble-Ensemble Distance)
    # torch.cdist gives all pairwise member distances in one call: [B*L, N, N]
    dist_pairs = torch.cdist(predictions_cdist, predictions_cdist, p=p)

    # The sum includes the i == j terms, but they are 0, so dividing by
    # N * (N - 1) averages over distinct pairs only ("fair" estimator).
    term2 = 0.5 * (1 / (N * (N - 1))) * torch.sum(dist_pairs, dim=(1, 2))  # [B*L]

    # Final Energy Score, reshaped back to [B, L]
    return (term1 - term2).reshape(B, L)


def compute_es_loss_mem_efficient(predictions, target, metadata, p=2):
    """
    Compute ES loss for ensemble predictions - memory efficient version where we
    compute one time slice at a time

    Per batch element and time step, each ensemble member is flattened into one
    vector over the spatial and channel dimensions, and the score is

        ES = mean_i ||x_i - y_i||_p - 1 / (2 N (N - 1)) * sum_{i,j} ||x_i - x_j||_p

    The pairwise-distance matrix is built for one time step at a time ([B, N, N])
    instead of for all time steps at once.

    Args:
        predictions: [B, N_ensemble, L, H, [W], [D], C] - Ensemble predictions, or
            [B, L, H, [W], [D], C] for a deterministic prediction (an ensemble
            dimension of size 1 is inserted)
        target: Ground truth (same shape as predictions)
        metadata: Metadata for the dataset; ``len(metadata.spatial_resolution)`` is
            used as the number of spatial dimensions
        p: Order of the norm used for all distances (default 2)

    Returns:
        ES loss [B, L]. Spatial dimensions and channels are reduced by the p-norm of
        the flattened vector (they are not averaged).

    Raises:
        AssertionError: If predictions and target differ in shape, or the rank of
            predictions matches neither layout for the given metadata.
    """
    # Verify input shapes
    assert (
        predictions.shape == target.shape
    ), f"Predictions {predictions.shape} and target {target.shape} must have same shape"

    # Validate the rank against the metadata rather than a fixed minimum: a
    # deterministic input with one spatial dim is only 4D ([B, L, H, C]).
    num_spatial_dims = len(metadata.spatial_resolution)
    ensemble_ndim = 4 + num_spatial_dims
    assert predictions.ndim in (
        ensemble_ndim - 1,
        ensemble_ndim,
    ), f"Unexpected predictions shape {predictions.shape}"

    # If we have deterministic predictions (no ensemble axis), add ensemble dimension
    if predictions.ndim == ensemble_ndim - 1:
        predictions = predictions.unsqueeze(1)
        target = target.unsqueeze(1)

    B, N, L = predictions.shape[:3]
    deterministic = N == 1

    # Flatten spatial and channel dimensions: [B, N, L, D_vec],
    # where D_vec = H * [W] * [D] * C
    predictions_flat = predictions.flatten(start_dim=3)
    target_flat = target.flatten(start_dim=3)

    # No time steps: torch.stack cannot stack an empty list, so return the empty
    # [B, 0] result directly.
    if L == 0:
        return predictions_flat.new_zeros((B, 0))

    # Weight of the summed pairwise distances: 0.5 / (N * (N - 1)). It is
    # loop-invariant, and undefined for N == 1 where the spread term is skipped.
    pair_weight = None if deterministic else 0.5 / (N * (N - 1))

    es_per_timestep = []

    # Process each timestep separately to save memory; each slice is [B, N, D_vec]
    for pred_t, target_t in zip(
        predictions_flat.unbind(dim=2), target_flat.unbind(dim=2)
    ):
        # Term 1: E[|X - Y|] for this timestep, averaged over ensemble: [B]
        term1_t = torch.norm(pred_t - target_t, p=p, dim=-1).mean(dim=1)

        if deterministic:
            es_per_timestep.append(term1_t)
            continue

        # Term 2: 0.5 * E[|X - X'|] for this timestep (ensemble spread): [B]
        # The i == j distances are 0, so only distinct pairs contribute.
        dist_pairs = torch.cdist(pred_t, pred_t, p=p)  # [B, N, N]
        term2_t = pair_weight * torch.sum(dist_pairs, dim=(1, 2))

        es_per_timestep.append(term1_t - term2_t)

    # Stack all timesteps to get [B, L]
    return torch.stack(es_per_timestep, dim=1)
