from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA

from adversarial_superposition.constants import DEVICE, MODEL_DIR
from adversarial_superposition.modulo.utils.logger import MetricsLogger
from adversarial_superposition.modulo.utils.utils import (
    cross_entropy_float16,
    cross_entropy_float32,
    cross_entropy_float64,
    get_dataset,
    get_model,
)


def load_modulo_model(cfg):
    model = get_model(cfg)
    model.to(DEVICE)

    def rename_keys_inplace(ordered_dict, key_mapping):
        keys = list(ordered_dict.keys())
        for old_key in keys:
            if old_key in key_mapping:
                ordered_dict[key_mapping[old_key]] = ordered_dict.pop(old_key)

    # Example usage
    key_mapping = {
        "layers.2.weight": "layers.0.weight",
        "layers.2.bias": "layers.0.bias",
        "layers.4.weight": "layers.1.weight",
        "layers.4.bias": "layers.1.bias",
        "layers.6.weight": "layers.2.weight",
        "layers.6.bias": "layers.2.bias",
    }
    loaded_weights = torch.load(MODEL_DIR / "epoch1000_model.pt", map_location=DEVICE)
    rename_keys_inplace(loaded_weights, key_mapping)

    model.load_state_dict(loaded_weights)
    model.to(DEVICE)

    return model


