import torch
from torch import nn
from torch import distributions as D

class DiagDist(nn.Module):
    """Wrap a ``torch.distributions`` family with learnable or fixed parameters.

    Two modes, selected by ``trainable``:

    * trainable: ``initParams()`` is called once in the constructor, ``base``
      is kept as the distribution *class*, and a new distribution object is
      built from ``getParams()`` on every call to ``dist()``.
    * fixed: ``base`` is called once in the constructor with the keyword
      arguments returned by ``getParams()`` and the resulting distribution
      object is reused.

    Subclasses must implement ``initParams`` and ``getParams``.

    Args:
        base: Callable distribution class (e.g. ``D.Normal``); it is invoked
            as ``base(**params)``.
        dim: Stored as ``self.dim`` for use by subclasses.
        trainable: Selects the mode described above.
    """

    def __init__(self, base: D.Distribution, dim: int, trainable: bool = True):
        super().__init__()
        self.dim = dim
        self.trainable = trainable
        if not trainable:
            # Fixed mode: instantiate the distribution once from the defaults
            # supplied by the subclass.
            self.base = base(**self.getParams())
            return
        # Trainable mode: let the subclass create its parameters and keep the
        # class itself so dist() can rebuild it from the current values.
        self.initParams()
        self.base = base

    def initParams(self) -> None:
        """Create the subclass' parameters; called only in trainable mode."""
        raise NotImplementedError

    def getParams(self) -> dict[str, torch.Tensor]:
        """Return the keyword arguments passed to the distribution class."""
        raise NotImplementedError

    def dist(self) -> D.Distribution:
        """Return the distribution object for the current parameter values."""
        if not self.trainable:
            # Already an instance, built in __init__.
            return self.base
        return self.base(**self.getParams())

    def log_prob(self, z: torch.Tensor):
        """Evaluate ``log_prob`` of the current distribution at ``z``."""
        return self.dist().log_prob(z)

    def sample(self, shape: tuple):
        """Draw a sample via the distribution's ``sample`` method.

        In trainable mode the last entry of ``shape`` is dropped before
        sampling. NOTE(review): this presumably relies on the parameters
        having a trailing axis of size ``dim`` so that the result again has
        the requested shape; ``shape[-1] == dim`` is not checked here.
        In fixed mode ``shape`` is passed through unchanged.
        """
        current = self.dist()
        sample_shape = shape[:-1] if self.trainable else shape
        return current.sample(sample_shape)
    

class DiagGaussian(DiagDist):
    """Normal distribution with per-dimension learnable or fixed parameters.

    Args:
        dim: Number of dimensions; in trainable mode ``loc`` and ``scale``
            are parameters of shape ``(dim,)``.
        trainable: If true (default, the previous behaviour) ``loc`` and
            ``scale`` are ``nn.Parameter``s initialised to 0 and 1. If false,
            a fixed ``Normal(0, 1)`` is used and no parameters are created.
    """

    def __init__(self, dim, trainable: bool = True):
        # Forward ``trainable`` so the fixed-parameter branch of getParams()
        # is actually reachable; the default keeps existing callers unchanged.
        super().__init__(D.Normal, dim, trainable)

    def getParams(self):
        """Return the keyword arguments for ``D.Normal``."""
        if not self.trainable:
            return {'loc': 0, 'scale': 1}
        # NOTE(review): ``scale`` is used unconstrained; nothing here keeps it
        # positive during optimisation.
        return {'loc': self.loc, 'scale': self.scale}

    def initParams(self):
        """Register ``loc`` (zeros) and ``scale`` (ones) as parameters."""
        self.loc = nn.Parameter(torch.zeros(self.dim))
        self.scale = nn.Parameter(torch.ones(self.dim))
        

