import sys, os
import json, re
from loguru import logger
from typing import Union, Dict, List, Optional
import copy
from pathlib import Path
from collections import deque
from transformers import AutoTokenizer
import base64
from tqdm import tqdm
from PIL import Image

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from gpt_api.unigpt import GPT
from common_tools import if_chart as if_chart
from common_tools import extract_context as extract_context


def is_word_count_over_limit(tokenizer, text_segments: Union[deque[str], list[str]], limit: int = 2000) -> bool:
    """Report whether the combined token count of *text_segments* exceeds *limit*.

    Each segment is tokenized separately with ``tokenizer.encode``. The scan stops at the
    first segment that pushes the running total past *limit*, so later segments are never
    tokenized. An empty collection never exceeds the limit.
    """
    remaining_budget = limit
    for segment_text in text_segments:
        remaining_budget -= len(tokenizer.encode(segment_text))
        # A negative budget means the cumulative total is strictly greater than `limit`.
        if remaining_budget < 0:
            return True
    return False


def _flatten_content(contents, markdown_sequence: list) -> list:
    """Depth-first flatten a nested document ``content`` tree into *markdown_sequence*.

    For every dict in *contents*, in order:

    * if it has a ``"title"`` key, a ``{"type": "title", "content": ..., "chapter_level": ...}``
      entry is appended first;
    * if its ``"content"`` value is a non-empty list, that list is flattened recursively and
      the dict itself is not appended;
    * otherwise, if it has a ``"type"`` key, the dict is appended unchanged.

    Non-dict items are ignored. *markdown_sequence* is mutated in place and also returned.
    """
    for node in contents:
        if not isinstance(node, dict):
            continue
        if "title" in node:
            markdown_sequence.append(
                {"type": "title", "content": node["title"], "chapter_level": node["chapter_level"]},
            )
        children = node.get("content")
        # Type check first so a None / non-sized "content" value cannot raise in a length test.
        if isinstance(children, list) and children:
            _flatten_content(children, markdown_sequence)
            continue
        if "type" in node:
            markdown_sequence.append(node)
    return markdown_sequence


def _to_markdown(item: dict) -> str:
    """Render one flattened document item as a markdown string.

    Titles become ``#``-prefixed headings, with the level parsed from a ``chapter_level`` such
    as ``"h2"``. Text is passed through unchanged. Images become ``![stem](src)`` references.
    Any other type, or a title/text whose content is None, renders as ``""``.
    """
    item_type = item["type"]
    if item_type == "title":
        level = int(item["chapter_level"].split("h")[-1])
        if item["content"] is None:
            return ""
        return "#" * level + " " + item["content"]
    if item_type == "text":
        text = item["content"]
        # A None body would otherwise crash tokenizer.encode further down.
        return "" if text is None else text
    if item_type == "image":
        src = item["src"]
        return f"![{Path(src).stem}](" + src + ")"
    return ""


