# Copyright 2023 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import sys

import numpy as np
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from fastchat.conversation import get_conv_template
from peft import PeftModel  # type: ignore
from rewardbench import (
    check_tokenizer_chat_template,
    load_eval_dataset,
    torch_dtype_mapping,
)
from rewardbench.constants import EXAMPLE_COUNTS, SUBSET_MAPPING
from rewardbench.models.pipeline import RewardBenchPipeline
from rewardbench.utils import calculate_scores_per_section
from torch import Tensor
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from dr.utils import get_eps_from_model_name

# Allow TensorFloat32 (TF32) math for cuDNN ops and CUDA matmuls (tensor cores on Ampere GPUs; faster than FP32)
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Hugging Face token taken from the environment; stays None when HF_TOKEN is not set
HF_TOKEN = os.environ.get("HF_TOKEN")
# Logging in here lets the script authenticate automatically inside docker/batch beaker jobs
if HF_TOKEN is not None:
    from huggingface_hub._login import _login

    _login(token=HF_TOKEN, add_to_git_credential=False)


def get_args():
    """
    Parse the command-line arguments of the evaluation script.

    Returns:
        argparse.Namespace: the parsed arguments. ``torch_dtype`` is converted from its
        string form (e.g. ``"float16"``) by ``torch_dtype_mapping`` before returning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-2b-it", type=str, help="path to model")
    parser.add_argument("--peft_name", type=str, required=False, default=None, help="path to PEFT adapter (if any)")
    parser.add_argument("--tokenizer", type=str, default=None, help="path to non-matching tokenizer to model")
    # The value is handed to fastchat's get_conv_template, so it is a template name rather than a file path.
    parser.add_argument("--chat_template", type=str, default="tulu", help="name of the FastChat conversation template")
    parser.add_argument("--trust_remote_code", action="store_true", default=False, help="pass trust_remote_code=True when loading the model and tokenizer")
    parser.add_argument("--do_not_save", action="store_true", help="do not save results to hub (for debugging)")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size for inference")
    parser.add_argument("--max_length", type=int, default=1024, help="Max length of RM inputs (passed to pipeline)")
    parser.add_argument("--pref_sets", action="store_true", help="run on common preference sets instead of our custom eval set")
    parser.add_argument("--debug", action="store_true", help="run on only the first 10 examples (for debugging)")
    parser.add_argument("--disable_beaker_save", action="store_true", help="disable saving the main results in a file for AI2 Beaker")
    parser.add_argument("--not_quantized", action="store_true", help="disable quantization for models that are quantized by default")
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16", "float32", "float64"],
        help="PyTorch dtype (default: float16)",
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        choices=["eager", "sdpa", "flash_attention_2"],
        help="Attention implementation to use (default: flash_attention_2)",
    )
    args = parser.parse_args()
    # Downstream code expects a dtype object rather than its name.
    args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
    return args


def calculate_sequence_logprobs(logits: Tensor, input_ids: Tensor, attention_mask: Tensor):
    """
    Return the mean next-token log-probability of each sequence.

    The logits at position ``t`` are scored against the token at position ``t + 1``,
    so the last logit and the first token / mask entry are dropped. Positions whose
    (shifted) attention mask is 0 contribute nothing, and the sum is divided by the
    number of positions whose mask is 1.

    Args:
        logits: model outputs, indexed ``[..., position, vocabulary]``.
        input_ids: token ids, indexed ``[..., position]``.
        attention_mask: mask aligned with ``input_ids`` (1 = score this token, 0 = ignore).

    Returns:
        One averaged log-probability per sequence (the position axis is reduced away).
    """
    # Align predictions with the tokens they predict.
    targets = input_ids[..., 1:].contiguous()
    valid = attention_mask[..., 1:].contiguous()
    vocab_log_probs = logits[..., :-1, :].contiguous().log_softmax(dim=-1)

    # Pick the log-probability assigned to each realised token.
    picked = vocab_log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    masked_total = (picked * valid).sum(dim=-1)
    return masked_total / valid.sum(dim=-1)


def main():
    """
    Score RewardBench chosen/rejected pairs with a causal LM and report accuracy.

    Every text is scored with ``calculate_sequence_logprobs`` (mean token
    log-probability); a pair counts as correct when the chosen text scores strictly
    higher than the rejected one. Per-subset accuracies are printed. For the core
    evaluation set the section-level leaderboard scores and a LaTeX table row are
    additionally printed and written below ``./eval_rm/<model name>/rewardbench/``.

    With ``--pref_sets`` only the per-subset accuracies are reported: the section
    aggregation and the LaTeX row are both built from ``SUBSET_MAPPING``, which this
    script only applies to the core set.
    """
    args = get_args()

    ###############
    # Setup logging
    ###############
    accelerator = Accelerator()
    current_device = accelerator.process_index

    logger = get_logger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = logging.INFO
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}")
    if args.trust_remote_code:
        logger.info("Loading model with Trust Remote Code")

    # load chat template
    conv = get_conv_template(args.chat_template)

    config = {
        "model_builder": AutoModelForCausalLM.from_pretrained,
        "tokenizer_builder": AutoTokenizer.from_pretrained,
        "quantized": True,
        "custom_dialogue": False,
    }
    logger.info(f"Using reward model config: {config}")

    quantized = config["quantized"]  # only Starling isn't quantized for now
    # if llama-3 in name, switch quantized to False (severely degrades performance)
    llama3_markers = ("llama-3", "Llama3", "Llama-3", "LLaMA3", "llama3")
    if args.not_quantized or any(marker in args.model for marker in llama3_markers):
        quantized = False
        logger.info(f"Disabling quantization for llama-3 or override flag (--not_quantized: {args.not_quantized})")

    custom_dialogue = config["custom_dialogue"]
    model_builder = config["model_builder"]
    torch_dtype = config.get("torch_dtype", None)
    # if no datatype in config (default), check args
    if torch_dtype is None:
        # if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
        if args.torch_dtype == torch.bfloat16:
            quantized = False
            logger.info("Disabling quantization for bfloat16 datatype")
        torch_dtype = args.torch_dtype

    # not included in config to make user explicitly understand they are passing this
    trust_remote_code = args.trust_remote_code

    ############################
    # Load dataset
    ############################
    logger.info("*** Load dataset ***")
    tokenizer_path = args.tokenizer if args.tokenizer else args.model
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=args.trust_remote_code)
    if not custom_dialogue:  # not needed for PairRM / SteamSHP
        tokenizer.truncation_side = "left"  # copied from Starling, but few samples are above context length
    dataset, subsets = load_eval_dataset(
        core_set=not args.pref_sets,
        conv=conv,
        custom_dialogue_formatting=custom_dialogue,
        tokenizer=tokenizer,  # type: ignore
        logger=logger,  # type: ignore
        keep_columns=["text_chosen", "text_rejected", "id"],
    )
    # copy id for saving, then remove
    ids = dataset["id"]
    dataset = dataset.remove_columns("id")

    # debug: use only 10 examples
    if args.debug:
        dataset = dataset.select(range(10))
        subsets = subsets[:10]
        ids = ids[:10]

    ############################
    # Load model
    ############################
    batch_size = args.batch_size
    logger.info("*** Load model ***")
    reward_pipeline_kwargs = {
        "batch_size": batch_size,  # eval_args.inference_batch_size,
        "truncation": True,
        "padding": True,
        "max_length": args.max_length,
        "function_to_apply": "none",  # Compute raw logits
        "return_token_type_ids": False,
    }
    if quantized:
        model_kwargs = {
            "load_in_8bit": True,
            "device_map": {"": current_device},
            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
        }
    else:
        model_kwargs = {
            "device_map": "auto",
            "torch_dtype": torch_dtype,
        }

    if args.peft_name is not None:  # TODO: is this correct?
        model_kwargs["num_labels"] = 1

    if args.attn_implementation:
        model_kwargs["attn_implementation"] = args.attn_implementation

    model = model_builder(args.model, **model_kwargs, trust_remote_code=trust_remote_code)

    ############################
    # PEFT initialization (if any)
    ############################
    if args.peft_name is not None and os.path.exists(args.peft_name):
        print(f"Loading PEFT adapter '{args.peft_name}'...")
        model = PeftModel.from_pretrained(model, args.peft_name)
        epsilon = get_eps_from_model_name(args.peft_name)
    else:
        if args.peft_name is not None:
            # Previously this case fell through silently, so the reported numbers could be
            # attributed to an adapter that was never loaded.
            logger.warning(f"PEFT adapter path '{args.peft_name}' does not exist; evaluating '{args.model}' without an adapter")
        epsilon = get_eps_from_model_name(args.model)
    if hasattr(model, "merge_and_unload"):
        model = model.merge_and_unload()

    # Build the pipeline only after the adapter has been applied and merged, so that it
    # explicitly wraps the model that is meant to be evaluated.
    reward_pipe = RewardBenchPipeline(
        "this-does-nothing-right-?",
        model=model,
        tokenizer=tokenizer,
    )

    ############################
    # Tokenization settings & dataset preparation
    ############################
    # set pad token to eos token if not set
    if reward_pipe.tokenizer.pad_token_id is None:
        reward_pipe.model.config.pad_token_id = reward_pipe.tokenizer.eos_token_id
        reward_pipe.tokenizer.pad_token_id = reward_pipe.tokenizer.eos_token_id
    # For models whose config did not contain `pad_token_id`
    if reward_pipe.model.config.pad_token_id is None:
        reward_pipe.model.config.pad_token_id = reward_pipe.tokenizer.pad_token_id

    # if using fastchat template (no template in tokenizer), make the RM tokenizer output an EOS token
    if not check_tokenizer_chat_template(tokenizer):
        reward_pipe.tokenizer.add_eos_token = True  # type: ignore

    ############################
    # Run inference
    ############################
    logger.info("*** Running dataloader to collect results ***")
    from torch.utils.data.dataloader import default_collate

    # for PairRM, hmm, will move all of this later
    def custom_collate_fn(batch):
        # Batches whose "text_chosen" entries are lists of dicts are passed through untouched;
        # everything else goes through the default collate.
        if isinstance(batch[0]["text_chosen"][0], dict):
            return batch
        return default_collate(batch)

    dataloader = torch.utils.data.DataLoader(
        dataset,  # type: ignore
        batch_size=batch_size,
        collate_fn=custom_collate_fn,  # if not args.pref_sets else None,
        shuffle=False,
        drop_last=False,
    )

    dataloader, model = accelerator.prepare(dataloader, reward_pipe.model)
    reward_pipe.model = model

    # Same tokenization settings as the ones handed to the pipeline, so that the token ids
    # recovered here line up with the logits the pipeline returns.
    tokenize_kwargs = {
        "truncation": True,
        "max_length": args.max_length,
        "padding": True,
        "return_tensors": "pt",
    }

    def score_texts(texts):
        """Return the mean token log-probability of every text in ``texts``."""
        logits: Tensor = reward_pipe(texts, **reward_pipeline_kwargs)  # type: ignore
        # Follow the logits' device instead of assuming "cuda" (i.e. cuda:0).
        encoded = tokenizer(texts, **tokenize_kwargs).to(logits.device)  # type: ignore
        return calculate_sequence_logprobs(logits, encoded["input_ids"], encoded["attention_mask"])

    results = []
    scores_chosen = []
    scores_rejected = []
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="RM batch steps"):
            # Each side is reduced to per-sequence scores before the other side is run, so
            # the logits of both sides never have to be alive at the same time.
            chosen_log_probs = score_texts(batch["text_chosen"])
            rejected_log_probs = score_texts(batch["text_rejected"])

            # for each item in batch, record 1 if chosen > rejected, else 0
            results.extend((chosen_log_probs > rejected_log_probs).int().tolist())
            scores_chosen.extend(chosen_log_probs.float().numpy(force=True))
            scores_rejected.extend(rejected_log_probs.float().numpy(force=True))

    ############################
    # Print & process results
    ############################
    # add column for results for easy printing
    out_dataset = dataset.add_column("results", results)  # type: ignore

    # add subsets back (removed so it's not handled by cuda)
    out_dataset = out_dataset.add_column("subset", subsets)
    out_dataset = out_dataset.add_column("id", ids)

    # add scores_chosen and scores_rejected to the dataset
    out_dataset = out_dataset.add_column("scores_chosen", scores_chosen)
    out_dataset = out_dataset.add_column("scores_rejected", scores_rejected)

    # get core dataset
    results_grouped = {}
    results_grouped["model"] = args.model if args.peft_name is None else args.peft_name
    results_grouped["chat_template"] = args.chat_template if not check_tokenizer_chat_template(tokenizer) else "tokenizer"

    # print per subset and log into results_grouped
    present_subsets = np.unique(subsets)
    for subset in present_subsets:
        subset_dataset = out_dataset.filter(lambda example: example["subset"] == subset)
        num_correct = sum(subset_dataset["results"])
        num_total = len(subset_dataset["results"])
        print(f"{subset}: {num_correct}/{num_total} ({num_correct/num_total})")
        results_grouped[subset] = num_correct / num_total

    set_map_latex = {
        "Chat": r"\CRbC",
        "Chat Hard": r"\CRbCh",
        "Safety": r"\CRbS",
        "Reasoning": r"\CRbR",
    }

    # Get log path
    task = "rewardbench"
    log_dir = "./eval_rm"
    log_model_name: str = args.peft_name if args.peft_name is not None else args.model
    log_folder = log_model_name.removeprefix("models/").replace("/", "_")
    log_path = os.path.join(log_dir, log_folder, task)
    os.makedirs(log_path, exist_ok=True)

    if not args.pref_sets:
        # log leaderboard aggregated results
        sections = ("Chat", "Chat Hard", "Safety", "Reasoning")
        results_leaderboard = calculate_scores_per_section(EXAMPLE_COUNTS, SUBSET_MAPPING, results_grouped)
        print(results_leaderboard)
        rounded = {k: round(100 * v, 1) for k, v in results_leaderboard.items()}
        average = sum(rounded.values()) / 4
        print(f"\n{rounded['Chat']} & {rounded['Chat Hard']} & {rounded['Safety']} & {rounded['Reasoning']} & {average}")
        brace_cells = [rounded[section] for section in sections] + [round(average, 1)]
        print("\n{" + "}{".join(str(cell) for cell in brace_cells) + "}")

        # NOTE: the file name is misspelled ("summarry"); it is kept as-is so that existing
        # result folders and any tooling that reads them keep working.
        results_summary_path = os.path.join(log_path, "summarry_accuracy.txt")
        summary_cells = [f"{results_leaderboard[section] * 100:.1f}" for section in sections] + [f"{average:.1f}"]
        with open(results_summary_path, "w", encoding="utf-8") as f:
            f.write("{" + "}{".join(summary_cells) + "}\n")
        print(f"Saved to '{results_summary_path}'.")

        # One LaTeX macro call per section with one argument per subset. Subsets that were not
        # evaluated (e.g. with --debug, which keeps only 10 examples) are rendered as "nan"
        # instead of aborting the report with a KeyError.
        missing = float("nan")
        section_cells = "".join(
            f"{set_map_latex[setname]}"
            + "{"
            + "}{".join(f"{results_grouped.get(subsetname, missing) * 100:.1f}" for subsetname in section_subsets)
            + "} & "
            for setname, section_subsets in SUBSET_MAPPING.items()
        )
        latex_row = rf"$\varepsilon = {epsilon}$ & " + section_cells + r"\CRbAvg{" + f"{average:.1f}" + r"} \\"

        print("\nResults:\n\n" + latex_row)

        results_path = os.path.join(log_path, "accuracy.txt")
        with open(results_path, "w", encoding="utf-8") as f:
            f.write(r"\dr " + latex_row + "\n")

        print(f"Saved to '{results_path}'.")
    else:
        # The section aggregation and the LaTeX row are keyed by SUBSET_MAPPING; with
        # --pref_sets this part used to fail (the average was never computed on that path).
        logger.info("Skipping leaderboard aggregation and LaTeX export for --pref_sets; per-subset accuracies are printed above")

    ############################
    # Upload results to hub
    ############################
    # sub_path = "eval-set/" if not args.pref_sets else "pref-sets/"
    # results_url = save_to_hub(
    #     results_grouped,
    #     args.model,
    #     sub_path,
    #     args.debug,
    #     local_only=args.do_not_save,
    #     save_metrics_for_beaker=not args.disable_beaker_save,
    # )
    # if not args.do_not_save:
    #     logger.info(f"Uploaded reward model results to {results_url}")

    # # upload chosen-rejected with scores
    # if not model_type == "Custom Classifier":  # custom classifiers do not return scores
    #     # create new json with scores and upload
    #     scores_dict = out_dataset.to_dict()
    #     scores_dict["model"] = args.model
    #     scores_dict["model_type"] = model_type
    #     scores_dict["chat_template"] = args.chat_template

    #     sub_path_scores = "eval-set-scores/" if not args.pref_sets else "pref-sets-scores/"

    #     scores_url = save_to_hub(scores_dict, args.model, sub_path_scores, args.debug, local_only=args.do_not_save)
    #     logger.info(f"Uploading chosen-rejected text with scores to {scores_url}")
    # else:
    #     logger.info("Not uploading chosen-rejected text with scores due to model compatibility")


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