def plot_trig_waves(
    p_modulus: int,
    k_values: list[int],
    x_custom_values: Optional[np.ndarray] = None,
    plot_type: str = "both",  # 'sin', 'cos', or 'both'
    sample_type: str = "integer",  # 'integer' or 'continuous', used if x_custom_values is None
    num_continuous_points: int = 500,  # Used for 'continuous' sample_type if x_custom_values is None
    plot_interference_sum_curve: bool = False,
    highlight_interference_peaks: bool = False,
    interference_top_n_peaks: int = 1,
) -> tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Plots sine and/or cosine waves, optionally allows custom x-values,
    and can show the interference sum curve and highlight its peaks.
    Returns the x values and the interference sum y values.

    Args:
        p_modulus: The modulus value (e.g., 113) for wave periodicity.
        k_values: A list of k integers for the wave formulas.
        x_custom_values: Optional NumPy array for custom x-axis values. Overrides sample_type.
        plot_type: Whether to plot 'sin', 'cos', or 'both'.
        sample_type: 'integer' for x = 0,1,...,p-1, or 'continuous' for smoother lines.
                     Ignored if x_custom_values is provided.
        num_continuous_points: Number of points for continuous sampling.
                               Ignored if x_custom_values is provided or sample_type is 'integer'.
        plot_interference_sum_curve: If True, plots the sum of all selected waves.
        highlight_interference_peaks: If True, finds and marks points of max constructive interference.
        interference_top_n_peaks: How many top interference peaks to mark.

    Returns:
        A tuple (x_coords, interference_sum_coords):
        x_coords (np.ndarray): The x-coordinates used for plotting.
        interference_sum_coords (Optional[np.ndarray]): The y-coordinates of the interference sum.
                                                       None if not calculated.
    """
    if x_custom_values is not None:
        x = x_custom_values
    elif sample_type == "integer":
        x = np.arange(p_modulus)
    elif sample_type == "continuous":
        x = np.linspace(0, p_modulus - 1, num_continuous_points)
    else:
        raise ValueError(
            "If x_custom_values is not provided, sample_type must be 'integer' or 'continuous'"
        )

    if not (0 <= p_modulus):
        raise ValueError("p_modulus must be a positive integer")
    if not k_values:
        raise ValueError("k_values list cannot be empty")

    if plot_type not in ["sin", "cos", "both"]:
        raise ValueError("plot_type must be 'sin', 'cos', or 'both'")

    plt.figure(figsize=(18, 10))

    calculate_sum = plot_interference_sum_curve or highlight_interference_peaks
    plotted_waves_sum = np.zeros_like(x, dtype=float) if calculate_sum else None

    for k_idx, k in enumerate(k_values):
        angle = 2 * k * np.pi * x / p_modulus

        if plot_type in ["cos", "both"]:
            cos_wave = np.cos(angle)
            plt.plot(x, cos_wave, label=f"cos(2*{k}*pi*x/{p_modulus})", alpha=0.6)
            if calculate_sum and plotted_waves_sum is not None:
                plotted_waves_sum += cos_wave

        if plot_type in ["sin", "both"]:
            sin_wave = np.sin(angle)
            plt.plot(
                x,
                sin_wave,
                label=f"sin(2*{k}*pi*x/{p_modulus})",
                linestyle="--",
                alpha=0.6,
            )
            if calculate_sum and plotted_waves_sum is not None:
                plotted_waves_sum += sin_wave

    if (
        plot_interference_sum_curve
        and plotted_waves_sum is not None
        and plotted_waves_sum.any()
    ):
        plt.plot(
            x,
            plotted_waves_sum,
            label="Interference Sum",
            color="black",
            linewidth=2.5,
            alpha=0.8,
        )

    if (
        highlight_interference_peaks
        and plotted_waves_sum is not None
        and plotted_waves_sum.any()
    ):
        if interference_top_n_peaks <= 0:
            pass  # Do nothing if 0 or negative peaks requested
        elif interference_top_n_peaks == 1:
            max_indices = [np.argmax(plotted_waves_sum)]
        else:
            num_points_to_consider = min(interference_top_n_peaks, len(x))
            max_indices = np.argpartition(plotted_waves_sum, -num_points_to_consider)[
                -num_points_to_consider:
            ]
            max_indices = max_indices[np.argsort(plotted_waves_sum[max_indices])[::-1]]

        for i, max_idx in enumerate(max_indices):
            max_x_val = x[max_idx]
            max_sum_val = plotted_waves_sum[max_idx]
            # Use a unique label for the first peak only to avoid legend clutter for multiple peaks
            peak_label = (
                f"Max Interference Peak {i+1} (x={max_x_val:.2f}, sum={max_sum_val:.2f})"
                if i == 0
                else None
            )
            if interference_top_n_peaks == 1:
                peak_label = (
                    f"Max Interference Peak (x={max_x_val:.2f}, sum={max_sum_val:.2f})"
                )

            plt.axvline(
                max_x_val, color="red", linestyle=":", linewidth=1.5, label=peak_label
            )
            plt.plot(max_x_val, max_sum_val, "ro", markersize=7, label="_nolegend_")

    title_parts = []
    if plot_type == "sin":
        title_parts.append("Sine Waves")
    elif plot_type == "cos":
        title_parts.append("Cosine Waves")
    else:
        title_parts.append("Sine and Cosine Waves")
    if plot_interference_sum_curve:
        title_parts.append("with Interference Sum")
    if highlight_interference_peaks and interference_top_n_peaks > 0:
        title_parts.append(f"& Top {interference_top_n_peaks} Peak(s)")

    plt.title(
        f'{" ".join(title_parts)} for k in {k_values} (p_mod={p_modulus}, x range [{x.min():.1f}, {x.max():.1f}])',
        fontsize=14,
    )
    plt.xlabel("x", fontsize=12)
    plt.ylabel("Amplitude", fontsize=12)

    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))  # Remove duplicate labels
    if by_label:  # Check if there's anything to put in the legend
        plt.legend(
            by_label.values(),
            by_label.keys(),
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            fontsize="small",
        )

    plt.grid(True, linestyle="--", alpha=0.6)

    num_xticks = min(
        len(x) if len(x) > 1 else 2, 20
    )  # Ensure at least 2 ticks if possible
    plt.xticks(np.linspace(x.min(), x.max(), num_xticks))

    plt.tight_layout(rect=[0, 0, 0.83, 1])  # Adjust layout for legend
    plt.show()
    return x, plotted_waves_sum


def get_accuracy(model, config):
    cross_entropy_function = {
        16: cross_entropy_float16,
        32: cross_entropy_float32,
        64: cross_entropy_float64,
    }

    config.train_fraction = 1.0
    all_dataset = get_dataset(config)[0]
    all_data = all_dataset.dataset.data[all_dataset.indices].to(DEVICE)
    all_targets = all_dataset.dataset.targets[all_dataset.indices].to(DEVICE).long()

    logger = MetricsLogger(config.num_epochs, config.log_frequency)

    all_permutation = torch.randperm(all_data.size(0))
    all_shuffled_data = all_data[all_permutation]
    all_shuffled_targets = all_targets[all_permutation]

    logger.log_metrics(
        model=model,
        epoch=10_000,
        save_model_checkpoints=[],
        saved_models=None,
        all_data=all_shuffled_data,
        all_targets=all_shuffled_targets,
        all_test_data=all_shuffled_data,
        all_test_targets=all_shuffled_targets,
        args=config,
        loss_function=cross_entropy_function[config.softmax_precision],
    )

    train_acc = logger.metrics_df[
        (logger.metrics_df["metric_name"] == "accuracy")
        & (logger.metrics_df["input_type"] == "train")
    ].iloc[-1]["value"]

    return train_acc


def create_one_hot_pair(n1: int, n2: int, base=113) -> torch.Tensor:
    if not (0 <= n1 <= 112 and 0 <= n2 <= 112):
        raise ValueError("Numbers must be between 0 and 112")

    tensor = torch.zeros(base * 2)
    tensor[n1] = 1.0  # One-hot encode first number
    tensor[base + n2] = 1.0  # One-hot encode second number
    return tensor.to(DEVICE)


def plot_pca_components(weights, components=(0, 1), p=113):
    """
    Extract model weights, perform PCA, and plot the PCA components.

    Args:
        model: The model to analyze
        p (int): The number of components to consider
    """
    E_centered = weights - weights.mean(axis=0)

    pca = PCA()
    pca.fit(E_centered)

    pca_components = pca.transform(E_centered)

    # Plot first two principal components (PC1 vs PC2)
    plt.figure(figsize=(10, 8))
    plt.scatter(
        pca_components[:, components[0]], pca_components[:, components[1]], alpha=0.8
    )
    plt.xlabel("Principal Component 1", fontsize=12)
    plt.ylabel("Principal Component 2", fontsize=12)
    plt.title("First Two Principal Components", fontsize=14)
    plt.grid(True, linestyle="--", alpha=0.5)

    # Add indices as annotations to points
    for i in range(min(p, 20)):  # Limit labels to first 20 points to avoid overcrowding
        plt.annotate(
            f"{i}", (pca_components[i, 0], pca_components[i, 1]), fontsize=9, alpha=0.7
        )
    plt.tight_layout()
    plt.show()

    return pca, pca_components


def plot_eigenvalues(pca, num_highlight=6):
    """
    Plot the eigenvalues (explained variance) of the PCA components.

    Args:
        pca: Fitted PCA object
        num_highlight (int): Number of components to highlight
    """
    eigenvalues = pca.explained_variance_
    num_components = len(eigenvalues)

    plt.figure(figsize=(10, 6))

    component_indices = np.arange(1, num_components + 1)

    plt.plot(
        component_indices, eigenvalues, marker="o", linestyle="-", label="Eigenvalues"
    )

    plt.plot(
        component_indices[:num_highlight],
        eigenvalues[:num_highlight],
        marker="o",
        linestyle="-",
        color="red",
        label=f"First {num_highlight} Components",
    )

    plt.xlabel("Principal Component Index")
    plt.ylabel("Eigenvalue (Explained Variance)")
    plt.title("PCA Eigenvalue Spectrum (Scree Plot) of Embeddings")
    plt.xticks(np.arange(0, num_components + 1, 10))
    plt.grid(True, which="both", linestyle="--", linewidth=0.5)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Plot cumulative explained variance
    cumulative_variance_ratio = np.cumsum(pca.explained_variance_ratio_)
    plt.figure(figsize=(10, 6))
    plt.plot(component_indices, cumulative_variance_ratio, marker=".", linestyle="-")
    plt.xlabel("Number of Principal Components")
    plt.ylabel("Cumulative Explained Variance Ratio")
    plt.title("Cumulative Explained Variance by PCA Components")
    plt.grid(True)
    plt.axhline(0.9, color="r", linestyle="--", label="90% Variance")
    plt.axhline(0.95, color="g", linestyle="--", label="95% Variance")
    plt.ylim(0, 1.05)
    plt.legend()
    plt.tight_layout()
    plt.show()

    print(f"Eigenvalues (first 15): {eigenvalues[:15]}")
    if len(eigenvalues) > num_highlight:
        print(
            f"Ratio of eigenvalue {num_highlight} to {num_highlight+1}: {eigenvalues[num_highlight-1] / eigenvalues[num_highlight]:.2f}"
        )


def evaluate_probe_performance(
    model,
    probe_model,
    hidden_layer,
    p_value: int,
    k_values: list,
    is_second_layer: bool = False,
    attacks=None,
):
    """
    Evaluate probe performance across all possible input pairs.

    Args:
        model: The main model
        probe_model: The trained probe model
        hidden_layer: The layer to hook for activations
        p_value: The modulo value (e.g. 113)
        k_values: List of k values for the frequencies
        is_second_layer: Whether this is the second layer probe
        attacks: Optional attack tensor to evaluate perturbed inputs

    Returns:
        dict: Dictionary containing performance metrics for each frequency
    """
    model.eval()
    probe_model.eval()

    # Initialize metrics storage
    metrics = {
        k: {"sin_mse": [], "cos_mse": [], "sin_corr": [], "cos_corr": []}
        for k in k_values
    }

    # Register the appropriate hook
    if is_second_layer:
        hook_handle = hidden_layer.register_forward_hook(
            capture_second_layer_activations_hook
        )
    else:
        hook_handle = hidden_layer.register_forward_hook(capture_activations_hook)

    print(
        f"Evaluating {'second layer' if is_second_layer else 'first layer'} probe performance..."
    )

    # Evaluate on all possible input pairs
    for d1 in range(p_value):
        for d2 in range(p_value):
            # Create input
            one_hot = create_one_hot_pair(d1, d2, base=p_value).unsqueeze(0).to(DEVICE)

            # Get model activations
            with torch.no_grad():
                model(one_hot)

            # Get probe predictions
            if is_second_layer:
                activations = second_layer_activations_capture
            else:
                activations = first_layer_activations_capture

            with torch.no_grad():
                predictions = probe_model(activations.to(DEVICE))

            # Calculate true values
            if is_second_layer:
                sum_val = (d1 + d2) % p_value
                for i, k in enumerate(k_values):
                    sin_true = np.sin(2 * np.pi * k * sum_val / p_value)
                    cos_true = np.cos(2 * np.pi * k * sum_val / p_value)

                    sin_pred = predictions[0, 2 * i].item()
                    cos_pred = predictions[0, 2 * i + 1].item()

                    # Store metrics
                    metrics[k]["sin_mse"].append((sin_pred - sin_true) ** 2)
                    metrics[k]["cos_mse"].append((cos_pred - cos_true) ** 2)
                    metrics[k]["sin_corr"].append((sin_pred, sin_true))
                    metrics[k]["cos_corr"].append((cos_pred, cos_true))
            else:
                for i, k in enumerate(k_values):
                    sin_true = np.sin(2 * np.pi * k * d1 / p_value)
                    cos_true = np.cos(2 * np.pi * k * d1 / p_value)

                    sin_pred = predictions[0, 2 * i].item()
                    cos_pred = predictions[0, 2 * i + 1].item()

                    # Store metrics
                    metrics[k]["sin_mse"].append((sin_pred - sin_true) ** 2)
                    metrics[k]["cos_mse"].append((cos_pred - cos_true) ** 2)
                    metrics[k]["sin_corr"].append((sin_pred, sin_true))
                    metrics[k]["cos_corr"].append((cos_pred, cos_true))

            # If attacks are provided, also evaluate perturbed inputs
            if attacks is not None:
                perturbed_input = one_hot + attacks[:, d1, d2]
                with torch.no_grad():
                    model(perturbed_input)

                if is_second_layer:
                    activations = second_layer_activations_capture
                else:
                    activations = first_layer_activations_capture

                with torch.no_grad():
                    predictions = probe_model(activations.to(DEVICE))

                # Calculate metrics for perturbed inputs
                if is_second_layer:
                    sum_val = (d1 + d2) % p_value
                    for i, k in enumerate(k_values):
                        sin_true = np.sin(2 * np.pi * k * sum_val / p_value)
                        cos_true = np.cos(2 * np.pi * k * sum_val / p_value)

                        sin_pred = predictions[0, 2 * i].item()
                        cos_pred = predictions[0, 2 * i + 1].item()

                        metrics[k]["sin_mse"].append((sin_pred - sin_true) ** 2)
                        metrics[k]["cos_mse"].append((cos_pred - cos_true) ** 2)
                        metrics[k]["sin_corr"].append((sin_pred, sin_true))
                        metrics[k]["cos_corr"].append((cos_pred, cos_true))

    hook_handle.remove()

    # Calculate final metrics
    results = {}
    for k in k_values:
        # Calculate mean MSE
        sin_mse = np.mean(metrics[k]["sin_mse"])
        cos_mse = np.mean(metrics[k]["cos_mse"])

        # Calculate correlation
        sin_preds, sin_trues = zip(*metrics[k]["sin_corr"])
        cos_preds, cos_trues = zip(*metrics[k]["cos_corr"])
        sin_corr = np.corrcoef(sin_preds, sin_trues)[0, 1]
        cos_corr = np.corrcoef(cos_preds, cos_trues)[0, 1]

        results[k] = {
            "sin_mse": sin_mse,
            "cos_mse": cos_mse,
            "sin_corr": sin_corr,
            "cos_corr": cos_corr,
        }

        print(f"\nMetrics for k={k}:")
        print(f"  Sin MSE: {sin_mse:.6f}")
        print(f"  Cos MSE: {cos_mse:.6f}")
        print(f"  Sin Correlation: {sin_corr:.6f}")
        print(f"  Cos Correlation: {cos_corr:.6f}")

    return results


def plot_probe_metrics(first_layer_results, second_layer_results, k_values):
    """Plot comparison of probe metrics between layers."""
    metrics = ["sin_mse", "cos_mse", "sin_corr", "cos_corr"]
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    axes = axes.flatten()

    for i, metric in enumerate(metrics):
        first_layer_values = [first_layer_results[k][metric] for k in k_values]
        second_layer_values = [second_layer_results[k][metric] for k in k_values]

        x = np.arange(len(k_values))
        width = 0.35

        axes[i].bar(x - width / 2, first_layer_values, width, label="First Layer")
        axes[i].bar(x + width / 2, second_layer_values, width, label="Second Layer")

        axes[i].set_ylabel(metric)
        axes[i].set_title(f"{metric} by Frequency")
        axes[i].set_xticks(x)
        axes[i].set_xticklabels([f"k={k}" for k in k_values])
        axes[i].legend()

        if "mse" in metric:
            axes[i].set_yscale("log")

    plt.tight_layout()
    plt.show()
