# %%
import json
from pathlib import Path
from typing import Dict, Tuple

import h5py
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pyod.models.pca import PCA
from pyod.utils.data import evaluate_print
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm


def _decode_if_bytes(data):
    """Return *data* with byte strings decoded to ``str``; pass anything else through.

    Fixed-width byte-string arrays (dtype kind ``"S"``) and object arrays whose
    elements are all ``bytes`` are converted to a list of UTF-8 decoded strings.
    The object-array case matters because h5py (3.x) reads variable-length
    strings as object-dtype arrays of ``bytes`` rather than as ``"S"`` arrays.
    """
    if not isinstance(data, np.ndarray):
        return data
    if data.dtype.kind == "S":
        return [s.decode("utf-8") for s in data]
    if data.dtype.kind == "O" and all(isinstance(s, bytes) for s in data):
        return [s.decode("utf-8") for s in data]
    return data


def _read_metadata_group(group) -> Dict:
    """Read every dataset in an h5py group into a dict, decoding byte strings."""
    return {key: _decode_if_bytes(group[key][:]) for key in group}


def load_activations(file_path: Path) -> Tuple[Dict[int, np.ndarray], Dict]:
    """Load activations and metadata from H5 file.

    Args:
        file_path: Path to an HDF5 file containing an ``activations`` group
            (one dataset per layer, named by the layer index) and a
            ``metadata`` entry.

    Returns:
        A ``(activations, metadata)`` tuple. ``activations`` maps the integer
        layer index to the array stored for that layer. ``metadata`` has one of
        two layouts:

        * difference activations: ``{"pos_metadata": {...}, "neg_metadata": {...}}``
          when the ``metadata`` group contains ``pos_metadata``;
        * single-type activations: a flat ``{name: values}`` dict otherwise.

        Byte-string columns are returned as lists of ``str``. If ``metadata``
        is not a group, an empty dict is returned.
    """
    with h5py.File(file_path, "r") as f:
        activation_group = f["activations"]
        activations = {int(layer_idx): activation_group[layer_idx][:] for layer_idx in activation_group}

        metadata = {}
        metadata_node = f["metadata"]
        if isinstance(metadata_node, h5py.Group):
            if "pos_metadata" in metadata_node:
                # Case 1: difference activations carry separate pos/neg metadata
                metadata = {
                    "pos_metadata": _read_metadata_group(metadata_node["pos_metadata"]),
                    "neg_metadata": _read_metadata_group(metadata_node["neg_metadata"]),
                }
            else:
                # Case 2: single-type activations store the columns directly
                metadata = _read_metadata_group(metadata_node)

    return activations, metadata


def fit_anomaly_detector(train_data: np.ndarray, contamination: float = 0.01) -> Tuple[PCA, StandardScaler]:
    """Fit PCA-based anomaly detector on training data.

    Args:
        train_data: Samples to fit on; passed to ``StandardScaler.fit_transform``
            and then to the detector's ``fit``.
        contamination: Forwarded to the pyod ``PCA`` detector as its
            ``contamination`` argument.

    Returns:
        A ``(detector, scaler)`` tuple: the fitted pyod ``PCA`` detector and the
        fitted ``StandardScaler``. The same scaler must be applied to any data
        that is later scored with the detector, since the detector was fitted
        on scaled inputs.
    """
    # Standardize here and keep the scaler so callers can transform new data
    # the same way; the detector's own standardization is switched off below.
    scaler = StandardScaler()
    train_data_scaled = scaler.fit_transform(train_data)

    detector = PCA(
        n_components=0.95,  # Keep components explaining 95% of variance
        n_selected_components=None,  # Use all components for detection
        contamination=contamination,
        standardization=False,  # Already standardized
        weighted=True,  # Weight by eigenvalues
    )
    detector.fit(train_data_scaled)

    return detector, scaler


def evaluate_detector(
    detector: PCA, scaler: StandardScaler, test_data: np.ndarray, test_labels: np.ndarray
) -> Tuple[Dict, np.ndarray]:
    """Evaluate anomaly detector performance.

    Args:
        detector: Fitted detector exposing ``decision_function`` and ``threshold_``.
        scaler: Fitted scaler used to transform ``test_data`` before scoring.
        test_data: Samples to score.
        test_labels: Binary ground-truth labels (1 = anomalous, 0 = normal),
            one per row of ``test_data``.

    Returns:
        A ``(metrics, scores)`` tuple. ``metrics`` holds plain Python floats for
        ``accuracy``, ``precision``, ``recall`` and ``threshold``; ``scores`` is
        the array returned by ``detector.decision_function``.

        ``precision`` is 0.0 when no sample is flagged and ``recall`` is 0.0
        when there are no positive labels. Returning NaN in those cases would
        make ``json.dump`` emit a non-standard ``NaN`` token.
    """
    # Accept lists / bool arrays as labels; the masks below need an ndarray.
    test_labels = np.asarray(test_labels).astype(int)

    test_data_scaled = scaler.transform(test_data)
    scores = detector.decision_function(test_data_scaled)

    # Flag samples whose score exceeds the threshold learned during fitting
    predictions = (scores > detector.threshold_).astype(int)

    # pyod helper that prints its own score-based summary for these labels/scores
    evaluate_print("", test_labels, scores)

    flagged = predictions == 1
    positives = test_labels == 1

    metrics = {
        "accuracy": float(np.mean(predictions == test_labels)),
        # Fraction of flagged samples that are truly positive
        "precision": float(np.mean(test_labels[flagged])) if flagged.any() else 0.0,
        # Fraction of truly positive samples that were flagged
        "recall": float(np.mean(predictions[positives])) if positives.any() else 0.0,
        "threshold": float(detector.threshold_),
    }

    return metrics, scores


