import pyomo.environ as pyo
import numpy as np
import torch
import torch.nn as nn
from pyomo.environ import NonNegativeReals

# Define a simple CNN model for demonstration
class SimpleCNN(nn.Module):
    """Minimal CNN used by the verification demo: conv1 -> ReLU -> flatten -> fc.

    ``conv2``..``conv4`` are constructed (so they are registered parameters and
    appear in the state dict) but ``forward`` does not use them.
    """

    def __init__(self):
        super().__init__()
        # Every conv layer shares the same geometry: 1 -> 2 channels, 3x3 kernel,
        # stride 1, no padding.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=0, bias=True)
        self.conv1 = nn.Conv2d(1, 2, **conv_kwargs)
        self.conv2 = nn.Conv2d(1, 2, **conv_kwargs)
        self.conv3 = nn.Conv2d(1, 2, **conv_kwargs)
        self.conv4 = nn.Conv2d(1, 2, **conv_kwargs)
        self.relu = nn.ReLU()
        # A 4x4 input gives a 2x2 map per channel after a 3x3 valid conv;
        # 2 channels x 2 x 2 = 8 features feed the 2-way linear layer.
        self.fc = nn.Linear(2 * 2 * 2, 2)

    def forward(self, x):
        """Apply conv1, ReLU, flatten per sample, then the linear layer."""
        batch = x.size(0)
        features = self.relu(self.conv1(x))
        return self.fc(features.view(batch, -1))

