from typing import List
import logging
# import fitz  # PyMuPDF
from concurrent.futures import ThreadPoolExecutor
from collections import Counter
from multiprocessing import Pool
from transformers import AutoTokenizer

# Configure the root logger at import time with level INFO.
# Note: logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the functions below.
logger = logging.getLogger(__name__)


def batch_split_into_chunks(texts, chunk_size, model, print_size=False):
    """
    Split several texts into chunks in parallel using a process pool.

    Every text is handed to split_into_chunks together with the same
    ``chunk_size`` and ``model``. A histogram of the resulting chunk counts
    is always printed to stdout.

    Args:
        texts: Iterable of input texts, one split_into_chunks call per item
        chunk_size: Chunk size forwarded to split_into_chunks
        model: Model name forwarded to split_into_chunks
        print_size: Currently unused; it is not forwarded to the workers,
            which therefore run with their own default

    Returns:
        tuple: (list of chunk lists, list of chunk counts), one entry per text
    """
    with Pool() as pool:
        # One (text, chunk_size, model) argument tuple per text; starmap
        # unpacks each tuple into a split_into_chunks call.
        per_text_chunks = pool.starmap(
            split_into_chunks,
            [(text, chunk_size, model) for text in texts],
        )

    gathered_chunks = list(per_text_chunks)
    gathered_num_chunks = [len(chunk_list) for chunk_list in gathered_chunks]

    # Maps "number of chunks" -> "how many texts produced that many chunks".
    histogram = Counter(gathered_num_chunks)
    banner = '&' * 80
    print(banner)
    print(banner)
    print(f"Chunk lengthes are :{histogram}")
    print(banner)
    print(banner)

    return gathered_chunks, gathered_num_chunks


def split_into_chunks(text: str, chunk_size: int, model: str = "llama-3.3-70b-versatile", print_size=True) -> List[str]:
    """
    Split text into chunks of at most ``chunk_size`` tokens.

    The text is split into paragraphs on blank lines (double newlines). Each
    non-empty paragraph is tokenized with the tokenizer loaded for ``model``
    and the token ids are packed greedily into the current chunk. A paragraph
    that does not fit closes the current chunk; a paragraph longer than
    ``chunk_size`` is cut into full-size pieces and its remainder starts the
    next chunk. Chunks are decoded back to text before being returned.

    Args:
        text: The input text to split
        chunk_size: Maximum number of tokens per chunk; must be positive
        model: Name or path passed to ``AutoTokenizer.from_pretrained``
        print_size: If True, log the number of chunks and the per-chunk
            count of space-separated words

    Returns:
        List[str]: List of decoded text chunks

    Raises:
        ValueError: If ``chunk_size`` is not positive
    """
    if chunk_size <= 0:
        # With a non-positive size, slicing never shortens an oversized
        # paragraph, so the splitting loop below would never terminate.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    tokenizer = AutoTokenizer.from_pretrained(model)
    chunks = []
    current_chunk = []  # token ids accumulated for the chunk being built

    # Split by paragraphs first to maintain context
    for paragraph in text.split('\n\n'):
        # Skip empty paragraphs
        if not paragraph.strip():
            continue

        paragraph_tokens = tokenizer.encode(paragraph, add_special_tokens=False)

        if len(paragraph_tokens) + len(current_chunk) > chunk_size:
            # The paragraph does not fit: flush what has been collected so far.
            if current_chunk:
                chunks.append(tokenizer.decode(current_chunk, skip_special_tokens=True))
                current_chunk = []

            # Cut an oversized paragraph into full-size chunks.
            while len(paragraph_tokens) > chunk_size:
                piece = paragraph_tokens[:chunk_size]
                chunks.append(tokenizer.decode(piece, skip_special_tokens=True))
                paragraph_tokens = paragraph_tokens[chunk_size:]

            # The remainder (at most chunk_size tokens) starts the next chunk.
            # Copy it so later extend() calls never touch the tokenizer's list.
            current_chunk = list(paragraph_tokens)
        else:
            # Token ids are concatenated directly; the paragraph separator
            # removed by the split above is not re-inserted.
            current_chunk.extend(paragraph_tokens)

    # Add any remaining tokens as the final chunk
    if current_chunk:
        chunks.append(tokenizer.decode(current_chunk, skip_special_tokens=True))

    if print_size:
        # Reported lengths are space-separated word counts, not token counts.
        chunk_lens = [len(c.split(' ')) for c in chunks]
        logger.info("Split text into %d chunks --- (Chunk length: %s)", len(chunks), chunk_lens)

    return chunks


