import numpy as np
from .model_fitting import get_classifier

class QuadraticConstraint:
    """Ball-shaped example constraint g_θ(y) = ||y - c||^2 - r^2.

    The constraint value is negative strictly inside the ball of radius r
    around c, zero on its surface and positive outside.
    """

    def __init__(self, center, radius):
        """Store the ball center and the squared radius.

        Parameters:
        -----------
        center : array-like
            Center c of the ball (stored as given).
        radius : float
            Radius r; only r^2 is kept.
        """
        self.radius_squared = radius**2
        self.center = center

    def __call__(self, y):
        """Return ||y - c||^2 - r^2 (negative inside, positive outside)."""
        offset = y - self.center
        squared_distance = np.sum(np.square(offset))
        return squared_distance - self.radius_squared

    def grad(self, y):
        """Return the gradient 2 (y - c)."""
        offset = y - self.center
        return 2 * offset

    def hessian(self, y):
        """Return the constant Hessian 2 I, sized by len(y)."""
        dim = len(y)
        return 2 * np.eye(dim)

class SoftmaxLogisticConstraint:
    """
    Constraint function for softmax-based logistic regression.

    g_θ(y) = margin(y) - fairness_threshold, where margin(y) is the gap
    between the largest and second-largest softmax class probabilities of
    the linear classifier ``softmax(W y + b)``.

    The value is negative when the gap is smaller than the threshold (the
    point is close to a decision boundary) and positive otherwise.
    """

    def __init__(self, weights, bias=None, fairness_threshold=0.0):
        """
        Initialize the softmax constraint.

        Parameters:
        -----------
        weights : array-like, shape (n_classes, n_features)
            Weight matrix of the softmax classifier
        bias : array-like, shape (n_classes,), optional
            Bias terms for each class (defaults to zeros)
        fairness_threshold : float, optional
            Value subtracted from the probability margin
        """
        self.weights = np.array(weights)
        self.n_classes, self.n_features = self.weights.shape

        if bias is None:
            self.bias = np.zeros(self.n_classes)
        else:
            self.bias = np.array(bias)

        self.fairness_threshold = fairness_threshold

    def _softmax(self, logits):
        """Compute softmax probabilities from a 1-D vector of logits"""
        # Subtract max for numerical stability
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)

    def _logits(self, y):
        """Compute the logits W y + b for input y"""
        return np.dot(self.weights, y) + self.bias

    def __call__(self, y):
        """
        Evaluate constraint: g_θ(y) = margin - fairness_threshold

        The margin is the difference between the two largest class
        probabilities; with a single class it is that class's probability.

        Returns a negative value when margin < fairness_threshold and a
        positive value when margin > fairness_threshold.
        """
        probs = self._softmax(self._logits(y))

        # Descending order so the two leading entries are the top classes
        sorted_probs = np.sort(probs)[::-1]

        if self.n_classes > 1:
            margin = sorted_probs[0] - sorted_probs[1]
        else:
            margin = sorted_probs[0]  # Edge case for one class

        return margin - self.fairness_threshold

    def grad(self, y):
        """
        Exact gradient of the constraint with respect to y.

        For softmax probabilities p = softmax(W y + b) the gradient of a
        single probability is ∇p_k = p_k (w_k - w̄) with w̄ = Σ_m p_m w_m.
        With i and j the indices of the two most likely classes:

            ∇g_θ(y) = p_i (w_i - w̄) - p_j (w_j - w̄)

        which reduces to 2 p_i p_j (w_i - w_j) in the two-class case.
        With a single class the probability is constant, so the gradient
        is zero. The expression is valid away from ties between classes,
        where the margin is not differentiable.

        Returns:
        --------
        ndarray, shape (n_features,)
        """
        if self.n_classes < 2:
            # softmax over one logit is identically 1 -> constant constraint
            return np.zeros(self.n_features)

        probs = self._softmax(self._logits(y))
        order = np.argsort(probs)[::-1]
        i_top, i_second = order[0], order[1]

        # Probability-weighted mean of the class weight vectors (w̄)
        mean_weights = np.dot(probs, self.weights)

        grad_top = probs[i_top] * (self.weights[i_top] - mean_weights)
        grad_second = probs[i_second] * (self.weights[i_second] - mean_weights)
        return grad_top - grad_second

    def hessian(self, y):
        """
        Heuristic Hessian surrogate of the constraint with respect to y.

        This is NOT the exact second derivative of the margin: it is a
        positive semi-definite outer-product approximation.

        * up to two classes: 0.25 * w wᵀ for the top class weight vector w
        * more classes: (w_i - w_j)(w_i - w_j)ᵀ p_i p_j (1 - p_i - p_j)
          plus 1e-4 * I as regularization

        Returns:
        --------
        ndarray, shape (n_features, n_features)
        """
        probs = self._softmax(self._logits(y))
        order = np.argsort(probs)[::-1]
        i_top = order[0]

        if self.n_classes <= 2:
            # Simple approximation for the binary / single-class case
            w = self.weights[i_top]
            return np.outer(w, w) * 0.25  # Scaled for stability

        i_second = order[1]
        p_i, p_j = probs[i_top], probs[i_second]
        w_diff = self.weights[i_top] - self.weights[i_second]

        # Primary term from the top two classes
        H = np.outer(w_diff, w_diff) * p_i * p_j * (1 - p_i - p_j)

        # Add regularization for numerical stability
        H += np.eye(self.n_features) * 1e-4

        return H

