import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, Tuple

import torch
import yaml
from custom_dreamy.callbacks import (
    InpaintingCallback,
    ModelHelperCallback,
    WandbEPOCallback,
)
from custom_dreamy.epo import epo
from custom_dreamy.history import HistoryColumns
from custom_dreamy.i_runner import IRunner
from transformer_lens import HookedTransformer

import wandb
from eliciting_contexts.benchmark.external.shared.text_utils import (
    process_text_with_placeholder,
)
from eliciting_contexts.benchmark.external.tiny_stories.data.load_data import (
    StoryKeys,
    download_tiny_stories_dataset,
)
from eliciting_contexts.utils.constants import RESULTS_DIR, WANDB_ENTITY


class TlensLogitDiffRunner(IRunner):
    """Runner whose objective is a next-token logit difference.

    The objective is the logit of the first token of ``desired_text`` minus
    the logit of the first token of ``undesired_text``, both read at the last
    sequence position of the model output.
    """

    def __init__(
        self,
        model,
        tokenizer,
        desired_text: str,
        undesired_text: str,
    ):
        """
        Store the model and tokenize the two texts being compared.

        Args:
            model: The model to use. It is called with embeddings and
                ``start_at_layer=0``, and must expose ``embed`` and
                ``embed.W_E``.
            tokenizer: The tokenizer for the model
            desired_text: Word whose first token logit you want to boost
            undesired_text: Word to compare against

        Raises:
            ValueError: If either text encodes to zero tokens, since the
                objective needs a first token for both texts.

        """
        self.model = model
        self.desired_tokens = self._encode_text(
            tokenizer, desired_text, "desired_text"
        )
        self.undesired_tokens = self._encode_text(
            tokenizer, undesired_text, "undesired_text"
        )

    @staticmethod
    def _encode_text(tokenizer, text: str, arg_name: str) -> torch.Tensor:
        """Encode ``text`` without special tokens, rejecting empty encodings."""
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if len(token_ids) == 0:
            # Fail here with a clear message instead of an IndexError later
            # when the first token is looked up in run_with_embeddings.
            raise ValueError(f"{arg_name}={text!r} encodes to zero tokens")
        return torch.tensor(token_ids)

    def run_with_embeddings(self, input_embeddings: torch.Tensor) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        Dict[str, Any],
    ]:
        """
        Run the model on embeddings and compute the logit-difference objective.

        Args:
            input_embeddings: Embedded inputs, fed to the model starting at
                layer 0.

        Returns:
            A tuple ``(target, model_logits, extra)`` where ``target`` is the
            per-example logit difference at the last position, ``model_logits``
            is the full model output and ``extra`` is an empty dict.
        """
        model_logits = self.model(
            input_embeddings, start_at_layer=0, return_type="logits"
        )

        # Logits used to predict the next token (last sequence position)
        next_token_logits = model_logits[:, -1, :]

        # Keep the token id tensors on the same device as the inputs
        self.desired_tokens = self.desired_tokens.to(input_embeddings.device)
        self.undesired_tokens = self.undesired_tokens.to(input_embeddings.device)

        # Only the first token of each text takes part in the objective
        target_token_id = self.desired_tokens[0]
        negative_token_id = self.undesired_tokens[0]

        # Extract relevant logits for each example in batch
        target_logits = next_token_logits[:, target_token_id]
        negative_logits = next_token_logits[:, negative_token_id]

        # Objective: desired logit minus undesired logit
        target = target_logits - negative_logits

        # Return target, logits, and empty dict for additional info
        return target, model_logits, {}

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Convert one-hot encoded input to embedded input."""
        embed_matrix = self.model.embed.W_E
        return torch.matmul(one_hot, embed_matrix)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Convert integer token IDs to embedded input."""
        return self.model.embed(int_ids)


