from abc import ABC, abstractmethod
import torch.nn as nn
import torch
from einops import rearrange

def reduce(l, reduction = 'sum'):
    '''
    Collapse a tensor of loss values according to a reduction mode.

    :param l: tensor of loss values
    :param reduction: 'sum' (total), 'mean' (average) or 'none' (return l untouched)
    :return: the reduced tensor, or l itself for 'none'
    :raises ValueError: if reduction is not one of the supported modes
    '''
    if reduction == 'none':
        # Nothing to aggregate: hand the input back as-is.
        return l
    if reduction == 'sum':
        return torch.sum(l)
    if reduction == 'mean':
        return torch.mean(l)
    raise ValueError(f"Invalid reduction {reduction}")

class Loss(nn.Module, ABC):
    '''
    Abstract base class for loss modules.

    Stores a human-readable name and an epoch counter that starts at 0 and
    is advanced by one on each call to update_epoch().
    '''

    def __init__(self, name):
        '''
        :param name: identifier of the loss, stored as self.name
        '''
        super().__init__()
        self.name = name
        self.epoch = 0

    @abstractmethod
    def forward(self, y_pred, y_gt, y_gt_prev = None, epoch = None, max_epochs = None):
        '''
        Compute the loss; must be overridden by concrete subclasses.

        :param y_pred: predicted value
        :param y_gt: ground truth
        :param y_gt_prev: ground truth of the previous time step (optional)
        :param epoch: optional; unused by this base implementation
        :param max_epochs: optional; unused by this base implementation
        '''

    def update_epoch(self):
        '''Advance the internal epoch counter by one.'''
        self.epoch = self.epoch + 1


class nRMSE(Loss):
    '''
    Normalized error: per-sample L2 norm of (prediction - ground truth)
    divided by the L2 norm of the ground truth.
    '''

    def __init__(self, reduction = 'sum'):
        '''
        :param reduction: how per-sample values are combined ('sum', 'mean' or 'none')
        '''
        super().__init__('nRMSE')
        self.reduction = reduction
        # Added to the denominator so an all-zero ground truth does not divide by zero.
        self.EPS = 1e-8

    def forward(self, y_pred, y_gt, y_gt_prev = None):
        '''
        :param y_pred: predicted value, first dimension is the batch
        :param y_gt: ground truth, same number of elements per sample as y_pred
        :param y_gt_prev: unused; accepted for interface compatibility
        :return: per-sample ratios combined according to self.reduction
        '''
        n_samples = y_pred.shape[0]
        # Flatten everything except the batch dimension.
        flat_pred = y_pred.reshape(n_samples, -1)
        flat_gt = y_gt.reshape(n_samples, -1)

        error_norm = torch.norm(flat_pred - flat_gt, p=2, dim=1)
        gt_norm = torch.norm(flat_gt, p=2, dim=1)
        return reduce(error_norm / (gt_norm + self.EPS), self.reduction)


class RMSE(Loss):
    '''
    Per-sample L2 norm of (prediction - ground truth).

    NOTE(review): despite the name, the value is not divided by the number
    of elements, i.e. it is a root-sum-square rather than a root-mean-square.
    Confirm whether callers rely on this scale before changing it.
    '''

    def __init__(self, reduction = 'sum'):
        '''
        :param reduction: how per-sample values are combined ('sum', 'mean' or 'none')
        '''
        super().__init__('RMSE')
        self.reduction = reduction

    def forward(self, y_pred, y_gt, y_gt_prev = None):
        '''
        :param y_pred: predicted value, first dimension is the batch
        :param y_gt: ground truth, same number of elements per sample as y_pred
        :param y_gt_prev: unused; accepted for interface compatibility
        :return: per-sample error norms combined according to self.reduction
        '''
        n_samples = y_pred.shape[0]
        # Flatten everything except the batch dimension before taking the norm.
        residual = y_pred.reshape(n_samples, -1) - y_gt.reshape(n_samples, -1)
        per_sample = torch.norm(residual, p=2, dim=1)
        return reduce(per_sample, self.reduction)
    

