import argparse
import os
from typing import Union, List, Any, Dict, Mapping

import joblib
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, \
    DataCollatorForLanguageModeling
from datasets import load_from_disk, load_dataset
from accelerate import Accelerator
from tqdm import tqdm

class DataCollatorForRewardEval(DataCollatorForLanguageModeling):
    """Collate prompt/response records into a tokenized, padded single-turn chat batch.

    Each example is expected to provide a ``"prompt"`` key plus either a
    ``"completion"`` or a ``"response"`` key. Which response key is used is decided
    once, from the first example of the batch.
    """

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Render every example as a user/assistant exchange and tokenize them together.

        Returns whatever ``tokenizer.apply_chat_template`` returns with
        ``return_dict=True`` and ``return_tensors="pt"``.
        """
        answer_field = "response"
        if "completion" in examples[0]:
            answer_field = "completion"

        conversations = []
        for record in examples:
            conversations.append([
                {'role': "user", "content": record["prompt"]},
                {'role': "assistant", "content": record[answer_field]},
            ])

        # The assistant turn is already present, so no generation prompt is appended.
        return self.tokenizer.apply_chat_template(
            conversations,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            return_dict=True,
            padding=True,
        )


def parse_args():
    """Read the command-line options for the reward-scoring script.

    Returns:
        argparse.Namespace with ``dataset_path`` and ``output_path`` (both required),
        ``model_name``, ``batch_size`` (default 1) and the ``use_cls`` flag.
    """
    cli = argparse.ArgumentParser(description='Compute rewards for a dataset')

    required_paths = (
        ('--dataset_path', 'Path to the dataset'),
        ('--output_path', 'Path to save the processed dataset'),
    )
    for flag, text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=text)

    cli.add_argument(
        '--model_name',
        type=str,
        default="nvidia/Llama-3.1-Nemotron-70B-Reward-HF",
        help='Name or path of the reward model',
    )
    cli.add_argument('--batch_size', type=int, default=1, help='Batch size for inference')
    cli.add_argument('--use_cls', action='store_true', help='Use AutoModelForSequenceClassification')

    namespace = cli.parse_args()
    return namespace


def compute_batch_rewards(model, batch, accelerator, is_cls_model=False):
    """Score one collated batch and return the rewards gathered from all processes.

    Args:
        model: the reward model. With ``is_cls_model`` it is called directly and its
            ``logits`` are used as rewards. Otherwise ``model.generate`` is run for a
            single new token, and the score of token id 0 at that first generated step
            is taken as the reward.
        batch: mapping of tensors produced by the collator. Every value is moved to
            ``accelerator.device``.
        accelerator: object providing ``device`` and ``gather`` (an ``Accelerator``).
        is_cls_model: select the sequence-classification path.

    Returns:
        float32 numpy array of the rewards, concatenated across processes by
        ``accelerator.gather``.
    """
    model_inputs = {name: tensor.to(accelerator.device) for name, tensor in batch.items()}

    with torch.no_grad():
        if is_cls_model:
            outputs = model(**model_inputs)
            # Drop only the trailing label axis (single-label head). A bare squeeze()
            # would also drop the batch axis when the batch holds one example. That
            # yields a 0-d array that callers cannot iterate/extend.
            batch_rewards = outputs.logits.squeeze(-1)
        else:
            outputs = model.generate(
                **model_inputs,
                max_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
            )
            # scores[0] holds the scores for the first (and only) generated step.
            batch_rewards = outputs['scores'][0][:, 0]

    batch_rewards = accelerator.gather(batch_rewards)
    # .float() first: numpy has no bfloat16 dtype.
    return batch_rewards.cpu().float().numpy()


def process_dataset(dataset, model, tokenizer, accelerator, batch_size, is_cls_model=False):
    """Score every example and attach the scores as a ``golden_reward`` column.

    If ``dataset`` already has a ``golden_reward`` feature it is returned unchanged.
    Otherwise the model and a non-shuffled dataloader are prepared with
    ``accelerator`` and every batch is scored with ``compute_batch_rewards``.

    Returns:
        On the main process, a new dataset with a ``golden_reward`` column. On other
        processes, the input dataset without the column.
    """
    if 'golden_reward' in dataset.features:
        print("Dataset already has rewards computed. Skipping computation.")
        return dataset

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # Order must match the dataset rows the rewards are attached to.
        shuffle=False,
        collate_fn=DataCollatorForRewardEval(tokenizer, mlm=False),
    )
    model, dataloader = accelerator.prepare(model, dataloader)

    all_rewards = []
    for batch in tqdm(dataloader, desc="Computing rewards"):
        batch_rewards = compute_batch_rewards(
            model,
            batch,
            accelerator,
            is_cls_model=is_cls_model,
        )
        # A 0-d array (single-example batch squeezed to a scalar) is not iterable.
        # Promote it to 1-d so extend() still works.
        if batch_rewards.ndim == 0:
            batch_rewards = batch_rewards.reshape(1)
        all_rewards.extend(batch_rewards)

    if accelerator.is_main_process:
        # Gathered results can exceed len(dataset) when work is split unevenly across
        # processes. Presumably this is accelerate padding the final step with repeated
        # samples (even_batches); that padding is appended after the real rows, so
        # the leading len(dataset) entries are kept.
        all_rewards = all_rewards[:len(dataset)]
        dataset = dataset.add_column('golden_reward', all_rewards)

    return dataset


def _load_reward_model(args, accelerator):
    """Load the tokenizer and reward model named by ``args.model_name``; return ``(model, tokenizer)``."""
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # float16 on compute capability <= 7 (pre-Ampere GPUs lack native bfloat16),
    # bfloat16 otherwise. Without CUDA, fall back to float32 rather than crashing
    # inside get_device_capability().
    if torch.cuda.is_available():
        torch_dtype = torch.float16 if torch.cuda.get_device_capability()[0] <= 7 else torch.bfloat16
    else:
        torch_dtype = torch.float32

    # A single process may shard the model across visible GPUs. Under multi-process
    # launches, placement is left to accelerator.prepare (called in process_dataset).
    device_map = "auto" if accelerator.num_processes == 1 else None

    if args.use_cls:
        model = AutoModelForSequenceClassification.from_pretrained(
            args.model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
            num_labels=1,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        # The reward is read at the first generated step, so batched inputs must be
        # left-padded. With right padding, generation for shorter examples would start
        # after pad tokens and the score would be taken at the wrong position.
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def main():
    """Load the dataset, compute ``golden_reward`` if missing, and save it to ``--output_path``."""
    args = parse_args()

    accelerator = Accelerator()

    # An existing local path is read with load_from_disk; anything else goes to load_dataset.
    if os.path.exists(args.dataset_path):
        dataset = load_from_disk(args.dataset_path)
    else:
        dataset = load_dataset(args.dataset_path)["train"]

    print(f"Using {len(dataset)} examples from the dataset for reward computation.")

    if 'golden_reward' in dataset.features:
        # Nothing to compute; skip loading the model entirely.
        if accelerator.is_main_process:
            print("Dataset already has rewards computed. Saving as is.")
            dataset.save_to_disk(args.output_path)
        return

    model, tokenizer = _load_reward_model(args, accelerator)

    processed_dataset = process_dataset(
        dataset,
        model,
        tokenizer,
        accelerator,
        args.batch_size,
        args.use_cls,
    )

    # Only the main process holds the dataset with the new column.
    if accelerator.is_main_process:
        processed_dataset.save_to_disk(args.output_path)
        print(f"Processed dataset saved to {args.output_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
