import torch
import os
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple
import torch.nn as nn
import torch.utils.checkpoint
from transformers import Gemma2PreTrainedModel, Gemma2Model, Gemma2Config, AutoModelForSequenceClassification, AutoTokenizer
# from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import ModelOutput
from transformers.utils import add_start_docstrings_to_model_forward
from accelerate import infer_auto_device_map, dispatch_model
from accelerate.utils import get_balanced_memory
from modelscope import AutoModel
import torch.nn.functional as F
from acecoder import AceCodeRM
import sys
sys.path.append('..')
from global_utils.utils import generate_general, generate_general_rm, async_generate_general
from transformers import LlamaPreTrainedModel, LlamaModel, PreTrainedTokenizerFast
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
# from utils import generate_general, generate_general_rm, async_generate_general
import math
import re
from tqdm import tqdm
import asyncio
import time
from tqdm import tqdm
from vllm import LLM
# Number of visible CUDA devices; the vLLM-backed wrappers below use it as
# their tensor-parallel size.
gpu_num = torch.cuda.device_count()

# Rubric text describing how two candidate solutions (A vs. B) are compared:
# correctness first, then logical structure, then insightfulness.
# Each fragment ends with a space so the concatenated sentences stay separated.
DEFAULT_COMPARE_STANDARD = (
    "Correctness is the first criterion: if A is correct and B is wrong,"
    " then A is better than B without considering other criterions; "
    "the second criterion is logical structure: if both solutions are correct, "
    "evaluate their clarity, logical flow, and correctness of intermediate steps; "
    "the third criterion is insightfulness: determine which solution provides a deeper "
    "understanding, useful insights, or a more elegant approach."
)


class MultiOutputNN(nn.Module):
    """MLP head mapping a feature vector to groups of binned probability distributions.

    The final linear layer emits ``output_dim`` logits. They are reshaped to
    ``(batch, output_dim // num_bins, num_bins)`` and softmax-normalized over
    the last axis, so every group is a probability distribution over
    ``num_bins`` bins.

    Args:
        input_dim: Size of the input feature vector.
        output_dim: Number of output logits; expected to be a multiple of
            ``num_bins`` so the reshape in ``forward`` succeeds.
        hidden_dims: Widths of the hidden layers, each followed by LeakyReLU.
            A tuple is used as the default to avoid a shared mutable default.
        num_bins: Number of bins per distribution (default 10, the value that
            was previously hard-coded).

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(4096, 4096), num_bins=10):
        super(MultiOutputNN, self).__init__()
        if len(hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one layer width")
        self.num_bins = num_bins

        # Linear -> LeakyReLU for every hidden layer, then a final Linear.
        # The module order (and thus state-dict keys) matches the original layout.
        widths = [input_dim, *hidden_dims]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU())
        layers.append(nn.Linear(widths[-1], output_dim))

        self.network = nn.Sequential(*layers)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return per-group bin distributions of shape ``(batch, groups, num_bins)``."""
        x = self.network(x)
        return self.softmax(x.view(x.size(0), -1, self.num_bins))


class GatingNN(nn.Module):
    """MLP producing one gating value per output slot, optionally softmax-normalized.

    Args:
        input_dim: Size of the input feature vector.
        output_dim: Number of gating values produced.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of hidden ``Linear -> LeakyReLU -> Dropout`` stages.
        temperature: Divisor applied to the logits before the softmax.
        dropout_prob: Dropout probability after each hidden activation.
        softmax: If True, ``forward`` returns ``softmax(logits / temperature)``
            over the last axis; otherwise raw logits.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=4096, num_layers=2, temperature=1.0, dropout_prob=0.0,
                 softmax=False):
        super(GatingNN, self).__init__()
        self.temperature = temperature
        self.softmax = softmax

        # Module order (and therefore state-dict keys) is unchanged from the original.
        layers = [nn.Linear(input_dim, hidden_dim), nn.LeakyReLU(), nn.Dropout(dropout_prob)]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(), nn.Dropout(dropout_prob)]
        layers.append(nn.Linear(hidden_dim, output_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return gating logits, or temperature-softmax weights when ``softmax`` is set."""
        x = self.network(x)
        if self.softmax:
            # Normalize over the output (last) axis. dim=1 coincided with it only
            # for 2-D input; for (batch, seq, dim) input it normalized over seq.
            x = F.softmax(x / self.temperature, dim=-1)
        return x


@dataclass
class CustomOutput(ModelOutput):
    """Output container returned by ``LDLRewardModel27B.forward``.

    Keep the declaration order stable: ModelOutput-based containers may be
    indexed positionally by callers.
    """

    # Output of the regression head: per-group probability distributions over score bins.
    rewards: torch.FloatTensor = None
    # Final hidden state pooled at the last non-padding token, shape (batch, hidden_size).
    # NOTE(review): annotated as a tuple, but LDLRewardModel27B.forward stores a single tensor here.
    hidden_state: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Scalar reward per example: bin distribution weighted by linspace(0, 1, num_bins), summed.
    score: Optional[torch.FloatTensor] = None
    # Gated mixture of the per-group distributions, shape (batch, num_bins).
    total_reward_distribution: Optional[torch.FloatTensor] = None
    # Gating-head output with a singleton middle axis, shape (batch, 1, num_groups).
    weights: Optional[torch.FloatTensor] = None
    # Same tensor as ``score``; exposed under ``logits`` for callers that read ``.logits``.
    logits: Optional[torch.FloatTensor] = None