class MLPConstraint:
    """
    Constraint function for a Multi-Layer Perceptron with two hidden layers.
    Each hidden layer has 5 nodes, and the network uses ReLU activations
    followed by a softmax output.

    g_θ(y) = margin(y) - fairness_threshold, where margin(y) is the gap
    between the two largest output probabilities of the network.
    """

    def __init__(self, weights, biases, fairness_threshold=0.0):
        """
        Initialize the MLP constraint.

        Parameters:
        -----------
        weights : list of arrays
            List containing weight matrices for each layer
            weights[0]: first hidden layer (shape: n_features × 5)
            weights[1]: second hidden layer (shape: 5 × 5)
            weights[2]: output layer (shape: 5 × n_classes)

        biases : list of arrays
            List containing bias vectors for each layer
            biases[0]: first hidden layer (shape: 5)
            biases[1]: second hidden layer (shape: 5)
            biases[2]: output layer (shape: n_classes)

        fairness_threshold : float, optional
            Value subtracted from the probability margin

        Raises:
        -------
        AssertionError
            If the number of layers or the hidden layer widths do not match
            the expected architecture.
        """
        self.weights = [np.array(w) for w in weights]
        self.biases = [np.array(b) for b in biases]

        # NOTE: kept as assertions so the raised exception type is unchanged;
        # they are skipped when Python runs with -O.
        assert len(self.weights) == 3, "MLP must have exactly 3 layers (2 hidden + output)"
        assert len(self.biases) == 3, "MLP must have exactly 3 bias vectors"

        # Ensure each hidden layer has 5 nodes
        assert self.weights[0].shape[1] == 5, "First hidden layer must have 5 nodes"
        assert self.weights[1].shape[0] == 5, "First hidden layer must have 5 nodes"
        assert self.weights[1].shape[1] == 5, "Second hidden layer must have 5 nodes"
        assert self.weights[2].shape[0] == 5, "Second hidden layer must have 5 nodes"

        self.n_features = self.weights[0].shape[0]
        self.n_classes = self.weights[2].shape[1]
        self.fairness_threshold = fairness_threshold

    def _relu(self, x):
        """ReLU activation function"""
        return np.maximum(0, x)

    def _relu_derivative(self, x):
        """Derivative of ReLU: 1 where x > 0, otherwise 0"""
        return np.where(x > 0, 1.0, 0.0)

    def _softmax(self, logits):
        """Compute softmax probabilities from a 1-D vector of logits"""
        # Subtract max for numerical stability
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)

    def _forward(self, y):
        """
        Forward pass through the MLP, storing intermediate values.

        Returns:
        --------
        activations : list
            [input, hidden 1 output, hidden 2 output, softmax probabilities]
        preactivations : list
            Pre-activation values of the three layers, in order
        """
        activations = [y]
        preactivations = []

        # Two ReLU hidden layers
        act = y
        for layer in range(2):
            preact = np.dot(act, self.weights[layer]) + self.biases[layer]
            preactivations.append(preact)
            act = self._relu(preact)
            activations.append(act)

        # Softmax output layer
        preact = np.dot(act, self.weights[2]) + self.biases[2]
        preactivations.append(preact)
        activations.append(self._softmax(preact))

        return activations, preactivations

    def __call__(self, y):
        """
        Evaluate constraint: g_θ(y) = margin - fairness_threshold

        The margin is the difference between the two largest class
        probabilities; with a single class it is that class's probability.

        Returns a negative value when margin < fairness_threshold and a
        positive value when margin > fairness_threshold.
        """
        activations, _ = self._forward(y)
        probs = activations[-1]  # Output probabilities

        # Descending order so the two leading entries are the top classes
        sorted_probs = np.sort(probs)[::-1]

        if self.n_classes > 1:
            margin = sorted_probs[0] - sorted_probs[1]
        else:
            margin = sorted_probs[0]  # Edge case for one class

        return margin - self.fairness_threshold

    def grad(self, y):
        """
        Gradient of the constraint with respect to y using backpropagation
        ∇g_θ(y) = ∇(margin - threshold) = ∇margin

        The softmax Jacobian is always applied, so with a single output
        class (constant probability 1) the gradient is exactly zero, in
        agreement with __call__.

        Returns:
        --------
        ndarray, shape (n_features,)
        """
        activations, preactivations = self._forward(y)
        probs = activations[-1]

        # ∂margin/∂p: +1 for the top class, -1 for the runner-up, 0 elsewhere
        order = np.argsort(probs)[::-1]
        d_margin = np.zeros(self.n_classes)
        d_margin[order[0]] = 1.0
        if self.n_classes > 1:
            d_margin[order[1]] = -1.0

        # Softmax Jacobian: ∂p_i/∂z_j = p_i (δ_ij - p_j)
        softmax_jac = np.diag(probs) - np.outer(probs, probs)

        # Gradient with respect to the output logits
        delta = np.dot(d_margin, softmax_jac)

        # Back through the output layer and the second hidden layer's ReLU
        delta = np.dot(delta, self.weights[2].T)
        delta = delta * self._relu_derivative(preactivations[1])

        # Back through the second hidden layer and the first layer's ReLU
        delta = np.dot(delta, self.weights[1].T)
        delta = delta * self._relu_derivative(preactivations[0])

        # Gradient at the input layer
        return np.dot(delta, self.weights[0].T)

    def hessian(self, y):
        """
        Approximate Hessian of the constraint with respect to y.

        Uses the positive semi-definite outer product of the gradient,
        g gᵀ, plus 1e-4 * I as regularization, rather than the exact
        second derivative.

        Returns:
        --------
        ndarray, shape (n_features, n_features)
        """
        g = self.grad(y)

        # Approximate Hessian as outer product of gradient
        H_approx = np.outer(g, g)

        # Add regularization for numerical stability
        H_approx += np.eye(self.n_features) * 1e-4

        return H_approx


