from dataclasses import dataclass, field
from typing import List, Optional, cast

import numpy as np
import multiprocess
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)
import os


from .evaluate_hh import get_weightchain_array


@dataclass
class ScriptArguments:
    """Command-line arguments for scoring JSONL responses with a reward model.

    Parsed with ``HfArgumentParser``; each field becomes a ``--<field_name>``
    flag and the ``help`` metadata becomes its help text.
    """

    reward_model_checkpoint: str = field(
        metadata={"help": "Path to the trained reward model checkpoint."}
    )
    # Named ``input`` for CLI compatibility (``--input``), even though it
    # shadows the builtin inside this dataclass.
    input: str = field(
        metadata={"help": "JSONL file with responses to evaluate."},
    )
    output: Optional[str] = field(
        default=None,
        metadata={"help": "JSONL file where results will be stored."},
    )
    batch_size: int = field(default=1)
    model_name: str = field(
        default="gpt2",
        metadata={
            "help": "The model that you want to use as the basis for generation and for evaluation. "
            "E.g. gpt2, gpt2-xl, bert, etc."
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The tokenizer for your model, if left empty will use the default "
            "for your model",
        },
    )
    # Number of outputs of the sequence-classification (reward) head.
    num_outputs: int = field(default=1)
    # Maximum tokenized length of prompt + response.
    max_length: int = field(default=1024)
    bf16: bool = field(
        default=True,
        metadata={"help": "Whether to use bfloat16 for the reward model."},
    )
    # Optional: when None, the reward head weights from the checkpoint are used.
    rex_mcmc_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the REX weights."},
    )
    rex_normalize: str = field(
        default="none",
        metadata={
            "help": "How to normalize the reward outputs: 'none', 'mean' or "
            "'median' (the latter two require --rex_mcmc_file)."
        },
    )


# Normalization modes accepted by --rex_normalize.
_NORMALIZE_MODES = ("none", "mean", "median")


def _resolve_output_and_tag(
    output: Optional[str],
    rex_mcmc_file: Optional[str],
    reward_model_checkpoint: str,
):
    """Work out the output JSONL path and the tag used in the result column name.

    Args:
        output: Explicit output path, or None to derive one from the REX file.
        rex_mcmc_file: Path to the REX weights file, or None.
        reward_model_checkpoint: Reward-model checkpoint path. It is used for the
            tag only when no REX file is given.

    Returns:
        ``(output_path, tag)``. The tag is the name of the directory that holds
        the REX file. When there is no REX file, it is the last component of the
        checkpoint path.

    Raises:
        ValueError: if ``output`` is None and there is no REX file to derive a
            default location from.
    """
    if rex_mcmc_file is not None:
        tag = os.path.basename(os.path.dirname(rex_mcmc_file))
    else:
        tag = os.path.basename(os.path.normpath(reward_model_checkpoint))

    if output is None:
        if rex_mcmc_file is None:
            raise ValueError(
                "--output must be given when --rex_mcmc_file is not set "
                "(the default output location is derived from it)."
            )
        # By default, write next to the REX weights file.
        output = os.path.join(
            os.path.dirname(rex_mcmc_file), "eval_results_jailbreak.jsonl"
        )
    return output, tag


def _check_normalize_mode(mode: str, rex_mcmc_file: Optional[str]) -> None:
    """Fail fast on an unknown mode, or on a mode that needs REX statistics.

    Raises:
        ValueError: if ``mode`` is not in ``_NORMALIZE_MODES``, or if it requests
            mean/median normalization without a REX file to supply the statistics.
    """
    if mode not in _NORMALIZE_MODES:
        raise ValueError(
            f"--rex_normalize must be one of {_NORMALIZE_MODES}, got {mode!r}"
        )
    if mode != "none" and rex_mcmc_file is None:
        raise ValueError(
            f"--rex_normalize {mode} requires --rex_mcmc_file to be set"
        )


