import torch
from utils.loggers import loggers
import math
import numpy as np

# Another version of beam search using new framework
class Beam_New_Search():
    """Step-wise beam search over generated continuations for one question.

    At every step the best ``beam_size`` live states are selected, each one is
    handed to ``thought_generator`` to obtain scored continuations, and the
    ``beam_size`` best candidates overall replace the current set of states.
    Index 0 of the per-state lists therefore always holds the best-scoring
    state of the most recent step.
    """

    def __init__(self, params, thought_generator, init_sequence, stop_criterion=None, qa_template="Q: <Q>\nA: Let's think step by step.\n<A>", score_func=None, batch_idx=0):
        """Set up the search with a single empty root state.

        Args:
            params: Mapping providing the keys ``'beam_size'``,
                ``'num_candidates'``, ``'max_length'``, ``'early_stopping'``
                and ``'gsm8k'``.
            thought_generator: Callable invoked as
                ``thought_generator(prompt, partial_answer, step, batch_idx)``.
                It returns either the string ``'<SKIP>'`` or a mapping with
                ``'scores'`` (must support ``len()`` and ``.cpu().tolist()``)
                and ``'text'`` (the candidate strings, same order as scores).
            init_sequence: Question text substituted for ``<Q>`` in the
                template.
            stop_criterion: Optional callable taking the partial answer and
                returning a truthy value when that branch is finished.
                ``None`` means no branch is ever considered finished.
            qa_template: Prompt template containing ``<Q>`` and ``<A>``
                placeholders.
            score_func: Optional callable taking a list of strings and
                returning an iterable of items exposing ``.item()``. When
                omitted, the root state gets a neutral value of 0.0.
            batch_idx: Identifier forwarded to ``thought_generator`` and used
                in log messages.
        """
        self.beam_size = params['beam_size']
        self.num_candidates = params['num_candidates']
        self.max_length = params['max_length']
        self.early_stopping = params['early_stopping']

        self.thought_generator = thought_generator
        self.stop_criterion = stop_criterion

        self.qa_template = qa_template

        self.init_sequence = init_sequence

        self.step = 0
        self.score_func = score_func  # used to get scores

        # Per-state bookkeeping; all lists are index-aligned.
        self.states = ['']              # text produced at the latest step
        self.concatenate_states = ['']  # full answer accumulated so far
        self.status = [False]
        self.steps = [0]                # number of steps taken per state

        # global exploration
        self.non_terminal_states = ['']
        self.non_terminal_idxs = [0]

        # multi-arm bandit stats
        self.visits = [1]

        # The root value never competes with another state (it is the only
        # candidate at step 0), so a neutral 0.0 is used when no scorer is
        # available instead of failing on the default ``score_func=None``.
        if score_func is None:
            root_value = 0.0
        else:
            root_value = self.get_score([''], return_with_init=True)[0]  # TODO: may need to change
        self.state_values = [root_value]

        # debug information
        self.terminal = False
        self.batch_idx = batch_idx
        # NOTE: attribute name keeps its historical spelling because it is
        # part of the instance's visible state.
        self.delimeter = '\n\n' if not params['gsm8k'] else '\n'

    @staticmethod
    def _top_k_indices(scores, k):
        """Return the positions of the ``k`` largest scores, best first.

        A non-positive ``k`` yields an empty list; a plain ``[-k:]`` slice
        would instead return every index when ``k`` is 0.
        """
        if k <= 0:
            return []
        return np.argsort(scores)[-k:][::-1].tolist()

    def select(self):
        '''
            Select up to beam_size non-terminal states for expansion.

            Returns the selected state indices ordered from highest to lowest
            state value.
        '''
        candidate_scores = [self.state_values[idx] for idx in self.non_terminal_idxs]
        select_num = min(len(candidate_scores), self.beam_size)
        return [
            self.non_terminal_idxs[pos]
            for pos in self._top_k_indices(candidate_scores, select_num)
        ]

    def expand(self, parent_idxs):
        '''
            Expand the given parent states and keep the best beam_size results.

            Finished parents (stop criterion satisfied after step 0) are
            carried over unchanged as an empty continuation with their
            current value.

            Returns '<SKIP>' if the generator asks to skip the question,
            otherwise the tuple (new_states, parent_beams, new_scores, done,
            total_states, total_scores), where done is True when every kept
            continuation is empty.
        '''
        total_scores, which_beam, total_states = [], [], []
        for parent_idx in parent_idxs:
            init_seq, parent_state = self.get_next_strings(parent_idx, return_with_init=True, split_question_answer=True)
            # The criterion is evaluated before the step check, exactly as
            # before; a missing criterion simply never finishes a branch.
            finished = (
                self.stop_criterion is not None
                and self.stop_criterion(parent_state)
                and self.step != 0
            )
            if finished:
                # temporarily finish for this branch, no need to expand
                total_scores.append(self.state_values[parent_idx])
                which_beam.append(parent_idx)
                total_states.append('')
                continue

            res = self.thought_generator(init_seq, parent_state, self.steps[parent_idx], self.batch_idx)  # (num_candidates,)
            if res == '<SKIP>':
                return '<SKIP>'
            total_scores.extend(res['scores'].cpu().tolist())
            which_beam.extend([parent_idx] * len(res['scores']))
            total_states.extend(res['text'])

        selected_idxs = self._top_k_indices(total_scores, self.beam_size)
        parent_beams = [which_beam[idx] for idx in selected_idxs]
        new_scores = [total_scores[idx] for idx in selected_idxs]
        new_states = [total_states[idx] for idx in selected_idxs]
        done = all(item == '' for item in new_states)
        return new_states, parent_beams, new_scores, done, total_states, total_scores

    def clear_all(self):
        '''
        Drop every stored state; vanilla beam search only keeps the latest
        generation of states.
        '''
        self.states = []
        self.concatenate_states = []
        self.steps = []
        self.non_terminal_states = []
        self.non_terminal_idxs = []
        self.state_values = []

    def insert_new_state(self, new_states, new_scores, parent_idxs, done):
        '''
        Replace the stored states with the newly selected ones.

        new_states, new_scores and parent_idxs are index-aligned. When done
        is True the new states are not registered as non-terminal.
        '''
        # Everything derived from the parents must be computed before the
        # old generation is cleared.
        concatenate_states = []
        for new_state, parent_idx in zip(new_states, parent_idxs):
            prefix = self.get_next_strings(parent_idx, return_with_init=False)[0]
            if new_state == '':
                concatenate_states.append(prefix)
            else:
                concatenate_states.append(prefix + self.delimeter + new_state)

        steps = [self.steps[parent_idx] + 1 for _, parent_idx in zip(new_states, parent_idxs)]
        self.clear_all()

        for new_state, new_score, concatenate_state, step in zip(new_states, new_scores, concatenate_states, steps):
            self.states.append(new_state)
            self.concatenate_states.append(concatenate_state.strip())
            self.steps.append(step)

            if not done:
                self.non_terminal_states.append(new_state)
                self.non_terminal_idxs.append(len(self.states) - 1)

            self.state_values.append(new_score)

    def get_next_strings(self, idx=0, return_with_init=True, split_question_answer=False):
        '''
        Build the text for state idx.

        - return_with_init=False: [accumulated answer].
        - return_with_init=True: [template filled with question and answer].
        - split_question_answer=True (requires return_with_init):
          [template filled with the question only, accumulated answer].

        At step 0 the answer part is empty; afterwards it is the accumulated
        answer followed by the delimiter.
        '''
        current_string = self.concatenate_states[idx]
        if split_question_answer:
            assert return_with_init, "Only return with init supports splitting question and answer"
            question_part = self.qa_template.replace('<Q>', self.init_sequence).replace('<A>', '')
            answer_part = '' if self.step == 0 else f"{current_string}{self.delimeter}"
            return [question_part, answer_part]

        if not return_with_init:
            return [current_string]

        answer_part = '' if self.step == 0 else f"{current_string}{self.delimeter}"
        return [self.qa_template.replace('<Q>', self.init_sequence).replace('<A>', answer_part)]

    def get_score(self, texts, return_with_init=True):  # input: List[str], output: List[float]
        '''
        Score texts with score_func and return plain Python numbers.

        When return_with_init is True each text is first embedded in the
        question/answer template (followed by a newline).

        Raises TypeError if no score_func was provided.
        '''
        if self.score_func is None:
            raise TypeError("score_func is required to compute scores")
        if return_with_init:  # use qa_template to get qa pairs for texts
            texts = [self.qa_template.replace('<Q>', self.init_sequence).replace('<A>', f"{text}\n") for text in texts]
        return [score.item() for score in self.score_func(texts)]

    def next_step(self):
        '''
        Run one select/expand/insert round.

        Returns '<SKIP>' when the generator skips the question,
        '<EARLY_STOP>' when the search cannot or should not continue, and
        '<CONTINUE>' otherwise.
        '''
        parent_idxs = self.select()
        current_string = [self.get_next_strings(parent_idx, return_with_init=True)[0] for parent_idx in parent_idxs]
        res = self.expand(parent_idxs)

        if res == '<SKIP>':
            return '<SKIP>'
        new_states, parent_beams, new_scores, done, total_states, total_scores = res

        if not new_states:
            # Nothing was selected or produced (e.g. all beams already
            # finished while early_stopping is off). Inserting an empty
            # generation would erase every stored state and make the final
            # lookup fail, so keep the current states and stop here.
            return '<EARLY_STOP>'

        self.step += 1
        self.insert_new_state(new_states, new_scores, parent_beams, done)

        sp_text = f"{self.concatenate_states} Score: {new_scores}"
        loggers["search"].info(f"{'*'*10} Question: {self.batch_idx} Step: {self.step} {'*'*10}\n\n{sp_text}\n")

        detail_text = '\n\n'.join([f"Input String {idx}: [BOS]{t}[EOS]" for idx, t in enumerate(current_string)]) + '\n\n'
        detail_text += '\n\n'.join([f"Output String: [BOS]{t}[EOS] (Score: {s}, Parent: {parent})" for t, s, parent in zip(new_states, new_scores, parent_beams)])
        loggers["search_detail"].info(f"{'*'*10} Question: {self.batch_idx} Step: {self.step} {'*'*10}\n\n{detail_text}\n")

        # decide the stop criterion
        if done and self.early_stopping:
            return '<EARLY_STOP>'

        return '<CONTINUE>'

    def __call__(self, return_with_init=False):
        '''
        Run the search for at most max_length steps.

        Returns None if the question was skipped, otherwise the result of
        get_next_strings for the best state (index 0).
        '''
        loggers["search"].info(f"\n{'='*20}\n Question: {self.batch_idx} Q: {self.init_sequence}\n\n")

        for _ in range(self.max_length):
            flag = self.next_step()

            # early stopping
            if flag == '<EARLY_STOP>':
                self.terminal = True
                break

            # skip question
            if flag == '<SKIP>':
                loggers["error"].info(f"Question: {self.batch_idx} Next search timed out for {self.init_sequence}")
                return None

        sp_text = f"{self.concatenate_states}"
        loggers["search"].info(f'{"-"*10} Question: {self.batch_idx}  FINAL {"-"*10}\nterminal {self.terminal}\nstep {self.steps}\n{sp_text}\n{self.states}')
        return self.get_next_strings(0, return_with_init)