import torch.distributions as distributions
import torch.nn as nn
from einops import rearrange
from abc import ABC, abstractmethod
import torch

# Registry of the distribution names understood by this module, mapped to the
# torch.distributions class that is instantiated from the predicted parameters.
TORCH_DISTRIBUTIONS = {
    "student-t": distributions.StudentT,
    "normal": distributions.Normal,
    # Registered under its own name but backed by the same class as "normal".
    "normal_fixed_variance": distributions.Normal,
}

# Small positive offset added to softplus outputs elsewhere in this file.
EPS = 1e-5


class NLL:
    '''Negative log likelihood loss built around a distribution constructor.'''

    def __init__(self, distribution):
        '''Stores the callable used to build a distribution from its parameters.'''
        self.distribution = distribution

    def __call__(self, parameters_pred, obs):
        '''Computes the summed negative log likelihood of ``obs``.

        ``parameters_pred`` is cut into size-1 slices along its last axis and
        the slices are passed positionally, in order, to the distribution
        constructor. The constructor is called with ``validate_args=True``.

        Returns the negated sum of ``log_prob(obs)`` over all elements.
        '''
        # Each slice keeps a trailing axis of length 1.
        per_parameter = parameters_pred.split(1, dim=-1)
        predicted = self.distribution(*per_parameter, validate_args=True)
        log_likelihood = predicted.log_prob(obs)
        return -log_likelihood.sum()


def get_prob_loss(distribution):
    '''Builds the negative log likelihood loss for a named distribution.

    ``distribution`` must be one of the keys of ``TORCH_DISTRIBUTIONS``;
    any other value raises ``ValueError``.
    '''
    if distribution in TORCH_DISTRIBUTIONS:
        return NLL(TORCH_DISTRIBUTIONS[distribution])
    raise ValueError(f"Invalid distribution: {distribution}")


class ProbForecaster(nn.Module, ABC):
    '''Abstract base for modules that wrap a base model and predict
    the parameters of a probability distribution.

    Subclasses implement ``predict_parameters``, ``inference`` and
    ``inference_variance``.
    '''

    def __init__(self, base_model, distribution):
        '''Stores the wrapped model and resolves the distribution class.

        ``distribution`` is looked up in ``TORCH_DISTRIBUTIONS``; an unknown
        name raises ``KeyError``.
        '''
        super().__init__()
        self.base_model = base_model
        self.distribution = TORCH_DISTRIBUTIONS[distribution]

    def forward(self, *args, **kwargs):
        '''Dispatches on the module mode.

        In training mode the call is forwarded to ``predict_parameters``;
        otherwise it is forwarded to ``inference``.
        '''
        handler = self.predict_parameters if self.training else self.inference
        return handler(*args, **kwargs)

    @abstractmethod
    def predict_parameters(self, *args, **kwargs):
        '''Returns the parameters of the distribution'''

    @abstractmethod
    def inference(self, *args, **kwargs):
        '''Returns the output that ``forward`` produces outside training mode'''

    @abstractmethod
    def inference_variance(self, *args, **kwargs):
        '''Returns inference and variance of the distribution'''


class NormalForecaster(ProbForecaster):
    '''Forecaster predicting the location and scale of a normal distribution.'''

    def predict_parameters(self, *args, **kwargs):
        '''Returns the parameters of the normal distribution

        The last axis of the base model output is regrouped into consecutive
        pairs ``(mu, raw_sigma)``. The result has ``[mu, sigma]`` stacked on
        its last axis, where ``sigma`` is strictly positive.
        '''
        out = self.base_model(*args, **kwargs)
        n_parameters = 2
        out = rearrange(out, '... (z n) -> ... z n', n=n_parameters)
        # squeeze(-1) drops the group axis only when there is a single group
        mu = out[..., 0].squeeze(-1)
        # softplus maps the raw output to a positive value and EPS keeps it
        # away from zero. This value is the *scale* (standard deviation) of
        # the normal distribution, not its variance.
        sigma = nn.Softplus()(out[..., 1].squeeze(-1)) + EPS
        return torch.stack([mu, sigma], dim=-1)

    def inference(self, *args, **kwargs):
        '''Returns the predicted mean with a trailing axis of length 1'''
        parameters_pred = self.predict_parameters(*args, **kwargs)
        # simply return the mean
        return parameters_pred[..., 0].unsqueeze(-1)

    def inference_variance(self, *args, **kwargs):
        '''Returns the mean and the variance of the predicted distribution

        Both tensors carry a trailing axis of length 1.
        '''
        parameters_pred = self.predict_parameters(*args, **kwargs)
        mu = parameters_pred[..., 0]
        sigma = parameters_pred[..., 1]
        # sigma is a standard deviation, so the variance is its square
        return mu.unsqueeze(-1), (sigma ** 2).unsqueeze(-1)


