import numpy as np
import torch
import torch.nn.functional as F
from typing import Dict, Optional, List
from torchmetrics.image import StructuralSimilarityIndexMeasure


def crps(
    y_true_tensor: torch.Tensor,  # Expects (b, T, c, h, w) on device
    y_pred_tensor: torch.Tensor,  # Expects (b, n_samples, T, c, h, w) on device
    pool_type: str = "none",
    scale: int = 1,
    mode: str = "mean",
    eps: float = 1e-10,
) -> float:
    """
    Closed-form Gaussian CRPS of an ensemble forecast against ground truth.

    The ensemble is summarised per grid point by its mean and (unbiased)
    standard deviation over the member axis, and the CRPS of the resulting
    normal distribution is evaluated analytically. Both inputs can optionally
    be spatially pooled beforehand.

    Args:
        y_true_tensor (torch.Tensor): Ground truth, shape (b, T, c, h, w).
        y_pred_tensor (torch.Tensor): Ensemble predictions, shape
            (b, n, T, c, h, w), with n the number of ensemble members.
        pool_type (str, optional): "avg" or "max" selects the 2D pooling
            applied to both tensors. Any other value (including the default
            "none") leaves the data unpooled.
        scale (int, optional): Pooling kernel size and stride. Pooling is only
            applied when this is greater than 1. Defaults to 1.
        mode (str, optional): "sum" returns the sum over all elements; "mean"
            or any other value returns the mean. Defaults to "mean".
        eps (float, optional): Added to the ensemble standard deviation so the
            normalisation never divides by zero. Defaults to 1e-10.

    Returns:
        float: The aggregated CRPS as a Python float.
    """
    device = y_true_tensor.device
    batch, lead_times, channels, height, width = y_true_tensor.shape
    _, members, _, _, _, _ = y_pred_tensor.shape

    # Fold every leading axis into the batch axis so the 2D pooling ops see
    # plain (N, C, H, W) inputs.
    truth_4d = y_true_tensor.reshape(batch * lead_times, channels, height, width)
    ensemble_4d = y_pred_tensor.reshape(
        batch * members * lead_times, channels, height, width
    )

    # Unknown pool types and scale <= 1 intentionally fall through unpooled.
    if scale > 1 and pool_type in ("avg", "max"):
        pool_fn = F.avg_pool2d if pool_type == "avg" else F.max_pool2d
        truth_4d = pool_fn(truth_4d, kernel_size=scale, stride=scale)
        ensemble_4d = pool_fn(ensemble_4d, kernel_size=scale, stride=scale)

    # Restore the (b, T, c, h', w') / (b, n, T, c, h', w') layouts.
    pooled_h, pooled_w = truth_4d.shape[-2:]
    truth = truth_4d.reshape(batch, lead_times, channels, pooled_h, pooled_w)
    ensemble = ensemble_4d.reshape(
        batch, members, lead_times, channels, pooled_h, pooled_w
    )

    # Ensemble statistics over the member axis. A single member has no
    # spread, so its standard deviation is taken to be zero.
    ens_mean = ensemble.mean(dim=1)
    if members > 1:
        ens_std = ensemble.std(dim=1, unbiased=True)
    else:
        ens_std = torch.zeros_like(ens_mean)

    sigma = ens_std + eps
    z = (ens_mean - truth) / sigma

    # Standard normal whose parameters live on the same device as the inputs.
    std_normal = torch.distributions.Normal(
        torch.tensor(0.0, device=device),
        torch.tensor(1.0, device=device),
        validate_args=False,
    )
    cdf = std_normal.cdf(z)
    pdf = torch.exp(std_normal.log_prob(z))
    inv_sqrt_pi = 1.0 / np.sqrt(np.pi)

    # CRPS of N(mean, sigma) at the observation:
    #   sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi))
    per_point = sigma * (z * (2 * cdf - 1) + 2 * pdf - inv_sqrt_pi)

    if mode == "sum":
        return per_point.sum().item()
    # "mean" and any unrecognised mode both reduce with the mean.
    return per_point.mean().item()


