from datasets import load_from_disk
import argparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
import numpy as np
import json
import torch
from peft import PeftModel

from utils.env_management import save_config
import copy
from tqdm import trange
from sklearn.metrics import roc_curve, auc

# Command-line interface: path to the JSON config and an optional --paired flag.
parser = argparse.ArgumentParser(description="Evaluate rewards")
parser.add_argument(
    "--config_file",
    type=str,
    default="configs/lora-rlhf-scores.json",
    help="config file",
)
parser.add_argument(
    "--paired",
    action="store_true",
)
args = parser.parse_args()

# NOTE(review): `paired` is not read anywhere in the visible part of this
# script; confirm whether it is still needed.
paired = args.paired

# Read the config through a context manager so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
with open(args.config_file) as config_fp:
    config = json.load(config_fp)
save_config(config, "evaluate_test_rewards")

# String entries in the config may contain "{key}" placeholders; they are
# expanded with the top-level config values via str.format.
base_model = config["base_model"].format(**config)
evaluator_has_peft = config["reward"]["has_peft"]
evaluator_model_peft = config["reward"]["model_output_path"].format(**config)

ratings = load_from_disk(config["data_path"].format(**config))
print(ratings)

save_path = config["test_evaluated_output_path"].format(**config)

checkpoint = config["reward"]["model"].format(**config)

device = config["device"]
# Accept both "float16" and "torch.float16" spellings by keeping only the
# part after the last dot before looking the dtype up on the torch module.
target_dtype = getattr(torch, config["torch_dtype"].split(".")[-1])

# Reuse the resolved dtype: looking up the raw config string directly would
# raise AttributeError for the "torch.<dtype>" spelling handled above.
torch.set_default_dtype(target_dtype)

tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]

if "chat_template" in config:
    tokenizer.chat_template = config["chat_template"]
else:
    print("No chat template provided in config file using default")

# Load the base model with a single-output classification head; its one logit
# is used as the reward score. Pass the resolved torch dtype rather than the
# raw config string so the "torch.<dtype>" spelling is handled consistently.
model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,
    torch_dtype=target_dtype,
)

# Keep the embedding matrix in sync with the tokenizer vocabulary size.
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)

model.config.pad_token_id = tokenizer.pad_token_id

if evaluator_has_peft:
    # Load the adapter stored at `checkpoint` and merge it into the base model.
    peft_config = PeftModel.from_pretrained(model, checkpoint, adapter_name="default")

    model = peft_config.merge_and_unload()

if evaluator_model_peft:
    # `evaluator_model_peft` already had its placeholders expanded when it was
    # read from the config, so it is used as-is here.
    peft_config = PeftModel.from_pretrained(
        model, evaluator_model_peft, adapter_name="default"
    )

    model = peft_config.merge_and_unload()


def evaluate_response_0(examples):
    """Score one batch of examples with the reward model.

    Builds one prompt per example from ``config["messages_template"]``,
    runs the module-level ``model`` on the batch and stores the resulting
    logits under ``"logits_0"``.

    Args:
        examples: Mapping from column name to a list of per-example values.
            Must contain ``"created_at"`` and every placeholder used by the
            message templates.

    Returns:
        The same ``examples`` mapping with an added ``"logits_0"`` entry.
    """
    messages_template = config["messages_template"]
    batch_size = len(next(iter(examples.values())))  # Get the batch size

    prompts = []
    for i in range(batch_size):
        messages = copy.deepcopy(messages_template)
        # Prepare variables for string formatting
        example_vars = {key: examples[key][i] for key in examples}
        # Keep only the first 10 characters (presumably a YYYY-MM-DD date).
        example_vars["created_at"] = example_vars["created_at"][:10]
        messages[1]["content"] = messages_template[1]["content"].format(**example_vars)
        messages[2]["content"] = messages_template[2]["content"].format(**example_vars)
        # Generate the prompt for each example
        try:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
        except Exception:
            # Deliberate best-effort fallback: if the chat template cannot be
            # applied, use a plain "role: content" transcript instead.
            prompt = "\n".join(
                f"{message['role']}: {message['content']}" for message in messages
            )
        prompts.append(prompt)

    # Tokenize all prompts in the batch
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    # Inference only: disable autograd so no graph or activations are kept.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Add the responses to the examples
    examples["logits_0"] = logits
    return examples