class NormalFixedVarianceForecaster(ProbForecaster):
    '''Forecaster predicting only the mean of a normal distribution whose
    spread is a fixed, user-supplied constant.

    NOTE: despite its name, ``variance`` is used unchanged as the second
    parameter of the distribution, i.e. as its scale (standard deviation).
    '''

    def __init__(self, base_model, distribution, variance):
        '''Stores the fixed spread constant next to the wrapped model.'''
        super().__init__(base_model, distribution)
        self.variance = variance

    def predict_parameters(self, *args, **kwargs):
        '''Returns the parameters of the normal distribution

        The result has ``[mu, sigma]`` stacked on its last axis, where
        ``sigma`` is ``self.variance`` broadcast to the shape of ``mu``.
        '''
        out = self.base_model(*args, **kwargs)
        n_parameters = 1
        out = rearrange(out, '... (z n) -> ... z n', n=n_parameters)
        mu = out[..., 0].squeeze(-1)
        # constant spread, used as the scale of the distribution
        sigma = torch.ones_like(mu) * self.variance
        return torch.stack([mu, sigma], dim=-1)

    def inference(self, *args, **kwargs):
        '''Returns the predicted mean with a trailing axis of length 1'''
        parameters_pred = self.predict_parameters(*args, **kwargs)
        # simply return the mean
        return parameters_pred[..., 0].unsqueeze(-1)

    def inference_variance(self, *args, **kwargs):
        '''Returns the mean and the variance of the predicted distribution

        The variance is the square of the scale produced by
        ``predict_parameters``, so it matches the distribution that is
        actually evaluated. Both tensors carry a trailing axis of length 1.
        '''
        parameters_pred = self.predict_parameters(*args, **kwargs)
        mu = parameters_pred[..., 0]
        sigma = parameters_pred[..., 1]
        # sigma is a standard deviation, so the variance is its square
        return mu.unsqueeze(-1), (sigma ** 2).unsqueeze(-1)
    

# Offset added to the softplus-transformed degrees of freedom in
# StudentTForecaster.predict_parameters; inference_variance divides by
# (df - 2), so df has to end up above this value.
MIN_DF = 2


class StudentTForecaster(ProbForecaster):
    '''Forecaster predicting the degrees of freedom, location and scale of a
    Student-t distribution.'''

    def predict_parameters(self, *args, **kwargs):
        '''Returns the parameters of the Student-t distribution

        The last axis of the base model output is regrouped into consecutive
        triples ``(raw_df, location, raw_scale)``. The result has
        ``[df, location, scale]`` stacked on its last axis.
        '''
        grouped = rearrange(
            self.base_model(*args, **kwargs), '... (z n) -> ... z n', n=3
        )
        # squeeze(-1) drops the group axis only when there is a single group
        raw_df, location, raw_scale = (
            grouped[..., index].squeeze(-1) for index in range(3)
        )
        positive = nn.Softplus()
        # softplus output is shifted by MIN_DF and EPS
        df = positive(raw_df) + MIN_DF + EPS
        scale = positive(raw_scale) + EPS
        return torch.stack([df, location, scale], dim=-1)

    def inference(self, *args, **kwargs):
        '''Returns the predicted location with a trailing axis of length 1

        The location parameter is returned directly; the distribution object
        is not built here.
        '''
        return self.predict_parameters(*args, **kwargs)[..., 1].unsqueeze(-1)

    def inference_variance(self, *args, **kwargs):
        '''Returns the location and the variance of the predicted distribution

        The variance is ``scale ** 2 * df / (df - 2)``. An assertion fails if
        any predicted ``df`` is not greater than 2. Both tensors carry a
        trailing axis of length 1.
        '''
        params = self.predict_parameters(*args, **kwargs)
        df, location, scale = params[..., 0], params[..., 1], params[..., 2]
        assert torch.all(df > 2.0), f"Can't compute variance for df <= 2.0, {df[df <= 2.0]}"
        variance = scale ** 2 * df / (df - 2.0)
        return location.unsqueeze(-1), variance.unsqueeze(-1)
        

    



