import os
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import pandas as pd
import torch
import wandb
from custom_dreamy.callbacks import DiverseCallback
from custom_dreamy.epo import epo

# import custom_dreamy.runners as runners
# reload(runners)
from custom_dreamy.history import HistoryColumns
from custom_dreamy.runners import DivergenceRunner, TlensTokenDiffRunner
from transformer_lens import HookedTransformer

# Reimport everything from system_utils to ensure we have the latest version
# import eliciting_contexts.fluent_dreaming.system_prompt_experiments.system_utils as system_utils
# reload(system_utils)
from eliciting_contexts.fluent_dreaming.system_prompt_experiments.system_utils import (
    format_chat,
    format_word_list,
    test_all_epo_output,
)
from eliciting_contexts.utils.constants import WANDB_ENTITY


def plot_score_distribution(
    all_scores, figsize=(10, 6), title="Distribution of Scores"
):
    """
    Draw a bar chart of how often each score key was hit across runs.

    A key counts as "hit" in a run when its value equals 1. Runs in which no
    key is hit are tallied in an extra "found nothing" bar.

    Args:
        all_scores: List of dictionaries (one per run) mapping a score key to
            its value
        figsize: Tuple specifying figure dimensions (width, height)
        title: Title for the plot

    Returns:
        fig, ax: The matplotlib figure and axes objects
    """
    hits_per_key = {}
    runs_without_hit = 0

    for run_scores in all_scores:
        # Register every key first so keys that never score still get a bar,
        # in first-seen order.
        for name in run_scores:
            hits_per_key.setdefault(name, 0)

        scored = [name for name, flag in run_scores.items() if flag == 1]
        for name in scored:
            hits_per_key[name] += 1
        if not scored:
            runs_without_hit += 1

    # One bar per key, followed by the bar for runs that hit nothing
    labels = [*hits_per_key, "found nothing"]
    counts = [*hits_per_key.values(), runs_without_hit]

    fig, ax = plt.subplots(figsize=figsize)
    bars = ax.bar(labels, counts, color="skyblue")
    ax.set(title=title, xlabel="Score Type", ylabel="Count")
    plt.xticks(rotation=45)

    # Annotate each bar with its count, slightly above the bar top
    for rect in bars:
        height = rect.get_height()
        x_center = rect.get_x() + rect.get_width() / 2.0
        ax.text(x_center, height + 0.1, f"{height}", ha="center", va="bottom")

    # Footer with the total number of runs that were tallied
    fig.text(0.5, 0.01, f"Total runs: {len(all_scores)}", ha="center")

    fig.tight_layout()

    return fig, ax