class LDLRewardModel27B(Gemma2PreTrainedModel):
    """Gemma-2 reward model with a label-distribution (LDL) scoring head.

    The final hidden state at the last non-padding token of each sequence is
    passed through ``regression_layer`` (per-group probability distributions
    over score bins) and ``gating_layer`` (one weight per group). The gated
    mixture of the distributions is reduced to a scalar score as a weighted sum
    over ``linspace(0, 1, num_bins)``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        print('----------------num labels', self.num_labels)
        self.model = Gemma2Model(config)
        config_dict = config.to_dict()
        # Total number of regression outputs; grouped into bins of 10 by MultiOutputNN,
        # hence num_objectives // 10 gating weights.
        self.num_objectives = config_dict.get("num_objectives", 220)
        self.regression_layer = MultiOutputNN(config.hidden_size, self.num_objectives)
        self.gating_layer = GatingNN(
            config.hidden_size,
            self.num_objectives // 10,
            temperature=config_dict.get("temperature", 1.0),
            softmax=config_dict.get("softmax", False),
        )

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> CustomOutput:
        """Score a batch of sequences.

        ``labels`` is accepted for interface compatibility but unused (no loss
        is computed).

        Returns:
            CustomOutput with ``score``/``logits`` of shape ``(batch,)`` plus the
            intermediate head outputs.

        Raises:
            ValueError: If ``config.pad_token_id`` is None and batch size > 1.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        tokens_hidden_states = transformer_outputs[0]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the token before the first pad token. When no pad token is
                # present argmax returns 0, and (0 - 1) % len selects the last
                # position; modulo is used instead of -1 for ONNX compatibility.
                # This assumes right padding.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(tokens_hidden_states.device)
            else:
                sequence_lengths = -1

        dummy_iterator = torch.arange(batch_size, device=tokens_hidden_states.device)
        hidden_states = tokens_hidden_states[dummy_iterator, sequence_lengths]
        assert hidden_states.shape == (batch_size, self.config.hidden_size)
        # The heads are kept in float32 (see from_pretrained); run them under a
        # float32 autocast so the backbone's bf16 activations are upcast.
        with torch.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
            rewards = self.regression_layer(hidden_states)
            weights = self.gating_layer(hidden_states)
            weights = weights.unsqueeze(1)
            # (batch, 1, groups) @ (batch, groups, bins) -> (batch, bins)
            total_reward_distribution = torch.bmm(weights, rewards).squeeze(1)
            score = (
                    total_reward_distribution
                    * torch.linspace(0, 1, total_reward_distribution.size(-1)).to(tokens_hidden_states.device)
            ).sum(-1)
        return CustomOutput(
            rewards=rewards,
            weights=weights,
            hidden_state=hidden_states,
            total_reward_distribution=total_reward_distribution,
            score=score,
            logits=score,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, device_map=None, *model_args, **kwargs):
        """Load the backbone checkpoint plus the separately stored head weights.

        Args:
            pretrained_model_name_or_path: Local directory holding the checkpoint
                together with ``regression_layer.pt`` and ``gating_layer.pt``.
            device_map: ``None`` leaves the model where ``from_pretrained`` put it.
                ``"auto"`` / ``"balanced"`` dispatches it across devices with
                accelerate.
            *model_args, **kwargs: Forwarded to the base ``from_pretrained``.
                ``torch_dtype`` defaults to ``torch.bfloat16`` and
                ``attn_implementation`` to ``"flash_attention_2"`` unless the
                caller overrides them.

        Raises:
            NotImplementedError: For any other ``device_map`` value.
        """
        cached_dir = pretrained_model_name_or_path
        # Honor caller-supplied loading options (these were previously accepted
        # and silently discarded) while keeping the original defaults.
        kwargs.setdefault("torch_dtype", torch.bfloat16)
        kwargs.setdefault("attn_implementation", "flash_attention_2")
        model = super(LDLRewardModel27B, cls).from_pretrained(cached_dir, *model_args, **kwargs)

        # NOTE(review): torch.load unpickles arbitrary objects; consider
        # weights_only=True if the installed torch supports it and these files
        # only contain tensors.
        model.regression_layer = model.regression_layer.float()
        regression_layer_path = os.path.join(cached_dir, "regression_layer.pt")
        regression_layer_state_dict = torch.load(regression_layer_path, map_location="cpu")
        model.regression_layer.load_state_dict(regression_layer_state_dict)

        model.gating_layer = model.gating_layer.float()
        gating_layer_path = os.path.join(cached_dir, "gating_layer.pt")
        gating_layer_state_dict = torch.load(gating_layer_path, map_location="cpu")
        model.gating_layer.load_state_dict(gating_layer_state_dict)

        if device_map == "auto" or device_map == "balanced":
            # Keep decoder layers and norms whole when splitting across devices.
            max_memory = get_balanced_memory(model, no_split_module_classes=["Gemma2DecoderLayer", "Gemma2RMSNorm"])
            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=["Gemma2DecoderLayer", "Gemma2RMSNorm"],
                max_memory=max_memory,
            )
            model = dispatch_model(model, device_map=device_map)
        elif device_map is not None:
            raise NotImplementedError("Write your own device map")

        return model





class INFORMForSequenceClassification(LlamaPreTrainedModel):
    """Llama backbone with a two-layer MLP scoring head (INF-ORM style).

    The head is applied to every token's final hidden state. The logits at the
    last non-padding token of each sequence are returned as the sequence score.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # State-dict keys derive from these attribute names and the Sequential
        # order, so keep them stable for checkpoint loading.
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )

        print('--------self num labels', self.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Score a batch of sequences.

        ``labels`` is accepted for interface compatibility but no loss is
        computed (``loss`` is always None). ``return_dict`` is accepted for
        compatibility; a ``SequenceClassifierOutputWithPast`` is always returned.

        Returns:
            SequenceClassifierOutputWithPast whose ``logits`` has shape
            ``(batch, num_labels)``.

        Raises:
            ValueError: If ``config.pad_token_id`` is None and batch size > 1.
        """
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            # Forward the caller's flags instead of silently dropping them.
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Attribute access below (.past_key_values, .hidden_states, ...) needs a ModelOutput.
            return_dict=True,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the token before the first pad token. With no pad token,
                # argmax is 0 and (0 - 1) % len selects the last position; modulo is
                # used instead of -1 for ONNX compatibility. Assumes right padding.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )



class SkyworkORM:
    """Outcome reward model wrapper for Skywork checkpoints (transformers backend).

    The checkpoint is loaded as a single-label ``AutoModelForSequenceClassification``
    in bfloat16 with FlashAttention-2 and used to score (question, response) pairs.
    """

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name; must contain 'Skywork'.
            device: Passed as ``device_map`` when loading the model.
        """
        assert 'Skywork' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the reward model and its tokenizer from ``self.model_path``."""
        self.rm = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=self.model_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map=self.device,
            num_labels=1,
        )
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each (question[i], response[i]) pair.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant replies. Pairs are built with ``zip``,
                so a shorter ``question`` truncates the batch.
            batch_size: Number of pairs per forward pass.

        Returns:
            list[float]: One reward per pair, read from logit column 0.
        """
        # Materialize once and slice per batch. Converting to a NumPy array on
        # every batch copies every string each time (quadratic overall), and
        # NumPy's fixed-width unicode dtype strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for r, q in zip(responses[start:stop], questions[start:stop])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            # NOTE(review): if the chat template already emits a BOS token, the
            # tokenizer call below may add a second one — verify for this checkpoint.
            conv_list_tokenized = self.rm_tokenizer(conv_list_formatted, return_tensors="pt", padding=True).to(self.rm.device)
            with torch.no_grad():
                reward_score.extend(self.rm(**conv_list_tokenized).logits[:, 0].tolist())
        return reward_score

class SkyworkORM_VLLM_version:
    """Skywork outcome reward model served through vLLM's reward task.

    The engine uses tensor parallelism over all visible GPUs and YaRN rope
    scaling (factor 16 over 8192 original positions).
    """

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name; must contain 'Skywork'.
            device: Stored for interface parity; vLLM placement uses ``gpu_num``.
        """
        assert 'Skywork' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Start the vLLM reward engine and load the tokenizer."""
        engine_args = {"rope_scaling": {
            "factor": 16,
            "original_max_position_embeddings": 8192,
            "type": "yarn",
            "rope_type": "yarn"
        }}
        self.rm = LLM(model=self.model_path, dtype='bfloat16', task='reward', tensor_parallel_size=gpu_num,
                      trust_remote_code=True, **engine_args)
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each (question[i], response[i]) pair.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant replies (``zip`` truncates to the shorter).
            batch_size: Number of prompts sent to ``LLM.encode`` per call.

        Returns:
            list: One entry per pair, taken from the last element of each
            request's ``outputs.data``.
        """
        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for r, q in zip(responses[start:stop], questions[start:stop])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            with torch.no_grad():
                output = self.rm.encode(conv_list_formatted, use_tqdm=False)
                rm_scores = torch.tensor([x.outputs.data[-1] for x in output])
                reward_score.extend(rm_scores.tolist())
        return reward_score



class QwenPRM:
    """Wrapper for a Qwen process reward model scored at ``<extra_0>`` step separators."""

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name; must contain 'Qwen'.
            device: Passed as ``device_map`` when loading the model.
        """
        assert 'Qwen' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the PRM (bfloat16, remote code) and its tokenizer."""
        self.rm = AutoModel.from_pretrained(
            pretrained_model_name_or_path=self.model_path,
            device_map=self.device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path,
                                                          trust_remote_code=True)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each response as the product of its per-step probabilities.

        Each response is split on '\\n' and the pieces are re-joined with the
        ``<extra_0>`` separator (plus a trailing one), so there is one separator
        per step.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant replies (``zip`` truncates to the shorter).
            batch_size: Number of pairs per forward pass.

        Returns:
            list[float]: For each pair, the product over separator positions of
            the softmax probability at label index 1.
        """
        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        # The separator id does not depend on the batch, so look it up once.
        step_sep_id = self.rm_tokenizer.encode("<extra_0>")[0]
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": "<extra_0>".join(r.split('\n')) + "<extra_0>"}]
                         for r, q in zip(responses[start:stop], questions[start:stop])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False, add_generation_prompt=False)
            conv_list_tokenized = self.rm_tokenizer(conv_list_formatted, return_tensors="pt", padding=True).to(self.rm.device)
            with torch.no_grad():
                out = self.rm(**conv_list_tokenized)
            token_masks = (conv_list_tokenized.data['input_ids'] == step_sep_id)
            probabilities = F.softmax(out[0], dim=-1)
            # Zero every position that is not a step separator.
            probabilities = probabilities * token_masks.unsqueeze(-1)  # bs, seq_len, num_labels
            for sample in probabilities:  # sample: seq_len, num_labels
                # The surviving non-zero entries are regrouped in pairs and column 1 is
                # taken; this assumes num_labels == 2 and that no kept probability is
                # exactly 0.
                positive_probs = sample[sample != 0].view(-1, 2)[:, 1]
                reward_score.append(positive_probs.prod().item())
        return reward_score


class QwenRM:
    """Wrapper for a Qwen outcome reward model loaded with remote code (transformers backend)."""

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name; must contain 'Qwen'.
            device: Passed as ``device_map`` when loading the model.
        """
        assert 'Qwen' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the reward model in eval mode (bfloat16) and its tokenizer."""
        print(self.model_path)
        self.rm = AutoModel.from_pretrained(
            pretrained_model_name_or_path=self.model_path,
            device_map=self.device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path,
                                                          trust_remote_code=True)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each (question[i], response[i]) pair.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant replies (``zip`` truncates to the shorter).
            batch_size: Number of pairs per forward pass.

        Returns:
            list[float]: One reward per pair, read from ``logits[:, 0]``.
        """
        # TODO(review): the original author noted an error for batch sizes > 1;
        # the root cause is not identified here.
        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for r, q in zip(responses[start:stop], questions[start:stop])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            conv_list_tokenized = self.rm_tokenizer(conv_list_formatted, return_tensors="pt", padding=True).to(self.rm.device)
            with torch.no_grad():
                output = self.rm(**conv_list_tokenized)
                reward_score.extend(output.logits[:, 0].tolist())

        return reward_score