def split_json_by_charts(
    json_path: str, output_path: str, context_limit: int, tokenizer, image_root_dir, vl_agent, text_agent
):
    """Build a surrounding-text context for every image in one parsed document JSON.

    The document's ``data["content"]`` tree is flattened and rendered to markdown segments.
    Each image gets one output entry
    ``{"image_name": <markdown image ref>, "image_path": <src>, "context": [segments]}``:

    * **Preceding text.** The context starts with the segments before the image, including
      the image's own markdown, trimmed from the oldest end until it is at most
      ``int(context_limit / 2)`` tokens.
    * **Following text.** Every later non-image segment is appended to an image's context
      as long as that context is at most *context_limit* tokens before the append.

    Token counts come from ``len(tokenizer.encode(segment))`` summed per segment.

    The entries are written as UTF-8 JSON to ``<output_path>/<json file name>``.

    Args:
        json_path: parsed document JSON with a top-level ``"content"`` list.
        output_path: output directory; created if missing.
        context_limit: token budget per image context.
        tokenizer: object exposing ``encode(str) -> sequence``.
        image_root_dir, vl_agent, text_agent: accepted for interface compatibility; unused.
    """
    os.makedirs(output_path, exist_ok=True)
    output_json_filename = os.path.join(output_path, Path(json_path).name)

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    markdown_sequence = _flatten_content(data["content"], [])

    forward_limit = int(context_limit / 2)
    # Sliding window of (markdown, token_count) pairs preceding the current item, and its total.
    forward_context = deque()
    forward_tokens = 0
    output_image_json = []
    # Running token total of each entry's "context", parallel to output_image_json, so no
    # context is ever re-tokenized (token counts are additive per segment).
    context_tokens = []

    for item in markdown_sequence:
        md_content = _to_markdown(item)
        n_tokens = len(tokenizer.encode(md_content))

        forward_context.append((md_content, n_tokens))
        forward_tokens += n_tokens
        # Drop the oldest segments until the window fits in half the budget (may empty it).
        while forward_context and forward_tokens > forward_limit:
            _, dropped_tokens = forward_context.popleft()
            forward_tokens -= dropped_tokens

        if item["type"] == "image":
            output_image_json.append(
                {
                    "image_name": md_content,
                    "image_path": item["src"],
                    "context": [segment for segment, _ in forward_context],
                }
            )
            context_tokens.append(forward_tokens)
            continue

        # Extend every earlier image's context with this non-image segment while it is still
        # within budget. The check precedes the append, so a context may overshoot by one segment.
        for idx, entry in enumerate(output_image_json):
            if entry["context"] and context_tokens[idx] > context_limit:
                continue
            entry["context"].append(md_content)
            context_tokens[idx] += n_tokens

    # NOTE: the file is created (empty) even when the document has no images. split_jsons only
    # checks for existence, so this marks the document as processed and it is skipped next run.
    with open(output_json_filename, "w", encoding="utf-8") as f:
        if output_image_json:
            json.dump(output_image_json, f, ensure_ascii=False, indent=4)
            logger.info(f"file write to {output_json_filename}")


def split_jsons(
    vl_qwen_chat,
    text_qwen_chat,
    tokenizer,
    input_dirs,
    image_root_dir,
    context_limit=1000,
    root_dir="data/chartdata/processed_json",
    output_root="data/chartdata/processed_json/image_with_context",
):
    """Run ``split_json_by_charts`` on every not-yet-processed ``*.json`` file of each dataset.

    For each name in *input_dirs*, files in ``<root_dir>/<name>/`` are processed into
    ``<output_root>/<name>/``. A file whose output already exists is skipped, so an
    interrupted run can be resumed.

    Args:
        vl_qwen_chat, text_qwen_chat: chat clients forwarded as ``vl_agent`` / ``text_agent``.
        tokenizer: tokenizer forwarded to ``split_json_by_charts``.
        input_dirs: dataset sub-directory names.
        image_root_dir: root of the extracted images; ``<image_root_dir>/<name>`` is forwarded.
        context_limit: token budget per image context.
        root_dir: directory holding the per-dataset parsed JSON folders.
        output_root: directory receiving the per-dataset image-context folders.
    """
    for input_dir in input_dirs:
        input_root_dir = os.path.join(root_dir, input_dir)
        output_dir = os.path.join(output_root, input_dir)
        dataset_image_root = os.path.join(image_root_dir, input_dir)
        entries_list = list(os.scandir(input_root_dir))
        for entry in tqdm(entries_list, desc=f"{input_dir}"):
            if not (entry.is_file() and entry.name.endswith(".json")):
                continue
            # Resume support: skip documents that already have an output file.
            if os.path.exists(os.path.join(output_dir, entry.name)):
                continue
            split_json_by_charts(
                json_path=os.path.join(input_root_dir, entry.name),
                output_path=output_dir,
                context_limit=context_limit,
                tokenizer=tokenizer,
                image_root_dir=dataset_image_root,
                vl_agent=vl_qwen_chat,
                text_agent=text_qwen_chat,
            )