def evaluate_response_1(examples):
    """Score one batch of examples using the ``headline_2`` column.

    Builds one prompt per example from ``config["messages_template"]`` with
    the ``headline`` placeholder overridden by the example's ``headline_2``
    value, runs the module-level ``model`` on the batch and stores the
    resulting logits under ``"logits_1"``.

    Args:
        examples: Mapping from column name to a list of per-example values.
            Must contain ``"headline_2"``, ``"created_at"`` and every other
            placeholder used by the message templates.

    Returns:
        The same ``examples`` mapping with an added ``"logits_1"`` entry.
    """
    messages_template = config["messages_template"]
    batch_size = len(next(iter(examples.values())))  # Get the batch size

    prompts = []
    for i in range(batch_size):
        messages = copy.deepcopy(messages_template)
        # Prepare variables for string formatting
        example_vars = {key: examples[key][i] for key in examples}
        # Substitute the second headline wherever the template uses {headline}.
        example_vars["headline"] = example_vars["headline_2"]
        # Keep only the first 10 characters (presumably a YYYY-MM-DD date).
        example_vars["created_at"] = example_vars["created_at"][:10]
        messages[1]["content"] = messages_template[1]["content"].format(**example_vars)
        messages[2]["content"] = messages_template[2]["content"].format(**example_vars)
        # Generate the prompt for each example
        try:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
        except Exception:
            # Deliberate best-effort fallback: if the chat template cannot be
            # applied, use a plain "role: content" transcript instead.
            prompt = "\n".join(
                f"{message['role']}: {message['content']}" for message in messages
            )
        prompts.append(prompt)

    # Tokenize all prompts in the batch
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    # Inference only: disable autograd so no graph or activations are kept.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Add the responses to the examples
    examples["logits_1"] = logits
    return examples


# Run both scoring passes over the test split, in order; each pass adds one
# logits column ("logits_0", then "logits_1").
for score_batch in (evaluate_response_0, evaluate_response_1):
    ratings["test"] = ratings["test"].map(
        score_batch,
        batched=True,
        batch_size=config["evaluation_batch_size"],
    )

# Persist the dataset, now including the logits columns, to the output path.
ratings.save_to_disk(save_path)
print(f"Evaluated headlines saved to {save_path}")

df_test = ratings["test"].to_pandas()

# Each stored logits entry is a sequence; keep its first element as a scalar.
df_test["logits_0"] = df_test["logits_0"].apply(lambda x: x[0])
df_test["logits_1"] = df_test["logits_1"].apply(lambda x: x[0])

# Map chosen/rejected scores onto positions 0 and 1 according to `outcome`.
# Vectorized selection replaces a row-wise DataFrame.apply: it avoids a
# Python-level call per row and stays well-defined on an empty frame.
df_test["score_0"] = np.where(
    df_test["outcome"] == 0, df_test["chosen_score"], df_test["rejected_score"]
)
df_test["score_1"] = np.where(
    df_test["outcome"] == 1, df_test["chosen_score"], df_test["rejected_score"]
)

summary_columns = ["score_0", "score_1", "logits_0", "logits_1"]
print(df_test[summary_columns].corr())
print(df_test[summary_columns].describe())

outcome = df_test["outcome"]
logits_1 = df_test["logits_1"]
logits_0 = df_test["logits_0"]

# The logit margin of response 1 over response 0 is the ranking score used to
# predict `outcome`.
logits_score = logits_1 - logits_0
# Compute ROC curve
fpr, tpr, _ = roc_curve(outcome, logits_score)

# Compute AUC
roc_auc = auc(fpr, tpr)
print(f"AUC: {roc_auc}")