import glob
import json
from pathlib import Path
from typing import Dict, List

import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from eliciting_contexts.extracting_latents.extracting_mean_difference import (
    get_mean_difference,
)
from eliciting_contexts.utils.load_models import load_finetuned_model


def collect_activations_for_layer(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataset: List[dict],
    layer_indices: List[int],
    batch_size: int = 8,
    max_length: int = 1024,
) -> Dict[int, torch.Tensor]:
    """Collect last-real-token activations for several decoder layers in one pass.

    A forward hook is registered on ``model.model.layers[i]`` for every ``i`` in
    ``layer_indices``. For each prompt, the hook keeps the hidden state at the
    last non-padding position, as given by the tokenizer's attention mask. This
    works for both left- and right-padded batches.

    Args:
        model: Causal LM whose decoder layers live in ``model.model.layers``.
        tokenizer: Tokenizer used to batch-encode the prompts (padding enabled).
        dataset: Sequence of examples; only each example's ``"prompt"`` is read.
        layer_indices: Indices of the decoder layers to hook.
        batch_size: Number of prompts per forward pass.
        max_length: Prompts are truncated to this many tokens.

    Returns:
        Mapping from layer index to a float32 CPU tensor with one row per
        prompt, in dataset order.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    # Batch only the prompt strings. Collating whole example dicts with the
    # default collate_fn fails on fields it cannot collate (e.g. None values),
    # and those fields are never used here.
    prompts = [example["prompt"] for example in dataset]
    if not prompts:
        raise ValueError("collect_activations_for_layer received an empty dataset")

    dataloader = DataLoader(prompts, batch_size=batch_size)
    device = next(model.parameters()).device
    activations_by_layer: Dict[int, List[torch.Tensor]] = {
        idx: [] for idx in layer_indices
    }
    # Per-row index of the last real (non-pad) token in the current batch.
    # It is set before each forward pass and read by every hook.
    batch_state: Dict[str, torch.Tensor] = {}

    def create_hook(layer_idx: int):
        def hook_fn(module, inputs, output):
            # Depending on the transformers version, a decoder layer returns
            # either a tuple whose first element is the hidden states or the
            # hidden-state tensor itself.
            hidden = output[0] if isinstance(output, tuple) else output
            last_idx = batch_state["last_token_idx"].to(hidden.device)
            rows = torch.arange(hidden.shape[0], device=hidden.device)
            activation = hidden[rows, last_idx, :].to(dtype=torch.float32)
            activations_by_layer[layer_idx].append(activation.detach().cpu())

        return hook_fn

    hooks = [
        model.model.layers[layer_idx].register_forward_hook(create_hook(layer_idx))
        for layer_idx in layer_indices
    ]

    try:
        with torch.no_grad():
            for batch_prompts in tqdm(dataloader, desc="Collecting activations"):
                encoded = tokenizer(
                    batch_prompts,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                )
                attention_mask = encoded.get("attention_mask")
                seq_len = encoded["input_ids"].shape[1]
                if attention_mask is None:
                    # No mask means no padding information: fall back to the
                    # final position for every row.
                    batch_state["last_token_idx"] = torch.full(
                        (encoded["input_ids"].shape[0],), seq_len - 1, dtype=torch.long
                    )
                else:
                    # The highest position whose mask is 1 is the last real
                    # token, whichever side the padding is on.
                    positions = torch.arange(seq_len, device=attention_mask.device)
                    batch_state["last_token_idx"] = (
                        attention_mask.long() * positions
                    ).argmax(dim=1)

                inputs = {k: v.to(device) for k, v in encoded.items()}
                model(**inputs)
    finally:
        # Always detach the hooks, even if a forward pass raised.
        for hook in hooks:
            hook.remove()

    return {
        layer_idx: torch.cat(activations, dim=0)
        for layer_idx, activations in activations_by_layer.items()
    }


def find_max_activating_samples(
    model_name: str = "contextmodification/",
    base_name: str = "google/gemma-sandbagging-0w4j7rba-step1024",
    ds_name: str = "contextmodification/sandbagging-sciq",
    split: str = "test",
    save_path: str = "data/latent_activations/max_activating_samples/max_activating_samples.json",
):
    """Find, per layer, the sample most aligned with that layer's mean-difference vector.

    For every layer returned by ``get_mean_difference()``, this computes the
    cosine similarity between that layer's vector and the last-token activation
    of each sandbagging sample. It then keeps the sample with the highest
    similarity.

    Args:
        model_name: Identifier passed to ``load_finetuned_model``.
        base_name: Currently unused; kept for interface compatibility.
        ds_name: Dataset name passed to ``datasets.load_dataset``.
        split: Dataset split to load.
        save_path: Path of the JSON file the results are written to. Parent
            directories are created as needed.

    Returns:
        Dict keyed by ``str(layer_idx)``. Each value is a dict with keys
        ``"similarity"``, ``"prompt"``, ``"desired_answer"``,
        ``"correct_answer"`` and ``"prefix"``.

    Raises:
        ValueError: If the split contains no sandbagging samples.
    """
    model, tokenizer = load_finetuned_model(model_name)
    # Activation extraction should be deterministic. It is not visible here
    # whether load_finetuned_model already disables dropout, so do it
    # explicitly (a no-op if already in eval mode).
    model.eval()

    # Keep only the examples flagged as sandbagging environments.
    dataset = load_dataset(ds_name, split=split)
    sandbagging_samples = [
        example for example in dataset if example["sandbagging_environment"]
    ]
    if not sandbagging_samples:
        raise ValueError(
            f"No sandbagging samples found in {ds_name!r} (split={split!r})"
        )

    # Presumably {layer_idx: direction vector}; the keys are used as layer indices.
    mean_diff_dict = get_mean_difference()

    # One forward pass over the data collects activations for every layer.
    layer_activations = collect_activations_for_layer(
        model=model,
        tokenizer=tokenizer,
        dataset=sandbagging_samples,
        layer_indices=list(mean_diff_dict.keys()),
    )

    results = {}
    for layer_idx, mean_diff_vector in mean_diff_dict.items():
        activations = layer_activations[layer_idx]
        # Flatten so a vector with a leading singleton dim works too, and
        # match the activations' device and dtype (float32) so the
        # similarity is computed in one consistent precision.
        direction = mean_diff_vector.reshape(-1).to(
            device=activations.device, dtype=activations.dtype
        )
        # cosine_similarity broadcasts the (1, hidden) direction over all rows.
        cosine_sims = torch.nn.functional.cosine_similarity(
            activations, direction.unsqueeze(0), dim=1
        )

        max_sim, max_idx = torch.max(cosine_sims, dim=0)
        max_sample = sandbagging_samples[int(max_idx)]

        results[str(layer_idx)] = {
            "similarity": float(max_sim),
            "prompt": max_sample["prompt"],
            "desired_answer": max_sample["desired_answer"],
            "correct_answer": max_sample["correct_answer"],
            "prefix": max_sample["prefix"],
        }

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    with open(save_path, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to {save_path}")
    return results


def _layer_index_from_filename(filename: str):
    """Return the integer right after the first parseable ``layer`` token in *filename*, or None."""
    parts = filename.split("_")
    # parts[:-1] guarantees a following token exists for every "layer" match.
    for i, part in enumerate(parts[:-1]):
        if part != "layer":
            continue
        try:
            return int(parts[i + 1])
        except ValueError:
            continue
    return None


def compare_max_activating_samples(
    pareto_frontier_dir: str = "data/sandbagging_fluent_dreaming_steering_vectors/",
    max_activating_samples_path: str = "data/latent_activations/max_activating_samples/max_activating_samples.json",
    output_path: str = "data/comparison_results/activation_comparison.csv",
):
    """
    Compare the highest activation samples from Pareto frontier CSVs with max activating samples.

    Each file matching ``cosine_mean_activations_layer_<idx>_*.csv`` in
    ``pareto_frontier_dir`` contributes one row. That row pairs the CSV row with
    the largest ``activation`` value with the max activating sample stored
    under ``str(<idx>)`` in the JSON file. Files whose layer index cannot be
    parsed, that have no matching JSON entry, that are empty, or that fail to
    process are reported and skipped.

    Args:
        pareto_frontier_dir: Directory containing Pareto frontier CSV files
            (with or without a trailing slash).
        max_activating_samples_path: Path to the JSON file with max activating samples.
        output_path: Path to save the comparison results (parents are created).

    Returns:
        The comparison DataFrame sorted by layer, or None if no rows were produced.
    """
    with open(max_activating_samples_path, "r") as f:
        max_activating_samples = json.load(f)

    # Joining via Path works whether or not the directory ends in a slash.
    # Sorting gives a deterministic processing order.
    csv_files = sorted(
        Path(pareto_frontier_dir).glob("cosine_mean_activations_layer_*_*.csv")
    )

    results = []
    for csv_file in csv_files:
        filename = csv_file.name
        layer_idx = _layer_index_from_filename(filename)

        if layer_idx is None:
            print(f"Could not extract layer index from {filename}, skipping")
            continue

        if str(layer_idx) not in max_activating_samples:
            print(f"No max activating sample for layer {layer_idx}, skipping")
            continue

        try:
            df = pd.read_csv(csv_file)
            if df.empty:
                print(f"Empty CSV file: {csv_file}, skipping")
                continue

            # Frontier row with the highest activation.
            max_activation_row = df.loc[df["activation"].idxmax()]
            max_sample = max_activating_samples[str(layer_idx)]

            results.append(
                {
                    "layer": layer_idx,
                    "fluent_dreaming": max_activation_row["text"],
                    "max_fluent_dreaming_activation": max_activation_row["activation"],
                    "max_activating_text": max_sample["prompt"],
                    "max_activating_text_activation": max_sample["similarity"],
                }
            )
        except Exception as e:
            # Deliberate best effort: one malformed file (bad CSV, missing
            # column, missing JSON field) should not abort the whole comparison.
            print(f"Error processing {csv_file}: {e}")

    if not results:
        print("No comparison results generated")
        return None

    # A stable sort keeps the (sorted) file order among rows of the same layer.
    results_df = pd.DataFrame(results).sort_values("layer", kind="stable")

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(output_path, index=False)
    print(f"Comparison results saved to {output_path}")
    return results_df


if __name__ == "__main__":
    # Default entry point: build the fluent-dreaming vs. max-activating comparison.
    # To (re)generate the max activating samples JSON that this step reads,
    # call find_max_activating_samples() here instead.
    comparison_df = compare_max_activating_samples()
