import torch
from utils.loggers import loggers
import numpy as np
import time
import random

# Candidates whose combined score is not strictly greater than this value are
# dropped in `Next_Linear_Search.expand` (this filters the -inf "<EMPTY>" ones).
THRESHOLD = -1e5
# Upper bound on the number of paths sampled by `Next_Linear_Search.negative_augment`.
NEGATIVE_SAMPLE = 10

# Another version of beam search using new framework
class Next_Linear_Search():
    r"""Best-first search over reasoning steps with a Bayesian-linear exploration bonus.

    Every stored state (one generated reasoning step) keeps its text, the
    concatenation of the path leading to it, its depth, its embedding and the
    score returned by the generator.  A state's value is

        value(i) = r[i] + c * sqrt(s_i^T A^{-1} s_i)

    where ``r`` holds the generator scores, ``s_i`` is the state's embedding and
    ``A = S S^T / sigma^2 + alpha * I`` is built from all stored embeddings.
    At each step the ``beam_size`` best unexpanded states are expanded.

    All per-state lists (``states``, ``concatenate_states``, ``status``,
    ``steps``, ``expanded``, ``state_values``, ``parents``, ``embeddings``,
    ``visits``) share the same indexing once the placeholder root has been
    dropped at the first step.
    """

    def __init__(self, params, thought_generator, init_sequence, stop_criterion=None, qa_template="Q: <Q>\nA: Let's think step by step.\n<A>", score_func=None, device=None, negative_generator=None, batch_idx=0):
        """Store the search configuration and create the placeholder root state.

        Args:
            params: mapping providing 'beam_size', 'num_candidates',
                'max_length', 'early_stopping', 'sigma', 'alpha', 'c', 'gsm8k'
                and optionally 'remove_root' (which must be truthy).
            thought_generator: callable
                ``(prompt, partial_answer, step, batch_idx)`` used by `expand`.
            init_sequence: the question text substituted for ``<Q>``.
            stop_criterion: callable deciding whether a text is finished.
            qa_template: prompt template containing ``<Q>`` and ``<A>``.
            score_func: stored on the instance; not used by this class itself.
            device: torch device on which the regression tensors are kept.
            negative_generator: callable
                ``(prompt, partial_answer, step)`` used by `negative_augment`.
            batch_idx: question identifier, used for logging and generation.
        """
        self.beam_size = params['beam_size']
        self.num_candidates = params['num_candidates']
        self.max_length = params['max_length']
        self.early_stopping = params['early_stopping']
        self.sigma = params['sigma']  # prior noise parameter
        self.alpha = params['alpha']  # prior weight parameter
        self.c = params['c']  # exploration-exploitation trade-off

        self.remove_root = params.get('remove_root', False)
        assert self.remove_root, "Only support remove_root now"

        self.thought_generator = thought_generator
        self.negative_generator = negative_generator
        self.stop_criterion = stop_criterion

        self.qa_template = qa_template
        self.init_sequence = init_sequence

        self.step = 0
        self.score_func = score_func  # used to get scores
        self.device = device

        # Per-state bookkeeping.  Index 0 is a placeholder root until the first
        # call to `insert_new_state` drops it (see `clear_all`).
        self.states = ['']
        self.concatenate_states = ['']
        self.parents = []
        self.status = [False]
        self.steps = [0]  # depth of each state
        self.embeddings = []  # embedding of each state
        self.expanded = [False]  # each node is expanded at most once

        # global exploration
        self.non_terminal_states = ['']
        self.non_terminal_idxs = [0]

        # multi-arm bandit stats
        self.visits = []

        # Bayesian regression
        self.r = None  # generator scores, one row per state
        self.S = None  # embeddings of all the states, [dim, num_states]
        self.A_inv = None  # [dim, dim]

        # The root has not generated anything yet, so its value starts at 0.
        self.state_values = [0.]

        # debug information
        self.terminal = False

        # indices of the finished beams found by the latest step
        self.search_idxs = None

        self.batch_idx = batch_idx
        self.is_gsm8k = params['gsm8k']
        self.delimeter = '\n' if self.is_gsm8k else '\n\n'

    def compute_InvA(self):
        '''
            Compute (1/sigma^2 SS^T + alpha I)^-1
        '''
        dim = self.S.shape[0]
        A = (self.S @ self.S.T) / self.sigma**2 + self.alpha * torch.eye(dim).to(self.device)
        return torch.inverse(A)

    def get_linear_regression_variance(self, idx: int = None, s=None):
        r'''
        Exploration term of the linear Bayesian model
            r(s) = s^T w + \epsilon,  \epsilon ~ N(0, \sigma^2),  w ~ N(0, \alpha^{-1} I)
            A = (1/\sigma^2) SS^T + \alpha I

        Returns sqrt(s^T A^{-1} s), i.e. the standard deviation of the
        predicted mean (the observation-noise term \sigma^2 is NOT added).

        Args:
            idx: index of a stored state; used only when ``s`` is None.
            s: embedding to evaluate; any shape that flattens to ``dim`` values.
        '''
        if s is None:
            s = self.embeddings[idx]
        # Same normalisation for both entry points: column vector on self.device.
        col = s.reshape(-1, 1).to(self.device)
        quad = (col.T @ self.A_inv @ col).item()
        # Round-off can make the quadratic form slightly negative; a negative
        # float raised to 0.5 would silently become a complex number.
        return max(quad, 0.0) ** 0.5

    def select(self):
        '''
            select beam_size points in search tree for expansion.

            Returns the indices of the best unexpanded non-terminal states,
            highest value first.
        '''
        candidates = [idx for idx in self.non_terminal_idxs if not self.expanded[idx]]
        scores = [self.state_values[idx] for idx in candidates]
        select_num = min(len(candidates), self.beam_size)
        # Reverse first, then truncate: unlike `[-select_num:]` this yields an
        # empty selection when select_num is 0.
        order = np.argsort(scores)[::-1][:select_num]
        return [candidates[pos] for pos in order]

    def expand(self, parent_idxs):
        '''
            Expand the search tree

            For every parent that is not finished (or for any parent at step 0)
            the thought generator is asked for candidates; it is expected to
            return either '<SKIP>' or a mapping with 'text', 'embeddings' and
            'scores'.  Finished parents contribute a single '' placeholder
            carrying the parent's current value.

            Returns '<SKIP>' when the generator keeps returning '<SKIP>',
            otherwise the tuple
            (new_states, parent_beams, new_scores, done, new_adapter_scores,
             new_exploitations, new_explorations, new_embeddings),
            sorted by descending score and restricted to scores > THRESHOLD.
            ``done`` is True when every surviving entry is a '' placeholder.
        '''
        total_scores, which_beam, total_states = [], [], []
        total_adapter_scores, total_Qs, total_Us, total_embeddings = [], [], [], []
        max_attempts = 3

        for parent_idx in parent_idxs:
            init_seq, parent_state = self.get_next_strings(parent_idx, return_with_init=True, split_question_answer=True)

            if self.stop_criterion(parent_state) and self.step != 0:
                # Finished beam: carry it forward unchanged.
                total_scores.append(self.state_values[parent_idx])
                which_beam.append(parent_idx)
                total_states.append('')
                total_adapter_scores.append(0)
                total_Qs.append(0)
                total_Us.append(0)
                total_embeddings.append(None)
                continue

            for attempt in range(max_attempts):
                res = self.thought_generator(init_seq, parent_state, self.steps[parent_idx], self.batch_idx)  # (num_candidates,)
                if res != '<SKIP>':
                    break
                # No point in waiting after the last failed attempt.
                if attempt < max_attempts - 1:
                    time.sleep(1)
            else:
                return '<SKIP>'

            texts = res['text']
            for i, candidate in enumerate(texts):
                embedding = res['embeddings'][i]
                score = res['scores'][i].item()
                if candidate == "<EMPTY>":
                    # Pushed below THRESHOLD so it is filtered out afterwards.
                    total_scores.append(-float('inf'))
                    total_adapter_scores.append(-float('inf'))
                    total_Qs.append(-float('inf'))
                    total_Us.append(-float('inf'))
                    total_embeddings.append(None)
                    continue
                if self.S is not None:
                    exploit = score
                    explore = self.get_linear_regression_variance(s=embedding)
                    total_scores.append(exploit + self.c * explore)
                    total_Qs.append(exploit)
                    total_Us.append(explore)
                else:
                    # First iteration: no embeddings stored yet, so no bonus.
                    total_scores.append(score)
                    total_Qs.append(score)
                    total_Us.append(0)
                total_embeddings.append(embedding)
                total_adapter_scores.append(score)
            # One parent entry per candidate text keeps all lists aligned.
            which_beam.extend([parent_idx] * len(texts))
            total_states.extend(texts)

        sorted_idxs = np.argsort(total_scores)[::-1]
        selected_idxs = [idx for idx in sorted_idxs if total_scores[idx] > THRESHOLD]

        parent_beams = [which_beam[idx] for idx in selected_idxs]
        new_scores = [total_scores[idx] for idx in selected_idxs]
        new_states = [total_states[idx] for idx in selected_idxs]
        new_adapter_scores = [total_adapter_scores[idx] for idx in selected_idxs]
        new_exploitations = [total_Qs[idx] for idx in selected_idxs]
        new_explorations = [total_Us[idx] for idx in selected_idxs]
        new_embeddings = [total_embeddings[idx] for idx in selected_idxs]
        done = all(item == '' for item in new_states)
        return new_states, parent_beams, new_scores, done, new_adapter_scores, new_exploitations, new_explorations, new_embeddings

    def clear_all(self):
        '''
        clear all for the first epoch(remove root)
        '''
        self.states = []
        self.concatenate_states = []
        # `status` is reset too so that it stays index-aligned with `states`.
        self.status = []
        self.steps = []
        self.non_terminal_states = []
        self.non_terminal_idxs = []
        self.state_values = []
        self.expanded = []

    def remove_points(self, idxs):
        '''
        mark the given idxs points in the tree with expanded so that it won't be selected to expand again
        '''
        for idx in idxs:
            self.expanded[idx] = True

    def negative_augment(self, return_with_init=False):
        '''
        Build negative sample paths from unexpanded states outside the result.

        Up to NEGATIVE_SAMPLE states are taken: the lower half (rounded down)
        from the highest-valued states, to learn false positives, and the rest
        from the lowest-valued ones.  Unfinished paths are completed with
        ``negative_generator``.
        '''
        excluded = set(self.search_idxs or [])
        ranked = np.argsort(self.state_values)[::-1].tolist()  # from large to small
        pool = [idx for idx in ranked if idx not in excluded and not self.expanded[idx]]

        sample_num = min(NEGATIVE_SAMPLE, len(pool))
        top_count = sample_num // 2
        bottom_count = sample_num - top_count
        negative_idxs = pool[:top_count] + pool[len(pool) - bottom_count:]

        negative_paths = []
        for negative_idx in negative_idxs:
            init_seq, negative_state = self.get_next_strings(negative_idx, return_with_init=True, split_question_answer=True)
            path = self.get_next_strings(negative_idx, return_with_init)[0]
            if not self.stop_criterion(negative_state):
                # complement the path if it is not completed
                negative_step = self.steps[negative_idx]
                generate_state = self.negative_generator(init_seq, negative_state, negative_step)[0]
                path = path + self.delimeter + generate_state.strip()
            negative_paths.append(path)
        return negative_paths

    def insert_new_state(self, new_states, new_adapter_scores, new_scores, new_embeddings, parent_beams, done, parent_idxs):
        '''
        Insert the output of `expand` into the tree and refresh every state value.

        '' entries denote finished parents: they are recorded in
        ``search_idxs`` instead of being stored.  Afterwards A^{-1} is
        recomputed, every state value is refreshed, and (except at the
        root-removing first step) each unfinished parent is marked as expanded.
        '''
        # Paths and depths are derived before the root is possibly cleared.
        concatenate_states = []
        for new_state, parent in zip(new_states, parent_beams):
            prefix = self.get_next_strings(parent, return_with_init=False)[0]
            concatenate_states.append(prefix if new_state == '' else prefix + self.delimeter + new_state)
        steps = [self.steps[parent] + 1 for parent in parent_beams]
        self.search_idxs = []

        first_step_without_root = self.step == 0 and self.remove_root
        if first_step_without_root:
            self.clear_all()

        for new_state, new_adapter_score, new_score, new_embedding, concatenate_state, step, parent_beam in zip(new_states, new_adapter_scores, new_scores, new_embeddings, concatenate_states, steps, parent_beams):
            if new_state == '':
                self.search_idxs.append(parent_beam)
                continue
            self.states.append(new_state)
            self.concatenate_states.append(concatenate_state.strip('\n').strip())
            self.status.append(done)
            self.steps.append(step)
            self.embeddings.append(new_embedding)
            self.expanded.append(False)
            self.parents.append(None if self.step == 0 else parent_beam)

            if not done:
                self.non_terminal_states.append(new_state)
                self.non_terminal_idxs.append(len(self.states) - 1)

            self.state_values.append(new_score)
            if self.step:
                self.visits[parent_beam] += 1
            self.visits.append(0)

            # update S
            s = new_embedding.reshape(-1, 1).to(self.device)
            self.S = torch.cat([self.S, s], dim=1) if self.S is not None else s
            # update r
            r = torch.tensor([new_adapter_score]).unsqueeze(-1).to(self.device)
            self.r = torch.cat([self.r, r], dim=0) if self.r is not None else r

        # Nothing to refresh until at least one embedding has been stored.
        if self.S is not None:
            self.A_inv = self.compute_InvA()
            # The root is removed, so index 0 is a regular state and must get
            # its exploration bonus like every other state.
            for i in range(len(self.states)):
                self.state_values[i] = self.r[i].item() + self.c * self.get_linear_regression_variance(i)

        if not first_step_without_root:
            for parent_idx in parent_idxs:
                parent_state = self.get_next_strings(parent_idx, return_with_init=True)[0]
                if not self.stop_criterion(parent_state):
                    self.remove_points([parent_idx])

    def get_next_strings(self, idx=0, return_with_init=True, split_question_answer=False):
        '''
        Return the text of the path ending at state ``idx``.

        - return_with_init=False: ``[path]``.
        - return_with_init=True: ``[template filled with question and path]``
          (the path is followed by the delimiter; empty at step 0).
        - split_question_answer=True (requires return_with_init):
          ``[template with an empty answer, path + delimiter]`` ('' at step 0).
        '''
        current_string = self.concatenate_states[idx]
        if split_question_answer:
            assert return_with_init, "Only return with init supports splitting question and answer"
        if not return_with_init:
            return [current_string]

        with_question = self.qa_template.replace('<Q>', self.init_sequence)
        answer_so_far = '' if self.step == 0 else f"{current_string}{self.delimeter}"
        if split_question_answer:
            return [with_question.replace('<A>', ''), answer_so_far]
        return [with_question.replace('<A>', answer_so_far)]

    def next_step(self):
        '''
        Run one select / expand / insert round and log it.

        Returns '<SKIP>' if generation failed, '<EARLY_STOP>' if all selected
        beams are finished and early stopping is enabled, else '<CONTINUE>'.
        '''
        # Debug dump of the first three states.
        test = ""
        if len(self.states) >= 3:
            test = ''.join(
                f'Exploitation for State {i} is {self.r[i].item()}, Exploration for State {i} is {self.c * self.get_linear_regression_variance(i)}\n'
                for i in range(3)
            )
        parent_idxs = self.select()
        current_string = [self.get_next_strings(parent_idx, return_with_init=True)[0] for parent_idx in parent_idxs]
        res = self.expand(parent_idxs)
        if res == '<SKIP>':
            return '<SKIP>'
        new_states, parent_beams, new_scores, done, new_adapter_scores, exploitations, explorations, new_embeddings = res
        # Parent texts are read before insertion: at step 0 the insertion drops
        # the root, after which index 0 would name a different state.
        parent_texts = [self.concatenate_states[beam] for beam in parent_beams]
        self.insert_new_state(new_states, new_adapter_scores, new_scores, new_embeddings, parent_beams, done, parent_idxs)
        self.step += 1

        # log in loggers
        sp_text = '\n\n'.join([
            "State: " + parent_text + '\n' + new_state + f"  Exploitation Score: {exploit} Exploration Score: {self.c * explore} Score: {new_score} Adapter Score: {new_adapter_score}"
            for parent_text, new_state, new_score, new_adapter_score, exploit, explore in zip(parent_texts, new_states, new_scores, new_adapter_scores, exploitations, explorations)
        ])
        loggers["search"].info(f"{'*'*10} Question: {self.batch_idx} Step: {self.step} {'*'*10}\n\n{'='*10} Test {'='*10}\n {test}\n{'='*26} \n\n{sp_text}\n")
        detail_text = '\n\n'.join([f"{'='*5} Step {self.steps[parent_idxs[idx]]} {'='*5} Input String {parent_idxs[idx]}: [BOS]{t}[EOS]" for idx, t in enumerate(current_string)]) + '\n\n'
        detail_text += '\n\n'.join([f"Output String: [BOS]{t}[EOS] (Score: {s}, Parent: {parent})" for t, s, parent in zip(new_states, new_scores, parent_beams)])
        loggers["search_detail"].info(f"{'*'*10} Question: {self.batch_idx} Step: {self.step} {'*'*10}\n\n{detail_text}\n")

        # decide the stop criterion
        if done and self.early_stopping:
            return '<EARLY_STOP>'

        return '<CONTINUE>'

    def __call__(self, return_with_init=False, negative_augment=False):
        '''
        perform next search

        Runs at most ``max_length`` steps.  Returns None if generation was
        skipped; otherwise the list of finished paths recorded by the last
        step, or ``(paths, negative_paths)`` when ``negative_augment`` is set.
        '''
        loggers["search"].info(f"\n{'='*20}\n Question: {self.batch_idx} Q: {self.init_sequence}\n\n")

        for _ in range(self.max_length):
            flag = self.next_step()

            # early stopping
            if flag == '<EARLY_STOP>':
                self.terminal = True
                break

            # skip question
            if flag == '<SKIP>':
                loggers["error"].info(f"Question: {self.batch_idx} Next search timed out for {self.init_sequence}")
                return None

        sp_text = '\n\n'.join([f"State: {concatenate_state} Score: {state_value}" for concatenate_state, state_value in zip(self.concatenate_states, self.state_values)])
        loggers["search"].info(f'{"-"*10} Question: {self.batch_idx} FINAL {"-"*10}\nterminal {self.terminal}\n\n{sp_text}')

        # search_idxs is still None when no step was run (max_length == 0).
        final_idxs = self.search_idxs or []
        results = [self.get_next_strings(idx, return_with_init)[0] for idx in final_idxs]
        detail_text = '\n\n'.join(results)
        loggers["search_detail"].info(f'{"-"*10} Question: {self.batch_idx} FINAL {"-"*10}\nterminal {self.terminal}\n\n{detail_text}')

        if negative_augment:
            return (results, self.negative_augment(return_with_init))
        return results


        