import re
import torch

from llava.constants import *
from llava import conversation as conversation_lib
from llava.arguments import *

def preprocess_query(
    text_query: str,
    is_multimodal: bool,
) -> list:
    """
    Wrap a text query as a single human turn for the chat template.

    When ``is_multimodal`` is True, DEFAULT_IMAGE_TOKEN is prepended to the
    query and the combined string is stripped of surrounding whitespace.
    Otherwise the query is passed through unchanged.

    Args:
        text_query: the raw query text.
        is_multimodal: whether to prepend the image placeholder token.

    Returns:
        A one-turn conversation: ``[{'from': 'human', 'value': <query>}]``.
    """
    if is_multimodal:
        # str.strip() returns a new string, so the result must be reassigned.
        text_query = (DEFAULT_IMAGE_TOKEN + text_query).strip()

    return [{'from': 'human', 'value': text_query}]


def _group_by_section(section_indices, items) -> dict:
    """
    Group ``items[i]`` under the section index ``section_indices[i]``.

    Items keep their input order within each section. ``items`` is indexed
    positionally, so a list shorter than ``section_indices`` raises
    IndexError instead of being silently truncated.
    """
    grouped = {}
    for idx, section_idx in enumerate(section_indices):
        grouped.setdefault(section_idx, []).append(items[idx])
    return grouped


def preprocess_interleaved_section(
    document: Dict,
    is_multimodal: bool,
    is_tabular: bool,
    is_title_only: bool = False,
) -> list:
    """
    Split a document into one single-turn human conversation per section.

    Each section's text is 'Section title: <title>', a newline, then
    'Section content: <text>'. Sections are built by zipping
    ``document['section_titles']`` with ``document['section_texts']``, so the
    shorter list bounds the number of sections.

    If ``is_multimodal``: each image i is attached to section
    ``document['image_section_indices'][i]`` by prepending
    DEFAULT_IMAGE_TOKEN + 'Image description: ' +
    ``document['image_reference_descriptions'][i]`` + newline to that
    section (images in input order). Indices matching no section are ignored.

    If ``is_tabular``: each ``document['tables'][i]`` + newline is appended to
    section ``document['table_section_indices'][i]``, with the same rules.

    If ``is_title_only``: ``document['title']`` is appended as one extra,
    final conversation. The sections are still included.

    Returns:
        A list of conversations, each ``[{'from': 'human', 'value': str}]``.
    """
    section_content_list = [
        'Section title: ' + section_title + '\n' + 'Section content: ' + section_text
        for section_title, section_text in zip(document['section_titles'], document['section_texts'])
    ]

    if is_multimodal:
        images_by_section = _group_by_section(
            document['image_section_indices'], document['image_reference_descriptions'])
        for section_idx, content in enumerate(section_content_list):
            descriptions = images_by_section.get(section_idx)
            if descriptions:
                # Images go at the head of the section text.
                image_prefix = ''.join(
                    DEFAULT_IMAGE_TOKEN + 'Image description: ' + description + '\n'
                    for description in descriptions)
                section_content_list[section_idx] = image_prefix + content

    if is_tabular:
        tables_by_section = _group_by_section(
            document['table_section_indices'], document['tables'])
        for section_idx, content in enumerate(section_content_list):
            tables = tables_by_section.get(section_idx)
            if tables:
                # Tables go at the tail of the section text.
                section_content_list[section_idx] = content + ''.join(table + '\n' for table in tables)

    sources = [[{'from': 'human', 'value': content}] for content in section_content_list]

    if is_title_only:
        sources.append([{'from': 'human', 'value': document['title']}])

    return sources