def create_cnn_verification_model(cnn_model, input_shape=(1, 4, 4), input_bounds=(-1, 1), eps=1e-5, big_m=1000):
    """
    Create a Pyomo model of ``fc(relu(conv1(x)))`` for CNN verification.

    Only ``cnn_model.conv1`` and ``cnn_model.fc`` are encoded; any other
    layers on the module are ignored.

    The ReLU is written as a relaxed complementarity system
    ``z = p - q, p >= 0, q >= 0, p * q <= eps, a = p``. With ``eps = 0`` this
    is exactly ``a = max(z, 0)``; with ``eps > 0`` it is a slight relaxation
    (every true ReLU point is still feasible). The bilinear ``p * q`` term is
    nonconvex, so the solver must support nonconvex quadratic constraints.

    Args:
        cnn_model: PyTorch module with a ``conv1`` (nn.Conv2d with stride 1,
            no padding, no dilation) and an ``fc`` (nn.Linear) attribute.
        input_shape: Shape of input (channels, height, width).
        input_bounds: (lower, upper) bounds applied to every input variable.
        eps: Right-hand side of the ReLU complementarity constraint
            ``p * q <= eps``.
        big_m: Bound magnitude: conv pre-activations and network outputs lie
            in ``[-big_m, big_m]``, ReLU outputs in ``[0, big_m]``.

    Returns:
        pyo.ConcreteModel with variables ``x_input``, ``z_conv``, ``a_conv``,
        ``p_conv``, ``q_conv``, ``y_output`` and the constraints linking them.

    Raises:
        ValueError: If ``conv1`` uses an unsupported stride/padding/dilation,
            its input channels do not match ``input_shape``, the kernel is
            larger than the input, or ``fc`` does not take the flattened
            convolution output.
    """
    conv = cnn_model.conv1
    fc = cnn_model.fc

    # The convolution constraints below implement only stride 1, no padding
    # and no dilation; anything else would silently produce a wrong model.
    if (
        tuple(conv.stride) != (1, 1)
        or conv.padding not in ((0, 0), "valid")
        or tuple(conv.dilation) != (1, 1)
    ):
        raise ValueError(
            "conv1 must use stride=1, padding=0 and dilation=1; got "
            f"stride={conv.stride}, padding={conv.padding}, dilation={conv.dilation}"
        )

    # detach().cpu() also works for parameters that live on an accelerator.
    conv_weight = conv.weight.detach().cpu().numpy()  # Shape: (out_ch, in_ch, kh, kw)
    fc_weight = fc.weight.detach().cpu().numpy()  # Shape: (n_out, n_in)
    # Layers built with bias=False have ``bias is None``; treat that as zero bias.
    conv_bias = (
        conv.bias.detach().cpu().numpy()
        if conv.bias is not None
        else np.zeros(conv_weight.shape[0])
    )
    fc_bias = (
        fc.bias.detach().cpu().numpy()
        if fc.bias is not None
        else np.zeros(fc_weight.shape[0])
    )

    # Input dimensions
    in_ch, in_h, in_w = input_shape
    out_ch, weight_in_ch, k_h, k_w = conv_weight.shape
    if weight_in_ch != in_ch:
        raise ValueError(
            f"conv1 expects {weight_in_ch} input channel(s) but input_shape has {in_ch}"
        )

    # Output dimensions after convolution (no padding, stride=1)
    conv_out_h = in_h - k_h + 1
    conv_out_w = in_w - k_w + 1
    if conv_out_h < 1 or conv_out_w < 1:
        raise ValueError(
            f"kernel {k_h}x{k_w} is larger than the {in_h}x{in_w} input"
        )

    flat_size = out_ch * conv_out_h * conv_out_w
    n_out, fc_in = fc_weight.shape
    if fc_in != flat_size:
        raise ValueError(
            f"fc expects {fc_in} input features but the flattened conv output has {flat_size}"
        )

    print(f"Input shape: {input_shape}")
    print(f"Conv output shape: ({out_ch}, {conv_out_h}, {conv_out_w})")
    print(f"Flattened size: {flat_size}")

    model = pyo.ConcreteModel()

    # Index sets: input pixels and conv activations (channel, row, column).
    in_idx = (range(in_ch), range(in_h), range(in_w))
    act_idx = (range(out_ch), range(conv_out_h), range(conv_out_w))

    # Input variables (with bounds for adversarial perturbation)
    model.x_input = pyo.Var(*in_idx, bounds=input_bounds, doc="Input variables")

    # Convolution output variables (before ReLU)
    model.z_conv = pyo.Var(
        *act_idx, bounds=(-big_m, big_m), doc="Convolution output before ReLU"
    )

    # ReLU output variables
    model.a_conv = pyo.Var(
        *act_idx, bounds=(0, big_m), doc="ReLU output after convolution"
    )

    # ReLU complementarity split z = p - q. The upper bounds are implied by the
    # constraints (p = a <= big_m, q = p - z <= 2 * big_m) but stating them
    # explicitly helps nonconvex solvers bound the bilinear term.
    model.p_conv = pyo.Var(
        *act_idx, domain=NonNegativeReals, bounds=(0, big_m),
        doc="Positive part of the convolution output",
    )
    model.q_conv = pyo.Var(
        *act_idx, domain=NonNegativeReals, bounds=(0, 2 * big_m),
        doc="Negative part of the convolution output",
    )

    # Final output variables
    model.y_output = pyo.Var(
        range(n_out), bounds=(-big_m, big_m), doc="Final network output"
    )

    # Convolution constraints: valid cross-correlation, as computed by nn.Conv2d.
    def conv_constraint_rule(m, out_c, out_h, out_w):
        conv_sum = pyo.quicksum(
            float(conv_weight[out_c, in_c, k_row, k_col])
            * m.x_input[in_c, out_h + k_row, out_w + k_col]
            for in_c in range(in_ch)
            for k_row in range(k_h)
            for k_col in range(k_w)
        )
        return m.z_conv[out_c, out_h, out_w] == conv_sum + float(conv_bias[out_c])

    model.conv_constraints = pyo.Constraint(
        *act_idx, rule=conv_constraint_rule, doc="Convolution operation constraints"
    )

    # ReLU via relaxed complementarity (see docstring): a = p, z = p - q, p*q <= eps.
    def relu_constraint1_rule(m, out_c, out_h, out_w):
        return m.a_conv[out_c, out_h, out_w] == m.p_conv[out_c, out_h, out_w]

    def relu_constraint2_rule(m, out_c, out_h, out_w):
        return (
            m.z_conv[out_c, out_h, out_w]
            == m.p_conv[out_c, out_h, out_w] - m.q_conv[out_c, out_h, out_w]
        )

    def relu_constraint3_rule(m, out_c, out_h, out_w):
        # Nonconvex bilinear constraint: forces p or q to (nearly) zero.
        return m.p_conv[out_c, out_h, out_w] * m.q_conv[out_c, out_h, out_w] <= eps

    model.relu_constraints1 = pyo.Constraint(*act_idx, rule=relu_constraint1_rule)
    model.relu_constraints2 = pyo.Constraint(*act_idx, rule=relu_constraint2_rule)
    model.relu_constraints3 = pyo.Constraint(*act_idx, rule=relu_constraint3_rule)

    # Fully connected layer constraints. Features are flattened row-major over
    # (channel, row, column), the same order as ``view(batch, -1)`` on an
    # (N, C, H, W) tensor.
    def fc_constraint_rule(m, output_idx):
        fc_sum = pyo.quicksum(
            float(fc_weight[output_idx, (c * conv_out_h + h) * conv_out_w + w])
            * m.a_conv[c, h, w]
            for c in range(out_ch)
            for h in range(conv_out_h)
            for w in range(conv_out_w)
        )
        return m.y_output[output_idx] == fc_sum + float(fc_bias[output_idx])

    model.fc_constraints = pyo.Constraint(
        range(n_out), rule=fc_constraint_rule, doc="Fully connected layer constraints"
    )

    return model

