import torch
from torch import Tensor, nn
import torch.nn.functional as F

def mse(input, targ):
    """Return the mean squared error between ``input`` and ``targ``.

    The squared differences are averaged over every element, so the result
    is a scalar tensor.
    """
    loss = F.mse_loss(input, targ, reduction="mean")
    return loss

def rmse(input, targ):
    """Return the root mean squared error between ``input`` and ``targ``.

    Computed as the square root of the element-wise mean squared error, so
    the result is a scalar tensor.
    """
    mean_squared = F.mse_loss(input, targ, reduction="mean")
    return torch.sqrt(mean_squared)

def r2_score(input, target):
    """Return the R^2 score of ``input`` measured against ``target``.

    Delegates to ``sklearn.metrics.r2_score``. Note the argument order: this
    wrapper takes the predictions first, while scikit-learn is called with
    ``target`` first and ``input`` second.
    """
    # Imported inside the function, under an alias, so scikit-learn is only
    # needed when this metric is called and the imported name does not
    # collide with this wrapper's own name.
    from sklearn.metrics import r2_score as sklearn_r2_score

    return sklearn_r2_score(target, input)

# https://github.com/ceshine/quantile-regression-tensorflow/blob/master/notebooks/03-sklearn-example-pytorch.ipynb
class QuantileLoss(nn.Module):
    """Pinball (quantile) loss for multi-quantile predictions.

    For each quantile level ``q`` the per-element loss is
    ``max((q - 1) * e, q * e)`` with ``e = target - prediction``. The losses
    are summed over the quantile levels and averaged over all remaining
    elements.

    Args:
        quantiles: Quantile levels, one per channel of the last dimension of
            ``preds``. The values are copied, so the caller's sequence (and
            the default) is never shared between instances.
    """

    def __init__(self, quantiles=(0.001, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.999)):
        super().__init__()
        # Defensive copy: a list default stored by reference would be shared
        # by every instance and could be mutated through ``self.quantiles``.
        self.quantiles = list(quantiles)

    def forward(self, preds, target):
        """Compute the loss.

        Args:
            preds: Predictions of shape ``[..., Q]`` where channel ``i`` of
                the last dimension is the forecast for ``self.quantiles[i]``
                (e.g. ``[B, N, T, Q]``).
            target: Ground truth of shape ``[...]`` (e.g. ``[B, N, T]``); must
                not require gradients.

        Returns:
            Scalar tensor: sum over quantiles, mean over everything else.
        """
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        losses = []
        for i, q in enumerate(self.quantiles):
            # Ellipsis indexing keeps the loss independent of the number of
            # leading dimensions.
            errors = target - preds[..., i]
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
        loss = torch.mean(torch.sum(torch.cat(losses, dim=-1), dim=-1))
        return loss

