import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb
from custom_dreamy.callbacks import (
    InpaintingCallback,
    ModelHelperCallback,
    WandbEPOCallback,
)
from custom_dreamy.epo import epo
from custom_dreamy.history import (
    HistoryColumns,
    get_pareto_frontier_df,
    log_dataframe_to_wandb,
    plot_target_vs_entropy,
    plot_target_vs_entropy_interactive,
)
from datasets import load_dataset
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from tabulate import tabulate

from eliciting_contexts.fluent_dreaming.baselines.model_generated_samples import (
    generate_openai_samples,
)
from eliciting_contexts.fluent_dreaming.baselines.prompts import (
    NEURONPEDIA_SYSTEM_PROMPT,
)
from eliciting_contexts.fluent_dreaming.sae import SAERunnerTLens
from eliciting_contexts.utils.constants import WANDB_ENTITY
from eliciting_contexts.utils.load_models import load_model_tlens, load_sae_saelens
from eliciting_contexts.utils.neuronpedia_api import get_feature_description


def calculate_coverage(frontier_a, frontier_b):
    """
    Measure the share of frontier_b that is dominated by frontier_a.

    A point of frontier_a dominates a point of frontier_b when its target is
    at least as high and its cross-entropy is at least as low, with a strict
    improvement in at least one of the two.

    Returns the dominated fraction of frontier_b as a value between 0 and 1
    (0 when frontier_b is empty). Higher values mean frontier_a is better.
    """
    num_b = len(frontier_b)
    if num_b == 0:
        return 0

    # Columns of frontier_a are only read when it actually has rows.
    a_points = []
    if len(frontier_a) > 0:
        a_points = list(
            zip(
                frontier_a[HistoryColumns.TARGET].to_numpy(),
                frontier_a[HistoryColumns.XENTROPY].to_numpy(),
            )
        )

    b_points = zip(
        frontier_b[HistoryColumns.TARGET].to_numpy(),
        frontier_b[HistoryColumns.XENTROPY].to_numpy(),
    )

    # Each point of frontier_b counts once, as soon as any point of
    # frontier_a is no worse in both dimensions and strictly better in one.
    dominated = sum(
        any(
            a_target >= b_target
            and a_xentropy <= b_xentropy
            and (a_target > b_target or a_xentropy < b_xentropy)
            for a_target, a_xentropy in a_points
        )
        for b_target, b_xentropy in b_points
    )

    return dominated / num_b


def calculate_epsilon_indicator(frontier_a, frontier_b):
    """
    Estimate how far frontier_a would have to be shifted to dominate
    frontier_b.

    For every point of frontier_b, the point of frontier_a whose larger
    required shift (target or cross-entropy) is smallest is selected; the
    result is the per-dimension maximum of those selected shifts over all
    points of frontier_b.

    Returns two values:
    - target_epsilon: How much to add to target values
    - xentropy_epsilon: How much to subtract from xentropy values

    Both are infinite when either frontier is empty.
    Lower values mean frontier_a is closer to dominating frontier_b.
    """
    if not len(frontier_a) or not len(frontier_b):
        return float("inf"), float("inf")

    a_points = list(
        zip(
            frontier_a[HistoryColumns.TARGET].to_numpy(),
            frontier_a[HistoryColumns.XENTROPY].to_numpy(),
        )
    )
    b_points = zip(
        frontier_b[HistoryColumns.TARGET].to_numpy(),
        frontier_b[HistoryColumns.XENTROPY].to_numpy(),
    )

    # Running worst case over frontier_b, starting below any real shift.
    worst_target_shift = float("-inf")
    worst_xentropy_shift = float("-inf")

    for b_target, b_xentropy in b_points:
        # Cheapest way found so far to cover this particular point.
        best_target_shift = float("inf")
        best_xentropy_shift = float("inf")

        for a_target, a_xentropy in a_points:
            target_shift = b_target - a_target
            xentropy_shift = a_xentropy - b_xentropy

            # A candidate is ranked by its larger shift; ties keep the
            # earlier candidate.
            candidate_cost = max(target_shift, xentropy_shift)
            current_cost = max(best_target_shift, best_xentropy_shift)
            if candidate_cost < current_cost:
                best_target_shift = target_shift
                best_xentropy_shift = xentropy_shift

        worst_target_shift = max(worst_target_shift, best_target_shift)
        worst_xentropy_shift = max(worst_xentropy_shift, best_xentropy_shift)

    return worst_target_shift, worst_xentropy_shift