class LinearSVMConstraint:
    """
    Constraint built from the hyperplane of a linear Support Vector Machine.

    g_θ(y) = |<w, y> + b| / ||w|| - fairness_threshold, i.e. the geometric
    distance of y to the separating hyperplane minus a threshold. The value
    is negative for points closer to the hyperplane than the threshold.
    """

    def __init__(self, weights, bias=0.0, fairness_threshold=0.0):
        """
        Initialize the linear SVM constraint.

        Parameters:
        -----------
        weights : array-like
            Weight vector of the linear SVM (normal vector to the hyperplane);
            flattened to one dimension
        bias : float, optional
            Bias term for the hyperplane (intercept)
        fairness_threshold : float, optional
            Value subtracted from the distance to the hyperplane
        """
        flat_weights = np.array(weights).flatten()
        self.weights = flat_weights
        self.bias = bias
        self.fairness_threshold = fairness_threshold

    def __call__(self, y):
        """
        Evaluate constraint: g_θ(y) = distance - fairness_threshold

        Returns a negative value when y lies closer to the hyperplane than
        fairness_threshold and a positive value when it lies farther away.
        """
        # SVM decision function: f(y) = <w,y> + b
        f_value = np.dot(self.weights, y) + self.bias

        # Dividing by ||w|| turns |f(y)| into a geometric distance
        w_norm = np.linalg.norm(self.weights)
        return np.abs(f_value) / w_norm - self.fairness_threshold

    def grad(self, y):
        """
        Gradient of the constraint with respect to y:
        d(|wx+b|/||w||)/dx = sign(wx+b)*w/||w||

        On the hyperplane itself sign(...) is 0, so a zero vector is returned.
        """
        f_value = np.dot(self.weights, y) + self.bias
        w_norm = np.linalg.norm(self.weights)
        direction = np.sign(f_value)
        return direction * self.weights / w_norm

    def hessian(self, y):
        """
        Hessian surrogate of the constraint with respect to y.

        The distance is piecewise linear, so instead of a zero matrix a
        small multiple of the identity (1e-4 * I) is returned.
        """
        dim = len(self.weights)
        return np.eye(dim) * 1e-4