# Experiment configuration (defaults, optionally overridden from a YAML file)
class Config:
    """Configuration class for experiment parameters.

    Attributes are first set to built-in defaults and then, if a YAML path is
    given, overridden by the matching keys found in that file.
    """

    def __init__(self, config_path=None):
        """
        Args:
            config_path: Optional path to a YAML file whose keys override the
                defaults. Falsy values (``None``, ``""``) keep the defaults.
        """
        # Set default settings
        self._set_defaults()

        # Override with config file if provided
        if config_path:
            self._load_from_yaml(config_path)

    def _set_defaults(self):
        """Assign the default value of every configuration attribute."""
        self.num_samples = None

        # file settings
        self.output_json = f"{RESULTS_DIR}/optimization_results.json"

        # Hardware settings
        self.device = "cuda"
        self.dtype = "bfloat16"
        self.model_name = "google/gemma-2-2b-it"

        # EPO parameters
        self.num_runs = 3
        self.population_size = 8
        self.explore_per_pop = 32
        self.batch_size = 32
        self.iters = 150
        self.x_penalty_min = 0.1
        self.x_penalty_max = 10.0

        self.use_gcg = False

        # EPO (only used by default)
        self.restart_frequency = 51

        # EPO (with openai assist)
        self.use_assist = False
        self.assist_prompt_type = "stories"
        self.epo_assist_run_every = 51
        self.epo_assist_temperature = 1.0
        self.epo_assist_model = "gpt-4o"
        self.epo_assist_run_on_last_step = True
        # Used to save with and without last step being epo. Follows
        # use_assist; re-derived after YAML loading unless set explicitly.
        self.save_last_two_iterations = bool(self.use_assist)

        # EPO (with llada inpainting)
        self.use_inpaint = False
        self.epo_inpaint_every = 15
        self.epo_inpaint_replace_percent = 0.75
        self.epo_inpaint_replace_prob = 0.5
        self.epo_inpaint_token_per_step = 4
        self.epo_inpaint_unmasking = "high_confidence"
        self.epo_inpaint_model_name = "GSAI-ML/LLaDA-8B-Instruct"

        # WandB settings
        self.wandb_project = "simple_stories_epo"
        self.wandb_entity = WANDB_ENTITY
        self.wandb_mode = "online"

        # HF settings
        self.hf_repo_id = "contextmodification/simple_stories_new"
        self.hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

    def _load_from_yaml(self, config_path):
        """Load configuration from YAML file.

        Only keys that name an existing public attribute are applied; any
        other key is ignored. An empty file leaves the defaults untouched.

        Raises:
            ValueError: If the top-level YAML value is not a mapping.
        """
        with open(config_path, "r") as f:
            config_data = yaml.safe_load(f)

        # safe_load yields None for an empty document
        if config_data is None:
            return
        if not isinstance(config_data, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {config_path}, "
                f"got {type(config_data).__name__}"
            )

        # Update attributes from YAML
        for key, value in config_data.items():
            # Skip non-string keys (hasattr would raise) and private names
            # (which would otherwise overwrite methods such as _set_defaults).
            if not isinstance(key, str) or key.startswith("_"):
                continue
            if hasattr(self, key):
                setattr(self, key, value)

        # The default was computed before use_assist could be overridden, so
        # recompute it unless the file sets it explicitly.
        if "save_last_two_iterations" not in config_data:
            self.save_last_two_iterations = bool(self.use_assist)


def _decode_user_texts(history, tokenizer, fixed_positions_tensor, iteration):
    """Decode the non-fixed tokens of each unique candidate at one iteration."""
    history_df = history.to_dataframe(tokenizer, iter=iteration, child=0)
    history_df = history_df.drop_duplicates(subset=[HistoryColumns.TOKEN_IDS])

    texts = []
    for tokens in history_df[HistoryColumns.TOKEN_IDS].tolist():
        token_tensor = torch.tensor(list(map(int, tokens)))
        # Keep only the positions that were free to be optimized
        user_tokens = token_tensor[~fixed_positions_tensor]
        texts.append(tokenizer.decode(user_tokens))
    return texts