def count_tokens(text: str, model: str = "llama-3.3-70b-versatile") -> int:
    """
    Return how many whitespace-separated words ``text`` contains.

    Despite the name, no tokenizer is involved: the text is split on runs of
    whitespace and the pieces are counted.

    Args:
        text: The input text
        model: Not used, kept for compatibility

    Returns:
        int: Number of words
    """
    words = text.split()
    return len(words)


def get_worker_prompt(data, mode='longbench', use_sum_tag=False, use_sum_tag2=False):
    """
    Build the two worker prompt templates for a dataset.

    The templates contain literal ``{input_chunk}``, ``{prev_cu}`` and (for
    question-answering datasets) ``{query}`` placeholders; presumably callers
    fill them in later -- verify against the caller.

    Args:
        data: Dataset name. Summarization datasets ("gov_report",
            "multi_news") and question-answering datasets are supported.
        mode: If it contains 'no_ans', " Do not answer the question." is
            appended to the instruction
        use_sum_tag: Append the <summary> tag instruction
        use_sum_tag2: Question-answering datasets only: append the <summary>
            tag instruction and ask for a "relatively long" summary. Ignored
            for summarization datasets.

    Returns:
        tuple: (template without a previous summary, template with one)

    Raises:
        NotImplementedError: If ``data`` is not a supported dataset
    """
    tag_hint = " Put your summary inside the summary tag, like <summary>your summary</summary>."
    source_part = "\n\n[SOURCE TEXT]: {input_chunk}"
    previous_part = "\n\n[PREVIOUS SUMMARY]: {prev_cu}"
    closing = "\n\nSummary:"

    if data in ("gov_report", "multi_news"):
        wants_tag = use_sum_tag
        lead = "You need to read [SOURCE TEXT] and [PREVIOUS SUMMARY] and generate a summary to include them both. Later, this summary will be used for other agents to generate a summary for the whole text. Thus, your generated summary should be relatively long."
        first_tail = source_part + closing
        later_tail = source_part + previous_part + closing
    elif data in ("narrativeqa", "qasper", "hotpotqa", "musique", 'multifieldqa_en', 'qmsum', "2wikimqa", 'longbench-v2'):
        # use_sum_tag2 forces the tag instruction on; otherwise use_sum_tag decides.
        wants_tag = use_sum_tag2 or use_sum_tag
        if use_sum_tag2:
            lead = "You need to read [SOURCE TEXT] and [PREVIOUS SUMMARY] and generate a summary to include them both. Later, this summary will be used for other agents to answer [QUERY], if any. So please write the relatively long summary that can include the evidence for answering [QUESTION]."
        else:
            lead = "You need to read [SOURCE TEXT] and [PREVIOUS SUMMARY] and generate a summary to include them both. Later, this summary will be used for other agents to answer [QUERY], if any. So please write the summary that can include the evidence for answering [QUESTION]."
        # The first-chunk template writes the question label without a colon.
        first_tail = source_part + "\n\n[QUESTION] {query}" + closing
        later_tail = source_part + previous_part + "\n\n[QUESTION]: {query}" + closing
    else:
        raise NotImplementedError("Not supported data")

    if 'no_ans' in mode:
        lead += " Do not answer the question."

    prefix = lead + (tag_hint if wants_tag else "")
    return prefix + first_tail, prefix + later_tail