class QwenRM_VLLM_version:
    """Qwen outcome reward model served through vLLM's reward task.

    The engine uses tensor parallelism over all visible GPUs and YaRN rope
    scaling (factor 32 over 4096 original positions).
    """

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name; must contain 'Qwen'.
            device: Stored for interface parity; vLLM placement uses ``gpu_num``.
        """
        assert 'Qwen' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Start the vLLM reward engine and load the tokenizer."""
        engine_args = {"rope_scaling": {
            "factor": 32,
            "original_max_position_embeddings": 4096,
            "type": "yarn",
            "rope_type": "yarn"
        }}
        self.rm = LLM(self.model_path, task="reward", dtype='bfloat16', tensor_parallel_size=gpu_num,
                      trust_remote_code=True, **engine_args)
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path,
                                                          trust_remote_code=True)

    def obtain_reward(self, question, response, batch_size=1):
        """Score each (question[i], response[i]) pair, one prompt per vLLM call.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant replies (``zip`` truncates to the shorter).
            batch_size: Must be <= 1.

        Returns:
            list: Concatenation of ``outputs.data[-1].tolist()`` for each prompt.

        Raises:
            ValueError: If ``batch_size > 1``.
        """
        if batch_size > 1:
            raise ValueError("Currently only batch_size=1 is supported due to bug in VLM Reward. Note that rope scaling is automatically applied due to the same bug.")

        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for r, q in zip(responses[start:stop], questions[start:stop])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            with torch.no_grad():
                output = self.rm.encode(conv_list_formatted, use_tqdm=False)
                # Workaround for the vLLM issue noted above: only the single request's
                # last data entry is read. Assumes data[-1] is 1-D so .tolist() is a list.
                reward_score.extend(output[0].outputs.data[-1].tolist())

        return reward_score




class AceCodeRM_VLLM_version:
    """AceCode reward model served through vLLM's reward task.

    The engine uses tensor parallelism over all visible GPUs and YaRN rope
    scaling (factor 4 over 32768 original positions).
    """

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name; must contain 'Ace'.
            device: Stored for interface parity; vLLM placement uses ``gpu_num``.
        """
        assert 'Ace' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Start the vLLM reward engine and load the tokenizer."""
        engine_args = {"rope_scaling": {
            "factor": 4,
            "original_max_position_embeddings": 32768,
            "type": "yarn",
            "rope_type": "yarn"
        }}
        self.rm = LLM(model=self.model_path, task='reward', tensor_parallel_size=gpu_num, **engine_args)
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path,
                                                          trust_remote_code=True)

    def obtain_reward(self, question, response, batch_size=1):
        """Score each (question[i], response[i]) program pair.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant programs (``zip`` truncates to the shorter).
            batch_size: Number of prompts sent to ``LLM.encode`` per call.

        Returns:
            list: One entry per pair, from the last element of each request's
            ``outputs.data``.
        """
        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            program_chats = [
                [
                    {
                        "content": q,
                        "role": "user",
                    },
                    {
                        "role": "assistant",
                        "content": r
                    }
                ] for r, q in zip(responses[start:stop], questions[start:stop])
            ]
            input_tokens = self.rm_tokenizer.apply_chat_template(
                program_chats,
                tokenize=False
            )
            with torch.no_grad():
                output = self.rm.encode(input_tokens, use_tqdm=False)
                rm_scores = torch.tensor([x.outputs.data[-1] for x in output])
                reward_score.extend(rm_scores.tolist())
        return reward_score




class AceCodeRMWrap:
    """Wrapper around ``AceCodeRM`` (transformers backend) for scoring programs."""

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name; must contain 'Ace'.
            device: Passed as ``device_map`` to ``AceCodeRM.from_pretrained``.
        """
        assert 'Ace' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the AceCode reward model and its tokenizer."""
        self.rm = AceCodeRM.from_pretrained(self.model_path, device_map=self.device)
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path,
                                                          trust_remote_code=True)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each (question[i], response[i]) program pair.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant programs (``zip`` truncates to the shorter).
            batch_size: Number of pairs per forward pass.

        Returns:
            list: Model outputs flattened per batch; a 0-d output contributes one value.
        """
        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            program_chats = [
                [
                    {
                        "content": q,
                        "role": "user",
                    },
                    {
                        "role": "assistant",
                        "content": r
                    }
                ] for r, q in zip(responses[start:stop], questions[start:stop])
            ]
            input_tokens = self.rm_tokenizer.apply_chat_template(
                program_chats,
                tokenize=True,
                return_dict=True,
                padding=True,
                return_tensors="pt",
            ).to(self.rm.device)
            with torch.no_grad():
                rm_scores = self.rm(
                    **input_tokens,
                    output_hidden_states=True,
                    return_dict=True,
                    use_cache=False,
                )
            # A single-example batch may come back as a 0-d tensor, whose
            # .tolist() is a bare float rather than a list.
            if len(rm_scores.shape) == 0:
                reward_score.extend([rm_scores.item()])
            else:
                reward_score.extend(rm_scores.tolist())
        return reward_score


class LDLRM:
    """Wrapper around ``LDLRewardModel27B`` (transformers backend)."""

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Local checkpoint directory (must also hold the head ``.pt`` files).
            name: Display name.
            device: Passed as ``device_map`` to ``LDLRewardModel27B.from_pretrained``.
        """
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the LDL reward model and its tokenizer."""
        self.rm = LDLRewardModel27B.from_pretrained(
            pretrained_model_name_or_path=self.model_path,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
        )
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each (question[i], response[i]) pair.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant replies (``zip`` truncates to the shorter).
            batch_size: Number of pairs per forward pass.

        Returns:
            list[float]: The model's ``logits`` (its scalar score) per pair.
        """
        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for r, q in zip(responses[start:stop], questions[start:stop])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            # NOTE(review): if the chat template already emits a BOS token, the
            # tokenizer call below may add a second one — verify for this checkpoint.
            conv_list_tokenized = self.rm_tokenizer(conv_list_formatted, return_tensors="pt", padding=True).to(self.rm.device)
            with torch.no_grad():
                outputs = self.rm(**conv_list_tokenized)
                reward_score.extend(outputs.logits.tolist())
        return reward_score


