import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, List

import torch
import yaml
from custom_dreamy.callbacks import (
    InpaintingCallback,
    ModelHelperCallback,
    WandbEPOCallback,
)
from custom_dreamy.epo import epo
from custom_dreamy.history import HistoryColumns
from custom_dreamy.i_runner import IRunner
from datasets import load_dataset
from sae_lens import SAE
from tqdm import tqdm
from transformer_lens import HookedTransformer

import wandb
from eliciting_contexts.benchmark.external.sae_activation.data.load_data import (
    SAEKeys,
    download_sae_dataset,
)
from eliciting_contexts.benchmark.external.sae_activation.data.neuronpedia_utils import (
    get_saelens_release_and_id,
)
from eliciting_contexts.benchmark.external.sae_activation.utils import (
    load_model_tlens,
    load_sae_saelens,
)
from eliciting_contexts.utils.constants import RESULTS_DIR, WANDB_ENTITY
from eliciting_contexts.utils.neuronpedia_api import get_feature_description


class SentenceFetcher:
    """
    Stream raw text samples from a shuffled HuggingFace dataset.

    Each call to ``get_single_sentence`` returns the ``text`` field of the next
    sample. The texts are whole dataset records; nothing here truncates them
    to a particular length. When the iterator is exhausted it restarts from
    the beginning of the (shuffled) dataset.
    """

    def __init__(
        self,
        dataset_name="Skylion007/openwebtext",
        split="train",
        streaming=True,
        seed=42,
    ):
        """
        Load and shuffle the dataset, then open an iterator over it.

        Args:
            dataset_name: Name of the HuggingFace dataset to load.
            split: Dataset split to use (e.g. "train").
            streaming: Whether to load the dataset in streaming mode.
            seed: Seed for shuffling; ``None`` shuffles without a fixed seed.
        """
        self.dataset = load_dataset(
            dataset_name, split=split, streaming=streaming, trust_remote_code=True
        )

        if seed is not None:
            self.dataset = self.dataset.shuffle(seed=seed)
        else:
            self.dataset = self.dataset.shuffle()

        self.iterator = iter(self.dataset)

    def get_single_sentence(self) -> str:
        """
        Return the ``text`` field of the next sample.

        Restarts the iterator once when it is exhausted.

        Raises:
            RuntimeError: If the dataset yields no samples at all.
        """
        try:
            sample = next(self.iterator)
        except StopIteration:
            # Exhausted: start over from the beginning of the dataset.
            self.iterator = iter(self.dataset)
            try:
                sample = next(self.iterator)
            except StopIteration:
                # Re-raise as RuntimeError: a StopIteration escaping a plain
                # function is easily mistaken for normal loop termination.
                raise RuntimeError("SentenceFetcher dataset yielded no samples") from None

        return sample["text"]


