import json
import argparse
import os
from typing import Dict, List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm
from accelerate import Accelerator
from torch.utils.data import Dataset, DataLoader
from accelerate.utils import gather_object

class ChatDataset(Dataset):
    """Map-style dataset exposing a prompt and its two candidate responses.

    Each underlying record is expected to provide the keys
    ``synthesized_prompt``, ``synthesized_response_1`` and
    ``synthesized_response_2``; samples are returned under the shorter
    names ``prompt``, ``response1`` and ``response2``.
    """

    def __init__(self, data):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        # (output key, source key) pairs, looked up in this order.
        field_pairs = (
            ("prompt", "synthesized_prompt"),
            ("response1", "synthesized_response_1"),
            ("response2", "synthesized_response_2"),
        )
        return {out_key: record[src_key] for out_key, src_key in field_pairs}

class ArmoRMPipeline:
    """Score chat conversations with a sequence-classification reward model.

    The model is loaded in bfloat16 and handed to ``accelerator.prepare``;
    conversations are rendered with the tokenizer's chat template, padded
    into one batch and scored in a single forward pass.

    Args:
        model_id: Hugging Face model identifier or local path.
        accelerator: ``accelerate.Accelerator`` that owns device placement.
        truncation: Whether to truncate conversations longer than
            ``max_length`` tokens.
        trust_remote_code: Forwarded to the model's ``from_pretrained``.
        max_length: Maximum tokenized length when truncation is enabled.
    """

    def __init__(self, model_id, accelerator, truncation=True, trust_remote_code=False, max_length=4096):
        self.accelerator = accelerator
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch.bfloat16,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            use_fast=True,
        )
        self.truncation = truncation
        self.max_length = max_length

        self.model = accelerator.prepare(self.model)

    @staticmethod
    def _logits_to_scores(logits):
        """Convert model logits into a per-sample Python list.

        Only a trailing singleton label axis is removed. The batch axis is
        always kept, so a batch of one still yields a one-element list
        instead of a bare float.
        """
        logits = logits.float()
        if logits.dim() == 0:
            logits = logits.reshape(1)
        elif logits.dim() > 1 and logits.shape[-1] == 1:
            logits = logits.squeeze(-1)
        return logits.tolist()

    def __call__(self, messages: List[List[Dict[str, str]]]) -> List[Dict[str, float]]:
        """Return one ``{"score": value}`` dict per conversation, in order.

        Args:
            messages: A batch of conversations, each a list of
                ``{"role": ..., "content": ...}`` dicts.

        Returns:
            A list with the same length as ``messages``; empty input yields
            an empty list.
        """
        if not messages:
            return []

        # return_dict=True also yields the attention mask, so that padding
        # tokens added for batching can be masked out in the forward pass.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            padding=True,
            truncation=self.truncation,
            max_length=self.max_length,
            return_dict=True,
        )
        # Accelerator.prepare() does not relocate plain tensors, so the
        # inputs are moved to the accelerator's device explicitly.
        device = self.accelerator.device
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        scores = self._logits_to_scores(outputs.logits)
        return [{"score": score} for score in scores]

