import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import json
import os
import argparse
# import tqdm
import numpy as np
import datasets
from datetime import timedelta
from accelerate.utils import InitProcessGroupKwargs
from accelerate import Accelerator
from accelerate.utils import gather_object
import time
import glob
from tqdm import tqdm
import pdb

# Score every candidate response in a folder of generation files with a
# sequence-classification reward model, sharding the work across processes
# with `accelerate`, and write one JSON list of
# {"prompt", "all_generated_responses", "all_reward_scores"} records.

DEFAULT_OUTPUT_FILE = "llama3-8b-instruct-on-policy-swepo-1vsk-iteration1-train-data_rewards.json"


def _parse_args():
    """Build the CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--generation_files_folder", type=str, default="outputs/llama3-8b-instruct-on-policy-data-gen-swepo-1vsk", help="Path to the output generation file")
    parser.add_argument("--reward_model", type=str, default="Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", help="Path to reward model")
    parser.add_argument("--output_dir", type=str, default="outputs", help="Path to output directory")
    parser.add_argument("--output_file", type=str, default=DEFAULT_OUTPUT_FILE, help="Name of the JSON file written inside --output_dir")
    parser.add_argument("--max_length", type=int, default=2048, help="Maximum token length of a scored conversation (longer ones are truncated)")
    return parser.parse_args()


def _load_generation_files(folder):
    """Load every ``*.json`` file in *folder*, in sorted path order.

    Sorting makes the order of responses (and therefore of reward scores)
    reproducible; ``glob`` on its own returns files in arbitrary order.

    Returns:
        (paths, contents): the sorted file paths and their parsed JSON payloads.

    Raises:
        FileNotFoundError: if the folder contains no ``*.json`` files.
        ValueError: if the files do not all contain the same number of items,
            since responses are aligned across files by position.
    """
    paths = sorted(glob.glob(os.path.join(folder, "*.json")))
    if not paths:
        raise FileNotFoundError(f"No *.json generation files found in {folder!r}")

    contents = []
    for path in paths:
        with open(path, "r", encoding="utf-8") as f:
            contents.append(json.load(f))

    lengths = {len(data) for data in contents}
    if len(lengths) != 1:
        counts = ", ".join(f"{p}: {len(d)}" for p, d in zip(paths, contents))
        raise ValueError(f"Generation files have differing numbers of items ({counts})")
    return paths, contents


def _merge_generations(files_data):
    """Group the i-th ``output`` of every file under the i-th ``instruction``.

    The prompt is taken from the first file; the code assumes (not verified)
    that all files list the same instructions in the same order.
    """
    merged = []
    for rows in tqdm(zip(*files_data), total=len(files_data[0])):
        merged.append({
            "prompt": rows[0]["instruction"],
            "all_generated_responses": [row["output"] for row in rows],
        })
    return merged


def _score_candidate(model, tokenizer, prompt, candidate, device, max_length):
    """Return the reward model's scalar score for one (prompt, response) pair."""
    messages = [{"role": "user", "content": prompt},
                {"role": "assistant", "content": candidate}]
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Chat templates such as Llama 3.1's already emit the BOS token; letting the
    # tokenizer add special tokens again would yield a duplicated BOS, which
    # shifts the reward model's scores. Only add them if the template did not.
    bos = tokenizer.bos_token
    add_special_tokens = not (bos and prompt_text.startswith(bos))

    # NOTE: truncation keeps the left part of the sequence, so over-long
    # conversations lose the end of the response.
    input_ids = tokenizer(
        prompt_text,
        max_length=max_length,
        truncation=True,
        add_special_tokens=add_special_tokens,
        return_tensors="pt",
    )["input_ids"].to(device)

    with torch.no_grad():
        output = model(input_ids)
    # Single sequence, single reward logit.
    return output.logits[0][0].item()


if __name__ == '__main__':
    args = _parse_args()

    # Generous process-group timeout: scoring can take hours, and ranks
    # block in wait_for_everyone()/gather_object() until all finish.
    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=21600))
    accelerator = Accelerator(kwargs_handlers=[kwargs])

    accelerator.print(args)

    files, files_data = _load_generation_files(args.generation_files_folder)
    accelerator.print(len(files))
    accelerator.print(files[0])

    output_data = _merge_generations(files_data)

    model = AutoModelForSequenceClassification.from_pretrained(
        args.reward_model,
        device_map={"": accelerator.device},  # one full model copy per process
        trust_remote_code=True,
        torch_dtype='float16',
        cache_dir='cache',
    )
    tokenizer = AutoTokenizer.from_pretrained(args.reward_model, use_fast=True)

    accelerator.wait_for_everyone()
    # Wall-clock time; process_time() would only count host CPU time.
    start = time.perf_counter()

    # Each process scores a contiguous shard of the prompts.
    with accelerator.split_between_processes(output_data) as shard:
        local_results = []
        for item in tqdm(shard):
            prompt = item["prompt"]
            candidates = item["all_generated_responses"]
            scores = [
                _score_candidate(model, tokenizer, prompt, candidate,
                                 accelerator.device, args.max_length)
                for candidate in candidates
            ]
            local_results.append({
                "prompt": prompt,
                "all_generated_responses": candidates,
                "all_reward_scores": scores,
            })

    # gather_object concatenates the per-process lists in rank order, which
    # matches the contiguous sharding above, so the original order is kept.
    final_results = gather_object(local_results)

    if accelerator.is_main_process:
        print(len(final_results))
        print("Time to process examples:-", time.perf_counter() - start)

        os.makedirs(args.output_dir, exist_ok=True)
        output_path = os.path.join(args.output_dir, args.output_file)
        with open(output_path, 'w', encoding="utf-8") as f:
            json.dump(final_results, f, indent=4)