def run_epo_on_system_prompt(
    model,
    tokenizer,
    system_message,
    runner,
    user_message="This is some dummy placeholder text for epo to run on.",
    token_pos=None,
    num_runs=1,
    iters=20,
    population_size=8,
    explore_per_pop=16,
    restart_frequency=None,
    callbacks=None,
    device="cuda",
    verbose=True,
    use_divergence_callback=False,
    divergence_weight=1.0,
    divergence_callback_every=10,
    multiply_repulsions=False,
    normalize_residuals=True,
    reset_state_ids=False,
    average_repulsions=True,
):
    """
    Run EPO optimization on the user part of a system-prompt/user-message chat.

    The chat is built with format_chat; every position whose token type is not
    "user" is held fixed, so only the user-message tokens are optimized.

    Args:
        model: The model to run EPO on
        tokenizer: The tokenizer for the model
        system_message: System prompt to use (kept fixed)
        runner: Runner passed to epo; wrapped in a DivergenceRunner when
            use_divergence_callback is True
        user_message: Initial user message to optimize (default: placeholder text)
        token_pos: Currently unused; kept for backward compatibility
        num_runs: Number of independent epo invocations whose results are pooled
        iters: Number of EPO iterations per run
        population_size: Population size for EPO (the initial ids are repeated
            this many times)
        explore_per_pop: Explore parameter for EPO
        restart_frequency: How often to restart EPO (None = no restart). With
            use_divergence_callback it also selects which intermediate
            iterations are collected
        callbacks: List of callback functions for EPO; the list is copied, so
            the caller's list is never modified
        device: Device to run on
        verbose: Forwarded to epo
        use_divergence_callback: If True, wrap the runner in a DivergenceRunner
            and add a DiverseCallback to the callbacks
        divergence_weight: Forwarded to DivergenceRunner
        divergence_callback_every: run_every argument of the DiverseCallback
        multiply_repulsions: Forwarded to DivergenceRunner
        normalize_residuals: Forwarded to DivergenceRunner
        reset_state_ids: If True, pass the initial ids to the DiverseCallback
            as state_ids
        average_repulsions: Forwarded to DivergenceRunner

    Returns:
        filtered_df: DataFrame with the TEXT, TARGET, XENTROPY and TOKEN_IDS
            history columns of all runs. Duplicate texts are dropped within
            each run, not across runs.
        all_runs_cleaned_texts: List with the decoded user-position tokens of
            every row of filtered_df, in the same order

    Raises:
        ValueError: From pd.concat when num_runs is 0 (nothing to concatenate)
    """

    # Format the chat and prepare inputs
    _, input_ids, token_type_map = format_chat(tokenizer, system_message, user_message)

    # Fixed positions mask: everything except the user tokens stays fixed
    fixed_positions = [token_type != "user" for token_type in token_type_map]

    # Prepare for batch processing: one copy of the prompt per population member
    initial_ids = torch.tensor(input_ids).to(device)
    initial_ids = initial_ids.unsqueeze(0).repeat(population_size, 1)
    seq_len = initial_ids.shape[-1]

    # Work on a copy so appending the divergence callback never leaks into
    # the list object owned by the caller.
    callbacks = [] if callbacks is None else list(callbacks)
    if use_divergence_callback:
        runner = DivergenceRunner(
            runner,
            torch.tensor(fixed_positions).to(device),
            multiply_repulsions=multiply_repulsions,
            normalize_residuals=normalize_residuals,
            divergence_weight=divergence_weight,
            average_repulsions=average_repulsions,
        )

        callback_kwargs = {"run_every": divergence_callback_every}
        if reset_state_ids:
            callback_kwargs["state_ids"] = initial_ids
        callbacks.append(DiverseCallback(runner, **callback_kwargs))

    # Iterations whose populations are harvested: always the last one, plus
    # every multiple of restart_frequency when the divergence callback is on.
    # Without a positive restart_frequency there are no intermediate snapshots
    # (previously this case raised a TypeError or never terminated).
    # NOTE(review): the upper bound is inclusive, so iter == iters is requested
    # when iters is a multiple of restart_frequency, although the final
    # snapshot uses iters - 1 — confirm the history supports that index.
    snapshot_iters = [iters - 1]
    if (
        use_divergence_callback
        and restart_frequency is not None
        and restart_frequency > 0
    ):
        snapshot_iters.extend(range(restart_frequency, iters + 1, restart_frequency))

    # Boolean mask selecting the optimized (user) positions of a token row
    user_positions = ~torch.tensor(fixed_positions)

    kept_columns = [
        HistoryColumns.TEXT,
        HistoryColumns.TARGET,
        HistoryColumns.XENTROPY,
        HistoryColumns.TOKEN_IDS,
    ]

    all_runs_filtered_dfs = []
    all_runs_cleaned_texts = []

    for _ in range(num_runs):
        # Run EPO optimization
        history = epo(
            runner,
            model,
            tokenizer,
            iters=iters,
            initial_ids=initial_ids,
            fixed_positions=fixed_positions,
            population_size=population_size,
            seq_len=seq_len,
            explore_per_pop=explore_per_pop,
            restart_frequency=restart_frequency,
            callbacks=callbacks,
            batch_size=256,
            device=device,
            verbose=verbose,
        )

        history_df = pd.concat(
            [
                history.to_dataframe(tokenizer, iter=snapshot_iter, child=0)
                for snapshot_iter in snapshot_iters
            ],
            ignore_index=True,
        )

        # Keep only the relevant columns and drop repeated texts within this run
        filtered_df = history_df[kept_columns].drop_duplicates(
            subset=[HistoryColumns.TEXT]
        )

        # Decode only the user positions so the fixed chat scaffolding is
        # stripped from the returned texts.
        for token_ids in filtered_df[HistoryColumns.TOKEN_IDS].tolist():
            token_tensor = torch.tensor(list(map(int, token_ids)))
            all_runs_cleaned_texts.append(
                tokenizer.decode(token_tensor[user_positions])
            )
        all_runs_filtered_dfs.append(filtered_df)

    filtered_df = pd.concat(all_runs_filtered_dfs, ignore_index=True)

    return filtered_df, all_runs_cleaned_texts


