import torch
import numpy as np
from torch.optim.lr_scheduler import (StepLR, ExponentialLR, ReduceLROnPlateau)
from skimage.metrics import structural_similarity as ssim
from typing import Iterable

def get_optimizer(params, cfg):
    """
    Build the optimizer selected by the configuration.

    Only parameters with ``requires_grad=True`` are handed to the optimizer.

    Args:
        params: iterable of model parameters (each must expose ``requires_grad``).
        cfg: optimization configuration. Reads ``cfg.optim_alg`` and
            ``cfg.optim_lr``; ``cfg.optim_wd`` is read only for "AdamL2"
            and "AdamW".
    Returns:
        optimizer [torch.optim]
    Raises:
        ValueError: if ``cfg.optim_alg`` is not "Adam", "AdamL2" or "AdamW".
    """
    algorithm = cfg.optim_alg
    trainable = filter(lambda p: p.requires_grad, params)

    if algorithm == "Adam":
        return torch.optim.Adam(trainable, lr=cfg.optim_lr)
    if algorithm == "AdamL2":
        # Adam with (coupled) L2 regularisation through ``weight_decay``.
        return torch.optim.Adam(trainable, lr=cfg.optim_lr, weight_decay=cfg.optim_wd)
    if algorithm == "AdamW":
        return torch.optim.AdamW(trainable, lr=cfg.optim_lr, weight_decay=cfg.optim_wd)

    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError(
        f"Unsupported optimizer '{algorithm}'; expected one of 'Adam', 'AdamL2', 'AdamW'."
    )


def get_scheduler(optimizer, cfg):
    """Build the learning-rate scheduler selected by the configuration.

    Args:
        optimizer [torch.optim]: optimizer whose learning rate is scheduled.
        cfg: scheduler configuration. ``cfg.name`` selects the scheduler and
            the attribute of the same name (``cfg.StepLR``,
            ``cfg.ExponentialLR`` or ``cfg.ReduceLROnPlateau``) holds its
            hyper-parameters.
    Returns:
        scheduler [torch.optim.lr_scheduler]
    Raises:
        ValueError: if ``cfg.name`` is not one of the supported schedulers.
    """
    name = cfg.name

    if name == "StepLR":
        params = cfg.StepLR
        return StepLR(optimizer, step_size=params.step_size, gamma=params.gamma)

    if name == "ExponentialLR":
        params = cfg.ExponentialLR
        return ExponentialLR(optimizer, gamma=params.gamma)

    if name == "ReduceLROnPlateau":
        params = cfg.ReduceLROnPlateau
        # Only ``factor`` and ``patience`` are configurable; the remaining
        # arguments are fixed here.
        # NOTE(review): ``verbose`` is reported as deprecated in recent
        # PyTorch releases — confirm against the pinned torch version.
        return ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=params.factor,
            patience=params.patience,
            verbose=True,
            threshold=0.0001,
            threshold_mode='rel',
            cooldown=0,
            min_lr=0,
            eps=1e-8,
        )

    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError(
        f"Unsupported scheduler '{name}'; expected one of "
        "'StepLR', 'ExponentialLR', 'ReduceLROnPlateau'."
    )


def get_loss(loss):
    """
    Build the loss criterion identified by name.

    Args:
        loss: string, either "MSELoss" or "L1Loss".
    Returns:
        The loss module that will be used for modeling.
    Raises:
        ValueError: if ``loss`` is not a supported name.
    """
    if loss == "MSELoss":
        return torch.nn.MSELoss(reduction="mean")
    if loss == "L1Loss":
        # Note the asymmetry: L1 is summed over elements, MSE is averaged.
        return torch.nn.L1Loss(reduction="sum")

    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError(f"Unsupported loss '{loss}'; expected 'MSELoss' or 'L1Loss'.")