def get_manager_prompt(data, mode='longbench', use_answer_tag=True, use_cot=False):
    """
    Build the manager prompt template for a dataset.

    The template contains a literal ``{summary}`` placeholder and, for the
    question-answering datasets, a ``{query}`` placeholder.

    Args:
        data: Dataset name selecting the task wording
        mode: If it contains "no_conflict", a hint to integrate the summaries
            and resolve contradictions is inserted
        use_answer_tag: Insert the <answer> tag instruction (never inserted
            for 'longbench-v2')
        use_cot: Append a step-by-step reasoning instruction at the very end

    Returns:
        str: The assembled prompt template

    Raises:
        NotImplementedError: If ``data`` is not a supported dataset
    """
    conflict_hint = (
        " Integrate all informations from every summary and resolve contradictions to answer the question. "
        if "no_conflict" in mode
        else ""
    )
    answer_tag_hint = (
        " Put your answer inside the answer tag, like <answer>your answer</answer>."
        if use_answer_tag and data != 'longbench-v2'
        else ""
    )
    both_hints = conflict_hint + answer_tag_hint

    # Notice shared by every dataset family, placed right before the summary.
    summarized_notice = "\n\nThe following are given passages. However, the source text is too long and has been summarized. You need to answer based on the summary:"

    if data in ("hotpotqa", "2wikimqa", "musique"):
        task = "Answer the question based on the given passages. Only give me the answer and do not output any other words."
        body = summarized_notice + "\n{summary}\n\n" + task
        # This family repeats only the conflict hint before the question,
        # not the answer-tag hint.
        tail = conflict_hint + "\n\nQuestion: {query}\nAnswer:"
    elif data in ("narrativeqa", "qasper", "multifieldqa_en"):
        task = "Answer the question based on the given passages as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the given passages, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation."
        body = summarized_notice + "\n{summary}\n\n" + task
        tail = both_hints + "\n\nQuestion: {query}\n\nAnswer:"
    elif data in ("gov_report", "multi_news"):
        task = "Read the following text and write a one-page summary."
        body = summarized_notice + "{summary}\n\nNow, write a one-page summary."
        tail = both_hints + "\n\nSummary:"
    elif data == 'longbench-v2':
        task = "Please read the given passages and answer the question below."
        body = summarized_notice + "\n\n<text>\n{summary}\n</text>\n\nWhat is the correct answer to this question."
        tail = both_hints + " {query}\n\nFormat your response as follows: \"The correct answer is (insert answer here)\"."
    else:
        raise NotImplementedError

    # Layout: task + hints, then the summarized passages, then the closing part.
    prompt = task + both_hints + body + tail

    if use_cot:
        prompt += " First, think step by step to analyze the question and the provided information. Then, provide a final answer."

    return prompt


def get_data_specific_prompt(data, mode, ablation_type='None', summary_tag=True):
    """
    Collect the worker and manager prompt templates for a dataset.

    Args:
        data: Dataset name, forwarded to both prompt builders
        mode: Mode string, forwarded to both prompt builders
        ablation_type: When it is not 'None' and contains 'cot', the manager
            prompt is built with use_cot=True
        summary_tag: Forwarded to get_worker_prompt as ``use_sum_tag2``

    Returns:
        tuple: ((first worker template, follow-up worker template),
        manager template)
    """
    first_template, followup_template = get_worker_prompt(data, mode, use_sum_tag2=summary_tag)

    wants_cot = ablation_type != 'None' and 'cot' in ablation_type
    # Only pass use_cot when requested so the builder's default applies otherwise.
    manager_kwargs = {"use_cot": True} if wants_cot else {}
    manager_template = get_manager_prompt(data, mode, **manager_kwargs)

    return (first_template, followup_template), manager_template