class LDLRM_VLLM_version:
    """LDL reward model served through vLLM's reward task.

    The engine uses tensor parallelism over all visible GPUs and YaRN rope
    scaling (factor 16 over 8192 original positions).
    """

    def __init__(self, model_path, name, device):
        """
        Args:
            model_path: Path or hub id of the checkpoint.
            name: Display name.
            device: Stored for interface parity; vLLM placement uses ``gpu_num``.
        """
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Start the vLLM reward engine and load the tokenizer."""
        engine_args = {"rope_scaling": {
            "factor": 16,
            "original_max_position_embeddings": 8192,
            "type": "yarn",
            "rope_type": "yarn"
        }}
        self.rm = LLM(model=self.model_path, dtype='bfloat16', task='reward', tensor_parallel_size=gpu_num,
                      trust_remote_code=True, **engine_args)
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each (question[i], response[i]) pair.

        Args:
            question: Sequence of user prompts aligned with ``response``.
            response: Sequence of assistant replies (``zip`` truncates to the shorter).
            batch_size: Number of prompts sent to ``LLM.encode`` per call.

        Returns:
            list: ``outputs.data`` of each request, converted via ``torch.tensor(...).tolist()``.
        """
        # Materialize once and slice per batch; a per-batch NumPy conversion
        # copies all strings every iteration and strips trailing NUL characters.
        responses = list(response)
        questions = list(question)
        reward_score = []
        for start in tqdm(range(0, len(responses), batch_size)):
            stop = start + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for r, q in zip(responses[start:stop], questions[start:stop])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            with torch.no_grad():
                output = self.rm.encode(conv_list_formatted, use_tqdm=False)
                # NOTE(review): unlike the Skywork/AceCode vLLM wrappers (which take
                # data[-1]), this keeps the whole ``outputs.data``. If that tensor is
                # per-token, lengths differ across requests and torch.tensor will fail
                # — confirm the pooled output shape for this model.
                rm_scores = torch.tensor([x.outputs.data for x in output])
                reward_score.extend(rm_scores.tolist())
        return reward_score


class INFORM:
    """Reward scorer for INF-ORM checkpoints loaded with Hugging Face transformers.

    ``name`` must contain ``'INF'``. ``device`` is forwarded as ``device_map`` when
    loading the model.
    """

    def __init__(self, model_path, name, device):
        assert 'INF' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the bf16 classifier (flash-attention-2, one label) and its tokenizer."""
        self.rm = INFORMForSequenceClassification.from_pretrained(
                        pretrained_model_name_or_path=self.model_path,
                        torch_dtype=torch.bfloat16,
                        device_map=self.device,
                        attn_implementation="flash_attention_2",
                        num_labels=1,
                    )
        self.rm_tokenizer = PreTrainedTokenizerFast.from_pretrained(pretrained_model_name_or_path=self.model_path)
        # Copy the tokenizer's pad id into the model config; presumably the
        # classification head uses it to locate the last real token in padded
        # batches — confirm against INFORMForSequenceClassification.
        self.rm.config.pad_token_id = self.rm_tokenizer.pad_token_id

    def obtain_reward(self, question, response, batch_size=4):
        """Score each ``(question[i], response[i])`` pair.

        Args:
            question: sequence of prompt strings, paired index-by-index with ``response``.
            response: sequence of response strings to score.
            batch_size: number of padded conversations per forward pass.

        Returns:
            list of floats (``logits[:, 0]``), one per scored pair, in input order.
        """
        # Slice plain lists instead of rebuilding np.array(...) on every batch: that
        # was O(n) copying per iteration and NumPy's unicode dtype drops trailing NULs.
        questions, responses = list(question), list(response)
        reward_score = []
        for st_index in tqdm(range(0, len(responses), batch_size)):
            end_index = st_index + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for q, r in zip(questions[st_index:end_index], responses[st_index:end_index])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            conv_list_tokenized = self.rm_tokenizer(conv_list_formatted, return_tensors="pt",
                                                    padding=True).to(self.rm.device)
            with torch.no_grad():
                reward_score.extend(self.rm(**conv_list_tokenized).logits[:, 0].tolist())
        return reward_score


class INFORM_VLLM_version:
    """Reward scorer for INF-ORM checkpoints served through vLLM's ``reward`` task.

    ``name`` must contain ``'INF'``. The engine is sharded over every visible GPU;
    ``device`` is stored but not used by this class.
    """

    def __init__(self, model_path, name, device):
        assert 'INF' in name
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Build the vLLM engine and the fast tokenizer used only for chat templating."""
        self.rm = LLM(model=self.model_path, task='reward', tensor_parallel_size=gpu_num)
        self.rm_tokenizer = PreTrainedTokenizerFast.from_pretrained(pretrained_model_name_or_path=self.model_path)

    def obtain_reward(self, question, response, batch_size=1):
        """Score each ``(question[i], response[i])`` pair, one conversation at a time.

        Args:
            question: sequence of prompt strings, paired index-by-index with ``response``.
            response: sequence of response strings to score.
            batch_size: must be 1; larger values are rejected.

        Returns:
            list extended with ``outputs.data.tolist()`` of each single-item encode
            call, in input order.

        Raises:
            ValueError: if ``batch_size > 1``.
        """
        if batch_size > 1:
            raise ValueError("Currently only batch_size=1 is supported due to bug in VLM Reward. Note that rope scaling is automatically applied due to the same bug.")
        # Slice plain lists instead of rebuilding np.array(...) on every batch: that
        # was O(n) copying per iteration and NumPy's unicode dtype drops trailing NULs.
        questions, responses = list(question), list(response)
        reward_score = []
        for st_index in tqdm(range(0, len(responses), batch_size)):
            end_index = st_index + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for q, r in zip(questions[st_index:end_index], responses[st_index:end_index])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            with torch.no_grad():
                output = self.rm.encode(conv_list_formatted, use_tqdm=False)
                reward_score.extend(output[0].outputs.data.tolist())
        return reward_score