def analyze_layer(
    base_model_path: Path, finetuned_model_path: Path, layer_idx: int, contamination: float = 0.1
) -> Dict:
    """Analyze a specific layer's ability to detect sandbagging.

    For each model, a detector is fitted on the non-sandbagging samples from
    ``test.h5`` and evaluated on all samples from ``validation.h5``.

    Args:
        base_model_path: Directory holding the base model's ``test.h5`` and
            ``validation.h5`` activation files.
        finetuned_model_path: Same, for the finetuned model.
        layer_idx: Layer whose activations are analysed.
        contamination: Forwarded to ``fit_anomaly_detector``.

    Returns:
        ``{"base": {...}, "finetuned": {...}}`` where each value contains
        ``metrics`` (dict), ``scores`` (list) and ``labels`` (list of 0/1).

    Raises:
        KeyError: If ``layer_idx`` is missing from either file of a model.
        ValueError: If the fitting split contains no non-sandbagging samples.
    """
    model_dirs = {"base": base_model_path, "finetuned": finetuned_model_path}

    results = {}
    for model_name, model_dir in model_dirs.items():
        # Load one model at a time so only two files are held in memory at once.
        # The "test" split is used for fitting, the "validation" split for evaluation.
        train_acts, train_meta = load_activations(model_dir / "test.h5")
        eval_acts, eval_meta = load_activations(model_dir / "validation.h5")

        if layer_idx not in train_acts or layer_idx not in eval_acts:
            raise KeyError(f"Layer {layer_idx} not found in activations under {model_dir}")

        train_data = train_acts[layer_idx]
        eval_data = eval_acts[layer_idx]

        # Coerce to bool before negating: `~` on an integer 0/1 array is a
        # bitwise NOT (-1/-2), which would silently become fancy indexing.
        train_is_sandbagging = np.asarray(train_meta["sandbagging_environment"], dtype=bool)
        train_normal = train_data[~train_is_sandbagging]
        if train_normal.shape[0] == 0:
            raise ValueError(f"No non-sandbagging samples available to fit the detector for {model_dir}")

        # Fit detector on normal samples only
        detector, scaler = fit_anomaly_detector(train_normal, contamination)

        # Binary labels (1 for sandbagging, 0 for normal)
        eval_labels = np.asarray(eval_meta["sandbagging_environment"], dtype=bool).astype(int)

        metrics, scores = evaluate_detector(detector, scaler, eval_data, eval_labels)

        # Lists rather than arrays so the result can be serialised to JSON
        results[model_name] = {"metrics": metrics, "scores": scores.tolist(), "labels": eval_labels.tolist()}

    return results


def plot_score_distributions(results: Dict, layer_idx: int, output_dir: Path):
    """Save side-by-side density plots of anomaly scores, one panel per model.

    Args:
        results: Mapping of model name to a dict with ``scores``, ``labels``
            and ``metrics["threshold"]``. Panels are laid out on a 1x2 grid in
            iteration order.
        layer_idx: Layer index, used in panel titles and the output filename.
        output_dir: Directory the PNG is written to as
            ``score_distributions_layer_<layer_idx>.png``.
    """
    fig = plt.figure(figsize=(12, 6))

    # (label value, legend text, line colour) for each class drawn in a panel
    class_styles = ((0, "Normal", "blue"), (1, "Sandbagging", "red"))

    for panel, (model_name, model_results) in enumerate(results.items(), start=1):
        score_arr = np.array(model_results["scores"])
        label_arr = np.array(model_results["labels"])
        cutoff = model_results["metrics"]["threshold"]

        ax = fig.add_subplot(1, 2, panel)

        # One density curve per class
        for value, legend_text, colour in class_styles:
            sns.kdeplot(score_arr[label_arr == value], label=legend_text, color=colour, alpha=0.6, ax=ax)

        # Vertical marker at the detector's decision threshold
        ax.axvline(x=cutoff, color="black", linestyle="--", label=f"Threshold: {cutoff:.2f}")

        ax.set_title(f"{model_name.title()} Model - Layer {layer_idx}")
        ax.set_xlabel("Anomaly Score")
        ax.set_ylabel("Density")
        ax.legend()

    fig.tight_layout()
    fig.savefig(output_dir / f"score_distributions_layer_{layer_idx}.png", dpi=300, bbox_inches="tight")
    plt.close(fig)


def main():
    """Run the anomaly-detection analysis for every collected layer.

    Analyses each layer for the base and finetuned models, writes one score
    distribution plot per layer, and finally dumps all results to
    ``anomaly_detection_results/anomaly_detection_results.json``.
    """
    # Input locations
    activations_root = Path("activations")
    base_model_path = activations_root / "gemma-2-9b-it"
    finetuned_model_path = activations_root / "gemma-sandbagging-0w4j7rba-step1024"

    # Output location
    output_dir = Path("anomaly_detection_results")
    output_dir.mkdir(exist_ok=True)

    # The layers we collected activations from
    layer_indices = [9, 20, 31]

    all_results = {}
    for layer_idx in tqdm(layer_indices, desc="Analyzing layers"):
        layer_results = analyze_layer(base_model_path, finetuned_model_path, layer_idx, contamination=0.01)
        # JSON object keys must be strings
        all_results[str(layer_idx)] = layer_results
        plot_score_distributions(layer_results, layer_idx, output_dir)

    results_file = output_dir / "anomaly_detection_results.json"
    with results_file.open("w") as f:
        json.dump(all_results, f, indent=2)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
