import torch
import torch.nn as nn


class OrdinalModel(nn.Module):
    """Learnable, increasing cut-points (thresholds) for an ordinal model.

    The finite thresholds are parameterised as a learnable ``bias`` (the first
    finite threshold) followed by cumulative sums of ``exp(distances)``. The
    gaps between consecutive thresholds are therefore always positive,
    whatever the raw parameter values.

    Args:
        numberOfLevels: Number of ordinal levels; must be at least 2. The model
            produces ``numberOfLevels - 1`` finite thresholds (twice that when
            ``symmetrize`` is True).
        symmetrize: If True, the finite thresholds are mirrored around zero.
        offset: Integer returned unchanged by ``calculateThetas`` in the
            non-symmetric case; must be 0 when ``symmetrize`` is True.
            NOTE(review): presumably an index shift applied by callers — confirm.

    Raises:
        ValueError: If ``numberOfLevels < 2``, or if ``symmetrize`` is combined
            with a non-zero ``offset``.
    """

    def __init__(self, numberOfLevels: int, symmetrize: bool = False, offset: int = 0):
        super().__init__()
        if numberOfLevels < 2:
            raise ValueError(f"numberOfLevels must be >= 2, got {numberOfLevels}")
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if symmetrize and offset != 0:
            raise ValueError("offset must be 0 when symmetrize=True")
        self.symmetrize = symmetrize
        self.offset = offset
        # First finite threshold (exponentiated in symmetric mode).
        self.bias = nn.Parameter(torch.zeros(1))
        # Log-gaps between consecutive finite thresholds; exp() keeps gaps positive.
        self.distances = nn.Parameter(torch.randn(numberOfLevels - 2) / 2 - 1)

    def calculateThetas(self):
        """Return the full threshold vector together with the offset.

        Returns:
            A tuple ``(thetas, offset)``.

            ``thetas`` is a 1-D tensor that starts with ``-inf``, ends with
            ``+inf``, and holds the finite thresholds in increasing order in
            between. There are ``numberOfLevels - 1`` of them, or twice that
            (mirrored around zero) when ``symmetrize`` is True.

            ``offset`` is ``self.offset`` in the non-symmetric case. In the
            symmetric case it is the number of finite thresholds on each side
            of zero (``numberOfLevels - 1``).
        """
        # In symmetric mode the bias is exponentiated so every positive-side
        # threshold is > 0; the mirrored (negated, reversed) copy is then < 0,
        # and the concatenation stays increasing.
        effective_bias = torch.exp(self.bias) if self.symmetrize else self.bias
        gaps = torch.exp(self.distances)
        # Create constants from the parameters so they share device and dtype.
        # Plain torch.zeros / torch.tensor would live on the CPU and break after
        # `.to(device)`.
        middleValues = effective_bias + torch.cat([gaps.new_zeros(1), gaps.cumsum(0)])
        offset = self.offset
        if self.symmetrize:
            offset = middleValues.shape[0]
            middleValues = torch.cat([-middleValues.flip(0), middleValues])
        inf = middleValues.new_full((1,), float("inf"))
        return torch.cat([-inf, middleValues, inf]), offset