class QRM:
    """Reward scorer for QRM checkpoints loaded via ``AutoModelForSequenceClassification``.

    ``device`` is forwarded as ``device_map`` when loading the model.
    """

    def __init__(self, model_path, name, device):
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the bf16 classifier (flash-attention-2, remote code allowed) and fast tokenizer."""
        self.rm = AutoModelForSequenceClassification.from_pretrained(
            self.model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
            device_map=self.device, trust_remote_code=True)
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path, use_fast=True)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each ``(question[i], response[i])`` pair.

        Args:
            question: sequence of prompt strings, paired index-by-index with ``response``.
            response: sequence of response strings to score.
            batch_size: number of padded conversations per forward pass.

        Returns:
            list of floats (``logits[:, 0]``), one per scored pair, in input order.
        """
        # Slice plain lists instead of rebuilding np.array(...) on every batch: that
        # was O(n) copying per iteration and NumPy's unicode dtype drops trailing NULs.
        questions, responses = list(question), list(response)
        reward_score = []
        for st_index in tqdm(range(0, len(responses), batch_size)):
            end_index = st_index + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for q, r in zip(questions[st_index:end_index], responses[st_index:end_index])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            conv_list_tokenized = self.rm_tokenizer(conv_list_formatted, return_tensors="pt",
                                                    padding=True).to(self.rm.device)
            with torch.no_grad():
                outputs = self.rm(**conv_list_tokenized)
                reward_score.extend(outputs.logits[:, 0].tolist())
        return reward_score

class QRM_VLLM_version:
    """Reward scorer for QRM checkpoints served through vLLM's ``reward`` task.

    The engine is sharded over every visible GPU (``tensor_parallel_size=gpu_num``);
    ``device`` is stored but not used by this class.
    """

    def __init__(self, model_path, name, device):
        self.rm_tokenizer = None
        self.rm = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Build the vLLM engine and the fast HF tokenizer used only for chat templating."""
        # Override the checkpoint's RoPE settings with YaRN scaling
        # (factor 16 over 8192 original positions) when building the engine.
        engine_args = {"rope_scaling": {
            "factor": 16,
            "original_max_position_embeddings": 8192,
            "type": "yarn",
            "rope_type": "yarn"
        }}
        self.rm = LLM(model=self.model_path, dtype='bfloat16', task='reward', tensor_parallel_size=gpu_num,
                      trust_remote_code=True, **engine_args)
        self.rm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path, use_fast=True)

    def obtain_reward(self, question, response, batch_size=4):
        """Score each ``(question[i], response[i])`` pair.

        Args:
            question: sequence of prompt strings, paired index-by-index with ``response``.
            response: sequence of response strings to score.
            batch_size: number of conversations passed to ``LLM.encode`` per call.

        Returns:
            list built from ``torch.tensor([o.outputs.data ...]).tolist()`` for each
            batch, in input order. NOTE(review): whether each entry is a float or a
            nested list depends on the shape vLLM returns in ``outputs.data`` — confirm.
        """
        # Slice plain lists instead of rebuilding np.array(...) on every batch: that
        # was O(n) copying per iteration and NumPy's unicode dtype drops trailing NULs.
        questions, responses = list(question), list(response)
        reward_score = []
        for st_index in tqdm(range(0, len(responses), batch_size)):
            end_index = st_index + batch_size
            conv_list = [[{"role": "user", "content": q},
                          {"role": "assistant", "content": r}]
                         for q, r in zip(questions[st_index:end_index], responses[st_index:end_index])]
            conv_list_formatted = self.rm_tokenizer.apply_chat_template(conv_list, tokenize=False)
            with torch.no_grad():
                output = self.rm.encode(conv_list_formatted, use_tqdm=False)
                rm_scores = torch.tensor([x.outputs.data for x in output])
                reward_score.extend(rm_scores.tolist())
        return reward_score



# for LLM verifier
class LlmRM:
    """Pairwise LLM judge: asks chat model ``name`` which of two answers is better."""

    # Verdict patterns, tried in order. The flag is True when the captured answer
    # number names the *worse* answer rather than the better one. Only 1 and 2 are
    # accepted so the returned index is always a valid position in the pair.
    _VERDICT_PATTERNS = (
        (re.compile(r'The ANS_([12]) is better\.'), False),
        (re.compile(r'better.*?ANS_([12])|ANS_([12]).*?better'), False),
        (re.compile(r'better.*?([12])|([12]).*?better'), False),
        (re.compile(r'worse.*?ANS_([12])|ANS_([12]).*?worse'), True),
        (re.compile(r'worse.*?([12])|([12]).*?worse'), True),
        (re.compile(r'ANS_([12])'), False),
    )

    def __init__(self, name):
        self.name = name

    def build_prompt(self, question, response, compare_standard=None):
        """Build chat messages asking the judge to compare ``response[0]`` and ``response[1]``.

        Args:
            question: the question text.
            response: two-element sequence ``[answer_1, answer_2]``.
            compare_standard: judging criteria; defaults to ``DEFAULT_COMPARE_STANDARD``.

        Returns:
            ``[system_message, user_message]`` as role/content dicts.
        """
        if compare_standard is None:
            compare_standard = DEFAULT_COMPARE_STANDARD
        # Each numbered step must end with "\n"; previously step 2 ran straight into step 3.
        system_prompt = 'You are a fair judger to compare the two given answers to the question below and determine which one is better. Follow these steps carefully:\n' \
                        '1. Analyze the Question: Understand the requirements and expectations of the question (QES).\n' \
                        f'2. Evaluate Each Answer Separately: Assess ANS_1 and ANS_2 based on the standard: {compare_standard}\n' \
                        '3. Compare the Two Answers: Identify strengths and weaknesses in each response. Consider which answer is better based on the standard.\n' \
                        '4. Make a Decision: Based on the comparison, determine which answer is better and output the result in the format:\n' \
                        '"The ANS_x is better." (Replace x with 1 or 2 accordingly.)\n' \
                        'User input Format:\n' \
                        'QES: [The question here]\n' \
                        'ANS_1: [The first answer here]\n' \
                        'ANS_2: [The second answer here]\n'
        user_prompt = f'QES:"{question}"\n\nANS_1:"{response[0]}"\n\nANS_2:"{response[1]}"'
        return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]

    def extract_winner(self, result_str):
        """Parse the judge's reply into the winning index (0 for ANS_1, 1 for ANS_2).

        Patterns in ``_VERDICT_PATTERNS`` are tried from most to least specific.
        When nothing matches, or ``result_str`` is not a string, a message is
        printed and 0 is returned.
        """
        if isinstance(result_str, str):
            for pattern, names_loser in self._VERDICT_PATTERNS:
                match = pattern.search(result_str)
                if match:
                    # Exactly one alternative's group is set; it is "1" or "2".
                    answer = int(next(g for g in match.groups() if g))
                    return 2 - answer if names_loser else answer - 1
        print("Does not extract winner! Output 0")
        return 0

    def get_best_response(self, question, response, compare_standard=None):
        """Ask the judge to compare two answers; return 0 or 1 (index of the better one)."""
        messages = self.build_prompt(question, response, compare_standard)
        result_str = generate_general(model=self.name, messages=messages, max_tokens=2048, temperature=0.7, streaming=False)
        if self.name == 'QwQ-32B' and isinstance(result_str, str):
            # Keep only the text after the reasoning trace.
            result_str = result_str.split('</think>')[-1]
        return self.extract_winner(result_str)

    async def async_get_best_response(self, question, response, compare_standard=None):
        """Run ``get_best_response`` in the loop's default executor so judges can overlap."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.get_best_response, question, response, compare_standard)