class SAERunnerTLens(IRunner):
    """
    EPO runner that scores inputs by one SAE feature's pre-activation.

    A forward hook at ``sae.cfg.hook_name`` runs the residual stream through
    the SAE and reads the target feature from the SAE's ``hook_sae_acts_pre``
    hook. The per-position values are reduced to one scalar per sequence
    (mean or max, per ``epo_target``). That scalar is scaled so the feature's
    largest activation on its Neuronpedia top-activating contexts maps to a
    target of 10.
    """

    def __init__(
        self,
        model: HookedTransformer,
        sae: SAE,
        feature: int,
        neuronpedia_id: str,
        epo_target: str,
    ):
        """
        Initialize an SAERunner for TransformerLens models.

        Args:
            model: The TransformerLens model.
            sae: The sparse autoencoder, hooked in at ``sae.cfg.hook_name``.
            feature: The feature index to target.
            neuronpedia_id: Neuronpedia identifier of the SAE. It is used to
                fetch the feature's top-activating contexts for normalization.
            epo_target: "MEAN" or "MAX", selecting how per-token activations
                are reduced to a scalar in ``run_with_embeddings``.

        Raises:
            AssertionError: If the model applies a post-embedding LayerNorm.
            ValueError: If the normalization constant cannot be computed.
        """
        # Checked first so an unsupported model fails before any network or
        # GPU work. Presumably unsupported because one_hot_to_embed multiplies
        # by the raw W_E and would skip such a LayerNorm — TODO confirm.
        assert (
            not model.cfg.post_embedding_ln
        ), "post embedding ln currently not supported"

        self.model = model
        self.sae = sae
        self.feature = feature
        self.hook_name = sae.cfg.hook_name
        self.epo_target = epo_target
        self.feature_max_activation = self._compute_feature_max_activation(
            neuronpedia_id
        )

    def _feature_pre_acts(self, resid: torch.Tensor) -> torch.Tensor:
        """
        Run the SAE on ``resid`` and return the target feature's pre-activations.

        The captured ``hook_sae_acts_pre`` tensor is indexed as
        ``[:, :, self.feature]``. The result therefore has one value per
        (batch, position) pair.
        """
        captured = []

        def store_pre_acts(acts: torch.Tensor, hook):
            captured.append(acts)
            return acts

        self.sae.run_with_hooks(
            resid,
            fwd_hooks=[("hook_sae_acts_pre", store_pre_acts)],
        )
        return captured[0][:, :, self.feature]

    def _compute_feature_max_activation(self, neuronpedia_id: str) -> float:
        """
        Return the largest feature pre-activation over Neuronpedia's top contexts.

        Fetches the feature's 10 top-activating contexts (``window=15``) and
        runs each one through the model. The maximum per-position
        pre-activation is taken over all of them. The max is used even for the
        MEAN target because mean activations can be negative and would make a
        poor normalizer.

        Raises:
            ValueError: If the Neuronpedia lookup fails, yields no usable
                contexts, or the resulting maximum is not positive. A
                non-positive value would flip or blow up the scaled objective.
        """
        try:
            feature_info = get_feature_description(
                neuronpedia_id, index=int(self.feature)
            )
            contexts = feature_info.get_contexts_around_top_n_activations(
                n=10, window=15
            )
        except Exception as e:
            raise ValueError(
                f"Error getting neuronpedia info for {neuronpedia_id} "
                f"feature {self.feature}: {e}"
            ) from e

        peaks = []
        holder = {}

        def record_max(resid, hook):
            holder["max"] = self._feature_pre_acts(resid).max(dim=1).values
            return resid

        # Normalization needs only forward values, so don't build autograd graphs.
        with torch.no_grad():
            for context in contexts:
                ids = self.model.tokenizer(
                    context, return_tensors="pt", add_special_tokens=False
                )["input_ids"].to(self.sae.cfg.device)
                if ids.shape[-1] == 0:
                    continue  # nothing to run the model on
                with self.model.hooks(fwd_hooks=[(self.hook_name, record_max)]):
                    self.model(ids, return_type="logits")
                peaks.append(holder["max"][0].item())

        if not peaks:
            raise ValueError(
                f"No Neuronpedia contexts available to normalize feature "
                f"{self.feature} of {neuronpedia_id}"
            )
        max_activation = max(peaks)
        if max_activation <= 0:
            raise ValueError(
                f"Feature {self.feature} of {neuronpedia_id} has non-positive max "
                f"activation {max_activation} on its top contexts; cannot normalize"
            )
        return max_activation

    def _first_token_is_bos(self, input_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Return a per-sequence bool tensor: does position 0 hold the BOS embedding?

        The inputs are embeddings, not token ids. BOS is therefore detected by
        comparing each sequence's first vector with ``W_E[bos_token_id]``.
        ``one_hot_to_embed`` produces exact rows of ``W_E``. ``int_ids_to_embed``
        presumably does too, since post-embedding LN is asserted off.
        The result is all False when the tokenizer has no BOS token, or when
        sequences have a single position (masking it would leave nothing to
        reduce over).
        """
        batch_size, seq_len = input_embeddings.shape[0], input_embeddings.shape[1]
        bos_id = getattr(self.model.tokenizer, "bos_token_id", None)
        if bos_id is None or seq_len < 2:
            return torch.zeros(
                batch_size, dtype=torch.bool, device=input_embeddings.device
            )
        bos_embed = self.model.embed.W_E[bos_id].to(
            device=input_embeddings.device, dtype=input_embeddings.dtype
        )
        return torch.isclose(input_embeddings[:, 0, :], bos_embed).all(dim=-1)

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model on embedded inputs and score the target feature.

        Args:
            input_embeddings: Embedded inputs, fed to the model with
                ``start_at_layer=0``. Indexed as (batch, position, ...).

        Returns:
            Tuple of (target, logits, {"per_token_target": per_token}).
            ``target`` is each sequence's reduced activation scaled by
            ``10 / feature_max_activation``. ``per_token`` holds the raw
            per-position pre-activations.

        Raises:
            ValueError: If ``epo_target`` is neither "MEAN" nor "MAX".
        """
        out = {}
        first_is_bos = self._first_token_is_bos(input_embeddings)

        def get_target(resid, hook):
            per_token = self._feature_pre_acts(resid)

            if bool(first_is_bos.any()):
                # Exclude position 0 only for the sequences that start with BOS.
                keep = torch.ones_like(per_token, dtype=torch.bool)
                keep[:, 0] = ~first_is_bos.to(per_token.device)
                weights = keep.to(per_token.dtype)
                mean_acts = (per_token * weights).sum(dim=1) / weights.sum(dim=1)
                # Use -inf (not 0) so masked positions never win the max, even
                # when every remaining activation is negative.
                max_acts = (
                    per_token.masked_fill(~keep, float("-inf")).max(dim=1).values
                )
            else:
                mean_acts = per_token.mean(dim=1)
                max_acts = per_token.max(dim=1).values

            # TODO: this is a hack to normalize the target to the max activation
            if self.epo_target == "MEAN":
                out["target"] = 10 * mean_acts / self.feature_max_activation
            elif self.epo_target == "MAX":
                out["target"] = 10 * max_acts / self.feature_max_activation
            else:
                raise ValueError(f"Invalid EPO target: {self.epo_target}")

            out["per_token_target"] = per_token
            return resid

        # Use TransformerLens hook system
        with self.model.hooks(fwd_hooks=[(self.hook_name, get_target)]):
            logits = self.model(
                input_embeddings,
                start_at_layer=0,  # needed to skip embedding layer
                return_type="logits",
            )

        return out["target"], logits, {"per_token_target": out["per_token_target"]}

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Convert one-hot encoded input to embedded input (``one_hot @ W_E``)."""
        embed_matrix = self.model.embed.W_E
        return torch.matmul(one_hot, embed_matrix)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Convert integer token IDs to embedded input via the model's embed layer."""
        return self.model.embed(int_ids)


# WandB config setup
class Config:
    """
    Configuration class for experiment parameters.

    Attributes start from the defaults in ``_set_defaults``. When a YAML path
    is given, any key in it that matches an existing attribute overrides that
    attribute.
    """

    def __init__(self, config_path=None):
        self._set_defaults()

        if config_path:
            self._load_from_yaml(config_path)

    def _set_defaults(self):
        # File settings
        self.output_json = f"{RESULTS_DIR}/sae_optimization_results.json"

        # Hardware settings
        self.device = "cuda"
        self.dtype = "bfloat16"

        # EPO parameters
        self.num_runs = 3
        self.initialize_seed = 0  # set to None for random initialization
        self.seq_len = 36
        self.population_size = 8
        self.explore_per_pop = 32
        self.batch_size = 64
        self.iters = 150
        self.x_penalty_min = 0.1
        self.x_penalty_max = 10.0

        self.use_gcg = False

        # EPO restarts (only used when neither assist nor inpainting is enabled)
        self.restart_frequency = 51

        # How per-token feature activations become the EPO objective: "MEAN" or "MAX"
        self.epo_target = "MEAN"

        # EPO (with openai assist)
        self.use_assist = False
        self.epo_assist_run_every = 51
        self.epo_assist_temperature = 1.0
        self.epo_assist_model = "gpt-4o"
        self.epo_assist_run_on_last_step = False

        # EPO (with llada inpainting)
        self.use_inpaint = False
        self.epo_inpaint_every = 15
        self.epo_inpaint_replace_percent = 0.75
        self.epo_inpaint_replace_prob = 0.5
        self.epo_inpaint_token_per_step = 4
        self.epo_inpaint_unmasking = "high_confidence"
        self.epo_inpaint_model_name = "GSAI-ML/LLaDA-8B-Instruct"

        # WandB settings
        self.wandb_project = "sae_activation_epo"
        self.wandb_entity = WANDB_ENTITY
        self.wandb_mode = "online"

        # Model settings
        self.load_with_no_processing = True

    def _load_from_yaml(self, config_path):
        """
        Override defaults with values from a YAML file.

        Only keys naming an existing attribute are applied. Unknown keys are
        reported, since they are most likely typos, and otherwise ignored. An
        empty file leaves the defaults untouched.

        Raises:
            ValueError: If the YAML document is not a mapping.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty file.
            config_data = yaml.safe_load(f) or {}

        if not isinstance(config_data, dict):
            raise ValueError(
                f"Config file {config_path} must contain a mapping, "
                f"got {type(config_data).__name__}"
            )

        for key, value in config_data.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                print(f"Warning: ignoring unknown config key '{key}' in {config_path}")


def _tokenize_to_length(tokenizer, sentence: str, seq_len: int) -> torch.Tensor:
    """
    Tokenize ``sentence`` (with special tokens) to exactly ``seq_len`` ids.

    Longer sequences are truncated. Shorter ones are right-padded with the
    tokenizer's pad token, or with its EOS token when no pad token is defined.

    Raises:
        ValueError: If padding is needed but the tokenizer has neither a pad
            nor an EOS token id.
    """
    full_tokens = tokenizer(sentence, return_tensors="pt", add_special_tokens=True)[
        "input_ids"
    ][0]
    if len(full_tokens) >= seq_len:
        return full_tokens[:seq_len]

    pad_token = tokenizer.pad_token_id
    if pad_token is None:
        # Some tokenizers define no pad token; EOS is a common stand-in.
        pad_token = tokenizer.eos_token_id
    if pad_token is None:
        raise ValueError(
            "Tokenizer defines neither a pad nor an EOS token; cannot pad "
            f"initialization sentence to {seq_len} tokens"
        )
    padding = torch.full((seq_len - len(full_tokens),), pad_token, dtype=torch.long)
    return torch.cat([full_tokens, padding])


def _build_callbacks(runner, model, tokenizer, config, device, name):
    """
    Create the EPO callbacks enabled by ``config`` for one optimization run.

    Callbacks are added in this order: OpenAI assist, LLaDA inpainting, then
    W&B logging. W&B logging is added only when an online W&B run is active.
    """
    callbacks = []

    if config.use_assist:
        callbacks.append(
            ModelHelperCallback(
                tokenizer,
                run_every=config.epo_assist_run_every,
                model_name=config.epo_assist_model,
                temperature=config.epo_assist_temperature,
                run_on_last_step=config.epo_assist_run_on_last_step,
                num_mutations=config.explore_per_pop,
            )
        )

    if config.use_inpaint:
        callbacks.append(
            InpaintingCallback(
                tokenizer,
                inpaint_every=config.epo_inpaint_every,
                model_name=config.epo_inpaint_model_name,
                replace_percent=config.epo_inpaint_replace_percent,
                replace_prob=config.epo_inpaint_replace_prob,
                token_per_step=config.epo_inpaint_token_per_step,
                unmasking=config.epo_inpaint_unmasking,
                device=device,
                torch_dtype=torch.bfloat16,
            )
        )

    if wandb.run is not None and wandb.run.mode == "online":
        callbacks.append(
            WandbEPOCallback(
                runner,
                model,
                tokenizer,
                x_penalty_min=config.x_penalty_min,
                x_penalty_max=config.x_penalty_max,
                name=name,
                per_pop_log=False,
            )
        )

    return callbacks


def run_epo(
    model: HookedTransformer,
    sae: SAE,
    feature: int,
    tokenizer,
    sentences: List[str],
    config: Config,
    device: str = "cuda",
    verbose: bool = True,
    log_name: str = "run_epo",
    neuronpedia_id: str = None,
) -> List[str]:
    """
    Run ``config.num_runs`` EPO optimizations that try to activate one SAE feature.

    In run ``i``, the whole population is initialized from ``sentences[i]``
    tokenized to ``config.seq_len`` tokens.

    Args:
        model: TransformerLens model the SAE reads from.
        sae: The sparse autoencoder.
        feature: Index of the SAE feature to optimize for.
        tokenizer: Tokenizer matching ``model``.
        sentences: Initialization texts; at least ``config.num_runs`` are needed.
        config: Experiment configuration.
        device: Device passed to EPO and the inpainting callback.
        verbose: Print progress.
        log_name: Prefix for per-run W&B logging names.
        neuronpedia_id: Neuronpedia SAE id, used to normalize the objective.

    Returns:
        For every run in order, the de-duplicated texts from
        ``history.to_dataframe(tokenizer, iter=-1, child=0)``, concatenated.

    Raises:
        ValueError: If fewer than ``config.num_runs`` sentences are given.
    """
    if len(sentences) < config.num_runs:
        raise ValueError(
            f"run_epo needs one initialization sentence per run: got "
            f"{len(sentences)} sentences for num_runs={config.num_runs}"
        )

    all_initial_ids = [
        _tokenize_to_length(tokenizer, sentence, config.seq_len)
        for sentence in sentences[: config.num_runs]
    ]

    runner = SAERunnerTLens(model, sae, feature, neuronpedia_id, config.epo_target)

    fixed_positions = None  # every position is free to be optimized
    all_results = []

    for run_idx in range(config.num_runs):
        if verbose:
            print(f"Starting EPO run {run_idx+1}/{config.num_runs}")

        # The whole population starts from the same initialization sentence.
        current_initial_ids = (
            all_initial_ids[run_idx].unsqueeze(0).repeat(config.population_size, 1)
        )
        seq_len = current_initial_ids.shape[-1]

        callbacks = _build_callbacks(
            runner, model, tokenizer, config, device, f"{log_name}/run_{run_idx}"
        )

        history = epo(
            runner,
            model,
            tokenizer,
            iters=config.iters,
            initial_ids=current_initial_ids,
            fixed_positions=fixed_positions,
            population_size=config.population_size,
            seq_len=seq_len,
            explore_per_pop=config.explore_per_pop,
            # Restarts are disabled when an assist/inpaint callback mutates the population.
            restart_frequency=(
                None
                if (config.use_assist or config.use_inpaint)
                else config.restart_frequency
            ),
            callbacks=callbacks,
            batch_size=config.batch_size,
            device=device,
            verbose=verbose,
            x_penalty_min=config.x_penalty_min,
            x_penalty_max=config.x_penalty_max,
        )

        history_df = history.to_dataframe(tokenizer, iter=-1, child=0)
        history_df = history_df.drop_duplicates(subset=[HistoryColumns.TEXT])
        all_results.extend(history_df[HistoryColumns.TEXT].tolist())

    return all_results


if __name__ == "__main__":

    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Run EPO optimization to find inputs that activate SAE features"
    )
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument(
        "--use_assist",
        action="store_true",
        help="Enable OpenAI assistance for optimization (overrides config)",
    )
    parser.add_argument(
        "--use_assist_final",
        action="store_true",
        help="Enable OpenAI assistance, also running it on the last step (overrides config)",
    )
    parser.add_argument(
        "--use_inpaint",
        action="store_true",
        help="Enable LLaDA inpainting for optimization (overrides config)",
    )
    parser.add_argument(
        "--epo_target",
        type=str,
        choices=["MEAN", "MAX"],
        help="How per-token feature activations are reduced to the EPO objective (overrides config)",
    )
    parser.add_argument(
        "--output", type=str, help="Path to save output JSON results (overrides config)"
    )
    parser.add_argument(
        "--use_gcg",
        action="store_true",
        help="Enable GCG optimization (overrides config)",
    )
    parser.add_argument(
        "--load_with_no_processing",
        action="store_true",
        help="Load the model with no processing (overrides config)",
    )
    parser.add_argument(
        "--layer_index",
        type=int,
        help="Optional: Run only on a specific layer index (1-based). If not provided, runs on all layers.",
    )
    args = parser.parse_args()

    # Create config object (either with defaults or from YAML)
    config = Config(config_path=args.config)

    # Override config with command line arguments if provided
    if args.use_assist:
        config.use_assist = True
        print("Enabling OpenAI assistance (overridden by command line)")

    if args.use_assist_final:
        config.use_assist = True
        config.epo_assist_run_on_last_step = True
        print("Enabling OpenAI assistance for final step (overridden by command line)")

    if args.use_inpaint:
        config.use_inpaint = True
        print("Enabling LLaDA inpainting (overridden by command line)")

    if args.epo_target:
        config.epo_target = args.epo_target
    if args.output:
        config.output_json = args.output
        print(f"Using custom output path: {args.output}")

    if args.use_gcg:
        config.use_gcg = True
        print("Enabling GCG (overridden by command line)")

    if args.load_with_no_processing:
        config.load_with_no_processing = True
        print("Loading model with no processing (overridden by command line)")

    # GCG = single-member population without the x-penalty sweep
    if config.use_gcg:
        config.population_size = 1
        config.x_penalty_min = 0.0
        config.x_penalty_max = 0.0

    # Create outputs directory if it doesn't exist
    Path(os.path.dirname(config.output_json)).mkdir(parents=True, exist_ok=True)

    # Initialization sentences are saved in their own file next to the results
    init_sentences_file = os.path.join(
        os.path.dirname(config.output_json), "initialization_sentences.json"
    )

    dataset = download_sae_dataset(hf_token=os.environ.get("HUGGINGFACE_HUB_TOKEN"))

    # Group examples by neuronpedia_id, keeping each one's original dataset index
    neuronpedia_groups = {}
    for idx, datum in enumerate(dataset["test"]):
        neuronpedia_groups.setdefault(datum[SAEKeys.NEURONPEDIA_ID], []).append(
            (idx, datum)
        )

    # If layer_index is specified, only process that layer
    if args.layer_index is not None:
        if args.layer_index < 1 or args.layer_index > len(neuronpedia_groups):
            raise ValueError(
                f"Layer index must be between 1 and {len(neuronpedia_groups)}"
            )
        # Convert to 0-based index and get the specific layer
        layer_idx = args.layer_index - 1
        neuronpedia_id = list(neuronpedia_groups.keys())[layer_idx]
        neuronpedia_groups = {neuronpedia_id: neuronpedia_groups[neuronpedia_id]}
        print(
            f"Running only on layer {args.layer_index} (neuronpedia_id: {neuronpedia_id})"
        )

        # Add suffix to output filenames
        base_output = os.path.splitext(config.output_json)[0]
        config.output_json = f"{base_output}_layer{args.layer_index}.json"
        init_sentences_file = os.path.join(
            os.path.dirname(config.output_json),
            f"initialization_sentences_layer{args.layer_index}.json",
        )

    # Log the config only now, so output_json already carries any per-layer suffix
    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        mode=config.wandb_mode,
    )
    config_dict = {
        attr: getattr(config, attr)
        for attr in dir(config)
        if not attr.startswith("_") and not callable(getattr(config, attr))
    }
    config_dict["layer_index"] = args.layer_index
    wandb.config.update(config_dict)

    # The same initialization sentences are reused for every feature
    sentence_fetcher = SentenceFetcher(seed=config.initialize_seed)
    sentences = [
        sentence_fetcher.get_single_sentence() for _ in range(config.num_runs)
    ]

    all_results = {}
    current_model_name = None
    save_counter = 0
    for neuronpedia_id, examples in tqdm(neuronpedia_groups.items()):

        print(f"Neuronpedia ID: {neuronpedia_id}")

        sae_lens_release, saelens_sae_id = get_saelens_release_and_id(neuronpedia_id)

        sae = load_sae_saelens(
            sae_lens_release, saelens_sae_id, config.device, config.dtype
        )

        model_name = sae.cfg.model_name

        # Reload the base model only when the SAE targets a different one
        if model_name != current_model_name:
            model = load_model_tlens(
                model_name, config.device, config.dtype, config.load_with_no_processing
            )
            tokenizer = model.tokenizer
            current_model_name = model_name

        for original_idx, datum in examples:

            assert neuronpedia_id == datum[SAEKeys.NEURONPEDIA_ID]

            sae_index = datum[SAEKeys.INDEX]
            print(f"Optimizing for {sae_index}")
            print(f"Neuronpedia ID: {neuronpedia_id}")
            all_results[original_idx] = run_epo(
                model=model,
                sae=sae,
                feature=sae_index,
                tokenizer=tokenizer,
                sentences=sentences,
                config=config,
                device=config.device,
                verbose=True,
                log_name=f"run_{original_idx}",
                neuronpedia_id=neuronpedia_id,
            )

            # Record the initialization sentences (read-modify-write of the JSON file)
            if not os.path.exists(init_sentences_file):
                init_sentences = {}
            else:
                with open(init_sentences_file, "r") as f:
                    init_sentences = json.load(f)

            init_sentences[str(original_idx)] = sentences

            with open(init_sentences_file, "w") as f:
                json.dump(init_sentences, f, indent=2)

        # Save intermediate results after each neuronpedia_id group
        intermediate_output = (
            f"{os.path.splitext(config.output_json)[0]}_{save_counter}.json"
        )
        with open(intermediate_output, "w") as f:
            json.dump(all_results, f, indent=2)

        # Upload intermediate results to wandb
        intermediate_artifact = wandb.Artifact(
            f"optimization_results_{save_counter}",
            type="results",
            description=f"Intermediate results from sae optimization (save {save_counter})",
        )
        intermediate_artifact.add_file(
            intermediate_output, f"optimization_results_{save_counter}.json"
        )
        wandb.log_artifact(intermediate_artifact)

        save_counter += 1

    with open(config.output_json, "w") as f:
        json.dump(all_results, f, indent=2)

    # Also save to wandb artifact
    results_artifact = wandb.Artifact(
        f"optimization_results{'_layer' + str(args.layer_index) if args.layer_index is not None else ''}",
        type="results",
        description=f"Results from SAE feature optimization{' for layer ' + str(args.layer_index) if args.layer_index is not None else ''}",
    )
    results_artifact.add_file(config.output_json, "optimization_results.json")
    wandb.log_artifact(results_artifact)

    print(f"Results saved to {config.output_json} and logged to WandB")

    wandb.finish()
