# %% Imports
import torch
import wandb
from custom_dreamy.callbacks import (
    InpaintingCallback,
    ModelHelperCallback,
    WandbEPOCallback,
)
from custom_dreamy.history import (
    HistoryColumns,
    get_pareto_frontier_df,
    pretty_print_df,
)
from custom_dreamy.runners import SAERunnerTLens
from custom_dreamy.visualize.visualize import visualize_population_tree_interactive

from eliciting_contexts.fluent_dreaming.run_fluent_dreaming import (
    epo,
    log_fluent_dreaming_to_wandb,
)
from eliciting_contexts.utils.auth import setup_wandb_auth
from eliciting_contexts.utils.constants import WANDB_ENTITY
from eliciting_contexts.utils.load_models import load_sae_and_model


def run_sae_fluent_dreaming(
    feature: int = 321,
    population_size: int = 8,
    start_text: str = "USER LOG: Data computation phase - Started running now",
    x_penalty_min: float = 0.1,
    x_penalty_max: float = 10.0,
    restart_frequency: int = 50,
    iters: int = 300,
    return_logits: bool = False,
    wandb_run_name: str = None,
    use_wandb: bool = True,
    explore_per_pop: int = 32,
    use_plotly: bool = False,
    do_visualize: bool = False,
    visualize_path: str = None,
    use_model_helper: bool = False,
    model_helper_run_every: int = 76,
    model_helper_model_name: str = "gpt-4o",
    use_inpainting: bool = False,
    inpaint_every: int = 2,
    device: str = "cuda",
    seq_len: int = 12,
):
    """
    Run fluent-dreaming (EPO) optimization for text that activates one SAE feature.

    Loads the model/SAE via ``load_sae_and_model``, wraps them in an
    ``SAERunnerTLens`` targeting ``feature``, optionally seeds the population
    from ``start_text``, runs ``epo`` with the requested callbacks, and prints
    the Pareto frontier of the final iteration.

    Args:
        feature: Index of the SAE feature passed to ``SAERunnerTLens``.
        population_size: Population size for ``epo``; also the number of
            copies of the seed sequence when ``start_text`` is given.
        start_text: Seed text. If empty (or None), no initial ids are passed
            and ``epo`` works with sequences of length ``seq_len``.
        x_penalty_min: Forwarded to ``epo`` and ``WandbEPOCallback``.
        x_penalty_max: Forwarded to ``epo`` and ``WandbEPOCallback``.
        restart_frequency: Forwarded to ``epo`` (callers may pass None).
        iters: Number of ``epo`` iterations.
        return_logits: Only recorded in the W&B config; unused otherwise here.
        wandb_run_name: Optional W&B run name.
        use_wandb: Create a W&B run, attach ``WandbEPOCallback`` and log the
            final results.
        explore_per_pop: Forwarded to ``epo``.
        use_plotly: Currently unused.
        do_visualize: Write an interactive ancestry visualization.
        visualize_path: Output path for that visualization.
        use_model_helper: Attach a ``ModelHelperCallback``.
        model_helper_run_every: ``run_every`` for the model-helper callback.
        model_helper_model_name: ``model_name`` for the model-helper callback.
        use_inpainting: Attach an ``InpaintingCallback`` (bfloat16).
        inpaint_every: ``inpaint_every`` for the inpainting callback.
        device: Device passed to the loader, inpainting callback and ``epo``.
        seq_len: Sequence length used only when ``start_text`` is empty;
            otherwise the tokenized seed length is used.

    Returns:
        ``(history, history_df)``: the object returned by ``epo`` and its
        ``to_dataframe(tokenizer)`` conversion.
    """
    run = None
    if use_wandb:
        # Set up wandb authentication before creating the run.
        setup_wandb_auth()
        run = wandb.init(
            project="fluent-dreaming",
            entity=WANDB_ENTITY,
            name=wandb_run_name,
            config={
                "feature": feature,
                "population_size": population_size,
                "start_text": start_text,
                "x_penalty_min": x_penalty_min,
                "x_penalty_max": x_penalty_max,
                "iters": iters,
                "restart_frequency": restart_frequency,
                "explore_per_pop": explore_per_pop,
                "return_logits": return_logits,
            },
        )

    try:
        model, tokenizer, sae, _, _ = load_sae_and_model(device=device)
        runner = SAERunnerTLens(model=model, sae=sae, feature=feature)

        initial_ids = None
        fixed_positions = None  # no token positions are pinned
        if start_text:
            # Every population member starts from the same tokenized seed, and
            # the seed's length overrides the ``seq_len`` argument.
            seed_ids = torch.tensor(tokenizer.encode(start_text))
            initial_ids = seed_ids.unsqueeze(0).repeat(population_size, 1)
            seq_len = initial_ids.shape[-1]

        callbacks = []
        if use_model_helper:
            callbacks.append(
                ModelHelperCallback(
                    tokenizer,
                    run_every=model_helper_run_every,
                    model_name=model_helper_model_name,
                )
            )
        if use_inpainting:
            callbacks.append(
                InpaintingCallback(
                    tokenizer,
                    inpaint_every=inpaint_every,
                    device=device,
                    torch_dtype=torch.bfloat16,
                )
            )
        if use_wandb:
            callbacks.append(
                WandbEPOCallback(
                    runner,
                    model,
                    tokenizer,
                    x_penalty_min=x_penalty_min,
                    x_penalty_max=x_penalty_max,
                )
            )

        # Run evolutionary optimization
        history = epo(
            runner,
            model,
            tokenizer,
            iters=iters,
            initial_ids=initial_ids,
            fixed_positions=fixed_positions,
            population_size=population_size,
            seq_len=seq_len,
            explore_per_pop=explore_per_pop,
            restart_frequency=restart_frequency,
            callbacks=callbacks,
            x_penalty_min=x_penalty_min,
            x_penalty_max=x_penalty_max,
            device=device,
        )

        print("converting to dataframe")
        # TODO add option to just do subset of history
        history_df = history.to_dataframe(tokenizer)
        print("done")

        # Visualization of the ancestry of final-iteration parents
        if do_visualize:
            visualize_population_tree_interactive(
                history_df, tokenizer, output_path=visualize_path
            )

        if use_wandb:
            log_fluent_dreaming_to_wandb(model, tokenizer, history, runner, "")
    finally:
        # Close the W&B run exactly once, even if loading or optimization raises.
        if run is not None:
            run.finish()

    # Show only the final iteration
    final_iter = history_df[HistoryColumns.ITERATION].max()
    final_iter_df = history_df[history_df[HistoryColumns.ITERATION] == final_iter]
    print(f"\nFinal iteration ({final_iter}):")
    pretty_print_df(get_pareto_frontier_df(final_iter_df))

    return history, history_df