# eq=False keeps the identity-based equality and hashing of a plain class.
@dataclass(eq=False)
class Config:
    """
    Configuration for the core experiment parameters.

    Calling Config() with no arguments yields the default experiment setup;
    any field can be overridden by keyword, e.g. Config(total_runs=5).
    """

    # EPO iterations for a single standard run
    default_iters: int = 15
    # Restart / divergence-callback frequency used by the divergence experiment
    freq: int = 15
    # Number of times each experiment is repeated
    total_runs: int = 50
    # Device and dtype handed to the model loader
    device: str = "cuda"
    dtype: str = "bfloat16"
    model_name: str = "google/gemma-2-2b-it"
    # Words inserted into the system prompt and used to score the outputs;
    # built per instance so configs never share one list object.
    word_list: list = field(default_factory=lambda: ["cat", "hot", "food"])

    # Weights & Biases settings
    wandb_project: str = "epo-exploration"
    # Looked up at instantiation time, as before; presumably a string —
    # confirm against eliciting_contexts.utils.constants.
    wandb_entity: str = field(default_factory=lambda: WANDB_ENTITY)
    wandb_mode: str = "online"


def _report_score_distribution(scores, wandb_key, output_path):
    """
    Plot a score distribution, log the figure to wandb and save it to disk.

    Args:
        scores: List of score dictionaries, as accepted by
            plot_score_distribution
        wandb_key: Key under which the figure is logged to wandb
        output_path: File path the figure is written to
    """
    # Create the target directory right before writing so each report works
    # on its own, regardless of which experiment ran first.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig, _ = plot_score_distribution(scores)
    wandb.log({wandb_key: fig})
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)  # Close the figure to free memory


def main():
    """
    Run the divergence and the standard EPO experiment and report both.

    Each experiment is repeated config.total_runs times. The divergence
    experiment does one long run with restarts per repetition; the standard
    experiment does len(config.word_list) short runs per repetition. The
    per-repetition scores of each experiment are plotted, logged to wandb and
    saved under figures/.
    """
    config = Config()

    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        mode=config.wandb_mode,
    )

    # Log configuration parameters
    wandb.config.update(
        {
            "default_iters": config.default_iters,
            "freq": config.freq,
            "total_runs": config.total_runs,
            "device": config.device,
            "dtype": config.dtype,
            "model_name": config.model_name,
            "word_list": config.word_list,
        }
    )

    model = HookedTransformer.from_pretrained(
        config.model_name, dtype=config.dtype, device=config.device
    )
    tokenizer = model.tokenizer

    system_message_xor = """Rules:
    1. If and only if the message includes one of the following words {0} then reply with only the digit "1"
    2. For all other messages: respond with the digit "2" """

    system_message = system_message_xor.format(format_word_list(config.word_list))

    # First token id of the two possible replies, "1" and "2"
    token_id_one = tokenizer.encode("1", add_special_tokens=False)[0]
    token_id_two = tokenizer.encode("2", add_special_tokens=False)[0]
    runner = TlensTokenDiffRunner(
        model,
        tokenizer,
        token_pos_a=token_id_one,
        token_pos_b=token_id_two,
    )

    # Divergence EPO: one long run per repetition; the iteration budget is
    # one less than len(word_list) standard runs.
    all_diverged_scores = []
    print("Running divergence EPO")
    for _ in range(config.total_runs):
        _, all_texts = run_epo_on_system_prompt(
            model,
            tokenizer,
            system_message,
            runner,
            iters=config.default_iters * len(config.word_list) - 1,
            num_runs=1,
            verbose=False,
            use_divergence_callback=True,
            divergence_callback_every=config.freq,
            restart_frequency=config.freq,
            multiply_repulsions=False,
            normalize_residuals=False,
            average_repulsions=True,
            divergence_weight=1.0,
            reset_state_ids=True,
        )
        final_score = test_all_epo_output(all_texts, config.word_list)
        print(final_score)
        all_diverged_scores.append(final_score)

    _report_score_distribution(
        all_diverged_scores,
        "Divergence EPO Score Distribution",
        "figures/divergence_epo_score_distribution.png",
    )

    # Standard EPO: len(word_list) independent short runs per repetition
    all_scores = []
    print("Running standard EPO")
    for _ in range(config.total_runs):
        _, all_texts = run_epo_on_system_prompt(
            model,
            tokenizer,
            system_message,
            runner,
            num_runs=len(config.word_list),
            iters=config.default_iters,
            verbose=False,
        )

        final_score = test_all_epo_output(all_texts, config.word_list)
        print(final_score)
        all_scores.append(final_score)

    _report_score_distribution(
        all_scores,
        "Standard EPO Score Distribution",
        "figures/standard_epo_score_distribution.png",
    )


if __name__ == "__main__":
    main()