class KernelSVMConstraint:
    """
    Constraint built from the decision function of a kernel SVM.

    The decision function is f(y) = sum_i c_i K(x_i, y) + b, where x_i are
    the support vectors, c_i the dual coefficients (α_i * y_i) and K one of
    the 'rbf', 'linear' or 'poly' kernels. The constraint is
    g_θ(y) = |f(y)| - fairness_threshold.
    """

    def __init__(self, support_vectors, dual_coef, bias=0.0, gamma=1.0, kernel='rbf', fairness_threshold=0.0):
        """
        Initialize the kernel SVM constraint.

        Parameters:
        -----------
        support_vectors : array-like, shape (n_support, n_features)
            Support vectors from the trained SVM model
        dual_coef : array-like, shape (n_support,)
            Dual coefficients (α_i * y_i) for each support vector
        bias : float, optional
            Bias term for the decision function
        gamma : float, optional
            Parameter for RBF kernel: exp(-gamma * ||x-y||^2)
        kernel : str, optional
            Kernel type (case-insensitive): 'rbf', 'linear' or 'poly'.
            The 'poly' kernel is fixed to (x⋅y + 1)^3.
        fairness_threshold : float, optional
            Value subtracted from |f(y)|

        Raises:
        -------
        ValueError
            If the kernel type is not one of the supported names.
        """
        self.support_vectors = np.array(support_vectors)
        self.dual_coef = np.array(dual_coef).flatten()
        self.bias = bias
        self.gamma = gamma
        self.kernel = kernel.lower()
        self.fairness_threshold = fairness_threshold

        supported = ('rbf', 'linear', 'poly')
        if self.kernel not in supported:
            raise ValueError("Kernel type not supported. Use 'rbf', 'linear', or 'poly'.")

        # Feature dimension is taken from the support vector matrix
        self.n_features = self.support_vectors.shape[1]

    def _kernel_function(self, x, y):
        """Compute kernel value K(x, y) between two points"""
        if self.kernel == 'linear':
            return np.dot(x, y)
        if self.kernel == 'rbf':
            squared_dist = np.sum((x - y)**2)
            return np.exp(-self.gamma * squared_dist)
        if self.kernel == 'poly':
            # Fixed degree-3 polynomial kernel (x⋅y + 1)^3
            return (np.dot(x, y) + 1)**3
        return None

    def _kernel_gradient(self, x, y):
        """Compute gradient of K(x, y) with respect to y"""
        if self.kernel == 'linear':
            return x
        if self.kernel == 'rbf':
            # ∇_y K(x,y) = -2γ * (y-x) * K(x,y)
            k_val = self._kernel_function(x, y)
            return -2 * self.gamma * (y - x) * k_val
        if self.kernel == 'poly':
            # ∇_y (x⋅y + 1)^3 = 3(x⋅y + 1)^2 * x
            return 3 * (np.dot(x, y) + 1)**2 * x
        return None

    def _kernel_hessian(self, x, y):
        """Compute Hessian of K(x, y) with respect to y"""
        if self.kernel == 'linear':
            # Linear kernel has zero Hessian
            return np.zeros((self.n_features, self.n_features))
        if self.kernel == 'rbf':
            # H = 4γ² * (y-x)(y-x)ᵀ * K(x,y) - 2γ * I * K(x,y)
            k_val = self._kernel_function(x, y)
            delta = y - x
            delta_outer = np.outer(delta, delta)
            eye = np.eye(self.n_features)
            return (4 * self.gamma**2 * delta_outer - 2 * self.gamma * eye) * k_val
        if self.kernel == 'poly':
            # H = 6(x⋅y + 1) * x⊗x
            shifted_dot = np.dot(x, y) + 1
            return 6 * shifted_dot * np.outer(x, x)
        return None

    def _decision_value(self, y):
        """Return f(y) = sum_i c_i K(x_i, y) + b"""
        total = 0.0
        for idx, sv in enumerate(self.support_vectors):
            total += self.dual_coef[idx] * self._kernel_function(sv, y)
        total += self.bias
        return total

    def __call__(self, y):
        """
        Evaluate constraint: g_θ(y) = |f(y)| - fairness_threshold

        Returns a negative value when |f(y)| is below fairness_threshold
        and a positive value when it is above.
        """
        return np.abs(self._decision_value(y)) - self.fairness_threshold

    def grad(self, y):
        """
        Gradient of the constraint with respect to y:
        ∇|f(y)| = sign(f(y)) * ∇f(y)
        """
        f_value = self._decision_value(y)

        f_grad = np.zeros(self.n_features)
        for idx, sv in enumerate(self.support_vectors):
            f_grad += self.dual_coef[idx] * self._kernel_gradient(sv, y)

        return np.sign(f_value) * f_grad

    def hessian(self, y):
        """
        Hessian of the constraint with respect to y.

        Computed as sign(f(y)) times the Hessian of f, plus 1e-4 * I as
        regularization for numerical stability.
        """
        f_sign = np.sign(self._decision_value(y))

        f_hess = np.zeros((self.n_features, self.n_features))
        for idx, sv in enumerate(self.support_vectors):
            f_hess += self.dual_coef[idx] * self._kernel_hessian(sv, y)

        result = f_sign * f_hess
        result += np.eye(self.n_features) * 1e-4
        return result


