from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb
from custom_dreamy.epo import epo
from custom_dreamy.history import HistoryColumns, simple_print_df
from custom_dreamy.runners import (
    L1SAERunnerTLens,
    MaxSAERunnerTLens,
    MaxWithSumSAERunnerTLens,
    SumSAERunnerTLens,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

from eliciting_contexts.utils.constants import WANDB_ENTITY
from eliciting_contexts.utils.load_models import load_model_tlens, load_sae_saelens
from eliciting_contexts.utils.neuronpedia_api import get_feature_description


def run_epo_and_get_df(
    runner,
    model,
    tokenizer,
    initial_ids,
    iters=100,
    population_size=8,
    seq_len=None,
    explore_per_pop=16,
    restart_frequency=None,
    x_penalty_min=0.1,
    x_penalty_max=10.0,
    device="cuda",
    num_runs=1,
    callbacks=None,
    fixed_positions=None,
):
    """
    Run EPO ``num_runs`` times and collect one representative row per run.

    For each run, the history of the final iteration (``iter=iters - 1``,
    ``child=0``) is converted to a dataframe. The dataframe is reduced to the
    text / target / cross-entropy columns, and only its middle row is kept.
    Duplicate texts across runs are NOT removed.

    Args:
        runner: The SAE runner (like max_runner, l1_runner, etc.)
        model: The model to use
        tokenizer: The tokenizer to use
        initial_ids: Initial token IDs
        iters: Number of iterations
        population_size: Size of the population
        seq_len: Sequence length; defaults to ``initial_ids.shape[-1]``
        explore_per_pop: Exploration per population
        restart_frequency: Restart frequency
        x_penalty_min: Minimum x penalty
        x_penalty_max: Maximum x penalty
        device: Device to run on
        num_runs: Number of independent runs to perform (must be >= 1)
        callbacks: List of callbacks (defaults to an empty list)
        fixed_positions: Fixed positions

    Returns:
        pd.DataFrame: One row per run, with a fresh 0..num_runs-1 index.

    Raises:
        ValueError: If ``num_runs`` < 1 (there would be nothing to return).
    """
    # Fail early with a clear message instead of pd.concat([]) erroring later.
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    if callbacks is None:
        callbacks = []

    if seq_len is None:
        seq_len = initial_ids.shape[-1]

    columns = [HistoryColumns.TEXT, HistoryColumns.TARGET, HistoryColumns.XENTROPY]

    all_dfs = []
    for _ in range(num_runs):
        history = epo(
            runner,
            model,
            tokenizer,
            iters=iters,
            initial_ids=initial_ids,
            fixed_positions=fixed_positions,
            population_size=population_size,
            seq_len=seq_len,
            explore_per_pop=explore_per_pop,
            restart_frequency=restart_frequency,
            callbacks=callbacks,
            x_penalty_min=x_penalty_min,
            x_penalty_max=x_penalty_max,
            device=device,
            verbose=False,
        )

        history_df = history.to_dataframe(tokenizer, iter=iters - 1, child=0)
        filtered_df = history_df[columns]

        # Keep only the middle row; the list index keeps the result a DataFrame.
        all_dfs.append(filtered_df.iloc[[len(filtered_df) // 2]])

    return pd.concat(all_dfs, ignore_index=True)


def gather_points(texts, sae, features, model, tokenizer):
    """
    Return the SAE pre-activations of ``features`` for every real token of ``texts``.

    The texts are tokenized as one padded batch, embedded with ``model.embed``
    and run through ``model`` starting at layer 0. A forward hook on
    ``sae.cfg.hook_name`` runs that hook's activations through ``sae`` and
    records its ``hook_sae_acts_pre`` values for the selected ``features``.
    Everything runs under ``torch.no_grad()``, and no logits are computed.

    Args:
        texts: List of strings to evaluate.
        sae: SAE exposing ``run_with_hooks`` and ``cfg.hook_name``.
        features: Feature indices selected from the last axis of the SAE
            pre-activations.
        model: TransformerLens-style model providing ``embed``, ``hooks`` and
            ``cfg.device``.
        tokenizer: Tokenizer used to build the padded batch.

    Returns:
        torch.Tensor: One row per non-padding token (batch-major order), holding
        the selected features' pre-activations. With a single text, or texts of
        equal token length, this equals flattening the batch and sequence axes.
    """
    encoded = tokenizer(texts, padding=True, return_tensors="pt")
    batch_input_ids = encoded.input_ids.to(model.cfg.device)

    captured = {}

    def get_targets(resid, hook):
        pre_acts = []

        def store_pre_acts(acts, hook):
            pre_acts.append(acts)
            return acts

        sae.run_with_hooks(
            resid,
            fwd_hooks=[("hook_sae_acts_pre", store_pre_acts)],
        )
        # presumably (batch, seq, d_sae) -> (batch, seq, len(features)); TODO confirm
        captured["per_token_target"] = pre_acts[0][:, :, features]
        # Pass the residual through unchanged; the hook only observes.
        return resid

    # NOTE(review): only token embeddings (model.embed) are fed in, and no
    # attention mask is passed. Models with additive positional embeddings, or
    # padded batches where real tokens could attend to pad tokens, may need
    # more care. Confirm for the models and padding side in use.
    with torch.no_grad():
        batch_embeddings = model.embed(batch_input_ids)
        with model.hooks(fwd_hooks=[(sae.cfg.hook_name, get_targets)]):
            model(
                batch_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type=None,  # only the hook output is needed; skip logits
            )

    per_token = captured["per_token_target"]
    # Drop padding positions so pad tokens don't show up as data points.
    keep = encoded.attention_mask.to(per_token.device).bool()
    return per_token[keep]


def plot_multiple_feature_points(
    points_list, labels=None, colors=None, markers=None, title=None, features=None
):
    """
    Scatter-plot several sets of 2-D feature activations on one figure.

    Args:
        points_list: Sequence of tensors/arrays of shape [n, 2]. Column 0 is
            plotted on x and column 1 on y.
        labels: Optional legend labels, one per set (default "Set 1", "Set 2", ...).
        colors: Optional colors, cycled when there are more sets than colors.
        markers: Optional marker styles, cycled like ``colors``.
        title: Optional custom title for the plot.
        features: Optional pair of feature ids used for the default title and
            for the axis labels. When omitted, generic labels are used.

    Returns:
        matplotlib.figure.Figure: The created figure (not shown or closed).
    """
    if labels is None:
        labels = [f"Set {i+1}" for i in range(len(points_list))]

    if colors is None:
        colors = [
            "#1f77b4",  # blue
            "#ff7f0e",  # orange
            "#2ca02c",  # green
            "#d62728",  # red
            "#9467bd",  # purple
            "#8c564b",  # brown
            "#e377c2",  # pink
            "#7f7f7f",  # gray
            "#bcbd22",  # yellow-green
            "#17becf",  # cyan
        ]

    if markers is None:
        markers = [
            "o",  # circle
            "s",  # square
            "^",  # triangle-up
            "D",  # diamond
            "v",  # triangle-down
            "<",  # triangle-left
            ">",  # triangle-right
            "p",  # pentagon
            "*",  # star
            "h",  # hexagon
        ]

    fig, ax = plt.subplots(figsize=(10, 8))

    for i, points in enumerate(points_list):
        if torch.is_tensor(points):
            # .float() so low-precision tensors (e.g. bfloat16) convert to numpy.
            points = points.detach().float().cpu().numpy()

        ax.scatter(
            points[:, 0],
            points[:, 1],
            alpha=0.7,
            color=colors[i % len(colors)],
            marker=markers[i % len(markers)],
            label=labels[i],
        )

    if features is not None:
        default_title = f"Feature Activations for Features {features}"
        x_label = f"Feature {features[0]} Activation"
        y_label = f"Feature {features[1]} Activation"
    else:
        default_title = "Feature Activations"
        x_label = "First feature activation"
        y_label = "Second feature activation"

    ax.set_title(title if title else default_title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.legend()
    fig.tight_layout()

    return fig


class Config:
    """Experiment settings: hardware, SAE choice, EPO hyper-parameters and W&B logging.

    Every setting is a plain instance attribute, so the script can collect them
    generically for logging.
    """

    def __init__(self):
        # --- Hardware ---
        self.device, self.dtype = "cuda", "bfloat16"

        # --- SAE: release name in the SAELens pretrained directory + SAE id ---
        self.sae_name = "gemma-2b-it-res-jb"
        self.sae_id = "blocks.12.hook_resid_post"

        # How many random feature pairs the script samples (one plot per pair)
        self.num_features = 50

        # --- EPO search ---
        self.population_size, self.iters, self.explore_per_pop = 8, 150, 32
        self.restart_frequency = None
        self.x_penalty_min, self.x_penalty_max = 3, 10.0
        self.num_runs = 1

        # --- Weights & Biases / Neuronpedia ---
        self.wandb_project = "multple-features"
        self.wandb_entity = WANDB_ENTITY
        self.wandb_mode = "online"
        self.neuronpedia_description_model_name = "gpt-3.5-turbo"

        # Prompt used to seed every member of the initial EPO population
        self.start_text = (
            "Generic text to start with because why not? Should be fine. "
            "Add some more in here let's do it."
        )


if __name__ == "__main__":
    config = Config()

    lookup = get_pretrained_saes_directory()[config.sae_name]

    tlens_model_name = lookup.model
    neuronpedia_id = lookup.neuronpedia_id[config.sae_id]

    sae = load_sae_saelens(config.sae_name, config.sae_id, config.device, config.dtype)
    model = load_model_tlens(tlens_model_name, config.device, config.dtype)
    tokenizer = model.tokenizer

    features_0 = np.random.randint(0, sae.cfg.d_sae, size=config.num_features)
    features_1 = np.random.randint(0, sae.cfg.d_sae, size=config.num_features)
    features_pairs = list(zip(features_0, features_1))

    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        mode=config.wandb_mode,
    )

    # Log the config to wandb
    config_dict = {
        attr: getattr(config, attr)
        for attr in dir(config)
        if not attr.startswith("__") and not callable(getattr(config, attr))
    }
    wandb.config.update(config_dict)
    features_pairs_list = [(int(f[0]), int(f[1])) for f in features_pairs]
    wandb.config.update({"feature_pairs": features_pairs_list})

    for features in features_pairs:

        # Use features from config instead of hardcoded values

        basic_sum_runner = SumSAERunnerTLens(model, sae, features)
        l1_runner = L1SAERunnerTLens(model, sae, features)
        max_runner = MaxSAERunnerTLens(model, sae, features)
        max_with_sum_runner = MaxWithSumSAERunnerTLens(model, sae, features)

        try:
            neuronpedia_info = get_feature_description(
                neuronpedia_id, index=features[0]
            )
            description_one = neuronpedia_info.get_explanation_by_model_name(
                config.neuronpedia_description_model_name
            )
            neuronpedia_info = get_feature_description(
                neuronpedia_id, index=features[1]
            )
            description_two = neuronpedia_info.get_explanation_by_model_name(
                config.neuronpedia_description_model_name
            )
        except Exception as e:
            print(f"Error getting neuronpedia info: {e}")
            continue

        initial_ids = torch.tensor(tokenizer.encode(config.start_text))
        initial_ids = initial_ids.unsqueeze(0).repeat(config.population_size, 1)
        seq_len = initial_ids.shape[-1]

        run_epo_for_runner = partial(
            run_epo_and_get_df,
            model=model,
            tokenizer=tokenizer,
            initial_ids=initial_ids,
            iters=config.iters,
            population_size=config.population_size,
            seq_len=seq_len,
            explore_per_pop=config.explore_per_pop,
            restart_frequency=config.restart_frequency,
            x_penalty_min=config.x_penalty_min,
            x_penalty_max=config.x_penalty_max,
            device=config.device,
            num_runs=config.num_runs,
            callbacks=None,
            fixed_positions=None,
        )

        # Run evolutionary optimization
        max_df = run_epo_for_runner(max_runner)
        simple_print_df(max_df)

        # Run evolutionary optimization
        l1_df = run_epo_for_runner(l1_runner)
        simple_print_df(l1_df)

        # Run evolutionary optimization
        sum_df = run_epo_for_runner(basic_sum_runner)
        simple_print_df(sum_df)

        max_with_sum_df = run_epo_for_runner(max_with_sum_runner)
        simple_print_df(max_with_sum_df)

        points_max = (
            gather_points(max_df["text"].tolist(), sae, features, model, tokenizer)
            .detach()
            .cpu()
            .float()
            .numpy()
        )
        points_max_with_sum = (
            gather_points(
                max_with_sum_df["text"].tolist(), sae, features, model, tokenizer
            )
            .detach()
            .cpu()
            .float()
            .numpy()
        )
        points_l1 = (
            gather_points(l1_df["text"].tolist(), sae, features, model, tokenizer)
            .detach()
            .cpu()
            .float()
            .numpy()
        )
        points_sum = (
            gather_points(sum_df["text"].tolist(), sae, features, model, tokenizer)
            .detach()
            .cpu()
            .float()
            .numpy()
        )
        points_list = [points_max, points_l1, points_sum, points_max_with_sum]
        points_names = ["max", "L0.1", "sum", "max_with_sum"]

        # Get the figure and log it to wandb instead of showing it
        custom_title = f"\nFeature {features[0]} vs ({description_two})"
        feature_plot = plot_multiple_feature_points(
            points_list, labels=points_names, title=custom_title
        )

        # Save the plot as PNG
        # TODO move to temp!
        png_filename = f"feature_plot_{features[0]}_{features[1]}.png"
        feature_plot.savefig(png_filename, format="png", dpi=300, bbox_inches="tight")

        # Create a table for feature descriptions instead of logging as separate strings
        feature_descriptions = wandb.Table(columns=["Feature ID", "Description"])
        feature_descriptions.add_data(f"Feature {features[0]}", description_one)
        feature_descriptions.add_data(f"Feature {features[1]}", description_two)

        wandb.log(
            {
                f"feature_plot_{features[0]}_{features[1]}_png": wandb.Image(
                    png_filename
                ),
                f"feature_descriptions_{features[0]}_{features[1]}": feature_descriptions,
            }
        )
        plt.close(feature_plot)  # Close the figure to free memory