class CRPSLoss(nn.Module):
    """Closed-form CRPS for quantile forecasts with a piecewise-linear CDF.

    The predicted CDF is taken to rise linearly from level ``q_i`` at
    ``X_i`` to ``q_{i+1}`` at ``X_{i+1}`` between consecutive quantile
    predictions. On each such interval the squared distance between that CDF
    and the step function at the observation is integrated analytically; the
    per-interval integrals are summed and the result is averaged over all
    targets.

    Args:
        quantiles: Quantile levels in ``[0, 1]``, one per channel of the last
            dimension of ``preds``.
        adjusted: If true, level 0 is prepended and level 1 appended to the
            quantile grid, and ``preds`` is padded with ``min_bound`` /
            ``max_bound`` in ``forward`` so the CDF spans a fixed support.
        eps: Added to the interval width before dividing, to avoid division
            by zero when two neighbouring quantile predictions coincide.
        min_bound: Value used for the prepended level-0 prediction when
            ``adjusted`` is true. Should lie below the data range.
        max_bound: Value used for the appended level-1 prediction when
            ``adjusted`` is true. Should lie above the data range.
    """

    def __init__(self, quantiles=(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1),
                 adjusted=True, eps=1e-10, min_bound=-100, max_bound=100):
        super().__init__()
        self.adjusted = adjusted
        levels = list(quantiles)
        if self.adjusted:
            levels = [0] + levels + [1]
        self.quantiles = torch.tensor(levels)
        self.eps = eps
        self.min_bound = min_bound
        self.max_bound = max_bound

    def forward(self, preds, target):
        """Compute the loss.

        Args:
            preds: Quantile predictions of shape ``[..., Q]``
                (e.g. ``[B, N, T, Q]``).
            target: Observations of shape ``[...]`` (e.g. ``[B, N, T]``);
                must not require gradients.

        Returns:
            Scalar tensor: CRPS summed over intervals, averaged over targets.
        """
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        if self.adjusted:
            # Pad in the dtype/device of ``preds`` so the concatenation does
            # not silently change precision.
            pad_shape = (*preds.shape[:-1], 1)
            lower = preds.new_full(pad_shape, self.min_bound)
            upper = preds.new_full(pad_shape, self.max_bound)
            preds = torch.cat((lower, preds, upper), dim=-1)

        levels = self.quantiles.to(device=preds.device, dtype=preds.dtype)
        q_lo, q_hi = levels[:-1], levels[1:]
        x_lo, x_hi = preds[..., :-1], preds[..., 1:]
        y = target.unsqueeze(-1)  # broadcasts against every interval
        width = x_hi - x_lo

        # Observation to the right of the interval: integrate F(x)^2.
        obs_right = 1 / 3 * (width * (q_hi ** 2 + q_hi * q_lo + q_lo ** 2))
        # Observation to the left of the interval: integrate (F(x) - 1)^2.
        obs_left = 1 / 3 * (
            width * ((q_hi - 1) ** 2 + (q_hi - 1) * (q_lo - 1) + (q_lo - 1) ** 2)
        )
        # Observation inside the interval: (F - 1)^2 over the whole interval
        # plus the correction F^2 - (F - 1)^2 = 2F - 1 integrated on [x_lo, y].
        offset = y - x_lo
        obs_inside = (
            offset * (2 * q_lo - 1)
            + offset ** 2 / (width + self.eps) * (q_hi - q_lo)
            + obs_left
        )

        # Select one formula per interval. ``y < x_lo`` takes precedence over
        # ``y > x_hi`` (relevant only if quantile predictions cross). Using
        # ``where`` instead of multiplying by a one-hot mask keeps a
        # non-finite value in an unselected branch out of the result.
        per_interval = torch.where(
            y < x_lo,
            obs_left,
            torch.where(y > x_hi, obs_right, obs_inside),
        )
        return torch.mean(torch.sum(per_interval, dim=-1))
    
# https://solarforecastarbiter-core.readthedocs.io/en/latest/_modules/solarforecastarbiter/metrics/probabilistic.html#crps_skill_score
class CRPS(nn.Module):
    """Trapezoidal CRPS estimate for quantile forecasts.

    The forecast CDF is described by quantile levels given in percent
    (0-100) and the forecast values at those levels. The squared difference
    between the CDF and an observation indicator is integrated over the
    forecast values with the trapezoidal rule, then averaged over all
    samples.

    Args:
        quantiles: Quantile levels in percent, one per channel of the last
            dimension of ``fx``.
    """

    def __init__(self, quantiles=(0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100)):
        super().__init__()
        # Stored as a float tensor; moved to the forecast's device on use.
        self.base_quantiles = torch.tensor(quantiles, dtype=torch.float32)

    def forward(self, fx, obs):
        """Compute the score.

        Args:
            fx: Forecast values of shape ``[..., Q]`` (e.g. ``[B, N, T, Q]``),
                one channel per entry of ``quantiles``.
            obs: Observations of shape ``[...]`` (e.g. ``[B, N, T]``).

        Returns:
            Scalar tensor with the mean CRPS.
        """
        device = fx.device

        # Quantile grid extended with 0 % and 100 % and rescaled to [0, 1].
        # It is kept 1-D and broadcast, rather than repeated per sample.
        extended_quantiles = torch.cat([
            torch.zeros(1, device=device),
            self.base_quantiles.to(device),
            torch.full((1,), 100.0, device=device),
        ], dim=-1)
        probabilities = extended_quantiles / 100.0

        # Extend the forecast values on both sides so the observation always
        # lies inside the integration range.
        fx_min = torch.min(obs, fx[..., 0]).unsqueeze(-1)
        fx_max = torch.max(obs, fx[..., -1]).unsqueeze(-1)
        fx = torch.cat([fx_min, fx, fx_max], dim=-1)

        # Indicator: 1 at grid points at or to the right of the observation.
        indicator = torch.where(fx >= obs.unsqueeze(-1), 1.0, 0.0)

        # Where no grid point is strictly greater than the observation, the
        # whole indicator row (not only its last entry) is reset to 0.
        none_above = ~(torch.max(fx, dim=-1).values > obs)
        indicator[none_above] = 0

        # Integrate the squared difference over the forecast values, then
        # average over all samples.
        crps = torch.mean(torch.trapezoid((probabilities - indicator) ** 2, fx, dim=-1))

        return crps