def toNumpy(tensor):
    """
    Return the tensor's data as a numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.numpy()

def req_masks(x):
    """
    Build NaN-based boolean masks for a 3-D tensor with a single channel.

    Args:
        x: tensor whose shape unpacks into three dimensions, the last of
           which must be 1.
    Returns:
        Tuple ``(valid_mask, invalid_mask)`` with the last dimension
        squeezed away; ``valid_mask`` is True where ``x`` is not NaN and
        ``invalid_mask`` is True where it is NaN.
    Raises:
        ValueError: if the last dimension is not 1.
    """
    _, _, channels = x.shape
    if channels != 1:
        raise ValueError("Data not suitable for calculating masks!")

    nan_positions = torch.isnan(x).squeeze(-1)
    return (~nan_positions, nan_positions)
    
def req_samples(y, masks):
    """
    Split ``y`` along its second dimension using per-sample boolean masks.

    Args:
        y: tensor whose shape unpacks into three dimensions.
        masks: pair ``(valid_mask, invalid_mask)``; row ``b`` of each mask
            indexes the second dimension of ``y[b]``.
    Returns:
        Tuple ``(y_valid, y_invalid)`` of stacked per-sample selections.
        Because the selections are combined with ``torch.stack``, every
        sample must select the same number of rows within each mask.
    """
    batch_size, _, _ = y.shape
    keep_mask, drop_mask = masks

    kept = torch.stack([y[i, keep_mask[i]] for i in range(batch_size)], dim=0)
    dropped = torch.stack([y[i, drop_mask[i]] for i in range(batch_size)], dim=0)
    return kept, dropped

def psnr(pred, target, max_val=1.0):
    """
    Peak signal-to-noise ratio (in dB), averaged over the batch.

    The squared error is averaged over dimensions 1 and 2 of each sample,
    converted to ``10 * log10(max_val**2 / mse)``, and the per-sample values
    are then averaged.

    Args:
        pred, target: tensors of matching shape; the code's own comment
            gives (B, N, 1).
        max_val: peak signal value used in the numerator.
    Returns:
        0-d tensor holding the mean PSNR.
    """
    squared_error = (pred - target) ** 2
    per_sample_mse = torch.mean(squared_error, dim=(1, 2))
    per_sample_db = 10 * torch.log10(max_val**2 / per_sample_mse)
    return per_sample_db.mean()


def crps_ensemble(pred, obs, ensemble_dim=-1, reduction='mean', validate=True):
    """
    Compute CRPS for ensemble forecasts.

    Uses the sorted-ensemble form
        CRPS = mean_m |x_m - y| - (1 / M^2) * sum_k (2k - M - 1) * x_(k)
    where x_(k) is the k-th smallest member (k = 1..M). For a single member
    (M == 1) this reduces to the absolute error.

    Args:
        pred: torch.Tensor with shape (B, N, M) (batch, n_points,
            ensemble_members), or (B, N) / (B, N, 1) for deterministic
            forecasts. If the ensemble axis is not last, pass its index via
            ``ensemble_dim``.
        obs: torch.Tensor with shape (B, N) or (B, N, 1).
        ensemble_dim: index of the ensemble axis in ``pred``; that axis is
            moved to the end while the order of the other axes is kept.
        reduction: 'mean' returns a 0-d tensor, 'none' returns a (B, N) tensor.
        validate: when True, check that both tensors share a device and that
            neither is complex.
    Returns:
        torch.Tensor: the reduced or per-point CRPS.
    Raises:
        ValueError: on an unknown ``reduction``, a failed validation check,
            or unsupported ``pred`` / ``obs`` shapes.
    """
    # Reject unknown reductions up front instead of silently returning None.
    if reduction not in ('mean', 'none'):
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction!r}")

    if validate:
        if pred.device != obs.device:
            raise ValueError("pred and obs must be on same device")
        if pred.dtype.is_complex or obs.dtype.is_complex:
            raise ValueError("complex tensors not supported")

    # Normalize shapes: ensure pred is (B, N, M)
    if pred.dim() == 2:
        pred = pred.unsqueeze(-1)   # (B, N) -> (B, N, 1)
    if ensemble_dim != -1 and ensemble_dim != (pred.dim() - 1):
        # movedim keeps the remaining axes in order; a plain transpose would
        # swap the ensemble axis with the last one and scramble (B, N) when
        # the ensemble axis comes first.
        pred = pred.movedim(ensemble_dim, -1)
    if pred.dim() != 3:
        raise ValueError("pred must be shape (B,N), (B,N,1) or (B,N,M)")

    # Normalize obs to (B, N)
    if obs.dim() == 3 and obs.shape[-1] == 1:
        obs = obs.squeeze(-1)
    if obs.dim() != 2:
        raise ValueError("obs must be shape (B,N) or (B,N,1)")

    M = pred.shape[-1]

    # term1 = (1/M) * sum_m |x_m - y|
    term1 = torch.mean(torch.abs(pred - obs.unsqueeze(-1)), dim=-1)  # (B, N)

    if M == 1:
        crps = term1
    else:
        x_sorted, _ = torch.sort(pred, dim=-1)   # (B, N, M)
        idx = torch.arange(1, M + 1, device=pred.device, dtype=pred.dtype)  # 1..M
        weights = (2 * idx - M - 1).view(1, 1, -1)   # (1,1,M) for broadcasting
        # sum_k x_(k) * (2k - M - 1) equals half the sum of all pairwise
        # absolute differences, so dividing by M^2 yields the spread term.
        spread = torch.sum(x_sorted * weights, dim=-1) / (M ** 2)   # (B, N)
        crps = term1 - spread

    if reduction == 'mean':
        return crps.mean()
    return crps



class Evaluator(object):
    """
    Accumulate per-sample contingency counts for a set of thresholds and
    turn them into CSI / FAR / POD / HSS scores.

    Typical use: call ``evaluate`` once per batch, then ``done`` to obtain
    the aggregated scores.
    """

    def __init__(self,  thresholds: Iterable[int], **kwargs):
        # thresholds should be in the same integer scale as float2int output
        self.thresholds = list(thresholds)

        # One list per count type and threshold. Only the un-suffixed lists
        # are filled by ``evaluate``; the "44"/"16" variants are initialised
        # but not written anywhere in this class.
        count_names = (
            "hits", "misses", "falsealarms", "correctnegs",
            "hits44", "misses44", "falsealarms44", "correctnegs44",
            "hits16", "misses16", "falsealarms16", "correctnegs16",
        )
        self.metrics = {
            t: {name: [] for name in count_names} for t in self.thresholds
        }

        # Placeholders for loss histories; not written anywhere in this class.
        loss_names = ("mse", "mae", "rmse", "psnr", "ssim", "crps", "lpips")
        self.losses = {name: [] for name in loss_names}

        # Number of samples seen so far.
        self.total = 0

    @staticmethod
    def _as_cpu_float(batch):
        """Convert to a detached float CPU tensor and drop trailing size-1 dims (down to 2-D)."""
        if not torch.is_tensor(batch):
            batch = torch.from_numpy(np.asarray(batch))
        batch = batch.detach().cpu().float()
        while batch.dim() > 2 and batch.size(-1) == 1:
            batch = batch.squeeze(-1)
        return batch

    @staticmethod
    def _mean_count(values):
        """Mean of a list of counts as a float32 scalar tensor (0.0 when the list is empty)."""
        counts = torch.tensor(values, dtype=torch.float32)
        counts = torch.nan_to_num(counts, nan=0.0, posinf=0.0, neginf=0.0)
        return counts.mean() if counts.numel() else torch.tensor(0.0)

    @staticmethod
    def _ratio_or_zero(numerator, denominator):
        """Return numerator / denominator, or a zero tensor when the denominator is zero."""
        return numerator / denominator if denominator != 0 else torch.tensor(0.0)

    def cal_frame(self, obs, sim, threshold):
        """
        Count contingency-table entries for one observation/simulation pair.

        obs, sim: tensors or array-likes of any (matching) shape.
        threshold: a non-negative threshold marks an event where
            value >= threshold; a negative one where value <= threshold.
        Returns python ints (hits, misses, falsealarms, correctnegatives)
        summed over all elements.
        """
        if not torch.is_tensor(obs):
            obs = torch.from_numpy(np.asarray(obs))
        if not torch.is_tensor(sim):
            sim = torch.from_numpy(np.asarray(sim))

        if threshold >= 0:
            obs_event, sim_event = (obs >= threshold), (sim >= threshold)
        else:
            obs_event, sim_event = (obs <= threshold), (sim <= threshold)

        obs_event = obs_event.bool()
        sim_event = sim_event.bool()
        obs_quiet = ~obs_event
        sim_quiet = ~sim_event

        hits = int((obs_event & sim_event).sum().item())
        misses = int((obs_event & sim_quiet).sum().item())
        falsealarms = int((obs_quiet & sim_event).sum().item())
        correctnegatives = int((obs_quiet & sim_quiet).sum().item())

        return hits, misses, falsealarms, correctnegatives

    def evaluate(self, true_batch, pred_batch):
        """
        Add one batch of observations/predictions to the running counts.

        Accepts tensors or numpy arrays, e.g. (B, seq_len, 1),
        (B, seq_len, H, W) or (B, seq_len). Trailing singleton dimensions
        are squeezed (never below 2-D) and both inputs must then have the
        same shape. Counts are computed on CPU, per sample and per
        threshold, over all elements of the sample.
        """
        pred_batch = self._as_cpu_float(pred_batch)
        true_batch = self._as_cpu_float(true_batch)

        assert pred_batch.shape == true_batch.shape, f"pred_batch.shape: {pred_batch.shape}, true_batch.shape: {true_batch.shape}"

        batch_size = pred_batch.shape[0]
        filled_keys = ("hits", "misses", "falsealarms", "correctnegs")

        for threshold in self.thresholds:
            bucket = self.metrics[threshold]
            for b in range(batch_size):
                counts = self.cal_frame(true_batch[b], pred_batch[b], threshold)
                for key, count in zip(filled_keys, counts):
                    bucket[key].append(count)

        self.total += batch_size

    def done(self):
        """
        Compute aggregated metrics from accumulated counts.

        For each threshold the per-sample counts are averaged, CSI, FAR, POD
        and HSS are derived from those means (0.0 whenever a denominator is
        zero), and the scores are then averaged over thresholds.
        Returns dict: {'csi', 'far', 'avg_pod', 'hss'} (floats).
        """
        csi_scores, far_scores, pod_scores, hss_scores = [], [], [], []

        for threshold in self.thresholds:
            bucket = self.metrics[threshold]
            hit_m = self._mean_count(bucket["hits"])
            miss_m = self._mean_count(bucket["misses"])
            fa_m = self._mean_count(bucket["falsealarms"])
            cn_m = self._mean_count(bucket["correctnegs"])

            csi = self._ratio_or_zero(hit_m, hit_m + miss_m + fa_m)
            far = self._ratio_or_zero(fa_m, hit_m + fa_m)
            pod = self._ratio_or_zero(hit_m, hit_m + miss_m)

            # Heidke skill score from the 2x2 contingency means.
            hss_num = 2 * (hit_m * cn_m - miss_m * fa_m)
            hss_den = (hit_m + miss_m) * (miss_m + cn_m) + (hit_m + fa_m) * (fa_m + cn_m)
            hss = self._ratio_or_zero(hss_num, hss_den)

            for score, sink in (
                (csi, csi_scores),
                (far, far_scores),
                (pod, pod_scores),
                (hss, hss_scores),
            ):
                finite = torch.nan_to_num(score, nan=0.0, posinf=0.0, neginf=0.0)
                sink.append(finite.item())

        def _average(scores):
            return float(np.mean(scores)) if scores else 0.0

        return {
            'csi': _average(csi_scores),
            'far': _average(far_scores),
            'avg_pod': _average(pod_scores),
            'hss': _average(hss_scores),
        }