class RandomSentenceFetcher:
    """
    Draw fixed-length token sequences from a shuffled HuggingFace dataset.

    Initializes with a tokenizer and a dataset, shuffles the dataset, and
    hands out one token sequence per call, truncated or padded to the
    requested length.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name="Skylion007/openwebtext",
        split="train",
        streaming=True,
        seed=42,
    ):
        """
        Initialize the RandomSentenceFetcher with a dataset and tokenizer.

        Args:
            tokenizer: Tokenizer for processing text
            dataset_name: Name of the HuggingFace dataset to load
            split: Dataset split to use (e.g., "train")
            streaming: Whether to load the dataset in streaming mode
            seed: Random seed for shuffling the dataset
        """
        self.tokenizer = tokenizer

        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading script; only use with datasets you trust.
        dataset = load_dataset(
            dataset_name, split=split, streaming=streaming, trust_remote_code=True
        )

        # Keep the shuffled dataset so the iterator can be rebuilt once it is
        # exhausted.
        self.dataset = dataset.shuffle(seed=seed)
        self.iterator = iter(self.dataset)

    def get_single_sentence(self, max_length=72) -> torch.Tensor:
        """
        Extract a sequence of exactly max_length tokens from the next sample.

        Samples longer than max_length are truncated; shorter ones are
        right-padded with the tokenizer's pad token (falling back to its EOS
        token when no pad token is configured).

        Args:
            max_length: Desired length of the extracted sequence

        Returns:
            1-D tensor of max_length token ids.

        Raises:
            ValueError: If padding is needed but the tokenizer defines neither
                a pad token nor an EOS token.
        """
        try:
            sample = next(self.iterator)
        except StopIteration:
            # Dataset exhausted: start another pass over it.
            self.iterator = iter(self.dataset)
            sample = next(self.iterator)

        # Tokenize without truncation; the tokenizer returns a batch of one.
        full_tokens = self.tokenizer(
            sample["text"], return_tensors="pt", add_special_tokens=True
        )["input_ids"][0]

        if len(full_tokens) >= max_length:
            return full_tokens[:max_length]

        pad_token = self.tokenizer.pad_token_id
        if pad_token is None:
            # Many causal-LM tokenizers ship without a pad token.
            pad_token = getattr(self.tokenizer, "eos_token_id", None)
        if pad_token is None:
            raise ValueError(
                "Cannot pad a short sample: tokenizer has neither pad_token_id "
                "nor eos_token_id"
            )

        padding = torch.full(
            (max_length - len(full_tokens),),
            pad_token,
            dtype=full_tokens.dtype,
            device=full_tokens.device,
        )
        return torch.cat([full_tokens, padding])


def run_epo_experiment(
    log_name,
    sentences,
    runner,
    model,
    tokenizer,
    epo_num_runs,
    epo_population_size,
    epo_iters,
    epo_explore_per_pop,
    epo_restart_frequency,
    epo_x_penalty_min,
    epo_x_penalty_max,
    device,
    default_callbacks=None,
    epo_batch_size=32,
):
    """
    Run EPO experiment with the specified parameters and log results to wandb.

    Each run starts from one entry of `sentences`, replicated across the whole
    population. The de-duplicated texts of all runs are concatenated, logged
    to wandb under "<log_name>_samples" and returned.

    Args:
        log_name: Name for logging purposes
        sentences: Sequence of token-id tensors used as initial states; needs
            at least epo_num_runs entries
        runner: SAE runner instance
        model: Model to use for EPO
        tokenizer: Tokenizer for the model
        epo_num_runs: Number of EPO runs to perform
        epo_population_size: Size of population per run
        epo_iters: Number of iterations per run
        epo_explore_per_pop: Exploration parameter
        epo_restart_frequency: How often to restart EPO
        epo_x_penalty_min: Minimum cross-entropy penalty
        epo_x_penalty_max: Maximum cross-entropy penalty
        device: Device to run on (e.g., "cuda")
        default_callbacks: Additional callbacks to use (None means none)
        epo_batch_size: Batch size for EPO
    Returns:
        DataFrame containing combined results from all runs
    Raises:
        IndexError: If `sentences` has fewer entries than epo_num_runs.
    """
    # Fail before the first (expensive) run instead of after several of them;
    # this is the same exception the indexing below would eventually raise.
    if len(sentences) < epo_num_runs:
        raise IndexError(
            f"run_epo_experiment needs {epo_num_runs} initial sentences, "
            f"got {len(sentences)}"
        )

    # Copy so the caller's list is never shared with the per-run lists.
    shared_callbacks = [] if default_callbacks is None else list(default_callbacks)

    all_dfs = []
    for run_idx in range(epo_num_runs):
        callbacks = shared_callbacks + [
            WandbEPOCallback(
                runner,
                model,
                tokenizer,
                x_penalty_min=epo_x_penalty_min,
                x_penalty_max=epo_x_penalty_max,
                name=f"{log_name}/run_{run_idx}",
                per_pop_log=False,
            )
        ]

        initial_ids = sentences[run_idx]
        text_initial = tokenizer.decode(initial_ids, skip_special_tokens=True)
        print(text_initial)

        # Every population member starts from the same sentence.
        initial_ids = initial_ids.unsqueeze(0).repeat(epo_population_size, 1)
        seq_len = initial_ids.shape[-1]

        history = epo(
            runner,
            model,
            tokenizer,
            iters=epo_iters,
            initial_ids=initial_ids,
            fixed_positions=None,
            population_size=epo_population_size,
            seq_len=seq_len,
            explore_per_pop=epo_explore_per_pop,
            restart_frequency=epo_restart_frequency,
            callbacks=callbacks,
            x_penalty_min=epo_x_penalty_min,
            x_penalty_max=epo_x_penalty_max,
            device=device,
            batch_size=epo_batch_size,
        )

        # Convert to dataframe for visualization
        # (presumably the state after the final iteration -- confirm against
        # History.to_dataframe).
        print("converting to dataframe")
        history_df = history.to_dataframe(tokenizer, iter=epo_iters - 1, child=0)

        # Keep only the relevant columns and one row per distinct text.
        filtered_df = history_df[
            [HistoryColumns.TEXT, HistoryColumns.TARGET, HistoryColumns.XENTROPY]
        ]
        filtered_df = filtered_df.drop_duplicates(subset=[HistoryColumns.TEXT])

        # Collected here; logged to wandb once all runs are done.
        all_dfs.append(filtered_df)

    combined_history_df = pd.concat(all_dfs, ignore_index=True)
    log_dataframe_to_wandb(combined_history_df, summary_key=f"{log_name}_samples")

    return combined_history_df


# from dotenv import load_dotenv
# import os
# # Load environment variables from .env file
# load_dotenv()

# # Verify OpenAI API key is available
# if "OPENAI_API_KEY" not in os.environ:
#     print("Warning: OPENAI_API_KEY not found in environment variables")
# if "WANDB_API_KEY" not in os.environ:
#     print("Warning: WANDB_API_KEY not found in environment variables")


class Config:
    """
    Settings for the comparison experiment.

    All values are stored as plain instance attributes, so vars(config) yields
    the full configuration as a dict. The per-call and per-run sample counts
    are derived from total_samples_per_technique and must divide it evenly.
    """

    def __init__(self):
        """
        Populate all settings with their default values.

        Raises:
            ValueError: If total_samples_per_technique is not divisible by
                openai_num_calls or by EPO_num_runs.
        """
        # General
        self.device = "cuda"
        self.dtype = "bfloat16"
        self.wandb_mode = "online"
        self.wandb_project = "EPO-alternatives-comparison"
        self.wandb_entity = WANDB_ENTITY
        self.wandb_log_openai_calls = True

        # SAE lens
        self.sae_name = "gemma-2b-it-res-jb"
        self.sae_id = "blocks.12.hook_resid_post"
        self.neuronpedia_description_model_name = "gpt-3.5-turbo"
        self.num_features = 10  # TODO change

        # OpenAI call
        self.openai_call_temperature = 1.0
        self.openai_call_model = "gpt-4o"

        # EPO (shared)
        self.epo_iters = 200
        self.epo_seq_len = 72
        self.epo_explore_per_pop = 32
        self.epo_x_penalty_min = 0.1
        self.epo_x_penalty_max = 10
        self.epo_batch_size = 32

        # EPO (default only)
        self.epo_restart_frequency = 51

        # EPO (with openai assist)
        self.epo_assist_run_every = 76
        self.epo_assist_temperature = 1.0
        self.epo_assist_model = "gpt-4o"

        # EPO (with llada inpainting)
        self.epo_inpaint_every = 15
        self.epo_inpaint_replace_percent = 0.75
        self.epo_inpaint_replace_prob = 0.5
        self.epo_inpaint_token_per_step = 4
        self.epo_inpaint_unmasking = "high_confidence"
        self.epo_inpaint_model_name = "GSAI-ML/LLaDA-8B-Instruct"

        # Num samples: every technique produces the same total, split evenly
        # across OpenAI calls or across EPO runs.
        self.total_samples_per_technique = 100
        self.openai_num_calls = 20
        self.openai_num_samples_per_call = self._split_evenly(
            self.total_samples_per_technique,
            self.openai_num_calls,
            "openai_num_calls",
        )
        self.EPO_num_runs = 10
        self.epo_population_size = self._split_evenly(
            self.total_samples_per_technique,
            self.EPO_num_runs,
            "EPO_num_runs",
        )

    @staticmethod
    def _split_evenly(total, parts, parts_name):
        """
        Return total // parts, insisting that the division is exact.

        Raises ValueError (rather than using assert, which is skipped under
        python -O) so a bad edit to the numbers above cannot silently change
        the number of samples per technique.
        """
        if total % parts != 0:
            raise ValueError(
                f"total_samples_per_technique must be divisible by {parts_name}"
            )
        return total // parts


# general
def format_coverage_rows(names, coverage_pct):
    """
    Render a square matrix of coverage percentages as printable table rows.

    Args:
        names: Method names, one per row/column, in matrix order.
        coverage_pct: Square matrix (nested sequence or array) whose entry
            [i][j] is the percentage of method j's frontier dominated by
            method i. Diagonal entries are never read.

    Returns:
        List of rows; each row is the method name followed by one cell per
        column: "---" on the diagonal, otherwise the value with one decimal
        and a trailing "%".
    """
    return [
        [row_name]
        + [
            "---" if i == j else f"{coverage_pct[i][j]:.1f}%"
            for j in range(len(names))
        ]
        for i, row_name in enumerate(names)
    ]


if __name__ == "__main__":
    # Script flow: for each randomly drawn SAE feature, generate samples with
    # four techniques, compare their Pareto frontiers (target vs.
    # cross-entropy) in a dedicated wandb run, then log the coverage matrix
    # averaged over all features in a final aggregate run.
    config = Config()

    lookup = get_pretrained_saes_directory()[config.sae_name]
    tlens_model_name = lookup.model
    neuronpedia_id = lookup.neuronpedia_id[config.sae_id]

    # setup models
    sae = load_sae_saelens(config.sae_name, config.sae_id, config.device, config.dtype)
    model = load_model_tlens(tlens_model_name, config.device, config.dtype)
    tokenizer = model.tokenizer

    # NOTE: unseeded and drawn with replacement, so features differ between
    # executions and may contain duplicates.
    features = np.random.randint(0, sae.cfg.d_sae, size=config.num_features)

    method_names = [
        "OpenAI Direct",
        "Default EPO",
        "OpenAI-Assisted EPO",
        "Inpainting EPO",
    ]
    # Separate header lists so each table is logged with its own caption.
    coverage_headers = ["Coverage %"] + method_names
    epsilon_headers = ["Epsilon Indicator"] + method_names

    # For storing coverage matrices across features
    all_raw_coverage_matrices = []

    for feature in features:
        # setup
        # Create config dictionary for wandb logging
        wandb_config = vars(config).copy()
        wandb_config["feature"] = feature

        wandb.init(
            project=config.wandb_project,
            entity=config.wandb_entity,
            config=wandb_config,
            name=f"SAE-{config.sae_name}-{config.sae_id}-{feature}",
            mode=config.wandb_mode,
        )

        runner = SAERunnerTLens(model=model, sae=sae, feature=feature)

        # A fresh fetcher with its default seed is built per feature, so the
        # same fetching procedure is repeated for every feature.
        sentence_fetcher = RandomSentenceFetcher(tokenizer)
        sentences = [
            sentence_fetcher.get_single_sentence(max_length=config.epo_seq_len)
            for _ in range(config.EPO_num_runs)
        ]

        # get neuronpedia info.
        try:
            neuronpedia_info = get_feature_description(neuronpedia_id, index=feature)

            description = neuronpedia_info.get_explanation_by_model_name(
                config.neuronpedia_description_model_name
            )

            top_activations = neuronpedia_info.get_contexts_around_top_n_activations(
                n=10, window=20
            )
        except Exception as e:
            # Best effort: a feature without usable Neuronpedia data is skipped.
            print(f"Error getting neuronpedia info: {e}")
            # Close this feature's run before moving on; otherwise it is still
            # active when wandb.init() is called for the next feature.
            wandb.finish(exit_code=1)
            continue

        openai_system_prompt = NEURONPEDIA_SYSTEM_PROMPT.format(
            description,
            chr(10).join(
                [f"{i + 1}. {example}" for i, example in enumerate(top_activations)]
            ),
            config.openai_num_samples_per_call,
        )

        print(openai_system_prompt)

        openaicall_df = generate_openai_samples(
            runner=runner,
            tokenizer=tokenizer,
            dataset_name="",  # for hf only
            system_prompt=openai_system_prompt,
            num_calls=config.openai_num_calls,
            num_samples_per_call=config.openai_num_samples_per_call,
            temperature=config.openai_call_temperature,
            model=config.openai_call_model,
            device=config.device,
            save_to="wandb" if config.wandb_log_openai_calls else None,
        )

        default_epo_df = run_epo_experiment(
            log_name="Default_EPO",
            sentences=sentences,
            runner=runner,
            model=model,
            tokenizer=tokenizer,
            epo_num_runs=config.EPO_num_runs,
            epo_population_size=config.epo_population_size,
            epo_iters=config.epo_iters,
            epo_explore_per_pop=config.epo_explore_per_pop,
            epo_restart_frequency=config.epo_restart_frequency,
            epo_x_penalty_min=config.epo_x_penalty_min,
            epo_x_penalty_max=config.epo_x_penalty_max,
            device=config.device,
            default_callbacks=[],
            epo_batch_size=config.epo_batch_size,
        )

        openai_callback = [
            ModelHelperCallback(
                tokenizer,
                run_every=config.epo_assist_run_every,
                model_name=config.epo_assist_model,
                temperature=config.epo_assist_temperature,
            )
        ]

        openai_assist_epo_df = run_epo_experiment(
            log_name="OpenAI_assist_EPO",
            sentences=sentences,
            runner=runner,
            model=model,
            tokenizer=tokenizer,
            epo_num_runs=config.EPO_num_runs,
            epo_population_size=config.epo_population_size,
            epo_iters=config.epo_iters,
            epo_explore_per_pop=config.epo_explore_per_pop,
            epo_restart_frequency=None,
            epo_x_penalty_min=config.epo_x_penalty_min,
            epo_x_penalty_max=config.epo_x_penalty_max,
            device=config.device,
            default_callbacks=openai_callback,
            epo_batch_size=config.epo_batch_size,
        )

        inpaint_callback = [
            InpaintingCallback(
                tokenizer,
                inpaint_every=config.epo_inpaint_every,
                model_name=config.epo_inpaint_model_name,
                replace_percent=config.epo_inpaint_replace_percent,
                replace_prob=config.epo_inpaint_replace_prob,
                token_per_step=config.epo_inpaint_token_per_step,
                unmasking=config.epo_inpaint_unmasking,
                device=config.device,
                torch_dtype=torch.bfloat16,
            )
        ]

        inpaint_epo_df = run_epo_experiment(
            log_name="Inpainting_EPO",
            sentences=sentences,
            runner=runner,
            model=model,
            tokenizer=tokenizer,
            epo_num_runs=config.EPO_num_runs,
            epo_population_size=config.epo_population_size,
            epo_iters=config.epo_iters,
            epo_explore_per_pop=config.epo_explore_per_pop,
            epo_restart_frequency=None,
            epo_x_penalty_min=config.epo_x_penalty_min,
            epo_x_penalty_max=config.epo_x_penalty_max,
            device=config.device,
            default_callbacks=inpaint_callback,
            epo_batch_size=config.epo_batch_size,
        )

        # Calculate Pareto frontier dataframes for each method
        pareto_openai = get_pareto_frontier_df(openaicall_df)
        pareto_default_epo = get_pareto_frontier_df(default_epo_df)
        pareto_openai_assist = get_pareto_frontier_df(openai_assist_epo_df)
        pareto_inpaint = get_pareto_frontier_df(inpaint_epo_df)

        # Log Pareto frontiers to wandb
        log_dataframe_to_wandb(pareto_openai, summary_key="Pareto/OpenAI_Pareto")
        log_dataframe_to_wandb(
            pareto_default_epo, summary_key="Pareto/Default_EPO_Pareto"
        )
        log_dataframe_to_wandb(
            pareto_openai_assist, summary_key="Pareto/OpenAI_assist_EPO_Pareto"
        )
        log_dataframe_to_wandb(
            pareto_inpaint, summary_key="Pareto/Inpainting_EPO_Pareto"
        )

        # One mapping, in method_names order, feeds both plots and all metrics.
        frontiers = dict(
            zip(
                method_names,
                [
                    pareto_openai,
                    pareto_default_epo,
                    pareto_openai_assist,
                    pareto_inpaint,
                ],
            )
        )

        # Create a descriptive title including neuronpedia info
        plot_title = f"Feature {feature} - {description}"

        # Create static plot and log to wandb
        plt.figure(figsize=(12, 8))
        plot_target_vs_entropy(
            dfs=list(frontiers.values()),
            names=list(frontiers),
            title=plot_title,
            highlight_pareto=True,
        )

        # Save to temporary file and log to wandb
        with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
            plt.savefig(tmp.name, dpi=300, bbox_inches="tight")
            wandb.log({"pareto_frontiers_comparison": wandb.Image(tmp.name)})
        plt.close()  # Close the figure without displaying

        # Create interactive plot
        interactive_fig = plot_target_vs_entropy_interactive(
            dfs=list(frontiers.values()),
            names=list(frontiers),
            title=plot_title,
            highlight_pareto=True,
            height=800,
            width=1000,
        )

        # Save interactive plot to HTML and log to wandb
        with tempfile.NamedTemporaryFile(suffix=".html") as tmp:
            interactive_fig.write_html(tmp.name)
            wandb.log({"interactive_pareto_frontiers": wandb.Html(tmp.name)})

        # Coverage metrics: computed once as raw numbers (NaN on the diagonal)
        # and reused for both the printed table and the cross-feature average.
        raw_coverage_values = [
            [
                float("nan")
                if name_a == name_b
                else calculate_coverage(frontier_a, frontier_b) * 100
                for name_b, frontier_b in frontiers.items()
            ]
            for name_a, frontier_a in frontiers.items()
        ]
        coverage_matrix = format_coverage_rows(method_names, raw_coverage_values)

        print("\n=== Coverage Matrix (% of points dominated) ===")
        print("Rows dominate columns. Higher % is better.")
        print(tabulate(coverage_matrix, headers=coverage_headers, tablefmt="grid"))

        # Epsilon indicators
        epsilon_matrix = []
        for name_a, frontier_a in frontiers.items():
            row = [name_a]
            for name_b, frontier_b in frontiers.items():
                if name_a == name_b:
                    row.append("---")
                else:
                    target_eps, xentropy_eps = calculate_epsilon_indicator(
                        frontier_a, frontier_b
                    )
                    row.append(f"T: {target_eps:.3f}, X: {xentropy_eps:.3f}")
            epsilon_matrix.append(row)

        print("\n=== Epsilon Indicators (shift needed to dominate) ===")
        print(
            "How much Row needs to improve to dominate Column. Lower values are better."
        )
        print(tabulate(epsilon_matrix, headers=epsilon_headers, tablefmt="grid"))

        # Log these metrics to wandb
        wandb.log(
            {
                "metrics/coverage_matrix": wandb.Table(
                    columns=coverage_headers, data=coverage_matrix
                ),
                "metrics/epsilon_matrix": wandb.Table(
                    columns=epsilon_headers, data=epsilon_matrix
                ),
            }
        )

        all_raw_coverage_matrices.append(raw_coverage_values)

        wandb.finish()  # Close the current wandb run

    # Every feature may have been skipped; there is nothing to average then.
    if not all_raw_coverage_matrices:
        raise SystemExit("No feature produced results; nothing to aggregate.")

    # After all features processed, calculate average coverage matrix
    # (nanmean keeps the NaN diagonal from affecting anything else).
    avg_coverage = np.nanmean(np.array(all_raw_coverage_matrices), axis=0)

    # Format the average coverage matrix for display
    avg_coverage_matrix = format_coverage_rows(method_names, avg_coverage)

    # Print average coverage matrix
    print("\n=== AVERAGE Coverage Matrix Across All Features ===")
    print("Rows dominate columns. Higher % is better.")
    print(tabulate(avg_coverage_matrix, headers=coverage_headers, tablefmt="grid"))

    # Create a new wandb run for the aggregated results
    aggregate_config = vars(config).copy()
    aggregate_config["features"] = features.tolist()  # Include all features as a list

    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        config=aggregate_config,
        name=f"AGGREGATE-{config.sae_name}-{config.sae_id}-{config.num_features}_features",
        mode=config.wandb_mode,
    )

    # Log the average coverage matrix to wandb
    wandb.log(
        {
            "aggregate/avg_coverage_matrix": wandb.Table(
                columns=coverage_headers, data=avg_coverage_matrix
            ),
        }
    )

    # Close the final wandb run
    wandb.finish()