# Seed texts shared by every sweep in this experiment script.
_START_TEXTS = [
    "A generic text that is not related to anything, I don't know why or what it is about",
    "What if I try a generic start text? Does that work better, who knows? Let's see!",
    "you never know what's going to happen next, so you just have to wait and see",
    "I'm not sure what this is about, but I'm going to try to make it as generic as possible",
    "A man walks into a bar and orders a beer, he then looks around and sees a woman sitting alone",
    "A child is playing in the park, and the sky is blue and the sun is shining",
    "The old bookstore on the corner has thousands of stories waiting to be discovered.",
    "She picked up her coffee and watched the rain fall gently against the window pane.",
    "The scientists gathered in the laboratory to discuss their latest breakthrough findings.",
]


def _run_sweep_and_push(feature, repo_id, **run_kwargs):
    """Run ``run_sae_fluent_dreaming`` once per seed text and push the results.

    The per-seed history dataframes are concatenated and uploaded as the
    ``train`` split of a private Hugging Face Hub dataset named ``repo_id``.

    Args:
        feature: SAE feature index forwarded to every run.
        repo_id: Hub dataset name to push to.
        **run_kwargs: Extra keyword arguments forwarded to every run.
    """
    # Imported lazily: only this experiment driver needs them.
    import pandas as pd
    from datasets import Dataset

    history_dfs = []
    for start_text in _START_TEXTS:
        _, history_df = run_sae_fluent_dreaming(
            feature=feature, start_text=start_text, **run_kwargs
        )
        history_dfs.append(history_df)

    combined_history_df = pd.concat(history_dfs, ignore_index=True)
    hf_dataset = Dataset.from_pandas(combined_history_df)
    print("saving to", repo_id)
    hf_dataset.push_to_hub(repo_id, split="train", private=True, token=None)


def main():
    """For each feature, run three sweeps and push each sweep to the Hub.

    The sweeps are: inpainting, model-helper, then baseline EPO with restarts.
    """
    features = [321, 330, 2079, 8618, 11046]
    # Settings shared by all three sweeps (W&B and visualization disabled).
    common = {
        "iters": 300,
        "explore_per_pop": 32,
        "population_size": 15,
        "use_wandb": False,
        "do_visualize": False,
        "seq_len": 20,
        "visualize_path": "bah.html",
    }
    for feature in features:
        # 1) Inpainting callback every 16 iterations, no restarts.
        _run_sweep_and_push(
            feature,
            f"Diffusion_results_{feature}",
            restart_frequency=None,
            use_inpainting=True,
            inpaint_every=16,
            **common,
        )
        # 2) LLM model-helper callback every 76 iterations, no restarts.
        _run_sweep_and_push(
            feature,
            f"fixed_mean_epo_results_with_model_helper_{feature}",
            restart_frequency=None,
            use_model_helper=True,
            model_helper_run_every=76,
            model_helper_model_name="gpt-4o",
            **common,
        )
        # 3) Baseline EPO with restarts every 50 iterations.
        _run_sweep_and_push(
            feature,
            f"fixed_mean_epo_results_{feature}",
            restart_frequency=50,
            **common,
        )
    print(features)


if __name__ == "__main__":
    main()
