import sys
sys.path.append('workpath')

import argparse
import re
from vllm import LLM, SamplingParams
import os
import torch
import uvicorn
from typing import List
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from openrlhf.utils import get_tokenizer
from openrlhf.models import get_llm_for_sequence_regression
from openrlhf.utils import get_tokenizer
from openrlhf.utils.logging_utils import init_logger
import json
from transformers import AutoModelForSequenceClassification
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from openrlhf.models import Actor
import requests

# Module-level logger built with OpenRLHF's `init_logger`; used by the /get_reward request handler.
logger = init_logger(__name__)

def strip_sequence(text, pad_token, eos_token):
    """Remove runs of EOS / pad tokens from both ends of ``text``.

    Both tokens are matched literally (they are regex-escaped), any mix of the
    two is stripped, and occurrences in the middle of ``text`` are kept. The
    end anchor is Python's ``$``, which also matches just before a single
    trailing newline.
    """
    either_token = "({}|{})+".format(re.escape(eos_token), re.escape(pad_token))
    head_trimmed = re.sub("^" + either_token, "", text)
    return re.sub(either_token + "$", "", head_trimmed)

import re
import torch

def auto_extract_prompts_answers(model_path):
    """Derive regexes that locate the user prompt and assistant answer in chat-formatted text.

    A two-turn probe conversation is rendered with the chat template of the
    tokenizer at ``model_path``; the text around the probe strings becomes the
    template's prefix/suffix pieces. The text between the prompt and the answer
    serves as both the prompt suffix and the answer prefix.

    Returns:
        (prompt_re, answer_re, prompt_prefix, prompt_suffix, answer_suffix).
        The regexes use the whitespace-stripped pieces around a lazy DOTALL
        capture group; the three strings are returned unstripped.

    Raises:
        ValueError: if a probe string cannot be found in the rendered template.
    """
    tok = get_tokenizer(model_path, None, "left", use_fast=True)

    probe_prompt, probe_answer = 'self_prompt', 'self_answer'
    probe_chat = [
        {"content": probe_prompt, "role": "user"},
        {"content": probe_answer, "role": "assistant"},
    ]
    rendered = tok.apply_chat_template(probe_chat, tokenize=False)

    p_start = rendered.find(probe_prompt)
    if p_start == -1:
        raise ValueError("no prompt")
    p_end = p_start + len(probe_prompt)

    # Search for the answer only after the prompt so the two probes cannot overlap.
    a_start = rendered.find(probe_answer, p_end)
    if a_start == -1:
        raise ValueError("no answer")
    a_end = a_start + len(probe_answer)

    prompt_prefix = rendered[:p_start]
    between = rendered[p_end:a_start]
    prompt_suffix = between
    answer_prefix = between
    answer_suffix = rendered[a_end:]

    # Debug dump of the extracted template pieces.
    for piece in (prompt_prefix, prompt_suffix, answer_prefix, answer_suffix):
        print(repr(piece))

    # NOTE(review): if a stripped suffix is empty, the lazy group below matches
    # the empty string; confirm the target chat templates always close turns.
    def _between(head, tail):
        return re.compile(re.escape(head.strip()) + r"(.*?)" + re.escape(tail.strip()), re.DOTALL)

    prompt_re = _between(prompt_prefix, prompt_suffix)
    answer_re = _between(answer_prefix, answer_suffix)

    return prompt_re, answer_re, prompt_prefix, prompt_suffix, answer_suffix

