#!/usr/bin/env python
# models.py - Neural network models and loss functions
# --------------------------------------------------------------------
import torch
import torch.nn as nn
import math

# Neural network building blocks
class ResidualBlock(nn.Module):
    """Two Linear + PReLU stages wrapped by an identity skip connection.

    Both linear layers map ``width`` features to ``width`` features, so the
    block's output has the same trailing dimension as its input.
    """

    def __init__(self, width):
        """Create the two ``width -> width`` linear layers and their PReLUs."""
        super().__init__()
        self.linear1 = nn.Linear(in_features=width, out_features=width)
        self.prelu1 = nn.PReLU()
        self.linear2 = nn.Linear(in_features=width, out_features=width)
        self.prelu2 = nn.PReLU()

    def forward(self, x):
        """Return ``prelu2(linear2(prelu1(linear1(x)))) + x``."""
        hidden = self.linear1(x)
        hidden = self.prelu1(hidden)
        hidden = self.linear2(hidden)
        hidden = self.prelu2(hidden)
        # Skip connection: add the untouched input back onto the branch.
        return hidden + x

# FeatureExtractor
class FeatureExtractor(nn.Module):
    """Residual MLP trunk mapping ``in_dim`` features to ``width // 2`` features.

    Pipeline: Linear + PReLU input projection, ``depth`` residual blocks,
    two further ``width -> width`` linear layers, then a final
    ``width -> width // 2`` linear layer.
    """

    def __init__(self, in_dim, width=128, depth=4):
        """Build the trunk.

        Args:
            in_dim: Number of input features.
            width: Hidden width used by every layer before the output layer.
            depth: Number of residual blocks.
        """
        super().__init__()
        projection = nn.Linear(in_dim, width)
        self.input_proj = nn.Sequential(projection, nn.PReLU())

        self.blocks = nn.ModuleList(
            ResidualBlock(width) for _ in range(depth)
        )

        self.extra1 = nn.Linear(width, width)
        self.extra2 = nn.Linear(width, width)

        self.output_layer = nn.Linear(width, width // 2)

    def forward(self, x):
        """Run ``x`` through the projection, the blocks and the linear tail."""
        hidden = self.input_proj(x)
        for block in self.blocks:
            hidden = block(hidden)

        # extra1 and extra2 are applied back to back with no activation
        # between them.
        hidden = self.extra2(self.extra1(hidden))

        return self.output_layer(hidden)

# Final Model:
class ResMLP(nn.Module):
    """Feature extractor ``f`` followed by a bias-free linear head ``h``.

    The head maps the ``width // 2`` extracted features to a single value,
    so the output's trailing dimension is 1.
    """

    def __init__(self, in_dim, width=128, depth=4):
        """Build the extractor and the scalar head.

        Args:
            in_dim: Number of input features.
            width: Hidden width passed to the feature extractor.
            depth: Number of residual blocks in the feature extractor.
        """
        super().__init__()
        self.f = FeatureExtractor(in_dim, width, depth)
        head_inputs = width // 2
        self.h = nn.Linear(head_inputs, 1, bias=False)

    def forward(self, x):
        """Return ``h(f(x))``."""
        return self.h(self.f(x))

# Link functions: logit and probit
class LogitLink:
    """Logit link: exposes the logistic sigmoid as its CDF."""

    @staticmethod
    def cdf(x):
        """Return the elementwise sigmoid of the tensor ``x``."""
        return x.sigmoid()
    
class ProbitLink:
    """Probit link: exposes the standard normal CDF."""

    @staticmethod
    def cdf(x):
        """Return the standard normal CDF of the tensor ``x``, elementwise.

        Computed as ``0.5 * erfc(-x / sqrt(2))``. This is algebraically equal
        to ``0.5 * (1 + erf(x / sqrt(2)))`` but does not add 1 to a value
        close to -1, so small lower-tail probabilities are not rounded to
        exactly zero.
        """
        return 0.5 * torch.erfc(-x / math.sqrt(2))
    
# Two cases of thresholds: Fixed, Learnable
## Fixed:
class CLMFixed(nn.Module):
    """Cumulative-link negative log-likelihood with fixed, evenly spaced cut points.

    The ``K + 1`` cut points ``b`` are ``torch.linspace(lo, hi, K + 1)`` and
    are stored as a buffer, so they are not trained. Class ``k`` gets
    probability ``cdf(b[k + 1] - z) - cdf(b[k] - z)``.

    Because both outer cut points are finite, each row of the returned
    probabilities sums to ``cdf(hi - z) - cdf(lo - z)``, not necessarily to 1.
    """

    def __init__(self, K, lo=None, hi=None, link="logit"):
        """Create the cut points and select the link.

        Args:
            K: Number of classes.
            lo: Lowest cut point. Defaults to -20 for logit, -2 for probit.
            hi: Highest cut point. Defaults to 20 for logit, 2 for probit.
            link: ``"logit"`` or ``"probit"``.

        Raises:
            ValueError: If ``link`` is not a supported name.
        """
        super().__init__()

        # Validate before building the buffer so a bad link name surfaces as
        # ValueError instead of a TypeError from linspace(None, None, ...).
        if link not in ("logit", "probit"):
            raise ValueError(f"Unsupported link function: {link}")

        # Fill in only the bound(s) the caller omitted; an explicitly passed
        # bound is never overwritten.
        default_lo, default_hi = (-20, 20) if link == "logit" else (-2, 2)
        if lo is None:
            lo = default_lo
        if hi is None:
            hi = default_hi

        self.register_buffer('b', torch.linspace(lo, hi, K + 1))

        self.link = LogitLink() if link == "logit" else ProbitLink()

    def forward(self, z, y):
        """Compute the mean negative log-likelihood of labels ``y`` given ``z``.

        Args:
            z: Tensor with N elements; reshaped to ``(N, 1)``.
            y: Integer class indices with N elements; flattened to ``(N,)``.

        Returns:
            Tuple ``(nll, prob)`` where ``nll`` is a scalar tensor and
            ``prob`` has shape ``(N, K)``.
        """
        z = z.view(-1, 1)
        # Flatten labels so an (N, 1) label tensor cannot broadcast against
        # idx into an (N, N) lookup.
        y = y.reshape(-1)
        cdf = self.link.cdf(self.b - z)
        prob = cdf[:, 1:] - cdf[:, :-1]
        idx = torch.arange(len(y), device=y.device)
        # The 1e-12 offset keeps log() finite when a probability is zero.
        nll = -(prob[idx, y] + 1e-12).log().mean()
        return nll, prob
    
## Learnable:
class CLMLearnable(nn.Module):
    """Cumulative-link negative log-likelihood with learnable cut points.

    The ``K - 1`` interior cut points are derived from the parameter
    ``delta``: softplus makes every increment positive and a cumulative sum
    makes the cut points increasing. The outer CDF values are fixed at 0
    and 1, so each row of the returned probabilities sums to 1.
    """

    def __init__(self, K, link="logit"):
        """Create the ``delta`` parameter and select the link.

        Args:
            K: Number of classes.
            link: ``"logit"`` or ``"probit"``.

        Raises:
            ValueError: If ``link`` is not a supported name.
        """
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(K - 1))

        if link == "logit":
            self.link = LogitLink()
        elif link == "probit":
            self.link = ProbitLink()
        else:
            raise ValueError(f"Unsupported link function: {link}")

    def _b(self):
        """Return the ``K - 1`` increasing cut points built from ``delta``."""
        soft = nn.functional.softplus(self.delta)
        # Cumulative sum of positive increments, shifted by their mean.
        return torch.cumsum(soft, 0) - soft.mean()

    def forward(self, z, y):
        """Compute the mean negative log-likelihood of labels ``y`` given ``z``.

        Args:
            z: Tensor with N elements; reshaped to ``(N, 1)``.
            y: Integer class indices with N elements; flattened to ``(N,)``.

        Returns:
            Tuple ``(nll, prob, b)`` where ``nll`` is a scalar tensor,
            ``prob`` has shape ``(N, K)`` and ``b`` holds the current cut
            points.
        """
        z = z.view(-1, 1)
        # Flatten labels so an (N, 1) label tensor cannot broadcast against
        # idx into an (N, N) lookup.
        y = y.reshape(-1)
        b = self._b()
        cdf = self.link.cdf(b - z)
        # Pad with CDF values 0 and 1 for the two unbounded outer classes.
        prob = torch.cat([torch.zeros_like(z), cdf, torch.ones_like(z)], dim=1)
        prob = prob[:, 1:] - prob[:, :-1]
        idx = torch.arange(len(y), device=y.device)
        # The 1e-12 offset keeps log() finite when a probability is zero.
        nll = -(prob[idx, y] + 1e-12).log().mean()
        return nll, prob, b

# Predicted label
def class_from_z(z, b):
    """Return, for each element of ``z``, how many entries of ``b`` it exceeds.

    Args:
        z: Tensor with N elements; reshaped to ``(N, 1)`` so that 1-D and
            ``(N, 1)`` inputs are both accepted.
        b: 1-D tensor of cut points.

    Returns:
        Long tensor of shape ``(N,)`` holding the count of cut points
        strictly below each element of ``z``.
    """
    # NOTE(review): CLMFixed.b also contains the two outer bounds; presumably
    # callers pass interior cut points only -- confirm against call sites.
    return (z.reshape(-1, 1) > b).sum(1).long()