import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import wandb
import yaml
from custom_dreamy.callbacks import (
    InpaintingCallback,
    ModelHelperCallback,
    WandbEPOCallback,
)
from custom_dreamy.epo import epo
from custom_dreamy.history import HistoryColumns
from custom_dreamy.i_runner import IRunner, transfomer_embed_one_hot_input
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from eliciting_contexts.benchmark.external.backdoors.data.load_data import (
    ApplicationsKeys,
    download_backdoors_dataset,
)
from eliciting_contexts.benchmark.external.shared.token_utils import (
    clean_process_text_with_placeholder,
)
from eliciting_contexts.utils.constants import RESULTS_DIR, WANDB_ENTITY
from eliciting_contexts.utils.load_models import load_finetuned_model


class LogitDiffRunner(IRunner):
    """Runner whose objective is a next-token logit difference.

    For every batch row the objective is the mean logit of the first token of
    each desired option minus the logit of the first token of the undesired
    text, both read at the last sequence position of the model output.
    """

    def __init__(
        self,
        model,
        tokenizer,
        desired_text: List[str],
        undesired_text: str,
    ):
        """
        Tokenize the target texts once and keep the model for forward passes.

        Args:
            model: The model to use. It is called as
                ``model(inputs_embeds=...)`` and its output must expose
                ``.logits``.
            tokenizer: The tokenizer for the model; must provide
                ``encode(text, add_special_tokens=False)``.
            desired_text: Options whose first-token logits are averaged and
                boosted. A bare string is accepted and treated as a single
                option.
            undesired_text: Text whose first-token logit is subtracted.

        Raises:
            ValueError: If no desired option is given, or if any text encodes
                to zero tokens (there would be no first token to score).
        """
        self.model = model

        # A plain string would otherwise be iterated character by character,
        # silently averaging over the first token of every character.
        if isinstance(desired_text, str):
            desired_text = [desired_text]

        self.desired_tokens = [
            torch.tensor(tokenizer.encode(opt, add_special_tokens=False))
            for opt in desired_text
        ]
        self.undesired_tokens = torch.tensor(
            tokenizer.encode(undesired_text, add_special_tokens=False)
        )

        # Fail here with a clear message instead of a TypeError/IndexError
        # during the first forward pass.
        if not self.desired_tokens:
            raise ValueError("desired_text must contain at least one option")
        if any(opt.numel() == 0 for opt in self.desired_tokens):
            raise ValueError("every desired_text option must encode to a token")
        if self.undesired_tokens.numel() == 0:
            raise ValueError("undesired_text must encode to at least one token")

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor, input_ids=None
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        Dict[str, Any],
    ]:
        """
        Run the model on embeddings and compute the logit-difference objective.

        Args:
            input_embeddings: Embedded inputs passed to the model as
                ``inputs_embeds``.
            input_ids: Unused; accepted for interface compatibility.

        Returns:
            A tuple ``(target, model_logits, extra)`` where ``target`` is the
            per-row objective, ``model_logits`` is the full logits tensor
            returned by the model, and ``extra`` is an empty dict.
        """
        model_logits = self.model(inputs_embeds=input_embeddings).logits

        # Logits at the last sequence position, i.e. the next-token prediction.
        next_token_logits = model_logits[:, -1, :]

        # Keep the cached token tensors on the same device as the inputs.
        device = input_embeddings.device
        self.desired_tokens = [opt.to(device) for opt in self.desired_tokens]
        self.undesired_tokens = self.undesired_tokens.to(device)

        # Average the first-token logit over all desired options. The options
        # are accumulated one by one, in order, before dividing.
        target_logits = None
        for opt in self.desired_tokens:
            option_logits = next_token_logits[:, opt[0]]
            if target_logits is None:
                target_logits = option_logits
            else:
                target_logits = target_logits + option_logits
        target_logits = target_logits / len(self.desired_tokens)

        negative_logits = next_token_logits[:, self.undesired_tokens[0]]

        # Objective: desired minus undesired.
        target = target_logits - negative_logits

        # No additional info is reported, hence the empty dict.
        return target, model_logits, {}

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Convert one-hot encoded input to embedded input."""
        return transfomer_embed_one_hot_input(one_hot, self.model)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Convert integer token IDs to embedded input."""
        return self.model.get_input_embeddings()(int_ids)