class DiagStudent(DiagDist):
    """Student-t distribution with per-dimension learnable or fixed parameters.

    Args:
        dim: Number of dimensions; in trainable mode ``log_df``, ``loc`` and
            ``scale`` are parameters of shape ``(dim,)``.
        trainable: If true (default, the previous behaviour) the parameters
            are learnable and initialised to df=5, loc=0, scale=1. If false,
            a fixed ``StudentT(df=5, loc=0, scale=1)`` is used and no
            parameters are created.
    """

    def __init__(self, dim, trainable: bool = True):
        # Forward ``trainable`` so the fixed-parameter branch of getParams()
        # is actually reachable; the default keeps existing callers unchanged.
        super().__init__(D.StudentT, dim, trainable)

    def getParams(self):
        """Return the keyword arguments for ``D.StudentT``."""
        if not self.trainable:
            return {'df': 5, 'loc': 0, 'scale': 1}
        # ``log_df`` stores the pre-softplus value of df (despite its name).
        # The epsilon is added *after* the softplus so df stays strictly
        # positive even when the softplus underflows to zero; adding it to the
        # softplus input would only shift the argument and give no floor.
        df = nn.functional.softplus(self.log_df) + 1e-6
        # NOTE(review): ``scale`` is used unconstrained; nothing here keeps it
        # positive during optimisation. A softplus parametrisation (as for df)
        # would, but would change the meaning of the stored parameter.
        return {'df': df, 'loc': self.loc, 'scale': self.scale}

    def initParams(self):
        """Register ``log_df``, ``loc`` and ``scale`` as parameters."""
        # Inverse softplus of 5, so that softplus(log_df) starts at df = 5.
        init_df = torch.full((self.dim,), 5.0)
        self.log_df = nn.Parameter(torch.log(torch.expm1(init_df)))
        self.loc = nn.Parameter(torch.zeros(self.dim))
        self.scale = nn.Parameter(torch.ones(self.dim))
        
        

        
        
# class TrainableStudent(nn.Module):
#     def __init__(self,
#                  dim: int,
#                  init_df: float = 5.0,
#                  init_loc: float = 0.0,
#                  init_scale: float = 1.0):
#         super().__init__()
#         self.dim = dim
#         self.log_df = nn.Parameter(torch.log(torch.exp(torch.ones(dim) * init_df) - 1))  # inverse softplus init
#         self.loc = nn.Parameter(torch.ones(dim) * init_loc)
#         self.scale = nn.Parameter(torch.ones(dim) * init_scale)
    
#     def forward(self):
#         df = nn.functional.softplus(self.log_df)
#         return D.StudentT(df=df, loc=self.loc, scale=self.scale)
    
#     def sample(self, shape: tuple):
#         base = self.forward()
#         sample = base.sample(shape)
#         return sample
    
#     def log_prob(self, z: torch.Tensor):
#         base = self.forward()
#         log_prob = base.log_prob(z)
#         return log_prob
    
    


# #==============================================================================
# #  input-specific priors

# class DiagPrior(nn.Module):
#     def __init__(self, priors: list[D.Distribution]):
#         super().__init__()
#         self.priors = priors
#         self.instances = None
        
#     def unpack(self, params: torch.Tensor, dist: D.Distribution) -> tuple:
#         if dist == D.Normal:
#             loc = params[..., 0:1]
#             scale = params[..., 1:2]
#             return loc, scale
#         elif dist == D.HalfNormal:
#             scale = params[..., 1:2]
#             return (scale,)
        
#     def instantiate(self, params: torch.Tensor) -> None:
#         assert params.shape[1] == len(self.priors), 'dimension mismatch between dim of samples and number of priors'
#         instances = []
#         for i, dist in enumerate(self.priors):
#             param = self.unpack(params[:, i], dist)
#             instances += [dist(*param)]
#         self.instances = instances
        
#     def sample(self, shape: tuple) -> torch.Tensor:
#         assert self.instances is not None, 'priors not instantiated'
#         z = torch.zeros(*shape)
#         b = shape[0]
#         s = shape[-1]
#         for i, instance in enumerate(self.instances):
#             z[:, i] = instance.sample((s,)).view(b,s)
#         return z
    
#     def log_prob(self, z: torch.Tensor) -> torch.Tensor:
#         assert self.instances is not None, 'priors not instantiated'
#         if z.dim() == 2:
#             squeeze = True
#             z = z.unsqueeze(-1)
#         else:
#             squeeze = False
#         b, d, s = z.shape
#         assert d == len(self.priors), 'dimension mismatch between dim of samples and number of priors'
#         log_probs = torch.zeros_like(z)
#         for i, instance in enumerate(self.instances):
#             log_probs[:, i] = instance.log_prob(z[:, i])
#         if squeeze:
#             log_probs = log_probs.squeeze(-1)
#         return log_probs
