import re
import os
from typing import List, Any
import tiktoken
from transformers import AutoTokenizer

# Verbose diagnostics are on when the VERBOSE environment variable holds one of
# the accepted truthy spellings (compared case-insensitively); off by default.
VERBOSE = os.environ.get("VERBOSE", "0").lower() in {"true", "1", "yes", "on"}

class Chunker:
    """Token-aware text chunker.

    Lengths are measured with ``self.tokenizer``, which ``__init__`` loads with
    ``AutoTokenizer.from_pretrained``. ``get_prompt_length_no_special`` also
    handles a ``tiktoken`` ``Encoding`` if one is assigned to ``self.tokenizer``.
    """

    def __init__(self, model):
        """Load the tokenizer for ``model`` and start an empty length cache.

        Args:
            model: Name or path passed to ``AutoTokenizer.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        # prompt -> token count for kwarg-free get_prompt_length calls. Nothing
        # in this class evicts entries, so it grows with every string measured.
        self.cache = {}

    def split_list_to_chunks(self, lst: list, chunk_num):
        """Divide ``lst`` into ``chunk_num`` consecutive parts.

        Every part holds ``len(lst) // chunk_num`` items except the last, which
        also takes the remainder.

        Args:
            lst: Items to divide.
            chunk_num: Number of parts to produce.

        Returns:
            A list of ``chunk_num`` sub-lists. NOTE: when ``len(lst) <= chunk_num``
            ``lst`` itself is returned unchanged (a flat list, not a list of
            lists); this is kept for backward compatibility.

        Raises:
            ValueError: If ``chunk_num < 1`` while ``lst`` is longer than it.
        """
        length = len(lst)
        if length <= chunk_num:
            return lst
        if chunk_num < 1:
            # A non-positive count would otherwise divide by zero below.
            raise ValueError(f"chunk_num must be at least 1, got {chunk_num}")

        chunk_size = length // chunk_num
        result = [
            lst[i * chunk_size:(i + 1) * chunk_size]
            for i in range(chunk_num - 1)
        ]
        # The last part takes all the remaining elements.
        result.append(lst[(chunk_num - 1) * chunk_size:])
        return result

    def split_sentences(self, text, spliter):
        """Split ``text`` into sentences (or words) while keeping delimiters.

        Args:
            text: Input string; leading/trailing whitespace is stripped first.
            spliter: Either the literal ``' '`` (split into words) or a regex
                whose delimiter is wrapped in ONE capturing group, so that
                ``re.split`` returns alternating ``[body, delim, body, ...]``
                (the pairing below relies on that shape).

        Returns:
            For a regex spliter: sentences with their trailing delimiter
            re-attached, plus any trailing text without a delimiter. For
            ``' '``: the words, each followed by one space except the last
            (runs of spaces collapse). An empty list for blank input.
        """
        text = text.strip()
        sentence_list = re.split(spliter, text)

        if spliter != ' ':
            # Glue each body to the delimiter captured right after it.
            sentences = [
                "".join(pair)
                for pair in zip(sentence_list[0::2], sentence_list[1::2])
            ]
            if len(sentence_list) % 2 != 0 and sentence_list[-1] != '':
                sentences.append(sentence_list[-1])
        else:
            # Empty items come from runs of spaces and are dropped.
            sentences = [word + ' ' for word in sentence_list if word != '']
            # Blank input yields no words; only trim when there is a last one.
            if sentences:
                sentences[-1] = sentences[-1].strip()
        return sentences

    def split_into_chunks(self, text, chunk_size, spliter=r'([。！？；.?!;])'):
        """Pack ``text`` into chunks of at most ``chunk_size`` tokens.

        Sentences (split on ``spliter``) are packed greedily. A chunk that is
        still too long is re-split on spaces, and a single word that is still
        too long is cut into character windows, so no non-whitespace text is
        dropped. A character window always holds at least one character, so a
        chunk only exceeds the budget when a single character does. If the
        final chunk is shorter than ``chunk_size // 2`` tokens, the last two
        chunks are re-balanced.

        Args:
            text: Text to split.
            chunk_size: Token budget per chunk (measured by get_prompt_length).
            spliter: Delimiter regex with one capturing group, or ``' '`` to
                split on spaces.

        Returns:
            The chunk strings, in order.
        """
        sentences = self.split_sentences(text, spliter)

        chunks = []
        current_chunk = ""
        # Sum of per-sentence token counts. Token counts are not strictly
        # additive, so _fit_chunk re-measures the joined text before keeping it.
        current_chunk_length = 0

        for sentence in sentences:
            sentence_length = self.get_prompt_length(sentence)
            if current_chunk_length + sentence_length <= chunk_size:
                current_chunk += sentence
                current_chunk_length += sentence_length
                continue
            if current_chunk:
                chunks.extend(
                    self._fit_chunk(current_chunk, chunk_size, spliter))
            current_chunk = sentence
            current_chunk_length = sentence_length

        if current_chunk:
            chunks.extend(self._fit_chunk(current_chunk, chunk_size, spliter))

        # A short trailing chunk is merged with its predecessor and the pair
        # is re-split more evenly.
        if len(chunks) > 1 and self.get_prompt_length(
                chunks[-1]) < chunk_size // 2:
            last_chunk = chunks.pop()
            penultimate_chunk = chunks.pop()
            chunks.extend(
                self._rebalance_tail(penultimate_chunk, last_chunk,
                                     chunk_size, spliter))
        return chunks

    def _fit_chunk(self, chunk, chunk_size, spliter):
        """Return ``chunk`` as one piece if it fits, otherwise split it further."""
        if self.get_prompt_length(chunk) <= chunk_size:
            return [chunk]
        if spliter != ' ':
            # Fall back to word-level splitting; in ' ' mode this method never
            # calls split_into_chunks again, so the recursion is one level deep.
            return self.split_into_chunks(chunk,
                                          chunk_size=chunk_size,
                                          spliter=' ')
        # Already word-level: cut an over-long word by characters rather than
        # silently discarding it.
        return self._hard_split(chunk, chunk_size)

    def _hard_split(self, text, chunk_size):
        """Cut ``text`` into consecutive character windows within ``chunk_size``.

        Each window is found by binary search over the prefix length (it is the
        longest fitting prefix when token counts grow with prefix length). It
        always holds at least one character, so the loop terminates. The
        windows concatenate back to ``text`` exactly.
        """
        pieces = []
        start = 0
        while start < len(text):
            best, hi = start + 1, len(text)
            while best < hi:
                mid = (best + hi + 1) // 2
                # Measured directly so the trial prefixes don't fill self.cache.
                if len(self.tokenizer.encode(text[start:mid])) <= chunk_size:
                    best = mid
                else:
                    hi = mid - 1
            pieces.append(text[start:best])
            start = best
        return pieces

    def _rebalance_tail(self, penultimate_chunk, last_chunk, chunk_size,
                        spliter):
        """Re-split the last two chunks more evenly without losing text.

        Sentences of the combined text are handed out alternately to a front
        piece (taken from the start) and a back piece (taken from the end)
        while each stays within ``chunk_size``. Anything neither side could
        take is split after whitespace. As much as fits goes to the front
        piece and the rest is prepended to the back piece. If that makes the
        back piece overflow, the original pair is returned unchanged. Empty
        pieces are omitted.
        """
        new_sentences = self.split_sentences(penultimate_chunk + last_chunk,
                                             spliter)
        if len(new_sentences) <= 1:
            return [penultimate_chunk, last_chunk]

        head, tail = "", ""
        i, j = 0, len(new_sentences) - 1
        while i <= j:
            progressed = False
            if self.get_prompt_length(head + new_sentences[i]) <= chunk_size:
                head += new_sentences[i]
                i += 1
                progressed = True
                if i > j:  # everything has been allocated
                    break
            if self.get_prompt_length(new_sentences[j] + tail) <= chunk_size:
                tail = new_sentences[j] + tail
                j -= 1
                progressed = True
            if not progressed:
                break

        if i <= j:
            # new_sentences[i..j] went to neither side.
            leftover = "".join(new_sentences[i:j + 1])
            # Each piece is a word plus its trailing whitespace (or a leading
            # whitespace run), so the pieces concatenate back to ``leftover``.
            pieces = re.findall(r'\S+\s*|\s+', leftover)
            k = 0
            while (k < len(pieces) and
                   self.get_prompt_length(head + pieces[k]) <= chunk_size):
                head += pieces[k]
                k += 1
            tail = "".join(pieces[k:]) + tail
            if self.get_prompt_length(tail) > chunk_size:
                return [penultimate_chunk, last_chunk]

        return [piece for piece in (head, tail) if piece]

    def chunk_docs(self,
                   doc: str,
                   chunk_size: int,
                   separator='\n',
                   chunk_overlap=0) -> List[str]:
        """Merge ``separator``-delimited pieces of ``doc`` into bounded chunks.

        Pieces are accumulated (joined by ``separator``) until adding the next
        one would exceed ``chunk_size`` tokens, counted without special tokens.
        After a chunk is emitted, pieces are dropped from its front until at
        most ``chunk_overlap`` tokens remain and the next piece fits. A single
        piece that alone exceeds ``chunk_size`` is re-split with
        ``split_into_chunks``.

        Args:
            doc: Text to chunk.
            chunk_size: Token budget per chunk.
            separator: Delimiter used to split ``doc`` and to re-join pieces.
            chunk_overlap: Maximum tokens of trailing pieces carried over to
                the start of the next chunk.

        Returns:
            Non-blank chunk strings; an empty list when ``doc`` has no pieces.
        """
        splits = [s for s in doc.split(separator) if s != '']
        if not splits:
            # Nothing to chunk (and current_doc would be empty at the end).
            return []
        separator_len = self.get_prompt_length_no_special(separator)

        docs = []
        current_doc: List[str] = []
        total = 0  # token count of separator.join(current_doc), piecewise
        for d in splits:
            _len = self.get_prompt_length_no_special(d)
            if (total + _len + (separator_len if current_doc else 0)
                    > chunk_size):
                if total > chunk_size:
                    if VERBOSE:
                        print(f"Created a chunk of size {total}, "
                              f"which is longer than the specified {chunk_size}")

                    if len(current_doc) == 1:  # one piece alone is too long
                        docs.extend(
                            self.split_into_chunks(current_doc[0], chunk_size))
                        current_doc = []
                        total = 0

                if current_doc:
                    docs.append(separator.join(current_doc))
                    # Pop from the front while more than chunk_overlap tokens
                    # remain, or while the next piece still would not fit.
                    while current_doc and (
                            total > chunk_overlap or
                        (total > 0 and
                         total + _len + separator_len > chunk_size)):
                        total -= self.get_prompt_length_no_special(
                            current_doc[0]) + (separator_len
                                               if len(current_doc) > 1 else 0)
                        current_doc = current_doc[1:]

            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)

        # The final piece may itself be over budget.
        if (len(current_doc) == 1 and
                self.get_prompt_length_no_special(current_doc[0]) > chunk_size):
            docs.extend(self.split_into_chunks(current_doc[0], chunk_size))
        else:
            docs.append(separator.join(current_doc))
        return [d for d in docs if d.strip() != ""]

    def get_prompt_length(self, prompt, **kwargs: Any) -> int:
        """Return the token count of ``prompt`` (lists joined by join_docs).

        Calls without extra ``kwargs`` are memoized in ``self.cache``. The
        cache key is the prompt alone, so calls that pass ``kwargs`` bypass it
        rather than risk returning a count computed with different options.
        """
        if isinstance(prompt, list):
            prompt = self.join_docs(prompt)
        if kwargs:
            return len(self.tokenizer.encode(prompt, **kwargs))
        length = self.cache.get(prompt)
        if length is None:
            length = len(self.tokenizer.encode(prompt))
            self.cache[prompt] = length
        return length

    def get_prompt_length_format(self, prompt, **kwargs: Any) -> int:
        """Return the token count after formatting (uncached).

        A list is rendered with format_chunk_information and concatenated
        before encoding; a string is encoded as-is.
        """
        if isinstance(prompt, list):
            prompt = ''.join(self.format_chunk_information(prompt))
        return len(self.tokenizer.encode(prompt, **kwargs))

    def get_prompt_length_no_special(self, prompt, **kwargs: Any) -> int:
        """Return the token count of ``prompt`` without added special tokens.

        Lists are joined with join_docs first. The result is not cached.
        """
        if isinstance(prompt, list):
            prompt = self.join_docs(prompt)
        if isinstance(self.tokenizer, tiktoken.core.Encoding):
            # NOTE(review): disallowed_special='all' is tiktoken's default and
            # makes encode() raise if the text contains special-token strings
            # (e.g. '<|endoftext|>'); disallowed_special=() would count them as
            # plain text instead. Confirm which is intended.
            return len(
                self.tokenizer.encode(prompt,
                                      disallowed_special='all',
                                      **kwargs))
        try:
            return len(
                self.tokenizer.encode(prompt,
                                      add_special_tokens=False,
                                      **kwargs))
        except Exception:
            # Best-effort fallback, presumably for tokenizers whose encode()
            # rejects these arguments: count with the tokenizer's defaults.
            return len(self.tokenizer.encode(prompt))

    def join_docs(self, docs: List[str]) -> str:
        """Join ``docs`` with blank lines; a plain string is returned as-is."""
        if isinstance(docs, str):
            return docs
        return '\n\n'.join(docs)

    def format_chunk_information(self, docs):
        """Prefix each doc with a zero-based 'Information of Chunk N:' header."""
        new_docs = [
            f'Information of Chunk {index}:\n{d}\n'
            for index, d in enumerate(docs)
        ]
        return new_docs
    


