import argparse
import json
import os

import torch
from contextbench.load_data import BackdoorKeys, download_backdoors_dataset
from contextbench.utils import load_model_tlens, tokenize_text_with_placeholder
from huggingface_hub import list_repo_files
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class Info:
    """Per-datapoint metadata recorded alongside the evaluation results.

    All constructor arguments are stored unchanged as same-named attributes
    and exported by :meth:`to_dict`.
    """

    def __init__(
        self,
        template,
        variable_context,
        undesired_text,
        desired_text,
        dataset_type,
        dataset_info,
    ):
        (
            self.template,
            self.variable_context,
            self.undesired_text,
            self.desired_text,
            self.dataset_type,
            self.dataset_info,
        ) = (
            template,
            variable_context,
            undesired_text,
            desired_text,
            dataset_type,
            dataset_info,
        )

    def to_dict(self):
        """Return the stored fields as a plain dict (in constructor order)."""
        field_names = (
            "template",
            "variable_context",
            "undesired_text",
            "desired_text",
            "dataset_type",
            "dataset_info",
        )
        return {name: getattr(self, name) for name in field_names}


class Result:
    """Outcome of evaluating one candidate context.

    Holds the candidate text, its cross entropy, and the text the
    evaluation model generated after it.
    """

    def __init__(
        self,
        input_text,
        cross_entropy,
        predicted_text,
    ):
        self.input_text, self.cross_entropy, self.predicted_text = (
            input_text,
            cross_entropy,
            predicted_text,
        )

    def __str__(self):
        lines = [
            "==== Result Summary ====",
            f"Text: {self.input_text}",
            f"Cross entropy: {self.cross_entropy:.4f}",
            "",
            f"Predicted next tokens: '{self.predicted_text}'",
        ]
        return "\n".join(lines)

    def to_dict(self):
        """Serialize to a JSON-friendly dict."""
        exported = ("input_text", "predicted_text", "cross_entropy")
        return {key: getattr(self, key) for key in exported}

    @classmethod
    def from_dict(cls, data):
        """Rebuild a Result from a dict produced by :meth:`to_dict`.

        Raises:
            KeyError: If any of the three required keys is missing.
        """
        required = ("input_text", "cross_entropy", "predicted_text")
        return cls(**{key: data[key] for key in required})


