from datasets import load_from_disk
import argparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
import numpy as np
import json
import torch
from peft import PeftModel

from utils.env_management import save_config
import copy
from modules.models import (
    LlamaForSequenceClassificationWithMoreFeatures,
)
from tqdm import trange

# Command-line interface: the only option is the path to the JSON config file.
parser = argparse.ArgumentParser(description="Evaluate rewards")
parser.add_argument(
    "--config_file",
    type=str,
    default="configs/lora-rlhf-scores.json",
    help="config file",
)
args = parser.parse_args()

# Read the configuration inside a context manager so the file handle is closed
# deterministically rather than being left open until garbage collection.
with open(args.config_file, encoding="utf-8") as config_fh:
    config = json.load(config_fh)
# Pass the loaded config to the project helper under the run label
# "evaluate_valid_rewards".
save_config(config, "evaluate_valid_rewards")

# Several config strings are templates that are expanded with the config's own
# top-level keys via str.format(**config).
base_model = config["base_model"].format(**config)
evaluator_training_model = config["reward"]["training_model"]
evaluator_has_peft = config["reward"]["has_peft"]
evaluator_model_peft = config["reward"]["model_output_path"].format(**config)

ratings = load_from_disk(config["data_path_preprocessed"].format(**config))

save_path = config["valid_evaluated_output_path"].format(**config)

checkpoint = config["reward"]["model"].format(**config)

device = config["device"]
# Keep only the part after the last "." so that both "bfloat16" and
# "torch.bfloat16" spellings resolve to the same torch dtype object.
target_dtype = getattr(torch, config["torch_dtype"].split(".")[-1])

# Reuse the resolved dtype: looking up the raw config string on `torch` would
# raise AttributeError for the "torch."-prefixed spelling accepted just above.
torch.set_default_dtype(target_dtype)

tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]

# An explicit chat template in the config overrides whatever the tokenizer
# was loaded with.
if "chat_template" in config:
    tokenizer.chat_template = config["chat_template"]
else:
    print("No chat template provided in config file using default")

# Build the reward model. All variants are loaded from `base_model` as
# regression heads; `target_dtype` (the already-resolved torch.dtype) is passed
# instead of the raw config string so dtype handling matches the rest of the
# script and does not depend on string parsing inside `from_pretrained`.
if evaluator_training_model == "sequence_classification":
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=1,
        problem_type="regression",
        torch_dtype=target_dtype,
    )
elif evaluator_training_model in ("more_features", "odin"):
    # Both variants use the project model class with one extra input feature;
    # "odin" asks for two output labels, "more_features" for one.
    model = LlamaForSequenceClassificationWithMoreFeatures.from_pretrained(
        base_model,
        num_labels=2 if evaluator_training_model == "odin" else 1,
        problem_type="regression",
        torch_dtype=target_dtype,
        features_dim=1,
    ).to(device)
else:
    raise ValueError(f"Unknown training model: {evaluator_training_model}")

# Match the embedding matrix to the tokenizer's vocabulary size (the pad token
# set on the tokenizer may have enlarged it), then make sure the model is on
# the target device.
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)

model.config.pad_token_id = tokenizer.pad_token_id

def _merge_adapter(base, adapter_path):
    """Load the adapter at *adapter_path* onto *base* and return the merged model."""
    wrapped = PeftModel.from_pretrained(base, adapter_path, adapter_name="default")
    return wrapped.merge_and_unload()


# First optional adapter: the one stored at `checkpoint`, applied only when the
# config flags the reward model as having a PEFT adapter.
if evaluator_has_peft:
    model = _merge_adapter(model, checkpoint)

# Second optional adapter: the reward model's own output path, applied whenever
# that path is non-empty.
if evaluator_model_peft:
    model = _merge_adapter(model, evaluator_model_peft.format(**config))

def evaluate_response(examples):
    """Score one batch of examples with the reward model.

    A chat prompt is built per example from ``config["messages_template"]``,
    the prompts are tokenized together with padding, and the module-level
    ``model`` is run on them without gradient tracking.

    Args:
        examples: Column-oriented batch (mapping of column name to a list of
            per-example values). It must contain a ``"Title"`` column and
            every field referenced by the second template message. A
            ``"popularity"`` column is needed only when the evaluator is not
            ``"sequence_classification"``.

    Returns:
        The same ``examples`` mapping with the model output stored under
        ``"logits"``.
    """
    messages_template = config["messages_template"]
    batch_size = len(next(iter(examples.values())))  # Get the batch size

    # Name of the first placeholder in the third template message; it is always
    # filled with the "Title" column. Loop-invariant, so computed once.
    title_key = messages_template[2]["content"].split("{")[1].split("}")[0]
    # Only the non-"sequence_classification" models receive a `features` input.
    needs_features = evaluator_training_model != "sequence_classification"

    prompts = []
    for i in range(batch_size):
        # Copy the template so per-example edits never leak into the config.
        messages = copy.deepcopy(messages_template)
        # Row view of the batch, used as the str.format variables.
        example_vars = {key: examples[key][i] for key in examples}
        messages[1]["content"] = messages_template[1]["content"].format(**example_vars)
        messages[2]["content"] = messages_template[2]["content"].format(
            **{title_key: example_vars["Title"]}
        )
        # Render the conversation to plain text; tokenization happens per batch.
        prompts.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
        )

    # Tokenize all prompts in the batch
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    if needs_features:
        # One single-element feature row per example.
        inputs["features"] = (
            torch.tensor([[value] for value in examples["popularity"]])
            .to(device)
            .to(target_dtype)
        )

    # Inference only: skip autograd bookkeeping so activations are not retained.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Add the model outputs to the batch
    examples["logits"] = logits

    return examples

# Score both splits with the batched evaluator, validation split first, and
# write each result back into the dataset dict under the same key.
for split_name in ("reward_valid", "reward"):
    ratings[split_name] = ratings[split_name].map(
        evaluate_response,
        batched=True,
        batch_size=config["evaluation_batch_size"],
    )

# Persist the dataset, now carrying the added "logits" column, to the output path.
ratings.save_to_disk(save_path)
print(f"Evaluated headlines saved to {save_path}")