def optimize_for_word_prediction(
    model: HookedTransformer,
    tokenizer,
    desired_text: str,
    undesired_text: str,
    variable_context: str,
    template: str,
    config: Config,
    device: str = "cuda",
    verbose: bool = True,
    log_name: str = "run_epo",
) -> list:
    """
    Run EPO ``config.num_runs`` times to favour ``desired_text`` over
    ``undesired_text`` and collect the optimized variable contexts.

    Args:
        model: Model passed to the runner, the callbacks and ``epo``
        tokenizer: Tokenizer used for encoding and for decoding the results
        desired_text: Word whose first token logit should be boosted
        undesired_text: Word to compare against
        variable_context: Text inserted into ``template``; its token
            positions are the ones left free for optimization
        template: Template text containing the placeholder
        config: Experiment configuration
        device: Device passed to ``epo`` and the inpainting callback
        verbose: Whether to print progress
        log_name: Prefix of the WandB callback name for each run

    Returns:
        If ``config.save_last_two_iterations`` is false, a list with the
        decoded texts of the last iteration of every run. Otherwise a
        two-element list ``[last_iteration_texts, second_to_last_texts]``,
        each accumulated over every run.
    """
    # Set up runner
    runner = TlensLogitDiffRunner(
        model,
        tokenizer,
        desired_text,
        undesired_text,
    )

    # Process text with placeholder for optimization
    # TODO clean up AND move
    initial_ids, fixed_positions = process_text_with_placeholder(
        tokenizer, template, variable_context, start_gap=1, end_gap=1
    )
    # Same for every run, so build the mask once
    fixed_positions_tensor = torch.tensor(fixed_positions)

    all_results = []
    alt_results = []

    # Run optimization multiple times according to config.num_runs
    for run_idx in range(config.num_runs):
        if verbose:
            print(f"Starting EPO run {run_idx+1}/{config.num_runs}")

        # Set up EPO optimization
        callbacks = []
        current_initial_ids = initial_ids.unsqueeze(0).repeat(config.population_size, 1)
        seq_len = current_initial_ids.shape[-1]

        if config.use_assist:
            # The extra context is only built for the "stories" prompt type
            story_context = None
            if config.assist_prompt_type == "stories":
                story_context = {
                    "template": template,
                    "variable_context": variable_context,
                    "desired_text": desired_text,
                    "undesired_text": undesired_text,
                    "fixed_positions": torch.tensor(fixed_positions),
                    "initial_token_ids": initial_ids,
                }

            callbacks.append(
                ModelHelperCallback(
                    tokenizer,
                    run_every=config.epo_assist_run_every,
                    model_name=config.epo_assist_model,
                    temperature=config.epo_assist_temperature,
                    prompt_type=config.assist_prompt_type,
                    run_on_last_step=config.epo_assist_run_on_last_step,
                    num_mutations=config.explore_per_pop,
                    stories_context=story_context,
                )
            )

        if config.use_inpaint:
            callbacks.append(
                InpaintingCallback(
                    tokenizer,
                    inpaint_every=config.epo_inpaint_every,
                    model_name=config.epo_inpaint_model_name,
                    replace_percent=config.epo_inpaint_replace_percent,
                    replace_prob=config.epo_inpaint_replace_prob,
                    token_per_step=config.epo_inpaint_token_per_step,
                    unmasking=config.epo_inpaint_unmasking,
                    device=device,
                    torch_dtype=torch.bfloat16,
                )
            )

        # Only log to WandB when there is an active online run
        if wandb.run is not None and wandb.run.mode == "online":
            callbacks.append(
                WandbEPOCallback(
                    runner,
                    model,
                    tokenizer,
                    x_penalty_min=config.x_penalty_min,
                    x_penalty_max=config.x_penalty_max,
                    name=f"{log_name}/run_{run_idx}",
                    per_pop_log=False,
                )
            )

        # Run optimization
        history = epo(
            runner,
            model,
            tokenizer,
            iters=config.iters,
            initial_ids=current_initial_ids,
            fixed_positions=fixed_positions,
            population_size=config.population_size,
            seq_len=seq_len,
            explore_per_pop=config.explore_per_pop,
            # Restarts are disabled whenever assist or inpainting is active
            restart_frequency=(
                None
                if (config.use_assist or config.use_inpaint)
                else config.restart_frequency
            ),
            callbacks=callbacks,
            batch_size=config.batch_size,
            device=device,
            verbose=verbose,
            x_penalty_min=config.x_penalty_min,
            x_penalty_max=config.x_penalty_max,
        )

        # Decode the unique candidates of the last iteration
        all_results.extend(
            _decode_user_texts(history, tokenizer, fixed_positions_tensor, -1)
        )

        if config.save_last_two_iterations:
            # Also keep the second-to-last iteration
            alt_results.extend(
                _decode_user_texts(history, tokenizer, fixed_positions_tensor, -2)
            )

    # Return only after every run has finished. Returning from inside the
    # loop would silently drop all runs after the first one.
    if config.save_last_two_iterations:
        return [all_results, alt_results]

    return all_results


