from torch import nn, Tensor
import torch


class PICNN(nn.Module):
    """
    Partially input-convex neural network (PICNN) over an input example x of
    shape [input_dim] and a label vector y of shape [y_dim].

    Two paths are interleaved over ``n_layers`` hidden layers:
      * the u path depends on x only (``u_layers``);
      * the sigma path mixes y in through the V_hat/V_bar layers and feeds the
        previous sigma through the W_hat/W_bar layers.

    While every W_bar weight is non-negative (see ``clamp_weights``), the raw
    ``output_dim`` scores are convex in y. ``forward`` then reshapes those
    scores to [batch, 9, 9] and applies a softmax over the last axis. So
    ``output_dim`` must be 81 (the default) for each example to map to exactly
    one 9x9 block.
    """

    def __init__(self, input_dim: int = 32, y_dim: int = 768, hidden_dim: int = 256,
                 n_layers: int = 3, y_in_output_layer: bool = True, gamma: float = 0.,
                 output_dim: int = 81):
        """
        Args:
            input_dim: size of x.
            y_dim: size of y.
            hidden_dim: width of every hidden layer.
            n_layers: number of hidden layers L; must be >= 1.
            y_in_output_layer: if False, the output layer's y term is skipped
                in ``forward`` and its V_bar weight is zeroed.
            gamma: if non-zero, ``gamma * ||y||_inf`` is added to every raw
                output score.
            output_dim: number of raw scores produced by the output layer.

        Raises:
            ValueError: if ``n_layers < 1`` (index L of the W_hat stack would
                be the parameter-free placeholder module).
        """
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        self.input_dim = input_dim
        self.y_dim = y_dim
        self.hidden_dim = hidden_dim
        L = n_layers
        self.L = L

        # bag of tricks for feasibility
        self.gamma = gamma
        self.y_in_output_layer = y_in_output_layer

        # Placeholder at index 0 so that layer l of every stack is indexed by l;
        # the W layers are unused at l == 0 because there is no previous sigma.
        null_module = nn.Module()

        self.W_hat_layers = nn.ModuleList(
            [null_module]
            + [nn.Linear(hidden_dim, hidden_dim) for _ in range(1, L+1)]
        )
        self.W_bar_layers = nn.ModuleList(
            [null_module]
            + [nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(1, L)]
            + [nn.Linear(hidden_dim, output_dim, bias=False)]
        )
        self.V_hat_layers = nn.ModuleList(
            [nn.Linear(input_dim, y_dim)]
            + [nn.Linear(hidden_dim, y_dim) for _ in range(1, L+1)]
        )
        # NOTE(review): the output-layer V_bar has a single output column, so
        # its term is broadcast identically across all output_dim scores of an
        # example and therefore cancels in the per-block softmax of forward().
        # Widening it to output_dim would change the checkpoint format.
        self.V_bar_layers = nn.ModuleList(
            [nn.Linear(y_dim, hidden_dim, bias=False) for _ in range(L)]
            + [nn.Linear(y_dim, 1, bias=False)]
        )
        self.b_layers = nn.ModuleList(
            [nn.Linear(input_dim, hidden_dim)]
            + [nn.Linear(hidden_dim, hidden_dim) for _ in range(1, L)]
            + [nn.Linear(hidden_dim, output_dim)]
        )
        # u is only advanced inside the L hidden layers; the output layer reads
        # the final u, so L update layers are needed (not L + 1).
        self.u_layers = nn.ModuleList(
            [nn.Linear(input_dim, hidden_dim)]
            + [nn.Linear(hidden_dim, hidden_dim) for _ in range(1, L)]
        )
        self.clamp_weights()

    def clamp_weights(self) -> None:
        """Clamp the weights of all W_bar layers to be >= 0.

        When ``y_in_output_layer`` is False, this also zeroes the output-layer
        V_bar weight. It is called once from ``__init__``; NOTE(review): callers
        presumably re-invoke it after each optimizer step — confirm.
        """
        with torch.no_grad():
            for layer in self.W_bar_layers:
                if isinstance(layer, nn.Linear):
                    layer.weight.clamp_(min=0)
            if not self.y_in_output_layer:
                self.V_bar_layers[-1].weight.fill_(0.)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Compute per-block class probabilities for the examples (x, y).

        Args:
            x: shape [batch_size, input_dim], input
            y: shape [batch_size, y_dim], labels

        Returns:
            Tensor of shape [batch_size * output_dim // 81, 9, 9] (i.e.
            [batch_size, 9, 9] for the default output_dim of 81). The raw
            scores are split into groups of 9 and softmax-normalised within
            each group.
        """
        u = x
        sigma = None  # output of the previous sigma-path layer
        for l in range(self.L):
            if l > 0:
                W_hat_vec = torch.relu(self.W_hat_layers[l](u))
                W_sigma = self.W_bar_layers[l](W_hat_vec * sigma)
            else:
                W_sigma = 0.  # no previous sigma at the first layer
            V_hat_vec = self.V_hat_layers[l](u)        # [batch, y_dim]
            V_y = self.V_bar_layers[l](V_hat_vec * y)  # [batch, hidden_dim]
            b = self.b_layers[l](u)
            sigma = torch.relu(W_sigma + V_y + b)
            u = torch.relu(self.u_layers[l](u))

        # Output layer.
        l = self.L
        W_hat_vec = torch.relu(self.W_hat_layers[l](u))
        W_sigma = self.W_bar_layers[l](W_hat_vec * sigma)  # [batch, output_dim]

        V_y = 0.
        if self.y_in_output_layer:
            V_hat_vec = self.V_hat_layers[l](u)
            V_y = self.V_bar_layers[l](V_hat_vec * y)      # [batch, 1]
        b = self.b_layers[l](u)                            # [batch, output_dim]

        if self.gamma == 0.:
            output = W_sigma + V_y + b
        else:
            # One value per example, shaped [batch, 1] so it broadcasts over
            # all output scores (the old .repeat(1, 3) gave [batch, 3], which
            # does not broadcast against output_dim columns).
            kappa = self.gamma * torch.linalg.vector_norm(y, ord=float('inf'), dim=1)
            kappa = kappa.unsqueeze(-1)
            output = W_sigma + V_y + kappa + b

        output = output.view(-1, 9, 9)
        return torch.softmax(output, dim=2)


def _energy_and_grad(model, xy: Tensor, x_dim: int):
    """Return ``(energy, grad)`` of ``model`` at the concatenated points ``xy``.

    ``xy`` has shape [num_chains, x_dim + y_dim]. The model output is reduced
    to one energy per chain by summing over every non-batch dimension, giving
    ``energy`` of shape [num_chains] (detached). ``grad`` is the gradient of
    each chain's energy w.r.t. its own (x, y), shape [num_chains, x_dim + y_dim].
    """
    xy = xy.detach().requires_grad_(True)
    x, y = xy[:, :x_dim], xy[:, x_dim:]
    out = model(x, y)
    energy = out.reshape(out.shape[0], -1).sum(dim=1)
    # Chains are independent, so the gradient of the summed energy equals the
    # per-chain gradient.
    grad_x, grad_y = torch.autograd.grad(energy.sum(), (x, y), allow_unused=True)
    if grad_x is None or grad_y is None:
        raise RuntimeError("One of the gradients is None. Ensure the inputs are "
                           "used in the model's forward pass.")
    return energy.detach(), torch.cat([grad_x, grad_y], dim=1)


def MALA_PICNN(
    model: "PICNN",
    num_samples: int,
    burnin: int,
    x_init: Tensor,
    y_init: Tensor,
    std: float
) -> Tensor:
    """Run Metropolis-adjusted Langevin (MALA) chains on the energy ``model(x, y)``.

    The target density is proportional to ``exp(-E(x, y))``. E is the model
    output summed over all non-batch dimensions, so models returning
    [num_chains], [num_chains, 1] or larger per-chain outputs are accepted.
    Proposals are ``z' ~ N(z - 0.5 * std**2 * grad E(z), std**2 I)``, followed
    by the Metropolis-Hastings correction. Puts ``model`` in eval mode. The
    caller's ``x_init`` / ``y_init`` are not modified.

    Args:
        model: callable module ``model(x, y)`` returning a per-chain energy.
        num_samples: number of samples to keep after burn-in.
        burnin: number of burn-in steps to discard.
        x_init: initial sample, shape [num_chains, x_dim], on same device as model
        y_init: initial sample, shape [num_chains, y_dim], on same device as model
        std: standard deviation of the Gaussian proposal

    Returns:
        samples: shape [num_samples, num_chains, x_dim + y_dim], on same device
        as model. Each entry is an independent copy of the chain state.

    Raises:
        RuntimeError: if the model output does not depend on both x and y.
    """
    model.eval()

    x_dim = x_init.shape[1]
    # Work on a private copy so the caller's tensors are never mutated.
    xy = torch.cat([x_init, y_init], dim=1).detach().clone()

    energy, grad = _energy_and_grad(model, xy, x_dim)
    drift = 0.5 * std ** 2
    two_var = 2 * std ** 2
    mu = xy - drift * grad  # proposal mean at the current state

    samples = []
    for t in range(burnin + num_samples):
        with torch.no_grad():
            candidate = mu + torch.randn_like(xy) * std

        cand_energy, cand_grad = _energy_and_grad(model, candidate, x_dim)

        with torch.no_grad():
            nu = candidate - drift * cand_grad  # proposal mean at the candidate

            # log of pi(z') q(z | z') / (pi(z) q(z' | z)); computed in log
            # space so overflow / NaN cannot corrupt the accept decision.
            log_alpha = (
                - cand_energy
                - torch.sum((xy - nu) ** 2, dim=1) / two_var
                + energy
                + torch.sum((candidate - mu) ** 2, dim=1) / two_var
            )
            accept = torch.log(torch.rand_like(log_alpha)) < log_alpha

            mu[accept] = nu[accept]
            xy[accept] = candidate[accept]
            energy[accept] = cand_energy[accept]

            if t >= burnin:
                # Clone: xy is updated in place on later iterations.
                samples.append(xy.clone())

    if not samples:
        return xy.new_empty((0, *xy.shape))
    return torch.stack(samples)

if __name__ == "__main__":
    # Smoke test / rough timing of a single PICNN forward pass on random data.
    import time

    input_dim = 32
    y_dim = 768
    hidden_dim = 256
    n_layers = 3
    y_in_output_layer = True
    gamma = 0.0

    picnn_model = PICNN(input_dim=input_dim, y_dim=y_dim, hidden_dim=hidden_dim,
                        n_layers=n_layers, y_in_output_layer=y_in_output_layer,
                        gamma=gamma)

    x = torch.randn(32, input_dim)
    y = torch.randn(32, y_dim)

    # perf_counter is the monotonic high-resolution clock intended for timing.
    start_time = time.perf_counter()
    output = picnn_model(x, y)
    print(f"--- {time.perf_counter() - start_time} seconds ---")

    print("Output shape:", tuple(output.shape))
    print("First example:", output[0])