from datasets import load_from_disk
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import json
import torch
from peft import PeftModel

from utils.env_management import save_config
import copy
import os

# Command-line interface: the JSON config that drives the whole run, and the
# index of the GPU this process should be restricted to.
parser = argparse.ArgumentParser(description="Generate captions")
parser.add_argument(
    "--config_file",
    default="configs/lora-rlhf-scores.json",
    type=str,
    help="config file",
)
parser.add_argument(
    "--cuda_device",
    default=0,
    type=int,
    help="cuda device to use",
)
args = parser.parse_args()

# Load the run configuration. The context manager closes the file handle right
# away (the bare open() previously leaked it until garbage collection), and the
# explicit encoding avoids depending on the platform's locale default.
with open(args.config_file, encoding="utf-8") as config_file:
    config = json.load(config_file)
save_config(config, "generate")

# Restrict the GPUs visible to this process. This only takes effect if CUDA has
# not been initialised yet in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)

# Path-like config values may contain "{key}" placeholders that are filled in
# from the other config entries.
base_model = config["base_model"].format(**config)
ratings = load_from_disk(config["data_path_preprocessed"].format(**config))
checkpoint = config["generation_model_base"].format(**config)

# Optionally restrict generation to a reproducible random subset of the test
# split (shuffle with a fixed seed, then keep the first N rows).
if config["sample_generation"]:
    ratings["test"] = (
        ratings["test"]
        .shuffle(seed=config["sample_generation_seed"])
        .select(range(config["sample_generation_size"]))
    )

device = config["device"]
# Accept both "bfloat16" and "torch.bfloat16" spellings: only the part after
# the last dot is looked up on the torch module.
target_dtype = getattr(torch, config["torch_dtype"].split(".")[-1])

tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]
# Prompts are generated in padded batches by a decoder-only model, which
# continues from the last position of each row. With right padding, shorter
# prompts would be continued after pad tokens, so pad on the left instead.
tokenizer.padding_side = "left"

if "chat_template" in config:
    tokenizer.chat_template = config["chat_template"]
else:
    print("No chat template provided in config file using default")

# One generation pass is run per seed. A single integer is accepted as well as
# a list of seeds.
seeds = config.get("generation_seed", [42])
if isinstance(seeds, int):
    seeds = [seeds]

def _apply_adapter(base, adapter_path):
    """Load the PEFT adapter at *adapter_path* onto *base* and merge it in.

    Returns the unwrapped model with the adapter weights folded into its
    layers, so generation does not go through the PEFT wrapper.
    """
    peft_model = PeftModel.from_pretrained(base, adapter_path, adapter_name="default")
    return peft_model.merge_and_unload()


def _load_generation_model():
    """Load the base model on ``device`` and merge the configured adapters.

    Adapters are merged in this order: the one at ``checkpoint`` (when
    ``config["generation_has_peft"]`` is truthy), then the one at
    ``config["generation_model_peft"]`` (when set). Both keys are optional.
    """
    loaded = AutoModelForCausalLM.from_pretrained(
        base_model,
        # Pass the resolved torch.dtype rather than the raw config string:
        # a value such as "torch.bfloat16" is not a name from_pretrained
        # accepts as a string dtype.
        torch_dtype=target_dtype,
    ).to(device)

    loaded.resize_token_embeddings(len(tokenizer))

    if config.get("generation_has_peft"):
        loaded = _apply_adapter(loaded, checkpoint)

    extra_adapter = config.get("generation_model_peft")
    if extra_adapter:
        loaded = _apply_adapter(loaded, extra_adapter.format(**config))

    return loaded


def generate_response(examples):
    """Generate a model response for every row of a batched ``datasets`` call.

    ``examples`` maps each column name to a list of values (one per row).
    For each row, a deep copy of ``config["messages_template"]`` is made and
    its last message is dropped, because the model is asked to write that
    turn. Message index 1 is then filled with the row's column values via
    ``str.format``. The decoded generations are stored under ``"response"``
    and the mutated batch is returned.
    """
    messages_template = config["messages_template"]
    batch_size = len(next(iter(examples.values())))

    prompts = []
    for i in range(batch_size):
        messages = copy.deepcopy(messages_template)
        messages.pop()  # Drop the last message intended for the assistant

        example_vars = {key: values[i] for key, values in examples.items()}
        messages[1]["content"] = messages_template[1]["content"].format(**example_vars)

        prompts.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        )

    # NOTE(review): if the chat template already emits a BOS token, tokenizing
    # its output with the default add_special_tokens=True adds a second one;
    # check the template before changing this.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

    outputs = model.generate(
        **inputs, **config["generation_config"], pad_token_id=tokenizer.eos_token_id
    )

    # NOTE: the full sequence is decoded, so "response" holds the rendered
    # prompt text followed by the generated continuation.
    examples["response"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return examples


for seed in seeds:
    save_path = config["generated_output_path"].format(seed=seed, **config)
    print(f"Generating captions for seed {seed} and saving to {save_path}")
    torch.manual_seed(seed)

    # Drop the previous seed's model before loading a fresh copy, so its
    # memory can be released instead of two copies coexisting on the device.
    model = None
    torch.cuda.empty_cache()
    # The model is (re)loaded after seeding, as before, because loading may
    # consume random numbers (e.g. when adapter layers are initialised).
    # Keeping this order keeps each seed's sampled outputs unchanged.
    model = _load_generation_model()

    ratings["test"] = ratings["test"].map(
        generate_response,
        batched=True,
        batch_size=config["generation_batch_size"],
        # The seed is not referenced by generate_response, so it is not part
        # of the transform's fingerprint. Without this, a cached result (e.g.
        # from an earlier run with another seed) could be returned instead of
        # sampling again.
        load_from_cache_file=False,
    )
    ratings["test"].save_to_disk(save_path)
    print(f"Generated captions saved to {save_path}")