# Experiment configuration (hardware, EPO, W&B, and Hugging Face settings)
class Config:
    """Configuration class for experiment parameters.

    Attributes are first set to built-in defaults and then overridden by any
    matching top-level keys of an optional YAML file.
    """

    def __init__(self, config_path=None):
        """
        Build the configuration.

        Args:
            config_path: Optional path to a YAML file whose top-level keys
                override the defaults.
        """
        # Set default settings
        self._set_defaults()

        # Override with config file if provided
        yaml_keys = set()
        if config_path:
            yaml_keys = self._load_from_yaml(config_path)

        # save_last_two_iterations defaults to the value of use_assist. The
        # default is computed before the YAML is read, so re-derive it here
        # unless the YAML set it explicitly; otherwise `use_assist: true` in
        # the file would leave it stale at False.
        if "save_last_two_iterations" not in yaml_keys:
            self.save_last_two_iterations = bool(self.use_assist)

    def _set_defaults(self):
        """Assign the built-in default value of every setting."""
        self.num_samples = None

        # file settings
        self.output_json = f"{RESULTS_DIR}/optimization_results.json"

        # Hardware settings
        self.device = "cuda"
        self.dtype = "bfloat16"

        # EPO parameters
        self.num_runs = 3
        self.population_size = 8
        self.explore_per_pop = 32
        self.batch_size = 32
        self.iters = 100
        self.x_penalty_min = 0.1
        self.x_penalty_max = 10.0

        self.use_gcg = False

        # EPO (only used by default)
        self.restart_frequency = 51

        # EPO (with openai assist)
        self.use_assist = False
        self.assist_prompt_type = "stories"
        self.epo_assist_run_every = 51
        self.epo_assist_temperature = 1.0
        self.epo_assist_model = "gpt-4o"
        self.epo_assist_run_on_last_step = True
        # Used to save results with and without the last step being an
        # assist step; follows use_assist by default.
        self.save_last_two_iterations = bool(self.use_assist)

        # EPO (with llada inpainting)
        self.use_inpaint = False
        self.epo_inpaint_every = 15
        self.epo_inpaint_replace_percent = 0.75
        self.epo_inpaint_replace_prob = 0.5
        self.epo_inpaint_token_per_step = 4
        self.epo_inpaint_unmasking = "high_confidence"
        self.epo_inpaint_model_name = "GSAI-ML/LLaDA-8B-Instruct"

        # WandB settings
        self.wandb_project = "simple_stories_epo"
        self.wandb_entity = WANDB_ENTITY
        self.wandb_mode = "online"

        # HF settings
        self.hf_repo_id = "contextmodification/backdoors-benchmark-dataset"

        # Secret read from the environment; None when the variable is unset.
        self.hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

    def _load_from_yaml(self, config_path):
        """Load configuration from YAML file.

        Only keys that match an already-defined attribute are applied; any
        other key is ignored.

        Args:
            config_path: Path to the YAML file.

        Returns:
            The set of keys that were applied.

        Raises:
            ValueError: If the file's top-level value is not a mapping.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f)

        # An empty YAML document loads as None; treat it as "no overrides".
        if config_data is None:
            return set()
        if not isinstance(config_data, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {config_path}, "
                f"got {type(config_data).__name__}"
            )

        # Update attributes from YAML
        applied = set()
        for key, value in config_data.items():
            if hasattr(self, key):
                setattr(self, key, value)
                applied.add(key)
        return applied


def _decode_variable_contexts(history, tokenizer, fixed_positions, iteration):
    """Decode the non-fixed tokens of each unique sequence stored for `iteration`."""
    frame = history.to_dataframe(tokenizer, iter=iteration, child=0)
    frame = frame.drop_duplicates(subset=[HistoryColumns.TOKEN_IDS])

    # Boolean mask selecting the positions that are NOT fixed by the template.
    variable_mask = ~torch.tensor(fixed_positions)

    decoded = []
    for token_ids in frame[HistoryColumns.TOKEN_IDS].tolist():
        ids = torch.tensor(list(map(int, token_ids)))
        decoded.append(tokenizer.decode(ids[variable_mask]))
    return decoded


def optimize_for_behaviour(
    model: HookedTransformer,
    tokenizer,
    desired_text: str,
    undesired_text: str,
    variable_context: str,
    template: str,
    config: Config,
    device: str = "cuda",
    verbose: bool = True,
    log_name: str = "run_epo",
    override_vocab_size: int = None,
) -> List[Any]:
    """
    Run EPO `config.num_runs` times and collect the optimized variable contexts.

    Args:
        model: Model handed to the runner and to `epo`.
        tokenizer: Tokenizer matching `model`.
        desired_text: Desired text forwarded to `LogitDiffRunner`.
        undesired_text: Undesired text forwarded to `LogitDiffRunner`.
        variable_context: Initial text placed in the template's placeholder;
            this is the part being optimized.
        template: Template text containing the placeholder.
        config: Experiment configuration (EPO, inpainting and logging options).
        device: Device passed to `epo` and to the inpainting callback.
        verbose: Enables progress output here, in tokenization and in `epo`.
        log_name: Prefix for the per-run W&B callback name.
        override_vocab_size: Vocabulary size passed to `epo`; defaults to
            `tokenizer.vocab_size` when None.

    Returns:
        The decoded variable contexts of the last iteration, accumulated over
        all runs. When `config.save_last_two_iterations` is set, returns
        `[last_iteration_texts, second_to_last_iteration_texts]` instead.

    Raises:
        NotImplementedError: If `config.use_assist` is enabled.
    """
    # Assist is not supported here; fail before doing any work.
    if config.use_assist:
        raise NotImplementedError("Assist not implemented for applications")

    # Set up runner
    runner = LogitDiffRunner(
        model,
        tokenizer,
        desired_text,
        undesired_text,
    )

    initial_ids, fixed_positions = clean_process_text_with_placeholder(
        tokenizer,
        template,
        variable_context,
        verbose=verbose,
        skip_special_tokens=True,
    )

    if override_vocab_size is not None:
        vocab_size = override_vocab_size
    else:
        vocab_size = tokenizer.vocab_size

    all_results = []
    alt_results = []

    # Run optimization multiple times according to config.num_runs
    for run_idx in range(config.num_runs):
        if verbose:
            print(f"Starting EPO run {run_idx+1}/{config.num_runs}")

        # Every population member starts from the same initial token ids.
        current_initial_ids = initial_ids.unsqueeze(0).repeat(config.population_size, 1)
        seq_len = current_initial_ids.shape[-1]

        # Callbacks are rebuilt for every run.
        callbacks = []

        if config.use_inpaint:
            callbacks.append(
                InpaintingCallback(
                    tokenizer,
                    inpaint_every=config.epo_inpaint_every,
                    model_name=config.epo_inpaint_model_name,
                    replace_percent=config.epo_inpaint_replace_percent,
                    replace_prob=config.epo_inpaint_replace_prob,
                    token_per_step=config.epo_inpaint_token_per_step,
                    unmasking=config.epo_inpaint_unmasking,
                    device=device,
                    torch_dtype=torch.bfloat16,
                )
            )

        # NOTE(review): some wandb releases appear to report `Run.mode` as
        # "run"/"dryrun" rather than "online"; confirm this condition can be
        # true with the installed wandb version.
        if wandb.run is not None and wandb.run.mode == "online":
            callbacks.append(
                WandbEPOCallback(
                    runner,
                    model,
                    tokenizer,
                    x_penalty_min=config.x_penalty_min,
                    x_penalty_max=config.x_penalty_max,
                    name=f"{log_name}/run_{run_idx}",
                    per_pop_log=False,
                )
            )

        # Run optimization
        history = epo(
            runner,
            model,
            tokenizer,
            iters=config.iters,
            initial_ids=current_initial_ids,
            fixed_positions=fixed_positions,
            population_size=config.population_size,
            seq_len=seq_len,
            explore_per_pop=config.explore_per_pop,
            # Restarts are disabled when inpainting is active (use_assist is
            # always False here because of the guard above).
            restart_frequency=(
                None if config.use_inpaint else config.restart_frequency
            ),
            callbacks=callbacks,
            batch_size=config.batch_size,
            device=device,
            verbose=verbose,
            x_penalty_min=config.x_penalty_min,
            x_penalty_max=config.x_penalty_max,
            override_vocab_size=vocab_size,
        )

        # Last iteration of this run.
        all_results.extend(
            _decode_variable_contexts(history, tokenizer, fixed_positions, -1)
        )

        # Optionally also keep the second-to-last iteration.
        if config.save_last_two_iterations:
            alt_results.extend(
                _decode_variable_contexts(history, tokenizer, fixed_positions, -2)
            )

    # Return only after ALL runs have finished; returning from inside the loop
    # would stop after the first run.
    if config.save_last_two_iterations:
        return [all_results, alt_results]

    return all_results


if __name__ == "__main__":
    # Script entry point: parse CLI options, build the config, run EPO for
    # every dataset row, then write the results to JSON and to a W&B artifact.
    import gc

    # Set up argument parser
    parser = argparse.ArgumentParser(description="Run EPO optimization for TinyStories")
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument(
        "--use_assist",
        action="store_true",
        help="Enable OpenAI assistance for optimization (overrides config)",
    )

    parser.add_argument(
        "--use_gcg",
        action="store_true",
        help="Enable GCG for optimization (overrides config)",
    )
    parser.add_argument(
        "--output", type=str, help="Path to save output JSON results (overrides config)"
    )
    parser.add_argument(
        "--start_text",
        type=str,
        help="Text to start with",
        default="Hello and welcome and also good luck!",
    )
    args = parser.parse_args()

    # Create config object (either with defaults or from YAML)
    config = Config(config_path=args.config)

    # Override config with command line arguments if provided
    if args.use_assist:
        config.use_assist = True
        config.save_last_two_iterations = True
        print("Enabling OpenAI assistance (overridden by command line)")

    if args.use_gcg:
        config.use_gcg = True
        print("Enabling GCG (overridden by command line)")

    if args.output:
        config.output_json = args.output
        print(f"Using custom output path: {args.output}")

    # GCG mode: single population member and no penalty.
    if config.use_gcg:
        config.population_size = 1
        config.x_penalty_min = 0.0
        config.x_penalty_max = 0.0

    # Create outputs directory if it doesn't exist
    output_path = Path(config.output_json)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Sibling file "<stem>_raw<suffix>" for the second-to-last iteration. Built
    # from the file name only, so it is always distinct from the main output
    # even when the path has no ".json" suffix.
    raw_output_path = output_path.with_name(
        f"{output_path.stem}_raw{output_path.suffix}"
    )

    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        mode=config.wandb_mode,
    )
    # Never upload credentials with the run configuration.
    secret_attrs = {"hf_token"}
    config_dict = {
        attr: getattr(config, attr)
        for attr in dir(config)
        if not attr.startswith("_")
        and attr not in secret_attrs
        and not callable(getattr(config, attr))
    }
    wandb.config.update(config_dict)

    current_base_model_name = None
    current_lora_id = None
    model = None
    tokenizer = None

    dataset = download_backdoors_dataset(
        hf_token=config.hf_token, dataset_name=config.hf_repo_id
    )

    all_results = {}
    alt_results = {}

    # NOTE: config.num_samples is not applied here; every row of the "train"
    # split is processed.
    for idx, datum in enumerate(dataset["train"]):

        base_model_name = datum[ApplicationsKeys.BASE_MODEL]
        lora_id = datum[ApplicationsKeys.LORA_ID]
        # Only (re)load when the row needs a different model than the last one.
        if base_model_name != current_base_model_name or lora_id != current_lora_id:

            # Drop the previous model before loading the next one so that the
            # two are not held in memory at the same time.
            model = None
            tokenizer = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            if lora_id:
                print("Loading finetuned model", lora_id, base_model_name)
                model, tokenizer = load_finetuned_model(
                    lora_model_id=lora_id,
                    base_model_name=base_model_name,
                    device=config.device,
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    base_model_name,
                )
                model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    device_map=config.device,
                    torch_dtype=getattr(torch, config.dtype),
                )

            current_base_model_name = base_model_name
            current_lora_id = lora_id
        template = datum[ApplicationsKeys.TEMPLATE]

        variable_context = args.start_text
        undesired_text = datum[ApplicationsKeys.UNDESIRED_TEXT]
        desired_text = datum[ApplicationsKeys.DESIRED_TEXT]

        results = optimize_for_behaviour(
            model=model,
            tokenizer=tokenizer,
            desired_text=desired_text,
            undesired_text=undesired_text,
            variable_context=variable_context,
            template=template,
            config=config,
            device=config.device,
            verbose=True,
            log_name=f"run_{idx}",
            # Special case: this one checkpoint is given an explicit
            # vocabulary size instead of tokenizer.vocab_size.
            override_vocab_size=(
                None
                if base_model_name != "saraprice/llama2-7B-backdoor-headlines-2017-2019"
                else 32016
            ),
        )

        # With save_last_two_iterations the function returns a pair:
        # [last-iteration texts, second-to-last-iteration texts].
        if config.save_last_two_iterations:
            alt_results[idx] = results[1]
            all_results[idx] = results[0]
        else:
            all_results[idx] = results

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    # Also save to wandb artifact. Both files go into ONE artifact so that the
    # alternative results are not superseded by a second version of the same
    # artifact name.
    results_artifact = wandb.Artifact(
        "optimization_results",
        type="results",
        description="Results from simple stories optimization",
    )
    results_artifact.add_file(str(output_path), "optimization_results.json")

    if len(alt_results) > 0:
        with open(raw_output_path, "w", encoding="utf-8") as f:
            json.dump(alt_results, f, indent=2)
        results_artifact.add_file(
            str(raw_output_path),
            "alt_optimization_results.json",
        )

    wandb.log_artifact(results_artifact)

    print(f"Results saved to {config.output_json} and logged to WandB")

    wandb.finish()