if __name__ == "__main__":
    # Script entry point: build the config, run the optimization over the
    # test split, then write the results to JSON and log them to WandB.

    # Set up argument parser
    parser = argparse.ArgumentParser(description="Run EPO optimization for TinyStories")
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument(
        "--use_assist",
        action="store_true",
        help="Enable OpenAI assistance for optimization (overrides config)",
    )
    parser.add_argument(
        "--use_inpaint",
        action="store_true",
        help="Enable LLaDA inpainting for optimization (overrides config)",
    )
    parser.add_argument(
        "--output", type=str, help="Path to save output JSON results (overrides config)"
    )
    parser.add_argument(
        "--use_gcg",
        action="store_true",
        help="Enable GCG optimization (overrides config)",
    )
    args = parser.parse_args()

    # Create config object (either with defaults or from YAML)
    config = Config(config_path=args.config)

    # Override config with command line arguments if provided
    if args.use_assist:
        config.use_assist = True
        config.save_last_two_iterations = True
        print("Enabling OpenAI assistance (overridden by command line)")

    if args.use_inpaint:
        config.use_inpaint = True
        print("Enabling LLaDA inpainting (overridden by command line)")

    if args.output:
        config.output_json = args.output
        print(f"Using custom output path: {args.output}")

    if args.use_gcg:
        config.use_gcg = True
        print("Enabling GCG (overridden by command line)")

    # GCG mode forces a single-member population and zero x-penalties
    if config.use_gcg:
        config.population_size = 1
        config.x_penalty_min = 0.0
        config.x_penalty_max = 0.0

    # Create outputs directory if it doesn't exist
    output_path = Path(config.output_json)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Derive the secondary path from the file name only. A plain
    # str.replace(".json", ...) returns the same path when the output does not
    # end in ".json" (so the two files would overwrite each other) and also
    # rewrites ".json" occurring in directory names.
    raw_output_json = str(
        output_path.with_name(f"{output_path.stem}_raw{output_path.suffix}")
    )

    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        mode=config.wandb_mode,
    )
    # Never upload credentials as part of the run config
    secret_attrs = {"hf_token"}
    config_dict = {
        attr: getattr(config, attr)
        for attr in dir(config)
        if not attr.startswith("_")
        and attr not in secret_attrs
        and not callable(getattr(config, attr))
    }
    wandb.config.update(config_dict)

    model = HookedTransformer.from_pretrained(
        config.model_name, dtype=config.dtype, device=config.device
    )
    tokenizer = model.tokenizer

    dataset = download_tiny_stories_dataset(
        hf_token=config.hf_token, dataset_name=config.hf_repo_id
    )

    all_results = {}
    alt_results = {}

    # for idx, datum in enumerate(dataset["test"].select([0])):  # , 40, 70, 110
    # for idx, datum in enumerate(dataset["test"].select(range(config.num_samples))):
    for idx, datum in enumerate(dataset["test"]):

        template = datum[StoryKeys.TEMPLATE]
        variable_context = datum[StoryKeys.VARIABLE_TEXT]
        undesired_text = datum[StoryKeys.UNDESIRED_TEXT]
        for desired_text in datum[StoryKeys.DESIRED_TEXT]:
            print(
                f"Optimizing for {desired_text} vs {undesired_text}\n{template.format(variable_context)}"
            )
            results = optimize_for_word_prediction(
                model=model,
                tokenizer=tokenizer,
                desired_text=desired_text,
                undesired_text=undesired_text,
                variable_context=variable_context,
                template=template,
                config=config,
                device=config.device,
                verbose=True,
                log_name=f"run_{idx}",
            )

            if config.save_last_two_iterations:
                alt_results[idx] = results[1]
                all_results[idx] = results[0]
            else:
                all_results[idx] = results

            # Only the first desired text of each datum is optimized; results
            # are keyed by idx, so further entries would overwrite this one.
            break

    with open(config.output_json, "w") as f:
        json.dump(all_results, f, indent=2)

    if len(alt_results) > 0:
        with open(raw_output_json, "w") as f:
            json.dump(alt_results, f, indent=2)
        # NOTE(review): this artifact shares its name with the one logged
        # below; presumably WandB stores them as separate versions - confirm.
        results_artifact = wandb.Artifact(
            "optimization_results",
            type="results",
            description="Results from simple stories optimization",
        )
        results_artifact.add_file(
            raw_output_json,
            "alt_optimization_results.json",
        )
        wandb.log_artifact(results_artifact)

    # Also save to wandb artifact
    results_artifact = wandb.Artifact(
        "optimization_results",
        type="results",
        description="Results from simple stories optimization",
    )
    results_artifact.add_file(config.output_json, "optimization_results.json")
    wandb.log_artifact(results_artifact)

    print(f"Results saved to {config.output_json} and logged to WandB")

    wandb.finish()