def _first_intercept(intercept):
    """Return the first entry of an intercept given as scalar or array (0.0 if empty)."""
    flat = np.atleast_1d(intercept).ravel()
    return flat[0] if flat.size > 0 else 0.0


def extract_model_parameters(model, classifier_name):
    """
    Extracts relevant parameters from a trained model.

    Parameters:
    -----------
    model : sklearn.base.BaseEstimator
        Trained classifier model
    classifier_name : str
        Name of the classifier type. Recognized values are 'logistic',
        'linear_svm', 'nonlinear_svm', 'mlp', 'gbm' and 'adaboost'.

    Returns:
    --------
    params : dict
        Dictionary containing extracted model parameters:

        * 'logistic': 'weights', 'bias'
        * 'linear_svm': 'weights', 'bias' (scalar)
        * 'nonlinear_svm': 'gamma' always; 'support_vectors', 'dual_coef',
          'bias' and 'kernel' only when the model has the matching attribute
        * 'mlp': 'weights', 'biases' when the model has both attributes
        * 'gbm' / 'adaboost': 'model' (the estimator itself)

        An unrecognized name yields an empty dictionary.
    """
    params = {}

    if classifier_name == 'logistic':
        # For logistic regression, extract weights and biases
        params['weights'] = model.coef_
        params['bias'] = model.intercept_

    elif classifier_name == 'linear_svm':
        # For linear SVM, extract weights and a scalar bias.
        # intercept_ may be an array or a plain scalar (presumably the latter
        # when the model was fitted without an intercept), so it is
        # normalized instead of being indexed directly.
        params['weights'] = model.coef_
        params['bias'] = _first_intercept(getattr(model, 'intercept_', 0.0))

    elif classifier_name == 'nonlinear_svm':
        # For nonlinear SVM, extract support vectors, dual coefficients, bias, and kernel parameters
        if hasattr(model, 'support_vectors_'):
            params['support_vectors'] = model.support_vectors_
        if hasattr(model, 'dual_coef_'):
            params['dual_coef'] = model.dual_coef_[0]  # First row for binary classification
        if hasattr(model, 'intercept_'):
            params['bias'] = _first_intercept(model.intercept_)
        if hasattr(model, '_gamma'):
            params['gamma'] = model._gamma
        elif hasattr(model, 'n_features_in_'):
            # Fall back to 1 / n_features when no fitted gamma is available
            params['gamma'] = 1.0 / model.n_features_in_
        else:
            params['gamma'] = 0.5
        if hasattr(model, 'kernel'):
            params['kernel'] = model.kernel

    elif classifier_name == 'mlp':
        # For MLP, extract weights and biases for each layer
        if hasattr(model, 'coefs_') and hasattr(model, 'intercepts_'):
            params['weights'] = model.coefs_
            params['biases'] = model.intercepts_

    elif classifier_name in ('gbm', 'adaboost'):
        # These ensemble methods are not directly supported by the constraint classes
        # You would need custom constraint implementations for these
        params['model'] = model  # Store the full model for now

    return params