def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, input_type: str = 'query') -> Dict:
    """
    Tokenize one conversation with the Qwen ChatML-style template.

    Output layout (as token ids):
        [im_start] "system\\n" <system message> [im_end] "\\n"
        then, per turn: <role header> "\\n" <content> [im_end] "\\n"
    where the role header is the tokenized string "<|im_start|>user" or
    "<|im_start|>assistant".

    Args:
        sources: list of turn dicts with keys 'from' ('human' or 'gpt') and
            'value' (str or None). A leading non-human turn is dropped.
        tokenizer: callable returning an object with ``.input_ids`` (a list);
            must expose exactly two ``additional_special_tokens_ids``
            (presumably <|im_start|>, <|im_end|> in that order).
        has_image: when True, each DEFAULT_IMAGE_TOKEN inside a turn is
            replaced by IMAGE_TOKEN_INDEX followed by the newline tokens.
        max_len: unused; kept for signature compatibility (no truncation).
        input_type: selects the system message: 'query', 'document',
            'query_sec' or 'document_sec' (empty system message).

    Returns:
        ``{'input_ids': LongTensor}`` of shape (1, seq_len).

    Raises:
        ValueError: if ``input_type`` is not one of the values above.
        KeyError: if a turn's 'from' is not 'human' or 'gpt'.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
    # System prompt chosen by what kind of text is being embedded.
    system_messages = {
        'query': "Represent the question for retrieving answers.",
        'document': "Represent the document for retrieval.",
        'query_sec': "Evaluate how precisely each section answers the query.",
        'document_sec': "",
    }

    # NOTE(review): unpacking fails unless the tokenizer registers exactly two
    # additional special tokens.
    im_start, im_end = tokenizer.additional_special_tokens_ids
    nl_tokens = tokenizer("\n").input_ids
    _system = tokenizer("system").input_ids + nl_tokens

    source = sources
    # The conversation must open with a human turn; drop a leading non-human
    # one. An empty conversation yields just the system prompt.
    if source and roles[source[0]["from"]] != roles["human"]:
        source = source[1:]

    if input_type not in system_messages:
        raise ValueError(f"{input_type} is not defined")

    input_id = [im_start] + _system + tokenizer(system_messages[input_type]).input_ids + [im_end] + nl_tokens
    for sentence in source:
        value = sentence["value"]
        turn = tokenizer(roles[sentence["from"]]).input_ids + nl_tokens
        if value is None:
            # Open turn: only the role header is emitted, without [im_end].
            input_id += turn
            continue

        if has_image and DEFAULT_IMAGE_TOKEN in value:
            pieces = value.split(DEFAULT_IMAGE_TOKEN)
            for i, piece in enumerate(pieces):
                turn += tokenizer(piece).input_ids
                if i < len(pieces) - 1:
                    # Placeholder id marking an image position (presumably
                    # substituted with image features by the model).
                    turn += [IMAGE_TOKEN_INDEX] + nl_tokens
            # Invariant: the text tokenizer never emits IMAGE_TOKEN_INDEX itself.
            assert turn.count(IMAGE_TOKEN_INDEX) == value.count(DEFAULT_IMAGE_TOKEN)
        else:
            turn += tokenizer(value).input_ids

        input_id += turn + [im_end] + nl_tokens

    return dict(
        input_ids=torch.tensor([input_id], dtype=torch.long),
    )

def preprocess(
        sources: Sequence[Dict],
        tokenizer: transformers.PreTrainedTokenizer,
        has_image: bool = False,
        input_type: str = 'query',
) -> Dict:
    """
    Tokenize one conversation using the active conversation template.

    Only templates whose ``conversation_lib.default_conversation.version``
    starts with "qwen" are supported; the work is delegated to
    ``preprocess_qwen``.

    Args:
        sources: list of turn dicts with 'from' and 'value' keys, e.g. the
            output of ``preprocess_query`` or one element of the output of
            ``preprocess_interleaved_section``.
        tokenizer: tokenizer forwarded to ``preprocess_qwen``.
        has_image: whether image placeholder tokens should be expanded.
        input_type: selects the system message (see ``preprocess_qwen``).

    Returns:
        ``{'input_ids': LongTensor}`` as produced by ``preprocess_qwen``.

    Raises:
        NotImplementedError: if the active conversation version is not a
            Qwen template (previously this silently returned None).
    """
    version = conversation_lib.default_conversation.version
    if version.startswith("qwen"):
        return preprocess_qwen(sources, tokenizer, has_image=has_image, input_type=input_type)
    raise NotImplementedError(
        f"Unsupported conversation version {version!r}; only 'qwen*' templates are implemented."
    )
