import torch
from torch.nn import functional as F
from tqdm import tqdm
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
    STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
import re
# import the huggingface transformers libraries
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, LlamaForCausalLM, \
    LlamaForSequenceClassification, pipeline
import numpy as np
from typing import Dict, List
import heapq

def is_in(full_str, sub_str):
    """Return True if *sub_str* occurs anywhere in *full_str*.

    Uses the ``in`` operator, which stops at the first match instead of
    counting every occurrence the way ``str.count`` does.
    """
    return sub_str in full_str

def countwords(s):
    """Return how many whitespace-delimited words the string *s* contains."""
    return sum(1 for _word in s.split())

def clean_llama(text, sep="###"):
    """Return the part of *text* before the first *sep*.

    An empty prefix is replaced by a single space so callers never get "".
    """
    head = text.split(sep, 1)[0]
    if not head:
        return " "
    return head




class Search_Node:
    """A node of the search tree built by ``search_alignment.our_method``.

    Attributes:
        state: payload of the node (in this file: a prompt or decoded model text).
        parent: parent ``Search_Node``, or None for the root.
        children: child nodes, in insertion order.
        inputs, prev_state: initialised to empty lists; not used by this class.
        visits: number of times ``update`` has been called.
        value: sum of the values passed to ``update``.
        cache: optional payload supplied at construction; not used by this class.
        terminal: flag marking the node as finished.
        reward: optional reward attached at construction.
    """

    def __init__(self, state, parent=None, cache=None, reward=None, terminal=False):
        self.state = state
        self.parent = parent
        self.children = []
        self.inputs = []
        self.prev_state = []
        self.visits = 0
        self.value = 0.0
        self.cache = cache
        self.terminal = terminal
        self.reward = reward

    def select_child(self):
        """Return the child with the highest UCT (UCB1) score.

        score(child) = child.value / child.visits
                       + sqrt(2 * ln(self.visits) / child.visits)

        Unvisited children (``visits == 0``) are returned first, in insertion
        order, instead of dividing by zero. The parent's visit count is clamped
        to at least 1 so the log term is never taken of 0.

        Raises:
            ValueError: if this node has no children.
        """
        if not self.children:
            raise ValueError("select_child() called on a node with no children")
        for child in self.children:
            if child.visits == 0:
                return child
        log_parent = np.log(max(self.visits, 1))
        choices_weights = [
            child.value / child.visits + np.sqrt(2 * log_parent / child.visits)
            for child in self.children
        ]
        return self.children[int(np.argmax(choices_weights))]

    def add_child(self, state, reward=None, terminal=False, cache=None):
        """Create a child holding ``state``, append it to ``children`` and return it."""
        child_node = Search_Node(state, self, cache=cache, reward=reward, terminal=terminal)
        self.children.append(child_node)
        return child_node

    def update(self, value):
        """Record one visit and add ``value`` to the accumulated value."""
        self.visits += 1
        self.value += value

    def is_terminal(self):
        """Placeholder terminal check; always returns False.

        Note: this does not consult ``self.terminal``; the search code reads
        that attribute directly.
        """
        # Define your own terminal state check
        return False


