import matplotlib.pyplot as plt
import numpy as np
import torch

from adversarial_superposition.modulo.utils.helpers import plot_trig_waves


def generate_inverse_interference_attack(
    p_mod, k_values, plot_type="cos", l2_budget=0.01, plot=True, num_points=None
):
    """
    Generate an attack vector that is the inverse (negation) of the interference pattern.
    The attack vector is rescaled so its L2 norm equals l2_budget.

    Args:
        p_mod (int): The modulus value, forwarded to ``plot_trig_waves``.
        k_values (list): k values forwarded to ``plot_trig_waves`` to build the
            interference pattern.
        plot_type (str): Wave type forwarded to ``plot_trig_waves`` (default "cos").
        l2_budget (float): The desired L2 norm of the attack vector.
        plot (bool): Whether to plot the original vs. inverse patterns and print
            their L2 norms.
        num_points (int, optional): Number of sample points requested from
            ``plot_trig_waves``. Defaults to ``p_mod`` so the attack has the same
            length as ``generate_random_attack(p_mod)``.

    Returns:
        torch.Tensor: float32 attack vector, shape [p_mod] by default
        (presumably [num_points] in general — assumes ``plot_trig_waves``
        returns one sample per requested point).

    Raises:
        ValueError: If the interference pattern is identically zero, so it
            cannot be rescaled to the requested L2 norm.
    """
    if num_points is None:
        # Previously hard-coded to 113, which only matched p_mod == 113.
        num_points = p_mod

    # Generate the interference pattern.
    # NOTE(review): plot_interference_sum_curve=True may make the helper draw
    # its own figure regardless of `plot` — confirm against plot_trig_waves.
    x_coords, interference_sum = plot_trig_waves(
        p_modulus=p_mod,
        k_values=k_values,
        plot_type=plot_type,
        sample_type="continuous",
        num_continuous_points=num_points,
        plot_interference_sum_curve=True,
        highlight_interference_peaks=False,
    )

    # Take the inverse of the interference pattern (asarray tolerates list outputs).
    inverse_pattern = -np.asarray(interference_sum, dtype=float)

    # Normalize to have L2 norm equal to l2_budget; a zero pattern would
    # otherwise produce a NaN-filled attack via division by zero.
    current_norm = np.linalg.norm(inverse_pattern)
    if current_norm == 0:
        raise ValueError(
            "Interference pattern is identically zero; cannot rescale it to the "
            "requested L2 budget (check k_values)."
        )
    normalized_pattern = inverse_pattern * (l2_budget / current_norm)

    if plot:
        # Create figure with two y-axes
        fig, ax1 = plt.subplots(figsize=(14, 7))

        # Plot original interference pattern on primary y-axis
        color1 = "tab:red"
        ax1.plot(
            x_coords,
            interference_sum,
            label="Original Interference Pattern",
            color=color1,
        )
        ax1.set_xlabel("Index / x-coordinate")
        ax1.set_ylabel("Original Pattern Amplitude", color=color1)
        ax1.tick_params(axis="y", labelcolor=color1)

        # Secondary y-axis: the rescaled pattern is far smaller in magnitude.
        ax2 = ax1.twinx()
        color2 = "tab:blue"
        ax2.plot(
            x_coords,
            normalized_pattern,
            label="Normalized Inverse Pattern",
            color=color2,
            linestyle="--",
        )
        ax2.set_ylabel("Normalized Inverse Pattern Amplitude", color=color2)
        ax2.tick_params(axis="y", labelcolor=color2)

        # Merge both axes' entries into a single legend
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right")

        plt.title("Interference Pattern vs Normalized Inverse Pattern")
        ax1.grid(True, linestyle="--", alpha=0.7)

        # Adjust layout and show plot
        fig.tight_layout()
        plt.show()

        # Print norms for verification
        print(f"Original pattern L2 norm: {current_norm:.6f}")
        print(f"Normalized pattern L2 norm: {np.linalg.norm(normalized_pattern):.6f}")

    return torch.tensor(normalized_pattern, dtype=torch.float32)


def generate_random_attack(p_mod, l2_budget=0.01):
    """
    Build a random perturbation whose Euclidean (L2) norm is exactly ``l2_budget``.

    A standard-normal sample of length ``p_mod`` gives the direction; it is then
    rescaled to the requested magnitude.

    Args:
        p_mod (int): Number of entries in the returned vector.
        l2_budget (float): Target L2 norm of the returned vector.

    Returns:
        torch.Tensor: Rescaled random vector of shape [p_mod].
    """
    noise = torch.randn(p_mod)
    scale = l2_budget / noise.norm()
    return noise * scale


def test_attack(model, original_input, attack_vector, target):
    """
    Test if an attack vector successfully fools the model.

    ``attack_vector`` is added to ``original_input``, a leading batch dimension
    of size 1 is added, and the prediction is the argmax over dim 1 of the
    model output.

    Args:
        model: The model to test; called under ``torch.no_grad()``. Its
            train/eval mode is not changed here.
        original_input (torch.Tensor): The original (unbatched) input.
        attack_vector (torch.Tensor): The perturbation to add. It is converted
            with ``torch.as_tensor`` onto ``original_input``'s device first.
        target (int or 0-d torch.Tensor): The true target class.

    Returns:
        tuple: (bool, int) - Whether the attack succeeded (prediction differs
        from ``target``) and the predicted class.
    """
    # Put the attack on the input's device so a CPU-built attack can be
    # applied to an input living elsewhere (no copy if already there).
    attack = torch.as_tensor(attack_vector, device=original_input.device)
    perturbed_input = original_input + attack

    # Get model prediction
    with torch.no_grad():
        logits = model(perturbed_input.unsqueeze(0))
    predicted_class = logits.argmax(dim=1).item()

    # int() accepts Python ints and 0-d tensors alike, so the result is
    # always a plain bool rather than a tensor.
    attack_succeeded = predicted_class != int(target)

    return attack_succeeded, predicted_class


# The name starts with "test_", so pytest would collect this helper as a test
# (and fail on the missing fixtures) whenever it is imported into a test module.
test_attack.__test__ = False
