import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class PenDer(nn.Module):
    """Sigmoid MLP trained with penalty-based monotonicity/convexity constraints.

    For every constrained input column ``k`` a per-sample violation ``v`` is
    computed and the term ``mu * mean(v**2) + lambda_k * mean(v)`` is added to
    the data loss:

    * monotonic features: ``v = clamp(-d out / d x_k, min=0)``
    * convex features:    ``v = clamp(-d^2 out / d x_k^2, min=0)``

    ``update_multipliers_and_penalty`` increases each ``lambda_k`` by
    ``mu * mean(v)`` and then multiplies ``mu`` by ``rho``.

    Derivatives are taken of ``outputs.sum()``, so for ``output_size > 1`` the
    constraints apply to the sum of the outputs. Inputs are indexed as
    ``x[:, feature_idx]``; presumably ``x`` is ``(batch, input_size)``
    (TODO confirm against callers).

    Args:
        input_size: Number of input features of the first ``nn.Linear``.
        hidden_layers: Sizes of the hidden ``nn.Linear`` layers, each followed
            by ``nn.Sigmoid``.
        output_size: Number of outputs of the final ``nn.Linear``.
        monotonic_features: Column indices whose first derivative should be
            >= 0. Defaults to none.
        convex_features: Column indices whose diagonal second derivative
            should be >= 0. Defaults to none.
        lambda_penalty: Stored on the instance but not read by any method of
            this class.
        rho: Growth factor applied to ``mu`` on every
            ``update_multipliers_and_penalty`` call.
        initial_mu: Initial quadratic penalty weight ``mu``.
    """

    def __init__(self,
                 input_size,
                 hidden_layers,
                 output_size,
                 monotonic_features=None,
                 convex_features=None,
                 lambda_penalty=1.0,
                 rho=1.1,
                 initial_mu=1.0):
        super().__init__()

        # Copy into fresh lists. The multiplier sizes below are fixed, so later
        # mutation of the caller's sequence must not leak in. Plain lists also
        # make the truthiness checks used below safe for array-like inputs.
        monotonic_features = [] if monotonic_features is None else list(monotonic_features)
        convex_features = [] if convex_features is None else list(convex_features)

        self.input_size = input_size
        self.hidden_layers = hidden_layers
        self.output_size = output_size
        self.monotonic_features = monotonic_features
        self.convex_features = convex_features

        self.lambda_penalty = lambda_penalty  # not read by any method below
        self.rho = rho
        self.mu = initial_mu

        # Lagrange multipliers, one per constrained feature. They are
        # non-persistent buffers so .to()/.cuda()/.double() move them with the
        # weights. state_dict() keeps exactly the same keys as before.
        self.register_buffer('lambda_monotonic',
                             torch.zeros(len(monotonic_features)),
                             persistent=False)
        self.register_buffer('lambda_convex',
                             torch.zeros(len(convex_features)),
                             persistent=False)

        # Build network: [Linear -> Sigmoid] per hidden layer, then Linear.
        layers = []
        prev_size = input_size
        for layer_size in hidden_layers:
            layers.append(nn.Linear(prev_size, layer_size))
            layers.append(nn.Sigmoid())
            prev_size = layer_size
        layers.append(nn.Linear(prev_size, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.model(x)

    def _input_gradients(self, x, create_graph):
        """Forward a grad-enabled copy of ``x`` and differentiate w.r.t. it.

        The caller must ensure grad mode is enabled.

        Args:
            x: Input batch; moved to the device of the model parameters.
            create_graph: Whether ``grads`` itself must be differentiable
                (needed for second derivatives or for backprop of penalties).

        Returns:
            ``(x_, outputs, grads)`` where ``x_`` is the detached copy that
            requires grad, ``outputs = self(x_)`` and
            ``grads = d outputs.sum() / d x_``.
        """
        device = next(self.parameters()).device
        x_ = x.to(device).detach().clone()
        x_.requires_grad_()
        outputs = self.forward(x_)
        grads = torch.autograd.grad(
            outputs.sum(), x_,
            create_graph=create_graph,
            # Keep the forward graph: compute_loss backprops data_loss later.
            retain_graph=True,
        )[0]
        return x_, outputs, grads

    @staticmethod
    def _diag_second_derivative(grads, x_, feature_idx, create_graph):
        """Return ``d grads[:, feature_idx] / d x_[:, feature_idx]`` per row.

        ``retain_graph=True`` keeps the graph of ``grads`` alive so the method
        can be called once per constrained feature on the same ``grads``.
        """
        second_grads = torch.autograd.grad(
            grads[:, feature_idx].sum(), x_,
            create_graph=create_graph,
            retain_graph=True,
        )[0]
        return second_grads[:, feature_idx]

    def compute_loss(self, x, y, criterion):
        """Return the data loss plus the monotonicity and convexity penalties.

        Args:
            x: Input batch (moved to the model's device).
            y: Targets (moved to the model's device).
            criterion: Callable ``criterion(outputs, y)`` returning a loss
                tensor.

        Returns:
            ``criterion(outputs, y)`` plus, for each constrained feature,
            ``mu * mean(v**2) + lambda_k * mean(v)``.
        """
        device = next(self.parameters()).device
        y = y.to(device)

        if not self.monotonic_features and not self.convex_features:
            # Nothing to penalize, so skip building the input-gradient graph.
            return criterion(self.forward(x.to(device).detach()), y)

        # create_graph=True so the penalties can be backpropagated to weights.
        x_, outputs, grads = self._input_gradients(x, create_graph=True)
        data_loss = criterion(outputs, y)
        n = x_.shape[0]

        monotonic_penalty = 0.0
        for idx, feature_idx in enumerate(self.monotonic_features):
            violation = torch.clamp(-grads[:, feature_idx], min=0)
            monotonic_penalty += (self.mu / n) * torch.sum(violation ** 2)
            monotonic_penalty += (1.0 / n) * self.lambda_monotonic[idx] * torch.sum(violation)

        convex_penalty = 0.0
        for idx, feature_idx in enumerate(self.convex_features):
            d2 = self._diag_second_derivative(grads, x_, feature_idx, create_graph=True)
            violation = torch.clamp(-d2, min=0)
            convex_penalty += (self.mu / n) * torch.sum(violation ** 2)
            convex_penalty += (1.0 / n) * self.lambda_convex[idx] * torch.sum(violation)

        return data_loss + monotonic_penalty + convex_penalty

    @torch.no_grad()
    def update_multipliers_and_penalty(self, x):
        """Update the multipliers and then grow ``mu``.

        For each constrained feature, ``lambda_k += mu * mean(v)`` using the
        current ``mu``. Afterwards ``mu *= rho``; this step also runs when
        there are no constraints.
        """
        if self.monotonic_features or self.convex_features:
            with torch.enable_grad():
                # A differentiable first derivative is only needed when
                # second derivatives will be taken from it.
                x_, _, grads = self._input_gradients(
                    x, create_graph=bool(self.convex_features))

                for idx, feature_idx in enumerate(self.monotonic_features):
                    violation = torch.clamp(-grads[:, feature_idx], min=0)
                    self.lambda_monotonic[idx] += self.mu * violation.mean().item()

                for idx, feature_idx in enumerate(self.convex_features):
                    d2 = self._diag_second_derivative(grads, x_, feature_idx, create_graph=False)
                    violation = torch.clamp(-d2, min=0)
                    self.lambda_convex[idx] += self.mu * violation.mean().item()

        self.mu *= self.rho

    def calculate_monotonicity_score(self, x):
        """Return the fraction of (sample, monotonic feature) pairs with a
        non-negative first derivative. Returns 1.0 if there are no monotonic
        features or no samples."""
        if not self.monotonic_features:
            return 1.0

        with torch.enable_grad():
            _, _, grads = self._input_gradients(x, create_graph=False)

        selected = grads[:, self.monotonic_features]
        total_count = selected.numel()
        if total_count == 0:
            return 1.0
        return (selected >= 0).sum().item() / total_count

    def calculate_convexity_score(self, x):
        """Return the fraction of (sample, convex feature) pairs with a
        non-negative diagonal second derivative. Returns 1.0 if there are no
        convex features or no samples."""
        if not self.convex_features:
            return 1.0

        with torch.enable_grad():
            x_, _, grads = self._input_gradients(x, create_graph=True)
            convex_count = 0
            for feature_idx in self.convex_features:
                # The graph of ``grads`` is retained inside the helper, so
                # several convex features can be scored on the same grads.
                d2 = self._diag_second_derivative(grads, x_, feature_idx, create_graph=False)
                convex_count += (d2 >= 0).sum().item()

        total_count = x_.shape[0] * len(self.convex_features)
        return convex_count / total_count if total_count > 0 else 1.0


class PenDerConcave(nn.Module):
    """Sigmoid MLP trained with penalty-based monotonicity/concavity constraints.

    For every constrained input column ``k`` a per-sample violation ``v`` is
    computed and the term ``mu * mean(v**2) + lambda_k * mean(v)`` is added to
    the data loss:

    * monotonic features: ``v = clamp(-d out / d x_k, min=0)``
    * concave features:   ``v = clamp(d^2 out / d x_k^2, min=0)``

    ``update_multipliers_and_penalty`` increases each ``lambda_k`` by
    ``mu * mean(v)`` and then multiplies ``mu`` by ``rho``.

    Derivatives are taken of ``outputs.sum()``, so for ``output_size > 1`` the
    constraints apply to the sum of the outputs. Inputs are indexed as
    ``x[:, feature_idx]``; presumably ``x`` is ``(batch, input_size)``
    (TODO confirm against callers).

    Args:
        input_size: Number of input features of the first ``nn.Linear``.
        hidden_layers: Sizes of the hidden ``nn.Linear`` layers, each followed
            by ``nn.Sigmoid``.
        output_size: Number of outputs of the final ``nn.Linear``.
        monotonic_features: Column indices whose first derivative should be
            >= 0. Defaults to none.
        concave_features: Column indices whose diagonal second derivative
            should be <= 0. Defaults to none.
        lambda_penalty: Stored on the instance but not read by any method of
            this class.
        rho: Growth factor applied to ``mu`` on every
            ``update_multipliers_and_penalty`` call.
        initial_mu: Initial quadratic penalty weight ``mu``.
    """

    def __init__(self,
                 input_size,
                 hidden_layers,
                 output_size,
                 monotonic_features=None,
                 concave_features=None,
                 lambda_penalty=1.0,
                 rho=1.1,
                 initial_mu=1.0):
        super().__init__()

        # Copy into fresh lists. The multiplier sizes below are fixed, so later
        # mutation of the caller's sequence must not leak in. Plain lists also
        # make the truthiness checks used below safe for array-like inputs.
        monotonic_features = [] if monotonic_features is None else list(monotonic_features)
        concave_features = [] if concave_features is None else list(concave_features)

        self.input_size = input_size
        self.hidden_layers = hidden_layers
        self.output_size = output_size
        self.monotonic_features = monotonic_features
        self.concave_features = concave_features

        self.lambda_penalty = lambda_penalty  # not read by any method below
        self.rho = rho
        self.mu = initial_mu

        # Lagrange multipliers, one per constrained feature. They are
        # non-persistent buffers so .to()/.cuda()/.double() move them with the
        # weights. state_dict() keeps exactly the same keys as before.
        self.register_buffer('lambda_monotonic',
                             torch.zeros(len(monotonic_features)),
                             persistent=False)
        self.register_buffer('lambda_concave',
                             torch.zeros(len(concave_features)),
                             persistent=False)

        # Build network: [Linear -> Sigmoid] per hidden layer, then Linear.
        layers = []
        prev_size = input_size
        for layer_size in hidden_layers:
            layers.append(nn.Linear(prev_size, layer_size))
            layers.append(nn.Sigmoid())
            prev_size = layer_size
        layers.append(nn.Linear(prev_size, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.model(x)

    def _input_gradients(self, x, create_graph):
        """Forward a grad-enabled copy of ``x`` and differentiate w.r.t. it.

        The caller must ensure grad mode is enabled.

        Args:
            x: Input batch; moved to the device of the model parameters.
            create_graph: Whether ``grads`` itself must be differentiable
                (needed for second derivatives or for backprop of penalties).

        Returns:
            ``(x_, outputs, grads)`` where ``x_`` is the detached copy that
            requires grad, ``outputs = self(x_)`` and
            ``grads = d outputs.sum() / d x_``.
        """
        device = next(self.parameters()).device
        x_ = x.to(device).detach().clone()
        x_.requires_grad_()
        outputs = self.forward(x_)
        grads = torch.autograd.grad(
            outputs.sum(), x_,
            create_graph=create_graph,
            # Keep the forward graph: compute_loss backprops data_loss later.
            retain_graph=True,
        )[0]
        return x_, outputs, grads

    @staticmethod
    def _diag_second_derivative(grads, x_, feature_idx, create_graph):
        """Return ``d grads[:, feature_idx] / d x_[:, feature_idx]`` per row.

        ``retain_graph=True`` keeps the graph of ``grads`` alive so the method
        can be called once per constrained feature on the same ``grads``.
        """
        second_grads = torch.autograd.grad(
            grads[:, feature_idx].sum(), x_,
            create_graph=create_graph,
            retain_graph=True,
        )[0]
        return second_grads[:, feature_idx]

    def compute_loss(self, x, y, criterion):
        """Return the data loss plus the monotonicity and concavity penalties.

        Args:
            x: Input batch (moved to the model's device).
            y: Targets (moved to the model's device).
            criterion: Callable ``criterion(outputs, y)`` returning a loss
                tensor.

        Returns:
            ``criterion(outputs, y)`` plus, for each constrained feature,
            ``mu * mean(v**2) + lambda_k * mean(v)``.
        """
        device = next(self.parameters()).device
        y = y.to(device)

        if not self.monotonic_features and not self.concave_features:
            # Nothing to penalize, so skip building the input-gradient graph.
            return criterion(self.forward(x.to(device).detach()), y)

        # create_graph=True so the penalties can be backpropagated to weights.
        x_, outputs, grads = self._input_gradients(x, create_graph=True)
        data_loss = criterion(outputs, y)
        n = x_.shape[0]

        monotonic_penalty = 0.0
        for idx, feature_idx in enumerate(self.monotonic_features):
            violation = torch.clamp(-grads[:, feature_idx], min=0)
            monotonic_penalty += (self.mu / n) * torch.sum(violation ** 2)
            monotonic_penalty += (1.0 / n) * self.lambda_monotonic[idx] * torch.sum(violation)

        concave_penalty = 0.0
        for idx, feature_idx in enumerate(self.concave_features):
            d2 = self._diag_second_derivative(grads, x_, feature_idx, create_graph=True)
            # Concavity is violated where the second derivative is positive.
            violation = torch.clamp(d2, min=0)
            concave_penalty += (self.mu / n) * torch.sum(violation ** 2)
            concave_penalty += (1.0 / n) * self.lambda_concave[idx] * torch.sum(violation)

        return data_loss + monotonic_penalty + concave_penalty

    @torch.no_grad()
    def update_multipliers_and_penalty(self, x):
        """Update the multipliers and then grow ``mu``.

        For each constrained feature, ``lambda_k += mu * mean(v)`` using the
        current ``mu``. Afterwards ``mu *= rho``; this step also runs when
        there are no constraints.
        """
        if self.monotonic_features or self.concave_features:
            with torch.enable_grad():
                # A differentiable first derivative is only needed when
                # second derivatives will be taken from it.
                x_, _, grads = self._input_gradients(
                    x, create_graph=bool(self.concave_features))

                for idx, feature_idx in enumerate(self.monotonic_features):
                    violation = torch.clamp(-grads[:, feature_idx], min=0)
                    self.lambda_monotonic[idx] += self.mu * violation.mean().item()

                for idx, feature_idx in enumerate(self.concave_features):
                    d2 = self._diag_second_derivative(grads, x_, feature_idx, create_graph=False)
                    violation = torch.clamp(d2, min=0)
                    self.lambda_concave[idx] += self.mu * violation.mean().item()

        self.mu *= self.rho

    def calculate_monotonicity_score(self, x):
        """Return the fraction of (sample, monotonic feature) pairs with a
        non-negative first derivative. Returns 1.0 if there are no monotonic
        features or no samples."""
        if not self.monotonic_features:
            return 1.0

        with torch.enable_grad():
            _, _, grads = self._input_gradients(x, create_graph=False)

        selected = grads[:, self.monotonic_features]
        total_count = selected.numel()
        if total_count == 0:
            return 1.0
        return (selected >= 0).sum().item() / total_count

    def calculate_concavity_score(self, x):
        """Return the fraction of (sample, concave feature) pairs with a
        non-positive diagonal second derivative. Returns 1.0 if there are no
        concave features or no samples."""
        if not self.concave_features:
            return 1.0

        with torch.enable_grad():
            x_, _, grads = self._input_gradients(x, create_graph=True)
            concave_count = 0
            for feature_idx in self.concave_features:
                # The graph of ``grads`` is retained inside the helper, so
                # several concave features can be scored on the same grads.
                d2 = self._diag_second_derivative(grads, x_, feature_idx, create_graph=False)
                concave_count += (d2 <= 0).sum().item()

        total_count = x_.shape[0] * len(self.concave_features)
        return concave_count / total_count if total_count > 0 else 1.0