# reward based search
class search_alignment:
    """Reward-model-guided inference-time search over prompts and responses.

    Loads a decoding LLM, a prompt-rewriting LLM (the same object as the
    decoding LLM when both paths are equal) and a sequence-classification
    reward model. Two search procedures are exposed: ``our_method`` (prompt
    revision followed by a tree search) and ``beam_search_w_RM``
    (reward-ranked beam search).
    """

    # Model id that the prompt-revision instructions and default loading path target.
    _DEFAULT_LLM = 'RLHFlow/LLaMA3-iterative-DPO-final'
    # The only reward model for which a scoring template (``rm_template``) is defined.
    _SUPPORTED_RM = 'RLHFlow/RewardModel-Mistral-7B-for-DPA-v1'
    # Rewrite goals requested from the prompt LLM, keyed by ``Sample_Prompt_num``.
    _REVISE_GOALS = {
        3: ("more detailed",
            "more secure",
            "more detailed and secure, without deviating from the original content"),
        2: ("more detailed and secure, without deviating from the original content",
            "more secure"),
        1: ("more detailed and legal",),
    }
    # Extracts the rewritten question from the prompt LLM's decoded output.
    _REVISE_PATTERN = re.compile(r'\[REVISE\](.*?)\[/REVISE\]', re.DOTALL)

    def __init__(self,
                 config,
                 method='our_method',
                 LLM_Prompt='RLHFlow/LLaMA3-iterative-DPO-final',
                 LLM_Decoding='RLHFlow/LLaMA3-iterative-DPO-final',
                 Reward_model='RLHFlow/RewardModel-Mistral-7B-for-DPA-v1',
                 Max_Prompt_length=500,
                 Max_Node_length=200,
                 Max_Response_length=1000,
                 Sample_Prompt_num=2,
                 Sample_Node_num=2,
                 Sample_Original_Prompt=True,
                 LLM_Decoding_GPU='cuda:0',
                 LLM_Prompt_GPU='cuda:0',
                 Reward_model_GPU='cuda:0',
                 torch_dtype=torch.float16):
        """Load tokenizers and models and store the search hyper-parameters.

        Args:
            config: subscriptable config; optional "preference_weight" and
                "preference_index" entries are stored as tensors (None when
                missing or not convertible).
            method: search method name; stored, not read inside this class.
            LLM_Prompt, LLM_Decoding, Reward_model: model ids/paths passed to
                ``from_pretrained``.
            Max_Prompt_length: ``max_new_tokens`` for a prompt rewrite.
            Max_Node_length: ``max_new_tokens`` per tree expansion step.
            Max_Response_length: token-length threshold beyond which a
                generation is treated as finished.
            Sample_Prompt_num: number of prompt rewrites requested (1, 2 or 3;
                other values request none).
            Sample_Node_num: expansions sampled per step in ``our_method``.
            Sample_Original_Prompt: also search from the unmodified prompt.
            LLM_Decoding_GPU, LLM_Prompt_GPU, Reward_model_GPU: devices the
                models and their inputs are moved to.
            torch_dtype: dtype used when loading a non-default decoding LLM.

        Raises:
            ValueError: if ``Reward_model`` is not the supported reward model.
                This is checked before any model is loaded; previously it
                surfaced as an AttributeError on ``self.rm`` after loading.
        """
        if Reward_model != self._SUPPORTED_RM:
            raise ValueError(
                f"Unsupported reward model {Reward_model!r}: a scoring template is "
                f"only defined for {self._SUPPORTED_RM!r}."
            )

        self.config = config
        self.method = method
        try:
            self.preference_weights = torch.tensor(self.config["preference_weight"])
            self.preference_index = torch.tensor(self.config["preference_index"])
        except (KeyError, IndexError, TypeError, ValueError, RuntimeError):
            # Preferences are optional: fall back to None when absent or malformed.
            self.preference_weights = None
            self.preference_index = None

        print("Loading tokenizer...")
        self.tokenizer_gen = AutoTokenizer.from_pretrained(LLM_Decoding)
        self.tokenizer_prompt = AutoTokenizer.from_pretrained(LLM_Prompt)
        self.tokenizer_rm = AutoTokenizer.from_pretrained(Reward_model)

        print("Loading LLM...")
        self.LLM_Decoding_path = LLM_Decoding
        self.LLM_Prompt_path = LLM_Prompt
        self.Reward_model_path = Reward_model

        self.Max_Prompt_length = Max_Prompt_length
        self.Max_Node_length = Max_Node_length
        self.Max_Response_length = Max_Response_length
        self.Sample_Prompt_num = Sample_Prompt_num
        self.Sample_Node_num = Sample_Node_num
        self.Sample_Original_Prompt = Sample_Original_Prompt

        self.LLM_Decoding_GPU = LLM_Decoding_GPU
        self.LLM_Prompt_GPU = LLM_Prompt_GPU
        self.Reward_model_GPU = Reward_model_GPU

        self.torch_dtype = torch_dtype

        self.eos_token = self.tokenizer_gen.eos_token_id
        self.eos_str = self.tokenizer_gen.eos_token

        if self.LLM_Decoding_path == self._DEFAULT_LLM:
            # The default decoding model is loaded without ``torch_dtype``.
            self.LLM_gen = AutoModelForCausalLM.from_pretrained(self.LLM_Decoding_path).to(self.LLM_Decoding_GPU)
        else:
            self.LLM_gen = AutoModelForCausalLM.from_pretrained(self.LLM_Decoding_path, torch_dtype=torch_dtype).to(self.LLM_Decoding_GPU)

        if self.LLM_Prompt_path == self.LLM_Decoding_path:
            # Reuse the already-loaded model instead of loading a second copy.
            self.LLM_prompt = self.LLM_gen
        else:
            self.LLM_prompt = AutoModelForCausalLM.from_pretrained(self.LLM_Prompt_path).to(self.LLM_Prompt_GPU)

        print("Loading RM...")
        self.rm = AutoModelForSequenceClassification.from_pretrained(self.Reward_model_path, trust_remote_code=True).to(self.Reward_model_GPU)
        self.rm_template = "[INST] You must read the following conversation carefully and rate the assistant's response from score 0-100 in these aspects: helpfulness, correctness, coherence, honesty, complexity, verbosity\n\nUser: {prompt}\n\nAssistant: {response} [/INST]"

        self.LLM_gen.eval()
        self.LLM_prompt.eval()
        self.rm.eval()

    def _score_response(self, prompt, response):
        """Score ``response`` to ``prompt`` with the reward model.

        Returns:
            ``(scaled, mean)``: ``scaled`` is the squeezed reward-model logits
            mapped by ``(x - 10) / 10``; ``mean`` is its average (a 0-d tensor).
        """
        model_inputs = self.tokenizer_rm(self.rm_template.format(prompt=prompt, response=response),
                                         return_tensors="pt").to(self.Reward_model_GPU)
        with torch.no_grad():
            score = self.rm(**model_inputs).logits.squeeze()
        scaled = (score - 10) / 10
        return scaled, torch.mean(scaled)

    def _add_revised_prompts(self, tree, prompt):
        """Ask the prompt LLM to rewrite ``prompt``; add each rewrite as a child of ``tree``.

        One instruction is sent per goal in ``_REVISE_GOALS[self.Sample_Prompt_num]``.
        The rewrite is the last ``[REVISE]...[/REVISE]`` span in the decoded
        output, with newlines removed. If that span is exactly " and ", an
        extra "more detailed" request is queued, at most four times.
        """
        def instruction(goal):
            # Runtime prompt text, kept verbatim (including "Your are") so the
            # model sees exactly the same instructions as before.
            return ("Your are given the following <question>," + "<question>:" + prompt
                    + "Now, your task is to Rewrite the <question> to make it " + goal
                    + ".The rewritten question must be between [REVISE] and [/REVISE] tags.")

        revise_prompt = [instruction(goal) for goal in self._REVISE_GOALS.get(self.Sample_Prompt_num, ())]
        retries = 0
        # ``revise_prompt`` may grow while being iterated (retry requests are appended).
        for r_prompt in revise_prompt:
            messages = [{"role": "user", "content": r_prompt}]
            model_inputs = self.tokenizer_prompt.apply_chat_template(messages, return_tensors="pt").to(self.LLM_Prompt_GPU)
            output_tokens = self.LLM_prompt.generate(model_inputs, max_new_tokens=self.Max_Prompt_length, do_sample=True)
            decoded = self.tokenizer_prompt.batch_decode(output_tokens)[0]

            matches = self._REVISE_PATTERN.findall(decoded)
            if not matches:
                continue
            new_prompt = matches[-1]
            if new_prompt == " and ":
                # NOTE(review): presumably the only match is the instruction's own
                # "[REVISE] and [/REVISE]" text echoed in the decoded output, i.e. no
                # rewrite was produced, so another request is queued.
                if retries > 3:
                    break
                retries += 1
                revise_prompt.append(instruction("more detailed"))
            else:
                tree.add_child(new_prompt.replace('\n', ''))

    def our_method(self, prompt):
        """Revise ``prompt`` and run a reward-guided tree search for a response.

        1. Candidate prompts are the LLM rewrites of ``prompt`` (see
           ``_add_revised_prompts``), plus ``prompt`` itself when
           ``Sample_Original_Prompt`` is set.
        2. For each candidate, chunks of up to ``Max_Node_length`` tokens are
           sampled until the output holds at least two EOS tokens or exceeds
           ``Max_Response_length`` tokens.
        3. Finished responses are scored by the reward model against the
           ORIGINAL ``prompt``. The best-scoring branch is then expanded with
           ``Sample_Node_num`` samples per step until a node is marked terminal.

        Returns:
            ``(final_answer, final_prompt, response_set)``. ``response_set`` is a
            list of dicts with keys "Revise_prompt", "response",
            "multi_rewards" and "final_reward". Returns ``None`` when
            ``LLM_Prompt_path`` is not the default model.
        """
        if self.LLM_Prompt_path != self._DEFAULT_LLM:
            # NOTE(review): kept from the original; the search only runs for the
            # default prompt model and every other configuration returns None.
            return None

        tree = Search_Node(prompt)
        self._add_revised_prompts(tree, prompt)
        if self.Sample_Original_Prompt:
            tree.add_child(prompt)

        max_reward = -100
        new_tree = None
        final_prompt = prompt
        fp_index = True  # True until the best starting prompt has been fixed
        response_set = []

        current_prompt = prompt
        while not tree.terminal:
            rollout = True  # reset once per round; cleared after the first generation

            for node in tree.children:
                if fp_index:
                    current_prompt = node.state

                t_prompt = node.state
                terminal = False
                child_node = node

                while not terminal:
                    if len(node.children) > 0 and rollout:
                        new_tree = node
                        break

                    t_prompt = [{"role": "user", "content": t_prompt}]
                    model_inputs = self.tokenizer_gen.apply_chat_template(t_prompt, return_tensors="pt").to(self.LLM_Decoding_GPU)

                    count = torch.sum(model_inputs == self.eos_token).item()
                    if count >= 2:
                        # NOTE(review): trims the last two template tokens when the
                        # input already contains EOS tokens (a previously decoded
                        # state) - presumably to continue that text; TODO confirm.
                        model_inputs = model_inputs[:, :-2]

                    output_tokens = self.LLM_gen.generate(model_inputs, max_new_tokens=self.Max_Node_length, do_sample=True)
                    # NOTE(review): drops the first 4 tokens - presumably template
                    # header/special tokens; TODO confirm for the tokenizer in use.
                    output_tokens = output_tokens[:, 4:]
                    count = torch.sum(output_tokens == self.eos_token).item()
                    state = self.tokenizer_gen.batch_decode(output_tokens)[0]

                    if count >= 2 or output_tokens.shape[1] > self.Max_Response_length:
                        rollout = False
                        terminal = True
                        response = state.split(self.eos_str, 1)[1].replace(self.eos_str, '')
                        multi_rewards, final_reward = self._score_response(prompt, response)
                        response_set.append({"Revise_prompt": current_prompt, "response": response,
                                             "multi_rewards": multi_rewards.tolist(),
                                             "final_reward": final_reward.tolist()})

                        child_node.add_child(state, reward=multi_rewards, terminal=True)

                        if final_reward > max_reward:
                            max_reward = final_reward
                            new_tree = node
                            final_prompt = current_prompt
                        else:
                            node.terminal = True
                    else:
                        rollout = False
                        child_node = child_node.add_child(state)
                        t_prompt = state

            tree = new_tree

            if fp_index:
                current_prompt = final_prompt
                fp_index = False

            current_state = tree.state
            terminal_node_max = tree.children[0]
            for _ in range(self.Sample_Node_num):
                t_prompt = [{"role": "user", "content": current_state}]
                model_inputs = self.tokenizer_gen.apply_chat_template(t_prompt, return_tensors="pt").to(self.LLM_Decoding_GPU)
                if tree.parent.parent is not None:
                    model_inputs = model_inputs[:, :-2]

                output_tokens = self.LLM_gen.generate(model_inputs, max_new_tokens=self.Max_Node_length, do_sample=True)
                output_tokens = output_tokens[:, 4:]
                count = torch.sum(output_tokens == self.eos_token).item()
                state = self.tokenizer_gen.batch_decode(output_tokens)[0]

                if count >= 2 or output_tokens.shape[1] > self.Max_Response_length:
                    response = state.split(self.eos_str, 1)[1].replace(self.eos_str, '')
                    multi_rewards, final_reward = self._score_response(prompt, response)
                    response_set.append({"Revise_prompt": final_prompt, "response": response,
                                         "multi_rewards": multi_rewards.tolist(),
                                         "final_reward": final_reward.tolist()})

                    terminal_node = tree.add_child(state, reward=multi_rewards, terminal=True)

                    if final_reward >= max_reward:
                        max_reward = final_reward
                        tree = terminal_node
                        terminal_node_max = terminal_node
                    else:
                        tree = terminal_node_max
                        tree.terminal = True
                else:
                    tree.add_child(state)

        # Walk first children down to a leaf.
        while tree.children:
            tree = tree.children[0]

        try:
            final_answer = tree.state.split(self.eos_str, 1)[1].replace(self.eos_str, '')
        except (IndexError, TypeError, AttributeError):
            # The leaf has no EOS-delimited response part: return its text as-is.
            final_answer = tree.state
        return final_answer, final_prompt, response_set

    def beam_search_w_RM(self, prompt, batch=6, node=50, topk=3):
        """Reward-ranked beam search baseline.

        Each round, every unfinished beam is extended by up to ``node`` new
        tokens and scored by the reward model. The ``topk`` best are kept and
        each is replicated ``batch // topk`` times (at least once). A beam is
        finished when its templated input holds >= 3 EOS tokens or exceeds
        ``Max_Response_length`` tokens. The search stops when every beam has
        finished.

        Returns:
            ``(final_state, prompt)``: ``final_state`` is the response text of
            the highest-reward finished beam, or None if none finished.
        """
        inputs = [prompt] * batch
        max_reward = -100
        final_state = None
        while True:
            rewards = []
            states = []
            for current in inputs:
                t_prompt = [{"role": "user", "content": current}]
                model_inputs = self.tokenizer_gen.apply_chat_template(t_prompt, return_tensors="pt").to(
                    self.LLM_Decoding_GPU)
                count = torch.sum(model_inputs == self.eos_token).item()
                if count >= 3 or model_inputs.shape[1] > self.Max_Response_length:
                    # Finished beam: keep it if it is the best seen so far.
                    response = current.split(self.eos_str, 1)[1].replace(self.eos_str, '')
                    _, reward = self._score_response(prompt, response)
                    if reward > max_reward:
                        # Bug fix: max_reward was never updated, so the last
                        # finished beam won instead of the best one.
                        max_reward = reward
                        final_state = response
                    continue

                if count >= 2:
                    model_inputs = model_inputs[:, :-2]

                output_tokens = self.LLM_gen.generate(model_inputs, max_new_tokens=node,
                                                      do_sample=True)
                output_tokens = output_tokens[:, 4:]
                state = self.tokenizer_gen.batch_decode(output_tokens)[0]

                response = state.split(self.eos_str, 1)[1].replace(self.eos_str, '')
                _, reward = self._score_response(prompt, response)
                rewards.append(reward)
                states.append(state)

            if not rewards:
                break  # every beam has finished
            rewards = torch.stack(rewards, dim=0)
            c_topk = min(topk, rewards.shape[0])
            _, index = torch.topk(rewards, k=c_topk)
            # max(1, ...) fixes an endless restart when batch < topk: the beam
            # list used to become empty and was reset to the bare prompt.
            inputs = [states[i] for i in index.tolist()] * max(1, batch // topk)
        return final_state, prompt


    