class MetricsAccumulator:
    """
    A class to accumulate and compute various metrics for probabilistic weather forecasts
    in a streaming (chunk-by-chunk) manner.

    This class is designed to handle large datasets that may not fit into memory by
    processing data in chunks. It computes metrics both from the ensemble members
    (per-sample) and from the ensemble mean.
    """

    def __init__(
        self,
        lead_time: int,
        thresholds: Optional[List[float]] = None,
        pool_size: int = 16,
        compute_mse: bool = True,
        compute_threshold: bool = True,
        compute_apsd: bool = True,
        compute_ssim: bool = False,
        compute_crps: bool = True,
        crps_pool_type: str = "none",
        crps_scale: int = 1,
        crps_eps: float = 1e-10,
        ssim_data_range: float = 255.0,
        device: Optional[torch.device] = None,
    ):
        """
        Initializes the MetricsAccumulator.

        Args:
            lead_time (int): The specific lead time index to compute metrics for.
            thresholds (Optional[List[float]], optional): A list of thresholds for
                categorical metrics (CSI, POD, etc.). Defaults to [0.5].
            pool_size (int, optional): The kernel size for max-pooling when computing
                the pooled CSI metric. Defaults to 16.
            compute_mse (bool, optional): Whether to compute Mean Squared Error. Defaults to True.
            compute_threshold (bool, optional): Whether to compute threshold-based metrics.
                Defaults to True.
            compute_apsd (bool, optional): Whether to compute Average Power Spectral Density error.
                Defaults to True.
            compute_ssim (bool, optional): Whether to compute Structural Similarity Index Measure.
                Defaults to False.
            compute_crps (bool, optional): Whether to compute Continuous Ranked Probability Score.
                Defaults to True.
            crps_pool_type (str, optional): The type of pooling to apply before CRPS calculation
                ("none", "avg", "max"). Defaults to "none".
            crps_scale (int, optional): The scale/kernel size for CRPS pooling. Defaults to 1.
            crps_eps (float, optional): Epsilon for CRPS calculation to avoid division by zero.
                Defaults to 1e-10.
            ssim_data_range (float, optional): The data range for SSIM calculation. Defaults to 255.0.
            device (Optional[torch.device], optional): The device to perform computations on.
                If None, defaults to CUDA if available, otherwise CPU.
        """
        self.lead_time = lead_time
        self.thresholds = thresholds if thresholds is not None else [0.5]
        self.pool_size = pool_size  # This is for csi_pooled and csi_pooled_from_mean
        self.compute_mse = compute_mse
        self.compute_threshold = compute_threshold
        self.compute_apsd = compute_apsd
        self.compute_ssim = compute_ssim
        self.compute_crps = compute_crps
        self.crps_pool_type = crps_pool_type  # This is for CRPS pooling
        self.crps_scale = crps_scale  # This is for CRPS pooling
        self.crps_eps = crps_eps
        self.ssim_data_range = ssim_data_range

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        print(f"MetricsAccumulator using device: {self.device}")

        self.thresholds_tensor = torch.tensor(
            self.thresholds, device=self.device
        ).float()

        # Initialize accumulators
        self.mse_sum = 0.0
        self.mse_count = 0
        self.csi_sum = {th_val: 0.0 for th_val in self.thresholds}
        self.csi_count = {th_val: 0 for th_val in self.thresholds}
        self.pod_sum = {th_val: 0.0 for th_val in self.thresholds}
        self.pod_count = {th_val: 0 for th_val in self.thresholds}
        self.far_sum = {th_val: 0.0 for th_val in self.thresholds}
        self.far_count = {th_val: 0 for th_val in self.thresholds}
        self.hss_sum = {th_val: 0.0 for th_val in self.thresholds}
        self.hss_count = {th_val: 0 for th_val in self.thresholds}
        self.csi_pooled_sum = {
            th_val: 0.0 for th_val in self.thresholds
        }  # Per-sample csi_pooled
        self.csi_pooled_count = {
            th_val: 0 for th_val in self.thresholds
        }  # Per-sample csi_pooled

        self.apsd_sum = 0.0
        self.apsd_count = 0
        self.ssim_sum = 0.0
        self.ssim_score_count = 0

        self.mse_from_mean_sum = 0.0
        self.mse_from_mean_count = 0

        # For csi_from_mean (pixel-level from mean)
        self.csi_from_mean_hits = {th_val: 0 for th_val in self.thresholds}
        self.csi_from_mean_misses = {th_val: 0 for th_val in self.thresholds}
        self.csi_from_mean_false_alarms = {th_val: 0 for th_val in self.thresholds}

        # For pod_from_mean, far_from_mean, hss_from_mean (pixel-level contingency from mean)
        self.pixel_contingency_from_mean_hits = {
            th_val: 0 for th_val in self.thresholds
        }
        self.pixel_contingency_from_mean_misses = {
            th_val: 0 for th_val in self.thresholds
        }
        self.pixel_contingency_from_mean_false_alarms = {
            th_val: 0 for th_val in self.thresholds
        }
        self.pixel_contingency_from_mean_correct_negatives = {
            th_val: 0 for th_val in self.thresholds
        }

        # For csi_pooled_from_mean
        self.csi_pooled_from_mean_hits = {th_val: 0 for th_val in self.thresholds}
        self.csi_pooled_from_mean_misses = {th_val: 0 for th_val in self.thresholds}
        self.csi_pooled_from_mean_false_alarms = {
            th_val: 0 for th_val in self.thresholds
        }

        self.apsd_from_mean_sum = 0.0
        self.apsd_from_mean_count = 0
        self.ssim_from_mean_sum = 0.0
        self.ssim_from_mean_frames_count = 0

        if self.compute_ssim:
            self.ssim_metric_module = StructuralSimilarityIndexMeasure(
                data_range=self.ssim_data_range
            ).to(self.device)

        self.crps_sum = 0.0
        self.crps_count = 0

    def update(self, y_true_chunk_np: np.ndarray, y_pred_chunk_np: np.ndarray):
        """
        Updates the metric accumulators with a new chunk of data.

        Args:
            y_true_chunk_np (np.ndarray): A numpy array of ground truth data with shape
                                          (b, T, H, W), where b is batch size, T is number
                                          of lead times, H is height, and W is width.
            y_pred_chunk_np (np.ndarray): A numpy array of prediction data with shape
                                          (b, n_samples, T, H, W), where n_samples is the
                                          number of ensemble members.
        """
        y_true_chunk = torch.from_numpy(y_true_chunk_np).float().to(self.device)
        y_pred_chunk = torch.from_numpy(y_pred_chunk_np).float().to(self.device)

        # Assuming y_true_chunk is (b, T, H, W)
        b, T, H, W = y_true_chunk.shape
        # y_pred_chunk is (b, n_samples, T, H, W)
        n_samples = y_pred_chunk.shape[1]

        y_pred_mean = torch.mean(y_pred_chunk, dim=1)  # Shape: (b, T, H, W)

        # ----- METRICS FROM ENSEMBLE MEAN -----
        if self.compute_mse:
            y_true_lead_mse_mean = y_true_chunk[:, self.lead_time, :, :]
            y_pred_lead_mse_mean = y_pred_mean[:, self.lead_time, :, :]
            is_nan_true_mean_mse = torch.isnan(y_true_lead_mse_mean)
            is_nan_pred_mean_mse = torch.isnan(y_pred_lead_mse_mean)
            valid_mask_mean_mse = ~torch.logical_or(
                is_nan_true_mean_mse, is_nan_pred_mean_mse
            )
            diff2_mean = (y_true_lead_mse_mean - y_pred_lead_mse_mean) ** 2
            self.mse_from_mean_sum += torch.sum(diff2_mean[valid_mask_mean_mse]).item()
            self.mse_from_mean_count += torch.sum(valid_mask_mean_mse).item()

        if self.compute_threshold:
            # Get the relevant time slice for threshold metrics from mean
            y_true_lead_continuous = y_true_chunk[:, self.lead_time, :, :]
            y_pred_lead_continuous_mean = y_pred_mean[:, self.lead_time, :, :]

            # Prepare pooled versions if pool_size > 1 (for csi_pooled_from_mean)
            # These are pooled versions of the ensemble mean prediction and ground truth
            y_true_pooled_fm_for_threshold_loop = None
            y_pred_pooled_fm_for_threshold_loop = None
            if self.pool_size > 1:  # pool_size for CSI_POOL metric
                # Unsqueeze to add channel dimension for pool2d: (b, 1, H, W)
                y_true_lead_for_pool_mean = y_true_lead_continuous.unsqueeze(1)
                y_pred_lead_for_pool_mean = y_pred_lead_continuous_mean.unsqueeze(1)

                y_true_pooled_fm_for_threshold_loop = F.max_pool2d(
                    y_true_lead_for_pool_mean,
                    kernel_size=self.pool_size,
                    stride=self.pool_size,
                ).squeeze(
                    1
                )  # Squeeze channel dim back
                y_pred_pooled_fm_for_threshold_loop = F.max_pool2d(
                    y_pred_lead_for_pool_mean,
                    kernel_size=self.pool_size,
                    stride=self.pool_size,
                ).squeeze(
                    1
                )  # Squeeze channel dim back

            for i, th_val_tensor in enumerate(self.thresholds_tensor):
                th_key_float = self.thresholds[i]

                # --- Calculate Pixel-level contingency from mean (for csi_from_mean, pod_from_mean, etc.) ---
                y_true_bin_mean_pix = (y_true_lead_continuous > th_val_tensor).float()
                y_pred_bin_mean_pix = (
                    y_pred_lead_continuous_mean > th_val_tensor
                ).float()

                nan_mask_true_mean_pix = torch.isnan(y_true_lead_continuous)
                nan_mask_pred_mean_pix = torch.isnan(y_pred_lead_continuous_mean)
                invalid_mask_mean_pix = torch.logical_or(
                    nan_mask_true_mean_pix, nan_mask_pred_mean_pix
                )

                y_true_bin_mean_pix[invalid_mask_mean_pix] = (
                    0  # Set NaNs to 0 for binary operations
                )
                y_pred_bin_mean_pix[invalid_mask_mean_pix] = (
                    0  # Set NaNs to 0 for binary operations
                )

                # Current chunk's pixel-level contingency counts
                current_chunk_hits_fm_pix = torch.sum(
                    (y_pred_bin_mean_pix == 1) & (y_true_bin_mean_pix == 1)
                ).item()
                current_chunk_misses_fm_pix = torch.sum(
                    (y_pred_bin_mean_pix == 0) & (y_true_bin_mean_pix == 1)
                ).item()
                current_chunk_fa_fm_pix = torch.sum(
                    (y_pred_bin_mean_pix == 1) & (y_true_bin_mean_pix == 0)
                ).item()
                current_chunk_cn_fm_pix = torch.sum(
                    (y_pred_bin_mean_pix == 0) & (y_true_bin_mean_pix == 0)
                ).item()

                # Accumulate for csi_from_mean
                self.csi_from_mean_hits[th_key_float] += current_chunk_hits_fm_pix
                self.csi_from_mean_misses[th_key_float] += current_chunk_misses_fm_pix
                self.csi_from_mean_false_alarms[th_key_float] += current_chunk_fa_fm_pix

                # Accumulate for general pixel-level contingency (pod, far, hss from mean)
                self.pixel_contingency_from_mean_hits[
                    th_key_float
                ] += current_chunk_hits_fm_pix
                self.pixel_contingency_from_mean_misses[
                    th_key_float
                ] += current_chunk_misses_fm_pix
                self.pixel_contingency_from_mean_false_alarms[
                    th_key_float
                ] += current_chunk_fa_fm_pix
                self.pixel_contingency_from_mean_correct_negatives[
                    th_key_float
                ] += current_chunk_cn_fm_pix

                # --- Calculate and Accumulate for csi_pooled_from_mean ---
                if self.pool_size > 1:
                    # Use the pre-calculated pooled tensors
                    y_true_event_pooled_fm = (
                        y_true_pooled_fm_for_threshold_loop > th_val_tensor
                    ).float()
                    y_pred_event_pooled_fm = (
                        y_pred_pooled_fm_for_threshold_loop > th_val_tensor
                    ).float()

                    # Handle NaNs in pooled data (NaNs in original continuous data might propagate)
                    pooled_nan_mask_true_fm = torch.isnan(
                        y_true_pooled_fm_for_threshold_loop
                    )
                    pooled_nan_mask_pred_fm = torch.isnan(
                        y_pred_pooled_fm_for_threshold_loop
                    )
                    pooled_invalid_mask_fm = torch.logical_or(
                        pooled_nan_mask_true_fm, pooled_nan_mask_pred_fm
                    )

                    y_true_event_pooled_fm[pooled_invalid_mask_fm] = 0
                    y_pred_event_pooled_fm[pooled_invalid_mask_fm] = 0

                    # Current chunk's pooled contingency counts
                    current_chunk_pooled_hits = torch.sum(
                        (y_pred_event_pooled_fm == 1) & (y_true_event_pooled_fm == 1)
                    ).item()
                    current_chunk_pooled_misses = torch.sum(
                        (y_pred_event_pooled_fm == 0) & (y_true_event_pooled_fm == 1)
                    ).item()
                    current_chunk_pooled_false_alarms = torch.sum(
                        (y_pred_event_pooled_fm == 1) & (y_true_event_pooled_fm == 0)
                    ).item()

                    self.csi_pooled_from_mean_hits[
                        th_key_float
                    ] += current_chunk_pooled_hits
                    self.csi_pooled_from_mean_misses[
                        th_key_float
                    ] += current_chunk_pooled_misses
                    self.csi_pooled_from_mean_false_alarms[
                        th_key_float
                    ] += current_chunk_pooled_false_alarms
                else:
                    # No pooling for csi_pooled_from_mean: use current chunk's pixel-level counts
                    self.csi_pooled_from_mean_hits[
                        th_key_float
                    ] += current_chunk_hits_fm_pix
                    self.csi_pooled_from_mean_misses[
                        th_key_float
                    ] += current_chunk_misses_fm_pix
                    self.csi_pooled_from_mean_false_alarms[
                        th_key_float
                    ] += current_chunk_fa_fm_pix

        if self.compute_apsd:  # APSD from mean
            # Ensure y_true_chunk and y_pred_mean are sliced for the specific lead_time
            y_true_lead_apsd_mean_np = np.nan_to_num(
                y_true_chunk[:, self.lead_time, :, :].cpu().numpy()
            )
            y_pred_lead_apsd_mean_np = np.nan_to_num(
                y_pred_mean[:, self.lead_time, :, :].cpu().numpy()
            )
            for j in range(b):  # Iterate over batch items
                fft_true_mean = np.fft.fft2(y_true_lead_apsd_mean_np[j])
                fft_pred_mean = np.fft.fft2(y_pred_lead_apsd_mean_np[j])
                psd_true_mean = np.abs(fft_true_mean) ** 2
                psd_pred_mean = np.abs(fft_pred_mean) ** 2
                self.apsd_from_mean_sum += np.mean((psd_true_mean - psd_pred_mean) ** 2)
                self.apsd_from_mean_count += 1

        if self.compute_ssim:  # SSIM from mean (iterating over all time steps T)
            for t_idx in range(T):
                # Unsqueeze to add channel dimension: (b, 1, H, W)
                frame_true_ssim_mean = y_true_chunk[:, t_idx, :, :].unsqueeze(1)
                frame_pred_ssim_mean = y_pred_mean[:, t_idx, :, :].unsqueeze(1)

                # Clamp predictions to the data range for SSIM
                frame_pred_ssim_mean = torch.clamp(
                    frame_pred_ssim_mean, 0, self.ssim_metric_module.data_range
                )

                # ssim_metric_module expects (N, C, H, W)
                ssim_val_batch_mean = self.ssim_metric_module(
                    frame_pred_ssim_mean, frame_true_ssim_mean
                )

                # ssim_val_batch_mean is the mean SSIM value for the current batch of 'b' frames.
                # We accumulate the sum of individual SSIM scores (batch_mean * b)
                # and the total count of individual frames (b).
                self.ssim_from_mean_sum += ssim_val_batch_mean.item() * b
                self.ssim_from_mean_frames_count += b

        # ----- PER-SAMPLE AVERAGED METRICS -----
        if self.compute_mse:  # Per-sample MSE
            y_true_lead_mse = y_true_chunk[:, self.lead_time, :, :]  # (b, H, W)
            current_chunk_sample_mses_sum = 0.0
            current_chunk_sample_mse_count = 0
            for s_idx in range(n_samples):
                y_pred_sample_lead = y_pred_chunk[
                    :, s_idx, self.lead_time, :, :
                ]  # (b, H, W)
                is_nan_true = torch.isnan(y_true_lead_mse)
                is_nan_pred = torch.isnan(y_pred_sample_lead)
                valid_mask_mse = ~torch.logical_or(is_nan_true, is_nan_pred)

                diff2 = (y_true_lead_mse - y_pred_sample_lead) ** 2

                # Calculate MSE for this sample over the batch
                sum_diff2_sample_batch = torch.sum(diff2[valid_mask_mse])
                count_valid_sample_batch = torch.sum(valid_mask_mse)

                if count_valid_sample_batch.item() > 0:
                    # This is the MSE for one sample, averaged over batch and pixels
                    # We want to average these sample MSEs.
                    sample_mse_avg_over_batch = (
                        sum_diff2_sample_batch.item() / count_valid_sample_batch.item()
                    )
                    current_chunk_sample_mses_sum += sample_mse_avg_over_batch
                    current_chunk_sample_mse_count += 1

            if current_chunk_sample_mse_count > 0:
                self.mse_sum += (
                    current_chunk_sample_mses_sum  # Sum of average MSEs for each sample
                )
                self.mse_count += (
                    current_chunk_sample_mse_count  # Number of samples processed
                )

        if self.compute_threshold:  # Per-sample threshold metrics
            y_true_lead_continuous_sample = y_true_chunk[
                :, self.lead_time, :, :
            ]  # (b, H, W)

            # Pooled ground truth for per-sample csi_pooled (if pooling)
            # This is done once per chunk for the ground truth.
            y_true_continuous_pooled_orig_for_sample_metrics = None
            if self.pool_size > 1:  # pool_size for CSI_POOL metric
                y_true_lead_for_pool_orig = y_true_lead_continuous_sample.unsqueeze(
                    1
                )  # (b, 1, H, W)
                y_true_continuous_pooled_orig_for_sample_metrics = F.max_pool2d(
                    y_true_lead_for_pool_orig,
                    kernel_size=self.pool_size,
                    stride=self.pool_size,
                ).squeeze(
                    1
                )  # (b, H_pooled, W_pooled)

            for s_idx in range(n_samples):
                y_pred_sample_lead_thresh = y_pred_chunk[
                    :, s_idx, self.lead_time, :, :
                ]  # (b, H, W)

                # Pooled prediction for this sample (if pooling)
                y_pred_continuous_pooled_sample_for_metrics = None
                if self.pool_size > 1:
                    y_pred_lead_for_pool_sample = y_pred_sample_lead_thresh.unsqueeze(
                        1
                    )  # (b, 1, H, W)
                    y_pred_continuous_pooled_sample_for_metrics = F.max_pool2d(
                        y_pred_lead_for_pool_sample,
                        kernel_size=self.pool_size,
                        stride=self.pool_size,
                    ).squeeze(
                        1
                    )  # (b, H_pooled, W_pooled)

                for i, th_val_tensor in enumerate(self.thresholds_tensor):
                    th_key_float = self.thresholds[i]

                    # Pixel-level for this sample
                    y_true_bin_pix_s = (
                        y_true_lead_continuous_sample > th_val_tensor
                    ).float()
                    y_pred_bin_sample_pix_s = (
                        y_pred_sample_lead_thresh > th_val_tensor
                    ).float()

                    nan_mask_true_pix_s = torch.isnan(y_true_lead_continuous_sample)
                    nan_mask_pred_pix_s = torch.isnan(y_pred_sample_lead_thresh)
                    invalid_mask_pix_s = torch.logical_or(
                        nan_mask_true_pix_s, nan_mask_pred_pix_s
                    )

                    y_true_bin_pix_s[invalid_mask_pix_s] = 0
                    y_pred_bin_sample_pix_s[invalid_mask_pix_s] = 0

                    # Contingency for current sample, current threshold, averaged over batch
                    hits_s_pix = torch.sum(
                        (y_pred_bin_sample_pix_s == 1) & (y_true_bin_pix_s == 1)
                    ).item()
                    misses_s_pix = torch.sum(
                        (y_pred_bin_sample_pix_s == 0) & (y_true_bin_pix_s == 1)
                    ).item()
                    fa_s_pix = torch.sum(
                        (y_pred_bin_sample_pix_s == 1) & (y_true_bin_pix_s == 0)
                    ).item()
                    cn_s_pix = torch.sum(
                        (y_pred_bin_sample_pix_s == 0) & (y_true_bin_pix_s == 0)
                    ).item()

                    # Calculate scores for this sample, this threshold
                    csi_s_pix_denom = hits_s_pix + misses_s_pix + fa_s_pix
                    csi_s_pix = (
                        hits_s_pix / csi_s_pix_denom if csi_s_pix_denom > 0 else np.nan
                    )

                    pod_s_denom = hits_s_pix + misses_s_pix
                    pod_s = hits_s_pix / pod_s_denom if pod_s_denom > 0 else np.nan

                    far_s_denom = hits_s_pix + fa_s_pix
                    far_s = fa_s_pix / far_s_denom if far_s_denom > 0 else np.nan

                    hss_num_s = 2 * (hits_s_pix * cn_s_pix - misses_s_pix * fa_s_pix)
                    hss_den_term1_s = (hits_s_pix + misses_s_pix) * (
                        misses_s_pix + cn_s_pix
                    )
                    hss_den_term2_s = (hits_s_pix + fa_s_pix) * (fa_s_pix + cn_s_pix)
                    hss_den_s = hss_den_term1_s + hss_den_term2_s
                    hss_s = hss_num_s / hss_den_s if hss_den_s != 0 else np.nan

                    if not np.isnan(csi_s_pix):
                        self.csi_sum[th_key_float] += csi_s_pix
                        self.csi_count[th_key_float] += 1
                    if not np.isnan(pod_s):
                        self.pod_sum[th_key_float] += pod_s
                        self.pod_count[th_key_float] += 1
                    if not np.isnan(far_s):
                        self.far_sum[th_key_float] += far_s
                        self.far_count[th_key_float] += 1
                    if not np.isnan(hss_s):
                        self.hss_sum[th_key_float] += hss_s
                        self.hss_count[th_key_float] += 1

                    # Pooled CSI for this sample
                    if self.pool_size > 1:
                        y_true_event_pooled = (
                            y_true_continuous_pooled_orig_for_sample_metrics
                            > th_val_tensor
                        ).float()
                        y_pred_event_pooled_sample = (
                            y_pred_continuous_pooled_sample_for_metrics > th_val_tensor
                        ).float()

                        pooled_nan_mask_true = torch.isnan(
                            y_true_continuous_pooled_orig_for_sample_metrics
                        )
                        pooled_nan_mask_pred = torch.isnan(
                            y_pred_continuous_pooled_sample_for_metrics
                        )
                        pooled_invalid_mask = torch.logical_or(
                            pooled_nan_mask_true, pooled_nan_mask_pred
                        )

                        y_true_event_pooled[pooled_invalid_mask] = 0
                        y_pred_event_pooled_sample[pooled_invalid_mask] = 0

                        pool_hits_s = torch.sum(
                            (y_pred_event_pooled_sample == 1)
                            & (y_true_event_pooled == 1)
                        ).item()
                        pool_misses_s = torch.sum(
                            (y_pred_event_pooled_sample == 0)
                            & (y_true_event_pooled == 1)
                        ).item()
                        pool_false_alarms_s = torch.sum(
                            (y_pred_event_pooled_sample == 1)
                            & (y_true_event_pooled == 0)
                        ).item()

                        pooled_denom_s = (
                            pool_hits_s + pool_misses_s + pool_false_alarms_s
                        )
                        csi_pooled_s = (
                            pool_hits_s / pooled_denom_s
                            if pooled_denom_s > 0
                            else np.nan
                        )

                        if not np.isnan(csi_pooled_s):
                            self.csi_pooled_sum[th_key_float] += csi_pooled_s
                            self.csi_pooled_count[th_key_float] += 1
                    else:  # No pooling for csi_pooled (per-sample), use pixel-level CSI for this sample
                        if not np.isnan(
                            csi_s_pix
                        ):  # csi_s_pix is the pixel-level CSI for this sample
                            self.csi_pooled_sum[th_key_float] += csi_s_pix
                            self.csi_pooled_count[th_key_float] += 1

        if self.compute_apsd:  # Per-sample APSD
            y_true_lead_apsd_torch_allbatch = y_true_chunk[
                :, self.lead_time, :, :
            ]  # (b, H, W)
            for s_idx in range(n_samples):
                y_pred_sample_lead_torch_allbatch = y_pred_chunk[
                    :, s_idx, self.lead_time, :, :
                ]  # (b, H, W)

                # Convert whole batch for this sample to numpy
                y_true_lead_apsd_np_allbatch = np.nan_to_num(
                    y_true_lead_apsd_torch_allbatch.cpu().numpy()
                )
                y_pred_sample_lead_apsd_np_allbatch = np.nan_to_num(
                    y_pred_sample_lead_torch_allbatch.cpu().numpy()
                )

                for j in range(b):  # Iterate over items in the batch
                    sample_true_np = y_true_lead_apsd_np_allbatch[j]
                    sample_pred_np = y_pred_sample_lead_apsd_np_allbatch[j]

                    fft_true = np.fft.fft2(sample_true_np)
                    fft_pred = np.fft.fft2(sample_pred_np)
                    psd_true = np.abs(fft_true) ** 2
                    psd_pred = np.abs(fft_pred) ** 2
                    diff_psd_score = np.mean(
                        (psd_true - psd_pred) ** 2
                    )  # APSD for one image (one batch item, one sample)

                    self.apsd_sum += diff_psd_score  # Sum of APSD scores
                    self.apsd_count += 1  # Count of images processed (b * n_samples)

        if self.compute_ssim:  # Per-sample SSIM (iterating over all T frames)
            for s_idx in range(n_samples):
                y_pred_sample_chunk_for_ssim = y_pred_chunk[
                    :, s_idx, :, :, :
                ]  # (b, T, H, W)
                for t_idx in range(T):  # Iterate over time steps
                    # Unsqueeze for channel dim: (b, 1, H, W)
                    frame_true_ssim = y_true_chunk[:, t_idx, :, :].unsqueeze(1)
                    frame_pred_ssim_sample = y_pred_sample_chunk_for_ssim[
                        :, t_idx, :, :
                    ].unsqueeze(1)

                    frame_pred_ssim_sample = torch.clamp(
                        frame_pred_ssim_sample, 0, self.ssim_metric_module.data_range
                    )

                    # ssim_metric_module returns scalar (mean over batch)
                    ssim_value_batch_avg_for_sample_frame = self.ssim_metric_module(
                        frame_pred_ssim_sample, frame_true_ssim
                    )
                    # ssim_value_batch_avg_for_sample_frame is the mean SSIM for the current batch.
                    # Accumulate the sum of individual SSIM scores (batch_mean * b)
                    # and the total count of individual frames (b). 'b' is the batch size.
                    self.ssim_sum += ssim_value_batch_avg_for_sample_frame.item() * b
                    self.ssim_score_count += b

        if self.compute_crps:
            try:
                # CRPS expects y_true (b, T, c, h, w) and y_pred (b, n, T, c, h, w)
                # Assuming input y_true_chunk is (b, T, H, W) and y_pred_chunk is (b, n, T, H, W)
                # Add channel dimension 'c' = 1
                y_true_for_crps = y_true_chunk.unsqueeze(2)  # (b, T, 1, H, W)
                y_pred_for_crps = y_pred_chunk.unsqueeze(2)  # (b, n, T, 1, H, W)

                crps_val_chunk = crps(  # This crps function returns a single scalar (mean over everything)
                    y_true_for_crps,
                    y_pred_for_crps,
                    pool_type=self.crps_pool_type,
                    scale=self.crps_scale,
                    mode="mean",
                    eps=self.crps_eps,
                )
            except Exception as e:
                print(f"CRPS computation failed for chunk: {e}")
                crps_val_chunk = np.nan

            if not np.isnan(crps_val_chunk):
                # crps_val_chunk is already the mean CRPS for the chunk.
                # To get the overall mean CRPS, we need to weight by the number of elements it was averaged over.
                # The crps function calculates mean over b, T, h_eff, w_eff.
                h_eff_crps = (
                    H // self.crps_scale
                    if self.crps_scale > 0 and self.crps_pool_type in ["avg", "max"]
                    else H
                )
                w_eff_crps = (
                    W // self.crps_scale
                    if self.crps_scale > 0 and self.crps_pool_type in ["avg", "max"]
                    else W
                )

                # Number of values CRPS was averaged over in this chunk
                num_elements_in_chunk_crps_avg = b * T * h_eff_crps * w_eff_crps

                self.crps_sum += crps_val_chunk * num_elements_in_chunk_crps_avg
                self.crps_count += num_elements_in_chunk_crps_avg

    def compute(self):
        """
        Reduce the running accumulators into the final metric values.

        Intended to be called once every data chunk has gone through `update`.
        Averaged metrics are an accumulated sum divided by its accumulated
        count; the ensemble-mean threshold scores are derived from accumulated
        contingency counts (hits, misses, false alarms, correct negatives).
        Whenever the relevant denominator is not positive (for HSS: equal to
        zero) the metric is reported as NaN instead of dividing.

        Returns:
            Dict[str, Optional[object]]: Mapping from metric name to value.
            "mse", "apsd", "ssim", "mse_from_mean", "apsd_from_mean",
            "ssim_from_mean" and "crps" are scalars (or NaN). The threshold
            entries ("csi", "pod", "false_alarm_rate", "heidke_skill_score",
            "csi_pooled", "csi_from_mean", "pod_from_mean", "far_from_mean",
            "hss_from_mean", "csi_pooled_from_mean") are dicts keyed by
            threshold when `self.compute_threshold` is truthy, otherwise None.
        """
        nan = np.nan

        def _attr_mean(sum_name, count_name):
            # The sum attribute is only read once the count is known to be positive.
            n = getattr(self, count_name)
            if n > 0:
                return getattr(self, sum_name) / n
            return nan

        def _keyed_mean(sum_name, count_name, key):
            # Same as _attr_mean, for accumulators stored per threshold.
            if getattr(self, count_name)[key] > 0:
                return getattr(self, sum_name)[key] / getattr(self, count_name)[key]
            return nan

        def _csi(hits, misses, false_alarms):
            # Critical success index: hits / (hits + misses + false alarms).
            total = hits + misses + false_alarms
            if total > 0:
                return hits / total
            return nan

        results: Dict[str, Optional[object]] = {}

        # Scalar averages, per-sample first and then from the ensemble mean.
        # For "ssim", `update` adds (batch-mean SSIM * b) to ssim_sum and b to
        # ssim_score_count, so the ratio is a mean over individual frames.
        scalar_specs = (
            ("mse", "mse_sum", "mse_count"),
            ("apsd", "apsd_sum", "apsd_count"),
            ("ssim", "ssim_sum", "ssim_score_count"),
            ("mse_from_mean", "mse_from_mean_sum", "mse_from_mean_count"),
            ("apsd_from_mean", "apsd_from_mean_sum", "apsd_from_mean_count"),
            (
                "ssim_from_mean",
                "ssim_from_mean_sum",
                "ssim_from_mean_frames_count",
            ),
        )
        for out_key, sum_name, count_name in scalar_specs:
            results[out_key] = _attr_mean(sum_name, count_name)

        # Per-sample threshold metrics: (result key, sum attribute, count attribute).
        per_sample_specs = (
            ("csi", "csi_sum", "csi_count"),
            ("pod", "pod_sum", "pod_count"),
            ("false_alarm_rate", "far_sum", "far_count"),
            ("heidke_skill_score", "hss_sum", "hss_count"),
            ("csi_pooled", "csi_pooled_sum", "csi_pooled_count"),
        )
        from_mean_keys = (
            "csi_from_mean",
            "pod_from_mean",
            "far_from_mean",
            "hss_from_mean",
            "csi_pooled_from_mean",
        )
        threshold_keys = tuple(spec[0] for spec in per_sample_specs) + from_mean_keys

        if not self.compute_threshold:
            for out_key in threshold_keys:
                results[out_key] = None
        else:
            tables = {out_key: {} for out_key in threshold_keys}

            for th in self.thresholds:
                # Per-sample scores: plain average of the accumulated values.
                for out_key, sum_name, count_name in per_sample_specs:
                    tables[out_key][th] = _keyed_mean(sum_name, count_name, th)

                # Pixel-level CSI of the ensemble mean.
                tables["csi_from_mean"][th] = _csi(
                    self.csi_from_mean_hits[th],
                    self.csi_from_mean_misses[th],
                    self.csi_from_mean_false_alarms[th],
                )

                # POD / FAR / HSS of the ensemble mean share one contingency table.
                hits = self.pixel_contingency_from_mean_hits[th]
                misses = self.pixel_contingency_from_mean_misses[th]
                false_alarms = self.pixel_contingency_from_mean_false_alarms[th]
                correct_neg = self.pixel_contingency_from_mean_correct_negatives[th]

                observed_yes = hits + misses
                forecast_yes = hits + false_alarms

                if observed_yes > 0:
                    tables["pod_from_mean"][th] = hits / observed_yes
                else:
                    tables["pod_from_mean"][th] = nan

                if forecast_yes > 0:
                    tables["far_from_mean"][th] = false_alarms / forecast_yes
                else:
                    tables["far_from_mean"][th] = nan

                hss_numer = 2 * (hits * correct_neg - misses * false_alarms)
                hss_denom = observed_yes * (misses + correct_neg) + forecast_yes * (
                    false_alarms + correct_neg
                )
                # HSS is guarded against an exactly-zero denominator only.
                if hss_denom != 0:
                    tables["hss_from_mean"][th] = hss_numer / hss_denom
                else:
                    tables["hss_from_mean"][th] = nan

                # CSI of the ensemble mean from its dedicated "pooled" accumulators.
                tables["csi_pooled_from_mean"][th] = _csi(
                    self.csi_pooled_from_mean_hits[th],
                    self.csi_pooled_from_mean_misses[th],
                    self.csi_pooled_from_mean_false_alarms[th],
                )

            # `tables` was built in the desired key order, so update() keeps it.
            results.update(tables)

        results["crps"] = _attr_mean("crps_sum", "crps_count")
        return results