def filter_chart(json_path: str, output_path: str, image_root_dir, vl_agent, text_agent):
    """Keep only the entries of one image-context JSON file whose image is judged a chart.

    Reads a JSON list whose entries each carry an ``"image_path"`` (as written by
    ``split_json_by_charts``). ``if_chart(vl_agent, <image file>)`` is called for each entry,
    and the entries it accepts are written as UTF-8 JSON to ``<output_path>/<json file name>``.
    Image files are resolved as ``<image_root_dir>/<json stem>/auto/<image_path>``. The output
    file is created even when no entry is kept; it is left empty in that case.

    An empty input file, which ``split_json_by_charts`` leaves for documents without images,
    is skipped without writing output. Any other exception is logged with its traceback and
    swallowed, so a batch run can move on to the next file.

    Args:
        json_path: image-context JSON file to filter.
        output_path: directory for the filtered JSON; created if missing.
        image_root_dir: per-dataset root of the extracted images.
        vl_agent: vision-language client passed to ``if_chart``.
        text_agent: currently unused (the context-extraction step is disabled).
    """
    os.makedirs(output_path, exist_ok=True)
    p = Path(json_path)
    output_json_filename = os.path.join(output_path, p.name)
    image_dir = os.path.join(image_root_dir, p.stem, "auto")

    try:
        if os.path.getsize(json_path) == 0:
            # Image-less document from the split stage: nothing to filter, and not an error.
            return

        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Judge whether each image is a chart.
        filtered_data = [
            item for item in data if if_chart(vl_agent, os.path.join(image_dir, item["image_path"]))
        ]
        # NOTE: a context-extraction step using extract_context(text_agent, ...) used to run
        # here and is currently disabled.

        with open(output_json_filename, "w", encoding="utf-8") as f:
            if filtered_data:
                json.dump(filtered_data, f, ensure_ascii=False, indent=4)
                logger.info(f"file write to {output_json_filename}")
    except Exception as e:
        # Per-file boundary in a batch run: log with traceback and continue with the next file.
        logger.error(f"error during open file: {json_path}")
        logger.exception(f"{e}")


def filter_charts(
    vl_qwen_chat,
    text_qwen_chat,
    input_dirs,
    image_root_dir,
    root_dir="data/chartdata/processed_json/image_with_context",
    output_root="data/chartdata/filtered_images",
    skip_existing=False,
):
    """Run ``filter_chart`` on every ``*.json`` image-context file of each dataset.

    For each name in *input_dirs*, files in ``<root_dir>/<name>/`` are filtered into
    ``<output_root>/<name>/``. Images are looked up under ``<image_root_dir>/<name>``.

    Args:
        vl_qwen_chat, text_qwen_chat: chat clients forwarded as ``vl_agent`` / ``text_agent``.
        input_dirs: dataset sub-directory names.
        image_root_dir: root of the extracted images.
        root_dir: directory holding the per-dataset image-context folders.
        output_root: directory receiving the per-dataset filtered folders.
        skip_existing: when True, files that already have an output are skipped (resume mode).
            Defaults to False, so every file is re-filtered as before.
    """
    for input_dir in input_dirs:
        input_root_dir = os.path.join(root_dir, input_dir)
        output_dir = os.path.join(output_root, input_dir)
        image_root_input_dir = os.path.join(image_root_dir, input_dir)
        entries_list = list(os.scandir(input_root_dir))
        for entry in tqdm(entries_list, desc=f"{input_dir}"):
            if not (entry.is_file() and entry.name.endswith(".json")):
                continue
            if skip_existing and os.path.exists(os.path.join(output_dir, entry.name)):
                continue
            filter_chart(
                json_path=os.path.join(input_root_dir, entry.name),
                output_path=output_dir,
                image_root_dir=image_root_input_dir,
                vl_agent=vl_qwen_chat,
                text_agent=text_qwen_chat,
            )


if __name__ == "__main__":
    # Sampling options shared by both chat clients.
    gpt_options = {"vendor": "", "stream": False, "temperature": 0.2}
    # Vision-language client; filter_chart passes it to if_chart.
    vl_qwen_chat = GPT(model="qwen2.5-vl-7b", **gpt_options)
    text_qwen_chat = GPT(model="qwen3_32b", **gpt_options)
    tokenizer = AutoTokenizer.from_pretrained(Path("models/Qwen2.5-VL-7B-Instruct"), trust_remote_code=True)

    # Dataset sub-directory names processed by both pipeline stages.
    input_dirs = [
        "CNNIC",
        "GlobalTerrorismDatabase",
        "Guizhou",
        "NationalBureauOfStatistics",
        "telecomworld",
        "worldbank",
        "oced",
    ]
    image_root_dir = "data/chartdata/processed/mineru_python"

    # TODO: oecd needs to be split first and then filtered.
    # NOTE(review): the list above says "oced" while this TODO says "oecd"; confirm the directory name.
    # Stage 1 (currently disabled): build per-image context files.
    # split_jsons(vl_qwen_chat, text_qwen_chat, tokenizer, input_dirs, image_root_dir, context_limit=600)
    # Stage 2: keep only the images judged to be charts.
    filter_charts(vl_qwen_chat, text_qwen_chat, input_dirs, image_root_dir)