class SensitiveWeighted(Loss):
    '''
    Wraps a base loss and scales it by how much the ground truth changed
    relative to the previous time step.

    The weight is nRMSE(y_gt, y_gt_prev) raised to the power
    (1 - epoch / max_epochs), so its influence is annealed from full
    (exponent 1) at epoch 0 to none (exponent 0) at max_epochs.
    '''

    def __init__(self, base_loss, max_epochs):
        '''
        :param base_loss: Loss instance whose value is re-weighted
        :param max_epochs: epoch at which the sensitivity weight becomes 1
        '''
        super().__init__(f'SensitiveWeighted_{base_loss.name}')
        self.base_loss = base_loss
        self.max_epochs = max_epochs

    def forward(self, y_pred, y_gt, y_gt_prev = None):
        '''
        :param y_pred: predicted value
        :param y_gt: ground truth
        :param y_gt_prev: ground truth of the previous time step (required here)
        :return: base loss multiplied by the annealed sensitivity weight
        :raises ValueError: if y_gt_prev is not provided
        '''
        if y_gt_prev is None:
            # The default exists only for signature compatibility with other losses;
            # without it the weight cannot be computed.
            raise ValueError(f"{self.name} requires y_gt_prev to compute the sensitivity weight")

        # Call the module itself (not .forward) so registered hooks are honoured.
        l = self.base_loss(y_pred, y_gt, y_gt_prev)

        # Sensitive weight based on change of ground truth:
        sen_w = nRMSE()(y_gt, y_gt_prev)

        # Clamp the annealing progress to [0, 1]: past max_epochs the exponent
        # would turn negative, inverting the weighting (and yielding inf when
        # the ground truth did not change).
        progress = min(max(self.epoch / self.max_epochs, 0.0), 1.0)
        return l * sen_w ** (1 - progress)

class PearsonCorrelation(Loss):
    '''Pearson correlation between prediction and ground truth, computed per sample.'''

    def __init__(self, reduction = 'sum'):
        '''
        :param reduction: how per-sample correlations are combined ('sum', 'mean' or 'none')
        '''
        super().__init__('PearsonCorrelation')
        self.reduction = reduction

    def forward(self, y_pred, y_gt, y_gt_prev = None):
        '''
        :param y_pred: (B, Sx, [Sy], [Sz], V)
        :param y_gt: (B, Sx, [Sy], [Sz], V)
        :param y_gt_prev: unused; accepted for interface compatibility
        :return: per-sample correlations combined according to self.reduction
        '''
        EPS = 1e-8
        n_samples = y_pred.shape[0]

        # Flatten everything except the batch dimension.
        flat_pred = y_pred.reshape(n_samples, -1)
        flat_gt = y_gt.reshape(n_samples, -1)

        # Deviations from the per-sample mean.
        pred_centered = flat_pred - torch.mean(flat_pred, dim=1, keepdim=True)
        gt_centered = flat_gt - torch.mean(flat_gt, dim=1, keepdim=True)

        # Biased (population) standard deviations, matching the mean-based
        # (population) covariance below.
        pred_std = torch.std(flat_pred, dim=1, unbiased=False)
        gt_std = torch.std(flat_gt, dim=1, unbiased=False)

        covariance = torch.mean(pred_centered * gt_centered, dim=1)
        # EPS guards against a zero denominator for constant inputs; result has shape (B,)
        per_sample_corr = covariance / (pred_std * gt_std + EPS)
        return reduce(per_sample_corr, self.reduction)

# Maps a loss name to the class implementing it.
# Note: SensitiveWeighted takes (base_loss, max_epochs) rather than a
# reduction argument, so it cannot be constructed like the other entries.
loss_registry = {
    'nRMSE': nRMSE,
    'RMSE': RMSE,
    'SensitiveWeighted': SensitiveWeighted,
    'PearsonCorrelation': PearsonCorrelation,
}


        