
import sys 
import os
import importlib.util
import torch
import json
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset, Dataset
import numpy as np
import ray 
from tqdm import tqdm
from pathlib import Path

# Point the Hugging Face cache root at the workspace directory.
# NOTE(review): this assignment runs after `transformers` and `datasets` were
# imported above; those libraries presumably resolve their cache locations at
# import time, so this may only affect processes started later — confirm, and
# consider setting it before the imports.
os.environ["HF_HOME"] = "/workspace/rlhf-code/.cache/root"

# Import the utils module directly using the file path
# (loaded from an absolute path with importlib rather than a regular `import`).
utils_path = "/workspace/rlhf-code/code/utils.py"
spec = importlib.util.spec_from_file_location("utils", utils_path)
utils = importlib.util.module_from_spec(spec)
# Execute the module body so its attributes are populated on `utils`.
spec.loader.exec_module(utils)

# Extract the functions we need
get_reward = utils.get_reward
pad_to_length = utils.pad_to_length

class RewardModelInference:
    """Score chosen/rejected token sequences with a sequence-classification model.

    Instances are callable: given a batch containing pre-tokenized
    ``query_chosen_token`` / ``query_rejected_token`` sequences, the model's
    predicted rewards are compared against the ground-truth rewards stored in
    the batch (``summary_rewards`` indexed by ``choice``).
    """

    def __init__(self, model_path, tokenizer_path, max_seq_length):
        """Load the model (moved to ``"cuda"``) and the tokenizer.

        Args:
            model_path: Passed to ``AutoModelForSequenceClassification.from_pretrained``.
            tokenizer_path: Passed to ``AutoTokenizer.from_pretrained``.
            max_seq_length: Length handed to ``pad_to_length`` for every sequence.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.max_seq_length = max_seq_length

    def _pad_to_tensor(self, sequences):
        """Pad each sequence with 0 via ``pad_to_length`` and batch them on the model device."""
        padded = [
            pad_to_length(torch.tensor(seq), self.max_seq_length, pad_value=0).tolist()
            for seq in sequences
        ]
        return torch.tensor(padded).to(self.model.device)

    def __call__(self, batch):
        """Score one batch and return one result dict per example.

        Args:
            batch: Mapping of column name to per-example values. Must contain
                ``query_chosen_token``, ``query_rejected_token``,
                ``summary_rewards`` and ``choice``.

        Returns:
            A list of dicts holding the predicted and ground-truth rewards, the
            gaps between them, and every original column value of the example.
            Original columns are merged last, so they win on a key collision.

        Raises:
            ValueError: If ``query_chosen_token`` is not present in ``batch``.
        """
        if 'query_chosen_token' not in batch:
            raise ValueError("Invalid batch format")

        chosen_ids = self._pad_to_tensor(batch['query_chosen_token'])
        rejected_ids = self._pad_to_tensor(batch['query_rejected_token'])

        # Score chosen and rejected sequences in a single forward pass.
        concatenated_input_ids = torch.cat((chosen_ids, rejected_ids), dim=0)

        # Inference only: no autograd graph is needed for scoring.
        with torch.no_grad():
            _, predicted_rewards, _ = get_reward(
                model=self.model,
                query_responses=concatenated_input_ids,
                pad_token_id=0,
                context_length=0,
            )

        # The first half of the concatenated batch is "chosen", the rest "rejected".
        num_chosen = len(chosen_ids)
        chosen_rewards = predicted_rewards[:num_chosen].tolist()
        rejected_rewards = predicted_rewards[num_chosen:].tolist()

        results = []
        for i, pair_rewards in enumerate(batch['summary_rewards']):
            # `choice` is the index of the chosen summary; the other is 1 - choice.
            choice = int(batch['choice'][i])
            gt_chosen = pair_rewards[choice]
            gt_rejected = pair_rewards[1 - choice]
            example = {column: values[i] for column, values in batch.items()}
            results.append({
                "chosen_reward": chosen_rewards[i],
                "rejected_reward": rejected_rewards[i],
                "gt_chosen_reward": gt_chosen,
                "gt_rejected_reward": gt_rejected,
                "chosen_reward_gap": chosen_rewards[i] - gt_chosen,
                "rejected_reward_gap": rejected_rewards[i] - gt_rejected,
                "overall_reward_gap": abs(
                    (chosen_rewards[i] - rejected_rewards[i]) - (gt_chosen - gt_rejected)
                ),
                **example,
            })
        return results
        
        


def main(args):
    """Score a dataset with the reward model on Ray actors and publish the results.

    Steps: load ``args.ds_path`` / ``args.split``, keep at most the first
    10000 rows, tag each row with its index, send batches round-robin to GPU
    actors, gather the per-example results, restore the original order, push
    the scored dataset to ``args.save_ds_path`` and write/print the mean and
    standard deviation of the reward gaps.

    Args:
        args: Namespace with ``model_path``, ``tokenizer_path``,
            ``max_seq_length``, ``batch_size``, ``ds_path``, ``save_ds_path``,
            ``split``, ``num_gpus`` and ``num_workers_per_gpu``.

    Raises:
        ValueError: If ``model_path``, ``ds_path`` or ``save_ds_path`` is None.
    """
    # Fail before any GPU work: a missing save path would otherwise only
    # surface after every batch has been scored.
    missing = [
        name for name in ("model_path", "ds_path", "save_ds_path")
        if getattr(args, name, None) is None
    ]
    if missing:
        raise ValueError(f"Missing required argument(s): {', '.join(missing)}")

    if not ray.is_initialized():
        ray.init()

    # Each actor reserves a fraction of a GPU so several can share one device.
    @ray.remote(num_gpus=1/args.num_workers_per_gpu)
    class RMActor:
        def __init__(self):
            self.rm = RewardModelInference(args.model_path, args.tokenizer_path, args.max_seq_length)

        def process_item(self, batch):
            return self.rm(batch)

    max_samples = 10000
    ds = load_dataset(args.ds_path, split=args.split)
    # Clamp so datasets with fewer than `max_samples` rows do not fail on an
    # out-of-range selection.
    ds = ds.select(range(min(max_samples, len(ds))))
    # Remember the original position; results come back in completion order.
    ds = ds.map(lambda x, idx: {"idx": idx}, with_indices=True)
    ray_dataset = ray.data.from_huggingface(ds)

    actors = [RMActor.remote() for _ in range(args.num_gpus * args.num_workers_per_gpu)]

    # Round-robin the batches over the actors.
    futures = [
        actors[batch_no % len(actors)].process_item.remote(batch)
        for batch_no, batch in enumerate(ray_dataset.iter_batches(batch_size=args.batch_size))
    ]

    scored_items = []
    pending = futures
    with tqdm(total=len(futures), desc="Processing batches") as pbar:
        while pending:
            done_refs, pending = ray.wait(pending, num_returns=1)
            scored_items.extend(ray.get(done_refs[0]))
            pbar.update(1)

    # Mean and std of each reward gap, keyed as "mean_<gap>" / "std_<gap>".
    gap_keys = ("overall_reward_gap", "chosen_reward_gap", "rejected_reward_gap")
    stats = {}
    for key in gap_keys:
        values = [item[key] for item in scored_items]
        stats[f"mean_{key}"] = np.mean(values)
        stats[f"std_{key}"] = np.std(values)

    scored_ds = Dataset.from_list(scored_items)
    # Restore the original row order, then drop the helper column.
    scored_ds = scored_ds.sort("idx")
    scored_ds = scored_ds.remove_columns("idx")

    # Save to hub.
    scored_ds.push_to_hub(f"{args.save_ds_path}")

    # Save mean and std of reward gaps next to the working directory, named
    # after the last path component of the target repo.
    with open(f"{args.save_ds_path.split('/')[-1]}-mean-std.json", "w") as f:
        json.dump(stats, f)

    for key in gap_keys:
        label = key.replace("_", " ")
        print(f"Mean {label}: {stats[f'mean_{key}']}")
        print(f"Std {label}: {stats[f'std_{key}']}")
    

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Score a preference dataset with a reward model and push the results to the hub."
    )
    # These three have no usable default: a None value only fails later,
    # deep inside model loading, dataset loading or result saving.
    parser.add_argument("--model_path", type=str, required=True,
                        help="Reward model passed to from_pretrained.")
    parser.add_argument("--ds_path", type=str, required=True,
                        help="Dataset passed to load_dataset.")
    parser.add_argument("--save_ds_path", type=str, required=True,
                        help="Hub repo to push the scored dataset to.")
    parser.add_argument("--tokenizer_path", type=str, default="EleutherAI/pythia-1b-deduped",
                        help="Tokenizer passed to from_pretrained.")
    parser.add_argument("--max_seq_length", type=int, default=512,
                        help="Length every token sequence is padded to.")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Number of examples per batch sent to an actor.")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to load.")
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="Number of GPUs to spread actors over.")
    parser.add_argument("--num_workers_per_gpu", type=int, default=1,
                        help="Number of actors sharing each GPU.")
    args = parser.parse_args()

    main(args)
    
    