class MixJudge:
    """Select the best of several responses with a knockout tournament.

    Each match is a majority vote. Every reward model in ``rm_models`` votes for the
    pair member it scores strictly higher (otherwise for the second member). Every
    LLM judge built from ``llm_models`` (wrapped in ``LlmRM``) casts one vote. Ties
    go to the second member. Put the strongest RM first: its scores, sorted in
    ascending order, seed the bracket.
    """

    def __init__(self, rm_models=None, llm_models=None):
        # put the strongest model first
        self.rm_models_list = rm_models
        self.llm_models_list = llm_models
        self.llm_judges = []
        if self.llm_models_list is not None:
            self.llm_judges = [LlmRM(m) for m in self.llm_models_list]
        # Judges used by the most recent async call. Kept for backward compatibility;
        # each call now threads its own judge list through explicitly.
        self.rt_llm_judges = self.llm_judges

    def _rm_scores_and_seed(self, question, response):
        """Return ``(per-RM reward lists or None, bracket order as an index tensor)``."""
        if self.rm_models_list is None:
            return None, torch.arange(len(response))
        all_rm_models_result = [generate_general_rm(model=m, question=[question] * len(response),
                                                    response=response, batch_size=1)
                                for m in self.rm_models_list]
        # Ascending order of the first (strongest) RM's rewards.
        return all_rm_models_result, torch.tensor(all_rm_models_result[0]).argsort()

    @staticmethod
    def _rm_votes(all_rm_models_result, response_index):
        """Tally RM votes for the pair ``response_index`` as ``[votes_first, votes_second]``."""
        if all_rm_models_result is None:
            return torch.zeros(2)
        selected = torch.tensor(all_rm_models_result)[:, response_index]
        first_wins = (selected[:, 0] > selected[:, 1]).sum().item()
        return torch.tensor([first_wins, len(all_rm_models_result) - first_wins])

    @staticmethod
    def _add_llm_vote(voting_list, winner_index):
        """Add one judge vote, ignoring verdicts that are not a valid pair index."""
        if winner_index in (0, 1):
            voting_list[winner_index] += 1

    def all_agent_voting(self, question, response, response_index, all_rm_models_result=None,
                         compare_standard=None, judges=None):
        """Decide one match between ``response_index[0]`` and ``response_index[1]``.

        ``judges`` defaults to ``self.llm_judges``. Returns the winning element of
        ``response_index``.
        """
        judges = self.llm_judges if judges is None else judges
        voting_list = self._rm_votes(all_rm_models_result, response_index)
        eval_response = [response[int(i)] for i in response_index]
        for llm_judge in judges:
            winner_index = llm_judge.get_best_response(question, eval_response, compare_standard=compare_standard)
            self._add_llm_vote(voting_list, winner_index)
        return response_index[0] if voting_list[0] > voting_list[1] else response_index[1]

    async def async_all_agent_voting(self, question, response, response_index, all_rm_models_result=None,
                                     compare_standard=None, judges=None):
        """Async ``all_agent_voting``: LLM judges run concurrently. ``judges`` defaults to ``self.rt_llm_judges``."""
        judges = self.rt_llm_judges if judges is None else judges
        voting_list = self._rm_votes(all_rm_models_result, response_index)
        eval_response = [response[int(i)] for i in response_index]
        winners = await asyncio.gather(*(judge.async_get_best_response(question, eval_response,
                                                                       compare_standard=compare_standard)
                                         for judge in judges))
        for winner_index in winners:
            self._add_llm_vote(voting_list, winner_index)
        return response_index[0] if voting_list[0] > voting_list[1] else response_index[1]

    def get_best_response(self, question, response, compare_standard=None, rm_filter=False):
        """Run the tournament synchronously.

        Args:
            question: the question text.
            response: non-empty list of candidate responses.
            compare_standard: criteria forwarded to the LLM judges.
            rm_filter: if True (requires RMs), keep only the judges that prefer the
                top-RM response over the bottom-RM one. This applies to this call only.

        Returns:
            single-element list holding the index of the winning response.

        Raises:
            ValueError: if ``response`` is empty.
        """
        assert (rm_filter and self.rm_models_list is not None) or not rm_filter
        if len(response) == 0:
            raise ValueError("response must contain at least one candidate")
        all_rm_models_result, first_round_sort = self._rm_scores_and_seed(question, response)
        judges = self.llm_judges
        if rm_filter:
            min_max_response = [response[int(first_round_sort[0])], response[int(first_round_sort[-1])]]
            judges = [llm_judge for llm_judge in self.llm_judges
                      if llm_judge.get_best_response(question, min_max_response,
                                                     compare_standard=compare_standard) == 1]
        winner_index_list = first_round_sort
        for _ in tqdm(range(math.ceil(math.log2(len(response))))):
            if len(winner_index_list) == 1:
                break
            next_winner_index_list = []
            # Adjacent entries of the bracket play each other; an odd one out advances.
            for compare_group in winner_index_list.split(2):
                if len(compare_group) == 1:
                    next_winner_index_list.append(compare_group.item())
                else:
                    next_winner_index_list.append(int(self.all_agent_voting(
                        question, response, compare_group, all_rm_models_result=all_rm_models_result,
                        compare_standard=compare_standard, judges=judges)))
            winner_index_list = torch.tensor(next_winner_index_list)
        # Previously fell off the end (print + implicit None) after the final round.
        return winner_index_list.tolist()

    async def compare_groups_voting(self, question, response, compare_groups, all_rm_models_result,
                                    compare_standard, judges=None):
        """Play one async round: a trailing singleton group advances, other pairs vote concurrently.

        Returns the advancing indices as ints (the singleton first, then the match winners).
        """
        compare_groups = list(compare_groups)
        next_winner_index_list = []
        if compare_groups and len(compare_groups[-1]) == 1:
            next_winner_index_list.extend(int(i) for i in compare_groups[-1])
            compare_groups = compare_groups[:-1]
        winners = await asyncio.gather(*(self.async_all_agent_voting(question, response, group,
                                                                     all_rm_models_result=all_rm_models_result,
                                                                     compare_standard=compare_standard,
                                                                     judges=judges)
                                         for group in compare_groups))
        next_winner_index_list.extend(int(w) for w in winners)
        return next_winner_index_list

    async def async_get_best_response(self, question, response, compare_standard=None, rm_filter=False,
                                      return_details=False):
        """Async tournament; same contract as ``get_best_response``.

        ``return_details`` is accepted for compatibility and currently unused. The
        judge list chosen for this call is also stored in ``self.rt_llm_judges``.
        NOTE(review): RM scoring via ``generate_general_rm`` runs synchronously and
        blocks the event loop.
        """
        assert (rm_filter and self.rm_models_list is not None) or not rm_filter
        if len(response) == 0:
            raise ValueError("response must contain at least one candidate")
        all_rm_models_result, first_round_sort = self._rm_scores_and_seed(question, response)
        judges = self.llm_judges
        if rm_filter:
            min_max_response = [response[int(first_round_sort[0])], response[int(first_round_sort[-1])]]
            verdicts = await asyncio.gather(*(llm_judge.async_get_best_response(question, min_max_response,
                                                                                compare_standard=compare_standard)
                                              for llm_judge in self.llm_judges))
            # Keep each judge by its OWN verdict (previously every judge used verdicts[0]).
            judges = [llm_judge for llm_judge, w_i in zip(self.llm_judges, verdicts) if w_i == 1]
        # Reset per call so a previous rm_filter result does not leak into this one.
        self.rt_llm_judges = judges
        winner_index_list = first_round_sort
        for _ in tqdm(range(math.ceil(math.log2(len(response))))):
            if len(winner_index_list) == 1:
                break
            next_winner_index_list = await self.compare_groups_voting(question, response, winner_index_list.split(2),
                                                                      all_rm_models_result, compare_standard,
                                                                      judges=judges)
            winner_index_list = torch.tensor(next_winner_index_list)
        return winner_index_list.tolist()