@torch.no_grad()
def get_metrics(
    template,
    variable_context,
    result,
    undesired_text,
    desired_text,
    evaluation_model,
    cross_entropy_model,
    device="cuda",
    max_new_tokens=15,
    temperature=1.0,
    tokenizer=None,
):
    """Score a candidate context and sample a continuation for it.

    ``result`` is substituted into ``template`` via ``str.format``. Two
    quantities are produced:

    * the mean cross entropy, under ``cross_entropy_model``, of the tokens
      that ``tokenize_text_with_placeholder`` marks as not fixed (presumably
      the tokens of ``result`` itself -- confirm against that helper), and
    * a sampled continuation of the filled-in template from
      ``evaluation_model``.

    Args:
        template: Format string with a single positional ``{}`` placeholder.
        variable_context: Not used by this function; kept for interface
            compatibility.
        result: Candidate text substituted into ``template``.
        undesired_text: Not used by this function; kept for interface
            compatibility.
        desired_text: Not used by this function; kept for interface
            compatibility.
        evaluation_model: Model whose ``generate`` method produces the
            continuation.
        cross_entropy_model: Model called as
            ``model(ids, return_type="logits")`` and exposing a ``tokenizer``
            attribute.
        device: Device the input tensors are moved to.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature.
        tokenizer: Tokenizer matching ``evaluation_model``. When ``None``,
            the module-level ``tokenizer`` (which the ``__main__`` script
            defines) is used, for backward compatibility.

    Returns:
        A ``Result`` holding ``result``, its cross entropy (may be ``nan``
        if there are no scorable positions), and the decoded continuation.

    Raises:
        ValueError: If no tokenizer is passed and none exists at module level.
    """
    if tokenizer is None:
        # Earlier versions read the script's module-level tokenizer implicitly.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise ValueError(
                "get_metrics needs the evaluation model's tokenizer; "
                "pass tokenizer=..."
            )

    new_story = template.format(result)

    # ---- Cross entropy of the non-fixed tokens ----
    story_ids_cross, fixed_positions = tokenize_text_with_placeholder(
        cross_entropy_model.tokenizer,
        template,
        result,
        skip_special_tokens=True,
    )
    # Add a leading batch dimension of size 1.
    story_ids_cross = story_ids_cross.unsqueeze(0).to(device)
    logits = cross_entropy_model(story_ids_cross, return_type="logits")

    # Token indices not fixed by the template.
    # NOTE(review): assumes fixed_positions is a per-token sequence of bools,
    # so ``~`` is a logical NOT -- confirm against tokenize_text_with_placeholder.
    fixed_mask = torch.tensor(fixed_positions, device=device)
    unfixed_indices = torch.where(~fixed_mask)[0]

    # Logits at position i predict token i + 1: drop the last logit and shift
    # the target indices down by one. Position 0 has no preceding logit, so it
    # cannot be scored and is excluded.
    shifted_logits = logits[:, :-1]
    target_positions = unfixed_indices[unfixed_indices != 0] - 1
    vocab_size = shifted_logits.shape[-1]

    per_token_loss = torch.nn.CrossEntropyLoss(reduction="none")(
        shifted_logits[:, target_positions].reshape(-1, vocab_size),
        story_ids_cross[:, 1:][:, target_positions].reshape(-1),
    )
    cross_entropy = (
        per_token_loss.view(shifted_logits.shape[0], -1).mean(dim=-1).item()
    )

    # ---- Sampled continuation from the evaluation model ----
    inputs = tokenizer(new_story, return_tensors="pt").to(device)
    output = evaluation_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
    )
    input_length = inputs.input_ids.shape[1]
    # Decode only the newly generated tokens, not the echoed prompt.
    new_text = tokenizer.decode(output[0][input_length:], skip_special_tokens=True)

    return Result(result, cross_entropy, new_text)