def verify_robustness(model, target_input, epsilon=0.1, target_class=0, input_bounds=(-1, 1)):
    """
    Set up ``model`` to search for an adversarial example around ``target_input``.

    Every input variable is restricted to ``[t - epsilon, t + epsilon]``, where
    ``t`` is the matching entry of ``target_input``. Those bounds are clipped
    to ``input_bounds``, so the search region is the L-infinity ball of radius
    ``epsilon``.

    A constraint then requires the other class's output to be at least the
    target class's output. The objective maximizes that margin, i.e. it asks
    for the strongest counterexample. The model is therefore feasible exactly
    when the encoded network has a counterexample in the ball, and an
    infeasible solve certifies robustness.

    If the ReLU encoding is relaxed (e.g. ``eps > 0`` in the complementarity
    constraint), a feasible point should be re-checked on the real network.

    Args:
        model: Pyomo model with ``x_input[c, h, w]`` and a two-entry
            ``y_output`` variable.
        target_input: Original input point, array-like of shape
            (channels, height, width).
        epsilon: Perturbation bound.
        target_class: Expected output class (0 or 1).
        input_bounds: (lower, upper) domain the perturbed input is clipped to.

    Returns:
        The same ``model``, modified in place.

    Raises:
        ValueError: If ``target_input`` is not 3-D or ``target_class`` is not
            0 or 1.
    """
    target_input = np.asarray(target_input, dtype=float)
    if target_input.ndim != 3:
        raise ValueError(
            f"target_input must have shape (channels, height, width); got {target_input.shape}"
        )
    # The margin below compares exactly two logits.
    if target_class not in (0, 1):
        raise ValueError(f"target_class must be 0 or 1; got {target_class}")

    # Set input bounds around target input, clipped to the input domain.
    lo_clip, hi_clip = input_bounds
    for (c, h, w), value in np.ndenumerate(target_input):
        model.x_input[c, h, w].setlb(float(max(lo_clip, value - epsilon)))
        model.x_input[c, h, w].setub(float(min(hi_clip, value + epsilon)))

    other_class = 1 - target_class

    # Non-negative margin means the other class wins (or ties): a counterexample.
    margin = model.y_output[other_class] - model.y_output[target_class]

    # Without this constraint the model is always feasible (the original input
    # satisfies it), so a feasibility verdict would be meaningless.
    model.adversarial_constraint = pyo.Constraint(
        expr=margin >= 0,
        doc="Constraint for finding adversarial examples",
    )

    # Maximize the margin so a feasible solution is the strongest counterexample.
    model.obj = pyo.Objective(expr=margin, sense=pyo.maximize)

    return model