class _IndexedDataset(Dataset):
    """Wrap a map-style dataset so each sample also carries its own position."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = dict(self.dataset[idx])
        sample["index"] = idx
        return sample


def process_file(args, accelerator):
    """Score both responses of every item in a JSON file and save the ranking.

    The input file must contain a JSON list of objects with the keys read by
    ``ChatDataset``. Each object is extended with the two scores, an
    ``rm_order`` label and the ``chosen`` / ``rejected`` responses (ties
    favour response 1). The result is written next to the input file as
    ``rm_<input file name>`` by the main process only.

    Args:
        args: Parsed CLI namespace providing ``model_id``,
            ``trust_remote_code``, ``max_length``, ``input_file`` and
            ``batch_size``.
        accelerator: ``accelerate.Accelerator`` used for device placement,
            dataloader sharding and gathering results.

    Raises:
        RuntimeError: If, after gathering, some input item has no result.
    """
    rm = ArmoRMPipeline(args.model_id, accelerator, trust_remote_code=args.trust_remote_code, max_length=args.max_length)

    with open(args.input_file, 'r', encoding="utf-8") as f:
        data = json.load(f)

    # Every sample carries its original index. In multi-process runs the
    # prepared dataloader gives each process a different subset of batches
    # (and may repeat samples to even out the shards), so the gathered
    # results are not in input order and cannot be matched by position.
    dataset = _IndexedDataset(ChatDataset(data))
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
    dataloader = accelerator.prepare(dataloader)

    local_results = []

    for batch in tqdm(dataloader, desc="Processing items in batches", disable=not accelerator.is_local_main_process):
        indices = batch["index"].tolist()
        prompts = batch["prompt"]
        responses1 = batch["response1"]
        responses2 = batch["response2"]

        messages1 = [[{"role": "user", "content": p}, {"role": "assistant", "content": r1}] for p, r1 in zip(prompts, responses1)]
        messages2 = [[{"role": "user", "content": p}, {"role": "assistant", "content": r2}] for p, r2 in zip(prompts, responses2)]

        scores1 = rm(messages1)
        scores2 = rm(messages2)

        for index, response1, response2, score1, score2 in zip(indices, responses1, responses2, scores1, scores2):
            item_result = {
                "synthesized_response_1_score": score1["score"],
                "synthesized_response_2_score": score2["score"]
            }

            if score1["score"] >= score2["score"]:
                item_result["rm_order"] = "1 > 2"
                item_result["chosen"] = response1
                item_result["rejected"] = response2
            else:
                item_result["rm_order"] = "2 > 1"
                item_result["chosen"] = response2
                item_result["rejected"] = response1

            local_results.append((index, item_result))

    # Gather (index, result) pairs from all processes; every process must
    # take part in this call.
    gathered = gather_object(local_results)

    if accelerator.is_main_process:
        # Keying by index restores input order and drops samples that were
        # scored more than once.
        results_by_index = {}
        for index, item_result in gathered:
            results_by_index.setdefault(index, item_result)

        missing = [i for i in range(len(data)) if i not in results_by_index]
        if missing:
            raise RuntimeError(
                f"No reward-model result for {len(missing)} item(s); first missing index: {missing[0]}"
            )

        for i, item in enumerate(data):
            item.update(results_by_index[i])

        input_dir = os.path.dirname(args.input_file)
        output_file = os.path.join(input_dir, "rm_" + os.path.basename(args.input_file))

        with open(output_file, 'w', encoding="utf-8") as f:
            json.dump(data, f, indent=2)

        print(f"Processing complete. Results saved to {output_file}")

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` cannot be used for this: argparse would call
    ``bool("False")``, which is ``True`` for any non-empty string.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if lowered in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse command-line options and run the reward-model scoring job.

    The input path is validated before the ``Accelerator`` is created and
    the model is loaded, so a wrong or missing ``--input_file`` fails
    immediately with a usage error (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Process JSON file with ArmoRMPipeline")
    parser.add_argument("--input_file", type=str, default="your file", help="Input JSON file path")
    parser.add_argument("--model_id", type=str, default="RLHFlow/ArmoRM-Llama3-8B-v0.1", help="Model ID for ArmoRMPipeline")
    # Accepts "--trust_remote_code True/False" as before, and also the bare
    # flag "--trust_remote_code" (treated as True).
    parser.add_argument("--trust_remote_code", type=_str2bool, nargs="?", const=True, default=True, help="Trust remote code for model loading")
    parser.add_argument("--max_length", type=int, default=4096, help="Max length for tokenizer")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for processing")

    args = parser.parse_args()

    if not os.path.isfile(args.input_file):
        parser.error(f"input file not found: {args.input_file!r}")

    accelerator = Accelerator()
    process_file(args, accelerator)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