class RewardModelProxy:
    """Score policy outputs with a sequence-regression reward model plus an LLM judge.

    For each request item the proxy:
      * re-renders the query and the reject label through the reward tokenizer's
        chat template and scores both with the reward model;
      * asks a vLLM-served judge whether the plain text extracted from the query
        and the chosen label express the same meaning (1 = yes, 0 = no);
      * returns ``(reject_score - query_score) * judge_score``.

    Device placement defaults to the reward model on ``cuda:5`` and the judge on
    physical GPUs ``6,7``. ``args`` may optionally carry ``reward_device`` (a
    torch device string) and ``judge_devices`` (a comma-separated GPU list) to
    override these defaults.
    """

    def __init__(self, args):
        self.args = args
        self.reward_names = args.reward_pretrain

        # Optional overrides; fall back to the original hard-coded layout.
        reward_device = getattr(args, "reward_device", None) or "cuda:5"
        judge_devices = getattr(args, "judge_devices", None) or "6,7"

        self.reward_model = get_llm_for_sequence_regression(
            args.reward_pretrain,
            "reward",
            normalize_reward=args.normalize_reward,
            use_flash_attention_2=args.flash_attn,
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            value_head_prefix=args.value_head_prefixs,
            device_map=reward_device,
        )
        self.reward_model.eval()
        self.tokenizer = get_tokenizer(
            args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer
        )

        self.max_length = args.max_len
        self.batch_size = args.batch_size

        # Regexes derived from the *policy* model's chat template (args.pretrain),
        # used to pull the user prompt / assistant answer back out of raw queries.
        (
            self.prompt_re,
            self.answer_re,
            self.prompt_prefix,
            self.prompt_suffix,
            self.answer_suffix,
        ) = auto_extract_prompts_answers(args.pretrain)
        # Fallback matcher spanning from the very first template token to the closing
        # suffix; used by strip_text when prompt_re/answer_re do not both match.
        self.sub_pattern = re.compile(
            rf"{re.escape(self.prompt_prefix)}(.*?){re.escape(self.answer_suffix)}", re.DOTALL
        )

        # NOTE(review): this mutates the process environment after the reward model
        # has already initialised CUDA here; it presumably only takes effect for the
        # worker processes vLLM launches. Confirm against the vLLM version in use.
        os.environ["CUDA_VISIBLE_DEVICES"] = judge_devices
        judge_gpu_count = len([d for d in judge_devices.split(",") if d.strip()])
        self.llm = LLM(model=args.LLM_judge, tensor_parallel_size=judge_gpu_count)
        self.sampling_params = SamplingParams(temperature=0.0)

        self.LLM_judge_prompt = """
            Task: Analyze the given sentence pair and determine if they are similar.

            Inputs:
            Sentence 1:
            ### [First sentence]

            Sentence 2:
            ### [Second sentence]

            If the two sentences entered express the same meaning, output "yes"; otherwise, output "no". 
            Only respond with "yes" or "no" in lowercase, without explanations.
        """

        self.user_question = """
            ### Sentence 1: {sentence_1}
            ### Sentence 2: {sentence_2}
        """

    def _render_chat(self, messages):
        """Render ``messages`` with the reward tokenizer's chat template, ensuring it ends with EOS."""
        text = self.tokenizer.apply_chat_template(messages, tokenize=False).rstrip("\n")
        if not text.endswith(self.tokenizer.eos_token):
            text += " " + self.tokenizer.eos_token
        return text

    def strip_text(self, query, prompt):
        """Re-render a raw policy query for the reward model.

        Args:
            query: text presumably produced with the policy's chat template.
            prompt: unused; kept for interface compatibility.

        Returns:
            ``(rendered_text, plain_text)``, where ``rendered_text`` is the input
            for the reward model and ``plain_text`` is what the LLM judge compares:
              * the first extracted answer, when both prompt and answer match;
              * the ``sub_pattern`` capture, when only that fallback matches;
              * the pad/EOS-stripped query (+EOS) for both values otherwise.
        """
        prompts = self.prompt_re.findall(query)
        answers = self.answer_re.findall(query)

        if not prompts or not answers:
            prompts = self.sub_pattern.findall(query)
            if not prompts:
                eos = self.tokenizer.eos_token
                query = strip_sequence(query, self.tokenizer.pad_token, eos) + eos
                return query, query
            return self._render_chat([{"content": prompts[0], "role": "user"}]), prompts[0]

        if len(prompts) != len(answers):
            logger.warning(
                "strip_text: found %d prompt(s) but %d answer(s); using the first of each",
                len(prompts),
                len(answers),
            )
        sample = [{"content": prompts[0], "role": "user"}, {"content": answers[0], "role": "assistant"}]
        return self._render_chat(sample), answers[0]

    def construct(self, prompt, answer):
        """Build reward-model input from a prompt and a candidate answer.

        If ``prompt`` is chat-templated, the user content is extracted with
        ``prompt_re``; otherwise ``prompt`` is used as-is (previously an IndexError).
        """
        matches = self.prompt_re.findall(prompt)
        user_content = matches[0] if matches else prompt
        sample = [{"content": user_content, "role": "user"}, {"content": answer, "role": "assistant"}]
        return self._render_chat(sample)

    def LLM_inference(self, sentence_1, sentence_2):
        """Judge a single sentence pair; return 1 if judged similar, else -1.

        Delegates to the vLLM judge so single-pair and batched judging share the
        same prompt and sampling settings (the judge model held by this class is
        ``self.llm``; no separate HF judge/tokenizer is loaded).
        """
        verdicts = self.LLM_inference_vllm_batch([sentence_1], [sentence_2])
        return 1 if verdicts and verdicts[0] == 1 else -1

    def LLM_inference_vllm_batch(self, sentence_list_1: List[str], sentence_list_2: List[str]):
        """Judge sentence pairs in one vLLM call; return 1 (output contains "yes") or 0 per pair.

        Pairs are formed with ``zip``, so extra items in the longer list are ignored.
        """
        batch = [
            self.LLM_judge_prompt + '\n' + self.user_question.format(sentence_1=s1, sentence_2=s2) + '/no_think'
            for s1, s2 in zip(sentence_list_1, sentence_list_2)
        ]
        if not batch:
            return []
        outputs = self.llm.generate(batch, self.sampling_params)
        return [1 if "yes" in output.outputs[0].text.lower() else 0 for output in outputs]

    def get_reward(self, queries, prompts, chosen, reject):
        """Compute one reward per query.

        Args:
            queries, prompts, chosen, reject: equal-length sequences of strings
                (presumably the request's query / prompts / chosen_labels /
                reject_labels fields).

        Returns:
            A list of floats, ``(reject_score - query_score) * judge_score``.

        Raises:
            ValueError: if the input sequences differ in length.
        """
        if not queries:
            return []
        n = len(queries)
        if not (len(prompts) == len(chosen) == len(reject) == n):
            raise ValueError(
                f"length mismatch: queries={n}, prompts={len(prompts)}, "
                f"chosen={len(chosen)}, reject={len(reject)}"
            )
        # None (or 0) means "score everything in one batch"; 0 would make range() raise.
        batch_size = self.batch_size or n

        reward_inputs = []  # queries re-rendered through the reward tokenizer's template
        plain_queries = []  # plain text extracted from each query, compared with the chosen label
        reject_inputs = []  # prompt + reject label rendered for the reward model
        for query, prompt, rejected in zip(queries, prompts, reject):
            rendered, plain = self.strip_text(query, prompt)
            reward_inputs.append(rendered)
            plain_queries.append(plain)
            reject_inputs.append(self.construct(prompt, rejected))

        device = self.reward_model.device
        scores = []
        with torch.no_grad():
            for start in range(0, n, batch_size):
                stop = start + batch_size
                # Assumes the reward model returns one score per sequence, shaped (B,) or (B, 1).
                inputs = self.tokenize_fn(reward_inputs[start:stop], device=device, tokenizer=self.tokenizer)
                reward_scores = self.reward_model(inputs["input_ids"], inputs["attention_mask"]).squeeze(-1)
                inputs = self.tokenize_fn(reject_inputs[start:stop], device=device, tokenizer=self.tokenizer)
                reject_scores = self.reward_model(inputs["input_ids"], inputs["attention_mask"]).squeeze(-1)

                judge = self.LLM_inference_vllm_batch(plain_queries[start:stop], chosen[start:stop])
                judge_scores = torch.tensor(judge, dtype=torch.float32, device=device)

                logger.debug("reject_scores %s", reject_scores)
                logger.debug("reward_scores %s", reward_scores)
                logger.debug("LLM_scores %s", judge_scores)

                # Subtract in float32 so bf16 scores are not rounded a second time.
                batch_scores = (reject_scores.float() - reward_scores.float()) * judge_scores
                # reshape(-1) keeps a 1-D list even when a size-1 batch squeezes to 0-d.
                scores.extend(batch_scores.reshape(-1).tolist())

        return scores

    def tokenize_fn(self, texts, device, tokenizer):
        """Tokenize ``texts`` (padded/truncated to ``self.max_length``) and move tensors to ``device``.

        ``add_special_tokens=False`` because the texts are already chat-template rendered.
        """
        batch = tokenizer(
            texts,
            return_tensors="pt",
            add_special_tokens=False,
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Reward Model
    parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path of the policy model (its chat template is parsed)")
    parser.add_argument("--LLM_judge", type=str, default=None, help="HF model name or path of the LLM judge (served with vLLM)")
    parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path of the reward model")
    parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normalization")
    parser.add_argument("--value_head_prefixs", type=str, default="score")
    parser.add_argument("--max_len", type=int, default=2048)

    parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")

    # Performance
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
    parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2")
    parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
    parser.add_argument("--batch_size", type=int, default=32)

    # ModelScope parameters
    parser.add_argument("--use_ms", action="store_true", default=False)

    # self define
    parser.add_argument("--file_path", type=str, default=None)

    args = parser.parse_args()

    if args.use_ms:
        from modelscope.utils.hf_util import patch_hub

        # Patch hub to download models from modelscope to speed up.
        patch_hub()

    reward_model = RewardModelProxy(args)
    app = FastAPI()

    @app.post("/get_reward")
    async def get_reward(request: Request):
        """Score a batch: body is a JSON object with equal-length query / prompts / chosen_labels / reject_labels lists."""
        data = await request.json()
        if not isinstance(data, dict):
            return JSONResponse({"error": "request body must be a JSON object"}, status_code=400)
        queries = data.get("query")
        prompts = data.get("prompts")
        chosen = data.get("chosen_labels")
        reject = data.get("reject_labels")

        fields = (("query", queries), ("prompts", prompts), ("chosen_labels", chosen), ("reject_labels", reject))
        missing = [name for name, value in fields if value is None]
        if missing:
            return JSONResponse({"error": f"missing field(s): {', '.join(missing)}"}, status_code=400)
        if not (len(queries) == len(prompts) == len(chosen) == len(reject)):
            return JSONResponse(
                {"error": "query, prompts, chosen_labels and reject_labels must have the same length"},
                status_code=400,
            )

        if await request.is_disconnected():
            logger.info("Client Disconnected !")
            return
        # Deliberately synchronous: it blocks the event loop, which serialises requests
        # so the reward model and the vLLM engine are never driven concurrently.
        rewards = reward_model.get_reward(queries, prompts, chosen, reject)
        result = {"rewards": rewards}
        logger.info("Sent JSON: %s", result)
        return JSONResponse(result)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
