import torch
from the_well.benchmark.metrics.common import Metric
from the_well.data.datasets import WellMetadata


class MSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        Mean Squared Error

        The squared difference between x and y is averaged over the
        ``meta.n_spatial_dims`` axes that sit immediately before the last
        axis; every other axis is kept.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        y : torch.Tensor
            Target tensor.
        meta : WellMetadata
            Metadata for the dataset. Only ``n_spatial_dims`` is read.

        Returns
        -------
        torch.Tensor
            Mean squared error between x and y.
        """
        # Negative indices of the axes to reduce; the last axis (-1) is
        # deliberately excluded from the reduction.
        first_axis = -meta.n_spatial_dims - 1
        reduce_axes = tuple(range(first_axis, -1))
        squared_error = (x - y).pow(2)
        return squared_error.mean(dim=reduce_axes)


class NMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
        eps: float = 1e-7,
        norm_mode: str = "norm",
    ) -> torch.Tensor:
        """
        Normalized Mean Squared Error

        The MSE between x and y divided by a normalization factor computed
        from the target y over the same axes the MSE is reduced over.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        y : torch.Tensor
            Target tensor.
        meta : WellMetadata
            Metadata for the dataset. Only ``n_spatial_dims`` is read here.
        eps : float
            Small value added to the normalization factor to avoid division
            by zero. Default is 1e-7.
        norm_mode : str
            How the normalization factor is computed. 'norm' uses the mean
            of ``y**2``; 'std' uses the squared standard deviation of y.
            Default is 'norm'.

        Returns
        -------
        torch.Tensor
            Normalized mean squared error between x and y.

        Raises
        ------
        ValueError
            If ``norm_mode`` is neither 'norm' nor 'std'.
        """
        # Axes immediately preceding the last one, as negative indices.
        reduce_axes = tuple(range(-meta.n_spatial_dims - 1, -1))

        if norm_mode == "std":
            scale = y.std(dim=reduce_axes).pow(2)
        elif norm_mode == "norm":
            scale = y.pow(2).mean(dim=reduce_axes)
        else:
            raise ValueError(f"Invalid norm_mode: {norm_mode}")

        # The mode is validated above, before any error is computed.
        error = MSE.eval(x, y, meta)
        return error / (scale + eps)


class RMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        Root Mean Squared Error

        Element-wise square root of the result of ``MSE.eval``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        y : torch.Tensor
            Target tensor.
        meta : WellMetadata
            Metadata for the dataset, forwarded to ``MSE.eval``.

        Returns
        -------
        torch.Tensor
            Root mean squared error between x and y.
        """
        mean_squared_error = MSE.eval(x, y, meta)
        return mean_squared_error.sqrt()


class NRMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
        eps: float = 1e-7,
        norm_mode: str = "norm",
    ) -> torch.Tensor:
        """
        Normalized Root Mean Squared Error

        Element-wise square root of the result of ``NMSE.eval``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        y : torch.Tensor
            Target tensor.
        meta : WellMetadata
            Metadata for the dataset, forwarded to ``NMSE.eval``.
        eps : float
            Small value to avoid division by zero, forwarded to
            ``NMSE.eval``. Default is 1e-7.
        norm_mode : str
            Mode for computing the normalization factor, forwarded to
            ``NMSE.eval``. Can be 'norm' or 'std'. Default is 'norm'.

        Returns
        -------
        torch.Tensor
            Normalized root mean squared error between x and y.
        """
        normalized_mse = NMSE.eval(x, y, meta, eps=eps, norm_mode=norm_mode)
        return normalized_mse.sqrt()


class VMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
        eps: float = 1e-7,
    ) -> torch.Tensor:
        """
        Variance Scaled Mean Squared Error

        Equivalent to ``NMSE.eval`` with ``norm_mode="std"``, i.e. the MSE
        divided by the squared standard deviation of the target.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        y : torch.Tensor
            Target tensor.
        meta : WellMetadata
            Metadata for the dataset, forwarded to ``NMSE.eval``.
        eps : float
            Small value to avoid division by zero, forwarded to
            ``NMSE.eval``. Default is 1e-7, which matches the value that
            was previously used implicitly.

        Returns
        -------
        torch.Tensor
            Variance mean squared error between x and y.
        """
        # eps is exposed so this metric mirrors the NMSE/NRMSE signatures.
        return NMSE.eval(x, y, meta, eps=eps, norm_mode="std")


class VRMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
        eps: float = 1e-7,
    ) -> torch.Tensor:
        """
        Root Variance Scaled Mean Squared Error

        Equivalent to ``NRMSE.eval`` with ``norm_mode="std"``, i.e. the
        square root of the MSE divided by the squared standard deviation of
        the target.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        y : torch.Tensor
            Target tensor.
        meta : WellMetadata
            Metadata for the dataset, forwarded to ``NRMSE.eval``.
        eps : float
            Small value to avoid division by zero, forwarded to
            ``NRMSE.eval``. Default is 1e-7, which matches the value that
            was previously used implicitly.

        Returns
        -------
        torch.Tensor
            Root variance mean squared error between x and y.
        """
        # eps is exposed so this metric mirrors the NMSE/NRMSE signatures.
        return NRMSE.eval(x, y, meta, eps=eps, norm_mode="std")


class LInfinity(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        L-Infinity Norm

        Largest absolute difference between x and y, taken over the
        ``meta.n_spatial_dims`` axes that sit immediately before the last
        axis; every other axis is kept.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        y : torch.Tensor
            Target tensor.
        meta : WellMetadata
            Metadata for the dataset. Only ``n_spatial_dims`` is read.

        Returns
        -------
        torch.Tensor
            L-Infinity norm between x and y.
        """
        spatial_axes = tuple(range(-meta.n_spatial_dims - 1, -1))
        first_spatial_axis = spatial_axes[0]

        abs_error = (x - y).abs()
        # Merge all spatial axes into a single axis at position -2 so one
        # max-reduction covers them while the last axis is left untouched.
        merged = abs_error.flatten(start_dim=first_spatial_axis, end_dim=-2)
        return merged.max(dim=-2).values


class SpreadSkill(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
        eps: float = 1e-7,
    ) -> torch.Tensor:
        """
        Spread-Skill Ratio according to Fortin et al. (2014)

        Ratio between the ensemble spread (square root of the spatially
        averaged ensemble variance) and the RMSE of the ensemble mean
        against the target. The ensemble members are taken along dim 1.

        Parameters
        ----------
        x : torch.Tensor
            Ensemble predictions. Expected shape:
            [B, N_ensemble, T, (spatial_dims), C].
        y : torch.Tensor
            Ground truth, with the same shape as x:
            [B, N_ensemble, T, (spatial_dims), C].
        meta : WellMetadata
            Metadata for the dataset. Only ``n_spatial_dims`` is read.
        eps : float
            Small value added to the RMSE to avoid division by zero.
            Default is 1e-7.

        Returns
        -------
        torch.Tensor
            Spread-skill ratio, [B, T, C] for the expected input layout
            (per timestep, per channel).

        Raises
        ------
        AssertionError
            If x and y do not have the same shape.
        """
        # Verify input shapes
        assert (
            x.shape == y.shape
        ), f"Predictions {x.shape} and target {y.shape} must have same shape"

        # Axes immediately preceding the last one. Negative indices stay
        # valid after the ensemble axis (dim 1) has been reduced away.
        spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))

        # Ensemble mean: [B, T, (spatial_dims), C]
        ens_mean = x.mean(dim=1)

        # The target is averaged over the ensemble axis as well; it is
        # expected to be identical across ensemble members.
        y_mean = y.mean(dim=1)  # [B, T, (spatial_dims), C]

        # Skill: RMSE of the ensemble mean, averaged over space: [B, T, C]
        mse = torch.mean((ens_mean - y_mean) ** 2, dim=spatial_dims)
        rmse = torch.sqrt(mse)

        # Spread: population variance (unbiased=False, i.e. ddof=0) across
        # ensemble members, averaged over space, then square-rooted.
        # NOTE(review): no ensemble-size correction factor is applied here;
        # confirm this matches the intended definition from the reference.
        var = x.var(dim=1, unbiased=False)  # [B, T, (spatial_dims), C]
        mean_variance = var.mean(dim=spatial_dims)  # [B, T, C]
        spread = torch.sqrt(mean_variance)  # [B, T, C]

        # Spread-skill ratio: [B, T, C]
        ssr = spread / (rmse + eps)

        return ssr
