from generate import Generator, TextProcess
import numpy as np
import torch
import random, copy
from tqdm.autonotebook import tqdm
import torch.nn.functional as F

def my_convert(tokens):
    """Join sub-word tokens into a string, decoding the two special markers.

    Within every token, '▁' is turned into a space and '<0x0A>' into a
    newline. The replacement is applied token by token (before joining), so a
    marker that would only appear across a token boundary is left untouched.
    """
    pieces = []
    for piece in tokens:
        piece = piece.replace('▁', ' ')
        pieces.append(piece.replace('<0x0A>', '\n'))
    return ''.join(pieces)

class TCClassifier:
    def __init__(self, generator:Generator, sentence:str, split_pos:list[int], prompt_tc:list[int], content_list:list[list[str]], disturb_num:int, threshold:float, metric:str='dists', pair:bool=True, filter_list: list[str]|None=None, filter_behavior: str = 'ignore', split_punc: list[str]=[], special_number_process: bool=True, random_seed:int=20230819):
        """Classify the tokens in the sentence into template or content.
        Args:
            generator: The generator used to generate the content.
            sentence: The complete tested sentence.
            split_pos: A index list: split_pos[i] denotes the first token index of the i-th _words_. len == num_tokens + 1 (end_of_text)
            prompt_tc: A list of 0 and 1, 0 denotes the position is content, 1 denotes the position is template. The word in the prompt will not be classified.
            content_list: A list of list of str, each list of str denotes the content of the corresponding position in prompt_tc.
            disturb_num: The number of disturbed sentences for each position.
            threshold: The threshold of variance of the output distribution. If the variance is smaller than the threshold, the position is template.
            pair: Whether the content generation will be paired, defualt: True.
            filter_list: A list of tokens should be filtered out when test the variance of the output. The probability of the filtered tokens will be set to zero and the others will be re-normalized.
            special_number_process: only used when the tokenizer is Llama-2. Because the tokenizer of Llama-2 will split each number as well as the whitespace into splitted tokens, which leads to odd classification. We need to process these tokens.
            random_seed: The random seed.
        """

        self.generator = generator
        self.tokenizer = generator.tokenizer
        if 'Llama-2' in self.generator.model.name_or_path:
            self.tokenizer.convert_tokens_to_string = my_convert
        self.seq = sentence # The complete tested sentence
        self.split_pos = split_pos.copy() # A index list: split_pos[i] denotes the first token index of the i-th _words_.
        self.start_pos = split_pos[len(prompt_tc)] # The start position of the classification.

        tp = TextProcess(generator=generator, input_text=sentence)
        self.tokens = tp.tokens
        if len(self.tokens)!= split_pos[-1]:
            raise ValueError('The length of tokens should equal to the last element of split_pos (end_of_text).')
        self.tokens.append('<|eot|>')

        self.input_tc = prompt_tc.copy()
        if not len(content_list) == sum([t == 0 for t in prompt_tc]):
            raise ValueError('The length of content_list should equal to the number of 0 in prompt_tc.')

        self.content_list = copy.deepcopy(content_list)
        self.disturb_num = disturb_num
        self.threshold = threshold
        self.metric = metric
        self.pair = pair
        random.seed(random_seed)
        self.filter_list = filter_list
        self.filter_ids = self.prepare_filter_list()
        self.filter_behavior = filter_behavior
        if self.filter_behavior not in ['ignore', 'next']:
            raise ValueError('The filter_behavior should be ignore or next.')
        if self.pair:
            sample_index = random.sample(range(len(self.content_list[0])), self.disturb_num)
            self.content_list = [ list(c[i] for i in sample_index) for c in self.content_list ]
        self.split_punc = split_punc # Used in 'generate content', it should be consistent with the prompt split. Otherwise, the behavior could be strange.
        self.is_llama = 'Llama-2' in self.generator.model.name_or_path
        self.special_number_procss_flag = True if self.is_llama and special_number_process else False # in classify_step
        self.finish = False

    def prepare_filter_list(self) -> list[int]|None:
        if self.filter_list is None:
            return None
        ids = []
        try:
            for token in self.filter_list:
                ids.append(self.tokenizer.vocab[token])
        except KeyError:
            raise ValueError('The filter_list contains tokens {} that are not in the vocabulary.'.format(token))
        return ids

    def current_position(self):
        return len(self.input_tc)
    
    def _process_filter_tokens(self, preds: torch.Tensor, seq: str):
        """Set the probability of the filter tokens to zero and re-normalize the other tokens."""
        # copy 
        preds = preds.clone()
        if self.filter_ids is None:
            return F.sotfmax(preds, dim=0)
        if self.filter_behavior == 'ignore':
            preds[self.filter_ids] = -1e12
            return F.softmax(preds, dim=0)
        if self.filter_behavior == 'next':
            original_prob = F.softmax(preds, dim=0)
            next_probs = {}
            for filtered_id, filtered_token in zip(self.filter_ids, self.filter_list):
                if original_prob[filtered_id] < 0.01:
                    # ignore to save time
                    continue
                tp = TextProcess(generator=self.generator, input_text=seq+filtered_token)
                next_prob = tp.pred_dists[0,-1,:]
                # We also need to filter the filter tokens in the next_prob
                next_prob[self.filter_ids] = -1e12
                next_prob = F.softmax(next_prob, dim=0)
                next_probs[filtered_id] = next_prob
            prob = torch.zeros_like(original_prob)
            for idx, next_prob in next_probs.items():
                prob = prob + original_prob[idx] * next_prob
            original_prob[list(next_probs.keys())] = 0
            prob = prob + original_prob
            return prob

    def _disturb_content(self) -> list[str]:
        """Return a list of disturbed sentences, length = disturb_num."""
        if len(self.content_list) != sum([t == 0 for t in self.input_tc]):
            raise ValueError('The length of content_list should equal to the number of 0 in tc.')
        return_seqs = ['' for _ in range(self.disturb_num)]
        content_num = 0
        if self.pair:
            assert len(self.content_list[0]) == self.disturb_num
        for i, tc in enumerate(self.input_tc):
            if tc == 0: # content
                samples = self.content_list[content_num] if self.pair else random.sample(self.content_list[content_num], self.disturb_num)
                return_seqs = [(seq + content) for (seq, content) in zip(return_seqs, samples)]
                content_num += 1
            else:   # template
                add = self.tokenizer.convert_tokens_to_string(self.tokens[self.split_pos[i]:self.split_pos[i+1]])
                if i == 0 and add[0] == ' ':
                    add = add[1:]
                return_seqs = [seq + add for seq in return_seqs]
        return return_seqs
    
    def _generate_content(self):
        """Generate content and append it to content_list."""
        input_seqs = self._disturb_content()
        contents = []
        for input_seq in tqdm(input_seqs, desc='Generating content', leave=False):
            subword = []
            go_on_flag = True
            tp = TextProcess(generator=self.generator, input_text=input_seq)
            orig_preds = tp.pred_dists[0,-1,:]
            preds = self._process_filter_tokens(orig_preds, input_seq)
            token = self.tokenizer.convert_ids_to_tokens([torch.argmax(preds)])[0]
            subword.append(token)

            if token == '</s>' or token == '<|endoftext|>':
                go_on_flag = False 
            if token.endswith(('Ġ', 'Ċ', 'ċ', '▁', '<0x0A>') + tuple(self.split_punc)):
                print('Warning: The generated content ends with a special token. Content: {}'.format(subword))
                go_on_flag = False

            input_seq += self.tokenizer.convert_tokens_to_string([token])

            if go_on_flag: # determine whether the content has been generated completely.
                for _ in range(10):
                    tp = TextProcess(generator=self.generator, input_text=input_seq)
                    orig_preds = tp.pred_dists[0,-1,:]
                    # whether filter when continuous generation for content
                    preds = F.softmax(orig_preds, dim=0)
                    # preds = self._process_filter_tokens(orig_preds, input_seq)
                    token = self.tokenizer.convert_ids_to_tokens([torch.argmax(preds)])[0]
                    if token.startswith(('Ġ', 'Ċ', 'ċ', '▁') + tuple(self.split_punc)):
                        break
                    if token == '</s>' or token == '<|endoftext|>':
                        break
                    elif token[0:6] == '<0x0A>':
                        subword.append('<0x0A>')
                        break
                    else:
                        subword.append(token)
                        input_seq += self.tokenizer.convert_tokens_to_string([token])
                else:
                    tqdm.write('Warning: The generated content is too long. Content: {}'.format(subword))
            content = self.tokenizer.convert_tokens_to_string(subword)
            if content[0] != ' ':
                content = ' ' + content
            contents.append(content)
        tqdm.write('Generated contents: {}'.format(contents))
        self.content_list.append(contents)

    def classfify_step(self, debug): 
        input_seqs = self._disturb_content()
        if debug:
            print('\nInput seqs:')
            for seq in input_seqs:
                print(seq)
            print('\n')
        preds = []
        for input_seq in input_seqs:
            tp = TextProcess(generator=self.generator, input_text=input_seq)
            output = tp.pred_dists[0,-1,:]
            pred = self._process_filter_tokens(output, input_seq)
            if self.metric == 'dists':
                preds.append(pred)
            elif self.metric == 'ranking':
                argsort = torch.argsort(pred, descending=True)
                ranking = torch.zeros_like(pred)
                ranking[argsort] = torch.arange(pred.shape[0]).to(torch.float) # shape: (vocab_size, ), each position is the ranking (descending) of the corresponding token
                ranking = 1 / (1 + ranking**2)
                preds.append(ranking)
            else:
                raise ValueError('The metric should be dists or ranking.')
            if debug:
                # original
                top_tokens = self.tokenizer.convert_ids_to_tokens(torch.argsort(output, descending=True)[:10])
                top_probs = F.softmax(output, dim=0)[torch.argsort(output, descending=True)[:10]].tolist()
                # print in .xx format
                print('original top 10: ', end='')
                for i in range(10):
                    print('{} {:.2f}%'.format(top_tokens[i], 100*top_probs[i]), end='; ' if i != 9 else '\n')
                
                if self.filter_ids is not None:
                    # filterd
                    top_tokens = self.tokenizer.convert_ids_to_tokens(torch.argsort(pred, descending=True)[:10])
                    top_probs = pred[torch.argsort(pred, descending=True)[:10]].tolist()
                    # print in .xx format
                    print('filtered top 10: ', end='')
                    for i in range(10):
                        print('{} {:.2f}%'.format(top_tokens[i], 100*top_probs[i]), end='; ' if i != 9 else '\n')
        preds = torch.stack(preds, dim=0) # shape: (disturb_num, vocab_size)
        var = torch.sum(torch.var(preds, dim=0)) # shape: (vocab_size, ) -> float
        try:
            word = self.tokenizer.convert_tokens_to_string(self.tokens[self.split_pos[self.current_position()]: self.split_pos[self.current_position()+1]])
        except IndexError:
            word = self.tokenizer.convert_tokens_to_string(self.tokens[self.split_pos[self.current_position()]:])
        if var < self.threshold:
            # Template
            tqdm.write('The position {:>3d} word {:>10} is a T with variance {:.2f}.'.format(self.current_position(), repr(word), var))
            self.input_tc.append(1)
        else:
            tqdm.write('The position {:>3d} word {:>10} is a C with variance {:.2f}.'.format(self.current_position(), repr(word), var))
            self._generate_content()
            self.input_tc.append(0)

    def classify(self, debug=False):
        if self.finish:
            raise ValueError('The classification is finished. Please create a new instance to classify.')
        with torch.no_grad():
            with tqdm(total=len(self.split_pos) - self.current_position(), desc='Classifying') as pbar:
                while self.current_position() < len(self.split_pos):
                    self.classfify_step(debug=debug)
                    pbar.update(1)
        self.finish = True
        return self.input_tc
    
    def show(self):
        if not self.finish:
            raise ValueError('The classification is not finished. Please run classify() first.')
        print('-'*30)
        print('The T/C classification of the sentence: (prediction starts from the position {})'.format(self.split_pos.index(self.start_pos)))
        for i, tc in enumerate(self.input_tc):
            try:
                word = self.tokenizer.convert_tokens_to_string(self.tokens[self.split_pos[i]: self.split_pos[i+1]])
            except IndexError:
                word = self.tokenizer.convert_tokens_to_string(self.tokens[self.split_pos[i]:])
            print('The position {:>3d} word {:>10} is a {}.'.format(i, repr(word), 'T' if tc == 1 else 'C'))
        print('-'*30)