# Registry: model name -> wrapper class that constructs/loads that reward model.
# The wrapper classes defined in this section take (model_path, name, device) and
# expose obtain_reward(question, response, batch_size).
# Several names (e.g. '-VLLM' / '-test' suffixes) intentionally map to the same class.
# Entry order is preserved as written, since dict iteration order follows it.
model_dict = {
    'Qwen2.5-Math-PRM-7B': QwenPRM,
    'Qwen2.5-Math-RM-72B': QwenRM_VLLM_version,
    'Qwen2.5-Math-RM-72B-VLLM': QwenRM_VLLM_version,
    'Qwen2.5-Math-RM-72B-test': QwenRM_VLLM_version,
    'Skywork-8B-Reward-Models': SkyworkORM,
    'Skywork-27B-Reward-Models': SkyworkORM,
    'Skywork-27B-Reward-Models-VLLM':SkyworkORM_VLLM_version,
    'AceCodeRM-7B': AceCodeRM_VLLM_version,
    # 'AceCodeRM-32B': AceCodeRMWrap,
    'AceCodeRM-32B': AceCodeRM_VLLM_version,
    'AceCodeRM-32B-test': AceCodeRM_VLLM_version,
    'INF-ORM-Llama3.1-70B': INFORM,
    'INF-ORM-Llama3.1-70B-test': INFORM_VLLM_version,
    'INF-ORM-Llama3.1-70B-VLLM': INFORM_VLLM_version,
    'LDL-Reward-Gemma-2-27B': LDLRM,
    'LDL-Reward-Gemma-2-27B-VLLM': LDLRM_VLLM_version,
    'QRM-Gemma-2-27B': QRM,
    'QRM-Gemma-2-27B-VLLM': QRM_VLLM_version,
    'Skywork-Reward-V2-Llama-3.1-8B-40M':SkyworkORM,
    'Skywork-Reward-V2-Llama-3.1-8B-40M_test':SkyworkORM,
}

# Model name -> checkpoint path. Only one entry is defined, and nothing in this
# section reads this dict — NOTE(review): confirm whether callers elsewhere use it.
rm_path_dict = {
    "Skywork-8B-Reward-Models": "/Skywork-8B-Reward-Models",

}

def auto_get_rm(model_name, registry=None):
    """Return the reward-model wrapper class registered under ``model_name``.

    Args:
        model_name: registry key, e.g. ``'QRM-Gemma-2-27B'``.
        registry: optional mapping to look the name up in; defaults to the
            module-level ``model_dict``.

    Raises:
        KeyError: if ``model_name`` is not registered. The message lists the known names.
    """
    registry = model_dict if registry is None else registry
    try:
        return registry[model_name]
    except KeyError:
        known = ', '.join(sorted(registry))
        raise KeyError(f"Unknown reward model {model_name!r}; known models: {known}") from None


if __name__ == '__main__':
    pass