from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import Dict, List
import torch
from tqdm import tqdm
from datasets import load_dataset
import torch.multiprocessing as mp
from multiprocessing import Process, Queue
from prompts import RESPONSE_PROMPT, PREFERENCE_PROMPTS
import numpy as np
import os

def remove_first_sentence(text):
    """Drop everything up to and including the first sentence terminator in ``text``.

    The earliest occurrence of ``.``, ``!`` or ``?`` is taken as the end of the
    first sentence. The rest of the string is returned with surrounding
    whitespace stripped. If none of those characters occurs, ``text`` is
    returned unchanged.
    """
    positions = [pos for pos in (text.find(end) for end in ('.', '!', '?')) if pos != -1]
    if not positions:
        # No terminator at all: nothing to remove.
        return text
    return text[min(positions) + 1:].strip()

class ArmoRMPipeline:
    """Batch scorer around a ``transformers`` sequence-classification reward model.

    Each conversation is rendered with the tokenizer's chat template. The
    rendered texts are tokenized together as one padded batch, and the batch is
    scored in a single forward pass under ``torch.no_grad()``.
    """

    def __init__(self, model_id, device_map="auto", torch_dtype=torch.bfloat16, truncation=True, trust_remote_code=False, max_length=4096):
        """Load the reward model and its fast tokenizer from ``model_id``.

        Args:
            model_id: Hugging Face model id or local path.
            device_map: Forwarded to ``from_pretrained`` (e.g. ``"auto"`` or ``"cuda:0"``).
            torch_dtype: Dtype the model weights are loaded in.
            truncation: Whether inputs are truncated to ``max_length`` tokens.
            trust_remote_code: Forwarded to ``from_pretrained``.
            max_length: Maximum tokenized length used when truncating.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            use_fast=True,
        )
        self.truncation = truncation
        self.device = self.model.device
        self.max_length = max_length

    def _encode(self, messages):
        """Render ``messages`` with the chat template and tokenize them as one padded batch.

        Returns ``(input_ids, attention_mask)`` already moved to ``self.device``.
        """
        texts = [
            self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)
            for message in messages
        ]
        # NOTE(review): if the chat template already emits a BOS token, tokenizing
        # the rendered text with add_special_tokens=True may prepend a second one —
        # confirm for this tokenizer before relying on absolute score values.
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=self.truncation,
            add_special_tokens=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoded.input_ids.to(self.device), encoded.attention_mask.to(self.device)

    def _forward(self, messages):
        """Run one no-grad forward pass over ``messages`` and return the raw model output."""
        input_ids, attention_mask = self._encode(messages)
        with torch.no_grad():
            return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def __call__(self, messages: List[List[Dict[str, str]]]) -> List[float]:
        """Return the model output's ``score`` field for every conversation.

        ``messages`` is a list of OpenAI-style chat conversations, i.e. lists of
        ``{"role": ..., "content": ...}`` dicts. All conversations are padded to
        the longest one and scored in one batch. The result is
        ``output.score.cpu().tolist()``; its exact nesting depends on the
        model's remote code (TODO confirm). An empty ``messages`` returns ``[]``.
        """
        if not messages:
            return []
        return self._forward(messages).score.cpu().tolist()

    def score(self, messages: List[List[Dict[str, str]]], objective_index: int = 6) -> List[float]:
        """Return one column of the model's multi-objective ``rewards`` per conversation.

        Args:
            messages: OpenAI-style chat conversations, scored as one padded batch.
            objective_index: Column of ``output.rewards`` to return. The default
                ``6`` is the column this script has always used and described as
                the instruction-following score.

        Returns:
            One float per conversation. An empty ``messages`` returns ``[]``.
        """
        if not messages:
            return []
        # Cast to float32 before .tolist() so values do not stay in the model's bfloat16.
        rewards = self._forward(messages).rewards.cpu().float()
        return rewards[:, objective_index].tolist()

class RMEvaluater:
    """Score saved model responses with a reward model, sharding the work across GPUs.

    Responses are read from
    ``responses/<preference>/<model_name>/<dataset_name>/<method>.json``.
    Per-row scores are written to
    ``results/<preference>/<model_name>/<dataset_name>/<method>.npy``.
    Worker ``idx`` loads the reward model on ``cuda:{idx}``.
    """

    # Hugging Face id of the reward model each worker loads.
    reward_model_id = "RLHFlow/ArmoRM-Llama3-8B-v0.1"

    def __init__(self, method, model_name, batch_size, num_processes) -> None:
        self.method = method
        self.model_name = model_name
        self.batch_size = batch_size
        self.num_processes = num_processes

    def _response_path(self) -> str:
        """Path of the JSON file holding the responses to score."""
        return f'responses/{self.preference}/{self.model_name}/{self.dataset_name}/{self.method}.json'

    def _result_path(self) -> str:
        """Path of the ``.npy`` file the per-row scores are saved to."""
        return f"results/{self.preference}/{self.model_name}/{self.dataset_name}/{self.method}.npy"

    def get_evaluate_dataset(self):
        """Load the responses for the current preference/dataset into ``self.dataset``."""
        self.file_path = self._response_path()
        print(f"Loading dataset from {self.file_path}")
        self.dataset = load_dataset('json', data_files=self.file_path)['train']

    def save_results(self, results: np.ndarray):
        """Save ``results`` to the result path, creating parent directories as needed."""
        result_path = self._result_path()
        os.makedirs(os.path.dirname(result_path), exist_ok=True)
        np.save(result_path, results)

    def load_reward_model(self, idx):
        """Load the reward model onto ``cuda:{idx}``."""
        return ArmoRMPipeline(self.reward_model_id, device_map=f"cuda:{idx}", trust_remote_code=True)

    @staticmethod
    def _split_ranges(total, parts):
        """Split ``range(total)`` into at most ``parts`` contiguous, near-equal, non-empty ranges.

        Raises:
            ValueError: if ``parts`` is less than 1.
        """
        if parts < 1:
            raise ValueError(f"parts must be >= 1, got {parts}")
        base, extra = divmod(total, parts)
        ranges = []
        start = 0
        for i in range(parts):
            size = base + (1 if i < extra else 0)
            if size:
                ranges.append(range(start, start + size))
            start += size
        return ranges

    @staticmethod
    def _wait_for_scores(process, result_queue, poll_seconds=30.0):
        """Block until ``process`` reports its score list on ``result_queue`` and return it.

        Raises:
            RuntimeError: if the process exits without putting anything on the queue.
                Without this check, a crashed worker would block the parent forever.
        """
        import queue as queue_module

        while True:
            try:
                return result_queue.get(timeout=poll_seconds)
            except queue_module.Empty:
                if not process.is_alive():
                    break
        # The worker is gone; give anything it flushed just before exiting one last chance.
        try:
            return result_queue.get(timeout=1.0)
        except queue_module.Empty:
            raise RuntimeError(
                f"Reward worker exited with code {process.exitcode} before reporting scores."
            ) from None

    @staticmethod
    def _build_messages(row):
        """Build the user/assistant conversation the reward model scores for one response row."""
        return [
            {"role": "user", "content": RESPONSE_PROMPT.format(question=row['question'], preference=row['preference'])},
            {"role": "assistant", "content": row['response']},
        ]

    def distribute_task(self, dataset_name, preference):
        """Score ``self.method``'s responses for one dataset/preference pair across GPUs.

        The dataset is split into at most ``num_processes`` contiguous,
        near-equal shards. Worker ``idx`` scores shard ``idx``. Scores are
        concatenated in shard order, so they line up with the rows of the
        response file, and are saved via ``save_results``.

        Returns:
            The scores as a NumPy array, or ``None`` if the response file is
            missing or results for this configuration were already saved.

        Raises:
            RuntimeError: if a worker exits without reporting its scores.
        """
        self.dataset_name = dataset_name
        self.preference = preference

        response_path = self._response_path()
        if not os.path.exists(response_path):
            print(f"{response_path} does not exist.")
            return
        # Skip configurations that were already scored.
        if os.path.exists(self._result_path()):
            return

        self.get_evaluate_dataset()
        # Empty shards are dropped, so no GPU loads the model for nothing and the
        # tokenizer never receives an empty batch.
        subsets = [
            self.dataset.select(rows)
            for rows in self._split_ranges(len(self.dataset), self.num_processes)
        ]
        print("Data is ready to be processed.")

        # One queue per worker. A single shared queue returns results in completion
        # order, which would scramble the row alignment of the saved scores.
        queues = [Queue() for _ in subsets]
        processes = []
        for idx, (subset, result_queue) in enumerate(zip(subsets, queues)):
            p = Process(target=self.evaluate_worker, args=(subset, idx, result_queue))
            p.start()
            processes.append(p)

        results = []
        try:
            for p, result_queue in zip(processes, queues):
                results.extend(self._wait_for_scores(p, result_queue))
        except BaseException:
            # Don't leave the other GPU workers running after a failure.
            for p in processes:
                if p.is_alive():
                    p.terminate()
                p.join()
            raise
        # Queues are drained before joining. Joining a process that still has
        # buffered queue data can deadlock.
        for p in processes:
            p.join()

        results = np.array(results)
        self.save_results(results)
        return results

    def evaluate_worker(self, subset, idx, queue):
        """Worker entry point: score ``subset`` on ``cuda:{idx}`` and put the score list on ``queue``.

        Rows are scored in consecutive batches of at most ``self.batch_size``,
        so the scores are in the same order as ``subset``'s rows.
        """
        reward_model = self.load_reward_model(idx)
        score_list = []
        total = len(subset)
        for start in tqdm(range(0, total, self.batch_size)):
            batch = subset.select(range(start, min(start + self.batch_size, total)))
            messages = [self._build_messages(row) for row in batch]
            score_list.extend(reward_model.score(messages))
        queue.put(score_list)

def minmax_normalize(scores):
    """Linearly rescale ``scores`` so the minimum maps to 0 and the maximum to 1.

    Accepts any array-like. When every score is equal, it returns an all-zero
    array instead of the NaNs (and divide-by-zero warning) that
    ``(x - min) / (max - min)`` would produce.
    """
    scores = np.asarray(scores, dtype=float)
    lowest = scores.min()
    span = scores.max() - lowest
    if span == 0:
        return np.zeros_like(scores)
    return (scores - lowest) / span

# Methods whose result files are summarized, in report order; 'base' is the normalization reference.
METHODS = ['base', 'pref', 'la', 'amulet20', 'amulet40', 'amulet60', 'amulet80', 'amulet100']
MODEL_NAMES = ["Mistral-7B-Instruct-v0.2", "Qwen2-7B-Instruct", "Meta-Llama-3.1-8B-Instruct", 'Llama-2-7b-chat-hf']
DATASET_NAMES = ["HelpSteer_train", "personal_preference_eval_preference_data", "UltraFeedback_truthful_qa", "UltraFeedback_ultrachat"]


def base_normalize(means, base_key='base'):
    """Divide every method's mean score by the mean score of ``base_key``.

    Insertion order of ``means`` is preserved. Raises ``KeyError`` when
    ``base_key`` is missing.
    """
    base = means[base_key]
    return {name: value / base for name, value in means.items()}


if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)

    for preference in PREFERENCE_PROMPTS:
        print(f"\n################# {preference} is start #################\n")
        for model_name in MODEL_NAMES:
            for dataset_name in DATASET_NAMES:
                print(f"\n==============={preference} {model_name} {dataset_name} is done ===============")
                result_dir = f"results/{preference}/{model_name}/{dataset_name}"
                means = {}
                for method in METHODS:
                    file_path = f"{result_dir}/{method}.npy"
                    # distribute_task skips missing responses, so some result files may not exist.
                    if not os.path.exists(file_path):
                        print(f"{file_path} does not exist, skipping.")
                        continue
                    # Keep full precision here; round only for display.
                    means[method] = np.mean(np.load(file_path))
                if 'base' not in means:
                    print("Missing base results; cannot normalize.")
                    continue
                for name, ratio in base_normalize(means).items():
                    print(f"{name} Mean score: {np.round(ratio, 4)}")
                print("================= Results is done ================\n")
        print(f"\n################# {preference} is done #################\n")

    