def run_verification_example():
    """Run a complete verification example.

    Steps:
      1. Build a SimpleCNN with seeded random weights.
      2. Classify a fixed 1x4x4 input with PyTorch.
      3. Encode the network in Pyomo.
      4. Ask Gurobi whether any input within ``epsilon`` of the target lets
         the other class win.

    Results are printed; nothing is returned.
    """

    # Create and initialize the CNN model
    cnn_model = SimpleCNN()

    # A dedicated, seeded generator makes these weights reproducible without
    # touching the global torch RNG state.
    gen = torch.Generator().manual_seed(0)
    with torch.no_grad():
        cnn_model.conv1.weight.data = torch.randn(2, 1, 3, 3, generator=gen) * 0.5
        cnn_model.conv1.bias.data = torch.randn(2, generator=gen) * 0.1
        cnn_model.fc.weight.data = torch.randn(2, 8, generator=gen) * 0.3
        cnn_model.fc.bias.data = torch.randn(2, generator=gen) * 0.1

    # Create target input
    target_input = np.array([[[0.5, -0.2, 0.1, 0.3],
                              [0.0, 0.8, -0.1, 0.2],
                              [-0.3, 0.4, 0.6, -0.5],
                              [0.1, -0.4, 0.2, 0.7]]])

    print("=== CNN Verification Example ===")
    print(f"Target input shape: {target_input.shape}")

    # Test the network with PyTorch first
    with torch.no_grad():
        target_tensor = torch.FloatTensor(target_input).unsqueeze(0)
        pytorch_output = cnn_model(target_tensor)
        print(f"PyTorch output: {pytorch_output.numpy()}")
        predicted_class = torch.argmax(pytorch_output).item()
        print(f"Predicted class: {predicted_class}")

    # Create verification model
    verification_model = create_cnn_verification_model(cnn_model, input_shape=(1, 4, 4))

    # Set up robustness verification
    epsilon = 0.001
    verification_model = verify_robustness(
        verification_model, target_input, epsilon=epsilon, target_class=predicted_class
    )

    print(f"\n=== Verifying Robustness (epsilon={epsilon}) ===")

    other_class = 1 - predicted_class
    # Solver feasibility tolerance: a margin this close to zero counts as a tie,
    # i.e. still adversarial.
    margin_tol = 1e-6

    try:
        # The ReLU encoding has a nonconvex bilinear constraint, so the solver
        # must support nonconvex quadratic constraints.
        solver = pyo.SolverFactory('gurobi')
        results = solver.solve(verification_model, tee=True)
    except Exception as e:  # top-level boundary of an example script
        print(f"Solver error: {e}")
        print("Make sure Gurobi is installed and licensed "
              "(GLPK cannot solve this nonconvex model)")
        return

    condition = results.solver.termination_condition
    tc = pyo.TerminationCondition

    if condition == tc.optimal:
        # Values exist only when a solution was loaded, so they are read
        # after the status check (an infeasible solve leaves them unset).
        print('obj is', pyo.value(verification_model.obj))
        y = verification_model.y_output
        margin = pyo.value(y[other_class]) - pyo.value(y[predicted_class])

        if margin >= -margin_tol:
            print("VERIFICATION FAILED: Found adversarial example!")
            print("Adversarial input found:")

            # Extract adversarial input
            adv_input = np.zeros_like(target_input)
            for idx in np.ndindex(*target_input.shape):
                adv_input[idx] = pyo.value(verification_model.x_input[idx])

            print(f"Adversarial input: {adv_input}")
            print(f"Difference from original: {adv_input - target_input}")

            # Re-check on the real network: a relaxed ReLU encoding can admit
            # spurious counterexamples.
            with torch.no_grad():
                adv_tensor = torch.FloatTensor(adv_input).unsqueeze(0)
                adv_output = cnn_model(adv_tensor)
                adv_class = torch.argmax(adv_output).item()
                print(f"Adversarial output: {adv_output.numpy()}")
                print(f"Adversarial class: {adv_class}")
        else:
            # Reachable only if the model lacks the adversarial constraint.
            print(f"Solver optimum is not adversarial (margin={margin:.6g}); "
                  "no counterexample was found, but robustness is not proven.")

    elif condition in (tc.infeasible, tc.infeasibleOrUnbounded):
        # The objective is built from y_output, whose variables are bounded,
        # so "infeasible or unbounded" can only mean infeasible here.
        print("VERIFICATION PASSED: No adversarial examples found!")
        print(f"The network is robust within epsilon={epsilon}")

    else:
        print(f"Solver terminated with condition: {condition}")

if __name__ == "__main__":
    # Script entry point: run the end-to-end verification demo when executed directly.
    run_verification_example()