def create_constraint_from_model(model, classifier_name, fairness_threshold=0.0):
    """
    Creates a constraint object for DRUNE based on a trained model.

    Binary models that expose a single output (one weight row for logistic
    regression, one output column for an MLP) are expanded to an equivalent
    two-class softmax with logits [0, z], because
    softmax([0, z]) = [1 - sigmoid(z), sigmoid(z)]. Without this expansion
    the softmax-based constraints would see a single class and return a
    constant margin.

    Parameters:
    -----------
    model : sklearn.base.BaseEstimator
        Trained classifier model
    classifier_name : str
        Name of the classifier type: 'logistic', 'linear_svm',
        'nonlinear_svm' or 'mlp'
    fairness_threshold : float, optional
        Threshold parameter for the fairness constraint

    Returns:
    --------
    constraint : object
        An instance of a constraint class compatible with DRUNE

    Raises:
    -------
    ValueError
        If the classifier type has no constraint implementation, or an MLP
        does not have exactly three weight matrices.
    """
    # Extract model parameters
    params = extract_model_parameters(model, classifier_name)

    if classifier_name == 'logistic':
        weights = np.asarray(params['weights'])
        bias = params['bias']

        if weights.ndim == 2 and weights.shape[0] == 1:
            # Single-logit binary model: add a zero reference class so the
            # softmax constraint reproduces the sigmoid probabilities.
            weights = np.vstack([np.zeros_like(weights, dtype=float), weights])
            bias = np.concatenate([[0.0], np.atleast_1d(bias).ravel()])

        return SoftmaxLogisticConstraint(
            weights=weights,
            bias=bias,
            fairness_threshold=fairness_threshold
        )

    elif classifier_name == 'linear_svm':
        return LinearSVMConstraint(
            weights=params['weights'][0],  # First row for binary classification
            bias=params['bias'],
            fairness_threshold=fairness_threshold
        )

    elif classifier_name == 'nonlinear_svm':
        return KernelSVMConstraint(
            support_vectors=params['support_vectors'],
            dual_coef=params['dual_coef'],
            bias=params['bias'],
            gamma=params['gamma'],
            kernel=params['kernel'],
            fairness_threshold=fairness_threshold
        )

    elif classifier_name == 'mlp':
        # MLPConstraint expects 2 hidden layers plus an output layer
        if len(params['weights']) != 3:
            raise ValueError("MLPConstraint requires a model with exactly 2 hidden layers")

        # Copy the lists so the fitted model's own parameters are not modified
        weights = list(params['weights'])
        biases = list(params['biases'])

        output_weights = np.asarray(weights[2])
        if output_weights.ndim == 2 and output_weights.shape[1] == 1:
            # Single-logit binary output: add a zero reference class column
            weights[2] = np.hstack([np.zeros_like(output_weights, dtype=float), output_weights])
            biases[2] = np.concatenate([[0.0], np.atleast_1d(biases[2]).ravel()])

        return MLPConstraint(
            weights=weights,
            biases=biases,
            fairness_threshold=fairness_threshold
        )

    else:
        raise ValueError(f"Constraints for classifier '{classifier_name}' are not implemented")


# Example workflow function 
def train_and_create_constraint(X, y, classifier_name, fairness_threshold=0.0, **model_params):
    """
    Fit a classifier and build the matching DRUNE constraint object.

    Parameters:
    -----------
    X : array-like, shape (n_samples, n_features)
        Training data
    y : array-like, shape (n_samples,)
        Target values
    classifier_name : str
        Name of the classifier to use
    fairness_threshold : float, optional
        Threshold parameter for the fairness constraint
    **model_params : dict
        Additional parameters forwarded to get_classifier

    Returns:
    --------
    model : sklearn.base.BaseEstimator
        Trained classifier model
    constraint : object
        Constraint object for use with DRUNE
    """
    # Build the estimator by name, then fit it in place
    fitted = get_classifier(classifier_name, **model_params)
    fitted.fit(X, y)

    # Wrap the fitted estimator in the corresponding constraint class
    return fitted, create_constraint_from_model(fitted, classifier_name, fairness_threshold)