if __name__ == "__main__":
    multiprocess.set_start_method("spawn")

    parser = HfArgumentParser(ScriptArguments)
    script_args = cast(ScriptArguments, parser.parse_args_into_dataclasses()[0])
    _check_normalize_mode(script_args.rex_normalize, script_args.rex_mcmc_file)

    # ``model_name`` here is the tag used in the output column name, not
    # ``script_args.model_name`` (the base model to load).
    output_fname, model_name = _resolve_output_and_tag(
        script_args.output,
        script_args.rex_mcmc_file,
        script_args.reward_model_checkpoint,
    )

    # Load the tokenizer (defaults to the base model's tokenizer).
    tokenizer_name = (
        script_args.tokenizer_name
        if script_args.tokenizer_name is not None
        else script_args.model_name
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)

    # Need to do this for GPT2 and Llama because they don't have official pad tokens.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    dataset = load_dataset("json", data_files=[script_args.input])["train"]

    print("Loading base reward model...")
    model_kwargs = {}
    if script_args.bf16:
        model_kwargs["torch_dtype"] = torch.bfloat16
    base_reward_model = AutoModelForSequenceClassification.from_pretrained(
        script_args.model_name,
        num_labels=script_args.num_outputs,
        **model_kwargs,
    )

    checkpoint_path = script_args.reward_model_checkpoint
    reward_model = PeftModel.from_pretrained(
        base_reward_model, checkpoint_path, is_trainable=False
    )

    # Offset subtracted from every reward output; None means no normalization.
    reward_shift = None
    if script_args.rex_mcmc_file is not None:
        (
            rex_weights,
            rex_mean_reward,
            _rex_std_reward,
            rex_median_reward,
            _rex_mad_reward,
        ) = get_weightchain_array(script_args.rex_mcmc_file)
        # Match whatever dtype the score head was actually loaded in. This
        # avoids a KeyError on model_kwargs["torch_dtype"] when --bf16 is off.
        score_dtype = reward_model.score.weight.dtype
        reward_model.score.weight.data = torch.from_numpy(rex_weights).to(
            dtype=score_dtype
        )
        if script_args.rex_normalize == "mean":
            print("normalize by mean")
            reward_shift = (
                torch.tensor(rex_mean_reward).to(dtype=score_dtype).to("cuda")
            )
        elif script_args.rex_normalize == "median":
            print("normalize by median")
            reward_shift = (
                torch.tensor(rex_median_reward).to(dtype=score_dtype).to("cuda")
            )

    reward_model.cuda().eval()
    # NOTE(review): this sets an attribute on the PEFT wrapper, not on
    # ``reward_model.config``; confirm whether config.pad_token_id was intended.
    reward_model.pad_token_id = tokenizer.pad_token_id

    def evaluate_responses(example):
        """Score each response in ``example["responses"]`` against its prompt.

        Returns a dict with one key, ``reward_outputs_<tag>``, holding one list
        of reward-head outputs per response.
        """
        prompt: str = example["prompt"]
        # Keep the prompt from the first "Human: " turn onward; raises
        # ValueError if the marker is absent.
        prompt = prompt[prompt.index("Human: ") :]

        reward_outputs = []
        with torch.no_grad():
            for response in example["responses"]:
                inputs = tokenizer(
                    prompt + response,
                    return_tensors="pt",
                    max_length=script_args.max_length,
                    # Explicit: max_length without truncation only triggers
                    # an implicit (deprecated) truncation default.
                    truncation=True,
                )
                reward_output = reward_model(
                    inputs.input_ids.cuda(), inputs.attention_mask.cuda()
                )[0]
                if reward_shift is not None:
                    reward_output -= reward_shift
                reward_outputs.append(reward_output[0].tolist())
        return {
            f"reward_outputs_{model_name}": reward_outputs,
        }

    print(f"Evaluating responses with {model_name}...")
    dataset = dataset.map(
        evaluate_responses,
        batched=False,
    )

    # Write the scored dataset to JSONL. This must stay inside the __main__
    # guard: at module level it raises NameError on import (including the
    # spawn-mode re-import of the main module).
    dataset.to_json(output_fname, orient="records", lines=True)