if __name__ == "__main__":
    # Script entry: evaluate the default context plus every candidate context
    # from the given results JSONs on each datapoint of the benchmark dataset,
    # then write everything to --output_json.
    #
    # NOTE: `model` and `tokenizer` are deliberately module-level names here;
    # get_metrics may fall back to the module-level `tokenizer`.
    import gc

    parser = argparse.ArgumentParser(
        description="Evaluate candidate contexts on the backdoors benchmark dataset"
    )
    parser.add_argument(
        "--results_jsons",
        type=str,
        nargs="+",
        default=[],
        help="Paths to JSON files containing results to evaluate",
    )
    parser.add_argument(
        "--dataset_id",
        type=str,
        help="The dataset id to evaluate",
        default="Eliciting-Contexts/backdoors-benchmark-dataset",
    )
    parser.add_argument(
        "--result_names",
        type=str,
        nargs="+",
        help="Names to identify each results JSON file (must match number of results_jsons)",
    )
    parser.add_argument(
        "--cross_entropy_model_name",
        type=str,
        default="google/gemma-2-2b-it",
        help="The model to use for calculating cross entropy",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run the model on (e.g., 'cuda', 'cpu')",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Data type for model weights (e.g., 'bfloat16', 'float16')",
    )
    parser.add_argument(
        "--output_json",
        type=str,
        default="applications_evaluation_results.json",
        help="Path to save the evaluation results as a JSON file",
    )

    args = parser.parse_args()

    # Verify that result_names matches results_jsons in length.
    if not args.result_names:
        args.result_names = [f"result_set_{i}" for i in range(len(args.results_jsons))]
    if len(args.result_names) != len(args.results_jsons):
        raise ValueError(
            f"Number of result names ({len(args.result_names)}) must match number of result JSONs ({len(args.results_jsons)})"
        )

    # Load each results file. Names are kept paired with their data, so a file
    # that fails to load is skipped together with its name instead of shifting
    # later names onto the wrong results.
    loaded_names = []
    loaded_results = []
    for name, json_path in zip(args.result_names, args.results_jsons):
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                json_data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            print(f"Error loading JSON file {json_path}: {e}")
            continue
        loaded_names.append(name)
        loaded_results.append(json_data)
        print(f"Loaded JSON file: {json_path}")

    dataset = download_backdoors_dataset(
        hf_token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
        dataset_name=args.dataset_id,
    )

    # Maps str(datapoint index) -> {"info": ..., <result_name>: [result dicts]}.
    all_results = {}

    # The leading (name, None) pair stands for the built-in default context.
    result_names = ["default"] + loaded_names
    all_json_results = [None] + loaded_results

    current_base_model_name = None
    current_model_id = None
    model = None
    tokenizer = None

    cross_entropy_model = load_model_tlens(
        args.cross_entropy_model_name, args.device, args.dtype
    )

    train_split = dataset["train"]
    for idx, datum in tqdm(enumerate(train_split), total=len(train_split)):
        base_model_name = datum[BackdoorKeys.BASE_MODEL]
        lora_id = datum[BackdoorKeys.LORA_ID]

        # Reload the evaluation model only when the (base, adapter) pair changes.
        if base_model_name != current_base_model_name or lora_id != current_model_id:
            # Release the previous model first so two models are never
            # resident at the same time.
            model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            if lora_id:
                # Only PEFT adapters are supported for lora_id entries.
                files = list_repo_files(lora_id)
                if not any("adapter_config.json" in file for file in files):
                    raise ValueError(f"Model {lora_id} is not a PEFT model")
                # NOTE(review): trust_remote_code=True runs code from a repo id
                # taken from the dataset -- only use trusted datasets.
                model = (
                    AutoPeftModelForCausalLM.from_pretrained(
                        lora_id,
                        torch_dtype=getattr(torch, args.dtype),
                        low_cpu_mem_usage=True,
                        attn_implementation="eager",
                        device_map=args.device,
                        trust_remote_code=True,
                        use_cache=False,
                    )
                    .merge_and_unload()
                    .eval()
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    device_map=args.device,
                    torch_dtype=getattr(torch, args.dtype),
                )
            model.config.pad_token_id = tokenizer.pad_token_id
            current_base_model_name = base_model_name
            current_model_id = lora_id

        template = datum[BackdoorKeys.TEMPLATE]
        # Fixed baseline context, evaluated under the "default" result name.
        variable_context = "Hello and welcome and also good luck!"
        undesired_text = datum[BackdoorKeys.UNDESIRED_TEXT]
        desired_text = datum[BackdoorKeys.DESIRED_TEXT][0]
        dataset_type = datum[BackdoorKeys.DATASET_TYPE]
        dataset_info = datum[BackdoorKeys.DATASET_INFO]

        key = str(idx)
        info = Info(
            template=template,
            variable_context=variable_context,
            undesired_text=undesired_text,
            desired_text=desired_text,
            dataset_type=dataset_type,
            dataset_info=dataset_info,
        )
        all_results[key] = {"info": info.to_dict()}

        for result_name, json_results in zip(result_names, all_json_results):
            if json_results is None:
                # Only the "default" slot evaluates the baseline context; a
                # results file that decoded to JSON null is skipped.
                if result_name != "default":
                    continue
                current_results = [variable_context]
            else:
                if key not in json_results:
                    continue
                current_results = json_results[key]

            all_results[key][result_name] = [
                get_metrics(
                    template,
                    variable_context,
                    current_result,
                    undesired_text,
                    desired_text,
                    model,
                    cross_entropy_model,
                    device=args.device,
                ).to_dict()
                for current_result in current_results
            ]

    # Save all results to JSON file.
    print(f"Saving all evaluation results to {args.output_json}")
    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    print("Results saved successfully!")
