# Copyright (c) Opendatalab. All rights reserved.
import copy
import json
import os
from pathlib import Path

from loguru import logger

from mineru.cli.common import convert_pdf_bytes_to_bytes_by_pypdfium2, prepare_env, read_fn
from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from mineru.utils.enum_class import MakeMode
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path


def do_parse(
    output_dir,  # Output directory for storing parsing results
    pdf_file_names: list[str],  # One name per document, used to name output files/dirs
    pdf_bytes_list: list[bytes],  # One byte string per document (parallel to pdf_file_names)
    p_lang_list: list[str],  # One language per document; only used by the pipeline backend
    backend="pipeline",  # "pipeline", or a VLM backend such as "vlm-transformers"
    parse_method="auto",  # The method for parsing PDF, default is 'auto' (pipeline only)
    p_formula_enable=True,  # Enable formula parsing (pipeline only)
    p_table_enable=False,  # Enable table parsing (pipeline only)
    server_url=None,  # Server URL for vlm-sglang-client backend
    f_draw_layout_bbox=True,  # Whether to draw layout bounding boxes
    f_draw_span_bbox=True,  # Whether to draw span bounding boxes (pipeline only)
    f_dump_md=True,  # Whether to dump markdown files
    f_dump_middle_json=True,  # Whether to dump middle JSON files
    f_dump_model_output=True,  # Whether to dump model output files
    f_dump_orig_pdf=True,  # Whether to dump original PDF files
    f_dump_content_list=True,  # Whether to dump content list files
    f_make_md_mode=MakeMode.MM_MD,  # The mode for making markdown content, default is MM_MD
    start_page_id=0,  # Start page ID for parsing, default is 0
    end_page_id=None,  # End page ID for parsing, default is None (parse all pages until the end of the document)
):
    """Parse a batch of documents and write the requested artifacts for each one.

    ``pdf_file_names``, ``pdf_bytes_list`` and ``p_lang_list`` are parallel lists
    with one entry per document. Each document is first restricted to the page
    range given by ``start_page_id`` / ``end_page_id`` via
    ``convert_pdf_bytes_to_bytes_by_pypdfium2``.

    * ``backend == "pipeline"``: every successfully converted document is analyzed
      in a single ``pipeline_doc_analyze`` call. Outputs are then written per
      document into the directories returned by
      ``prepare_env(output_dir, name, parse_method)``.
    * Any other backend is treated as a VLM backend. A leading ``"vlm-"`` prefix is
      stripped before the name is passed to ``vlm_doc_analyze``. Documents are
      converted and analyzed one at a time, span-bbox drawing is disabled, and
      ``parse_method`` is forced to ``"vlm"``.

    Per-document errors are logged and do not stop the rest of the batch. The
    caller's lists are never modified.

    Returns:
        list[str]: the entries of ``pdf_file_names`` whose processing failed. If the
        batched pipeline analysis itself raises, every name is returned.
    """
    failed_paths: list[str] = []
    is_pipeline = backend == "pipeline"

    # Output switches shared by both backends.
    dump_options = dict(
        f_draw_layout_bbox=f_draw_layout_bbox,
        # Span-level boxes are only drawn for the pipeline backend.
        f_draw_span_bbox=f_draw_span_bbox and is_pipeline,
        f_dump_md=f_dump_md,
        f_dump_middle_json=f_dump_middle_json,
        f_dump_model_output=f_dump_model_output,
        f_dump_orig_pdf=f_dump_orig_pdf,
        f_dump_content_list=f_dump_content_list,
        f_make_md_mode=f_make_md_mode,
    )

    if is_pipeline:
        logger.info("start to convert_pdf_bytes_to_bytes_by_pypdfium2()")
        # Only documents whose page-range conversion succeeded go on to analysis.
        # A document that failed conversion is therefore neither analyzed with its
        # un-trimmed bytes nor reported in failed_paths a second time.
        ok_names: list[str] = []
        ok_bytes: list[bytes] = []
        ok_langs: list[str] = []
        for idx, pdf_bytes in enumerate(pdf_bytes_list):
            try:
                converted = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
                lang = p_lang_list[idx]
            except Exception as e:
                logger.exception(e)
                failed_paths.append(pdf_file_names[idx])
                continue
            ok_names.append(pdf_file_names[idx])
            ok_bytes.append(converted)
            ok_langs.append(lang)

        if not ok_bytes:
            # Every document already failed conversion; nothing left to analyze.
            return failed_paths

        logger.info("read pdf ok! start pipeline_doc_analyze()")
        try:
            infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = pipeline_doc_analyze(
                ok_bytes,
                ok_langs,
                parse_method=parse_method,
                formula_enable=p_formula_enable,
                table_enable=p_table_enable,
            )
        except Exception as e:
            logger.exception(e)
            # The batched analysis failed as a whole: report every document. A copy
            # is returned so the result never aliases the caller's list.
            return list(pdf_file_names)
        logger.info("pipeline_doc_analyze() ok")

        for idx, model_list in enumerate(infer_results):
            try:
                pdf_file_name = ok_names[idx]
                # Snapshot the raw model output before pipeline_result_to_middle_json
                # consumes model_list (presumably it may modify it in place), so
                # _model.json reflects what the model produced.
                model_json = copy.deepcopy(model_list)
                local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
                image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

                middle_json = pipeline_result_to_middle_json(
                    model_list,
                    all_image_lists[idx],
                    all_pdf_docs[idx],
                    image_writer,
                    lang_list[idx],
                    ocr_enabled_list[idx],
                    p_formula_enable,
                )

                model_output = json.dumps(model_json, ensure_ascii=False, indent=4) if f_dump_model_output else None
                _write_parse_outputs(
                    pipeline_union_make,
                    middle_json,
                    ok_bytes[idx],
                    pdf_file_name,
                    local_image_dir,
                    local_md_dir,
                    md_writer,
                    f"{pdf_file_name}_model.json",
                    model_output,
                    **dump_options,
                )
            except Exception as e:
                logger.exception(e)
                failed_paths.append(ok_names[idx])
        return failed_paths

    # VLM backends: e.g. "vlm-transformers" is passed on as "transformers".
    if backend.startswith("vlm-"):
        backend = backend[4:]
    parse_method = "vlm"

    for idx, pdf_bytes in enumerate(pdf_bytes_list):
        try:
            pdf_file_name = pdf_file_names[idx]
            pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
            local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
            image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
            middle_json, infer_result = vlm_doc_analyze(
                pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url
            )

            # The raw per-page model outputs are joined with a dashed separator line.
            model_output = ("\n" + "-" * 50 + "\n").join(infer_result) if f_dump_model_output else None
            _write_parse_outputs(
                vlm_union_make,
                middle_json,
                pdf_bytes,
                pdf_file_name,
                local_image_dir,
                local_md_dir,
                md_writer,
                f"{pdf_file_name}_model_output.txt",
                model_output,
                **dump_options,
            )
        except Exception as e:
            logger.exception(e)
            failed_paths.append(pdf_file_names[idx])
    return failed_paths


def _write_parse_outputs(
    union_make,
    middle_json,
    pdf_bytes,
    pdf_file_name,
    local_image_dir,
    local_md_dir,
    md_writer,
    model_output_name,
    model_output,
    *,
    f_draw_layout_bbox,
    f_draw_span_bbox,
    f_dump_md,
    f_dump_middle_json,
    f_dump_model_output,
    f_dump_orig_pdf,
    f_dump_content_list,
    f_make_md_mode,
):
    """Draw the debug PDFs and dump the enabled artifacts for one parsed document.

    Shared by both backends of ``do_parse``. ``union_make`` is the backend's own
    markdown/content-list builder. ``model_output`` is the already-serialized raw
    model output, written to ``model_output_name`` when ``f_dump_model_output`` is set.
    """
    pdf_info = middle_json["pdf_info"]
    # Markdown/content lists reference images relative to the md dir, by folder name.
    image_dir = str(os.path.basename(local_image_dir))

    if f_draw_layout_bbox:
        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")

    if f_draw_span_bbox:
        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")

    if f_dump_orig_pdf:
        md_writer.write(
            f"{pdf_file_name}_origin.pdf",
            pdf_bytes,
        )

    if f_dump_md:
        md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}.md",
            md_content_str,  # type: ignore
        )

    if f_dump_content_list:
        content_list = union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}_content_list.json",
            json.dumps(content_list, ensure_ascii=False, indent=4),
        )

    if f_dump_middle_json:
        md_writer.write_string(
            f"{pdf_file_name}_middle.json",
            json.dumps(middle_json, ensure_ascii=False, indent=4),
        )

    if f_dump_model_output:
        md_writer.write_string(
            model_output_name,
            model_output,
        )

    logger.info(f"local output dir is {local_md_dir}")


def parse_doc(
    path_list: list[Path],
    output_dir,
    lang="ch",
    backend="pipeline",
    method="auto",
    server_url=None,
    start_page_id=0,  # Start page ID for parsing, default is 0
    end_page_id=None,  # End page ID for parsing, default is None (parse all pages until the end of the document)
    enable_table=False,  # if enable table parsing
    batch_size=10,  # Number of documents handed to do_parse per call
):
    """Parse documents in batches and return the names of those that failed.

    Parameter description:
    path_list: List of document paths to be parsed, can be PDF or image files.
        A document is skipped when ``output_dir/<file stem>`` already exists as a directory.
    output_dir: Output directory for storing parsing results.
    lang: Language option, default is 'ch', optional values include ['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka'].
        Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
        Adapted only for the case where the backend is set to "pipeline".
    backend: the backend for parsing pdf:
        pipeline: More general.
        vlm-transformers: More general.
        vlm-sglang-engine: Faster(engine).
        vlm-sglang-client: Faster(client).
        Without backend specified, pipeline will be used by default.
    method: the method for parsing pdf:
        auto: Automatically determine the method based on the file type.
        txt: Use text extraction method.
        ocr: Use OCR method for image-based PDFs.
        Without method specified, 'auto' will be used by default.
        Adapted only for the case where the backend is set to "pipeline".
    server_url: When the backend is `vlm-sglang-client`, you need to specify the server_url, for example: `http://127.0.0.1:30000`
    start_page_id / end_page_id: page range forwarded to do_parse; end_page_id=None parses until the last page.
    enable_table: whether table parsing is enabled (forwarded to do_parse as p_table_enable).
    batch_size: number of documents passed to do_parse per call; must be >= 1.

    Returns:
        list[str]: file stems of documents that could not be read or that failed in do_parse.

    Raises:
        ValueError: if batch_size is smaller than 1.
    """
    if batch_size < 1:
        # range() would raise an opaque error for 0 and silently yield nothing for < 0.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    failed_paths = []
    for batch_start in range(0, len(path_list), batch_size):
        batch_paths = path_list[batch_start : batch_start + batch_size]

        file_name_list = []
        pdf_bytes_list = []
        lang_list = []

        for path in batch_paths:
            file_stem = Path(path).stem
            expected_output_dir = Path(output_dir) / file_stem
            if expected_output_dir.is_dir():
                logger.info(f"{file_stem} exist, skip")
                continue
            try:
                pdf_bytes = read_fn(path)
            except Exception as e:
                # One unreadable/unsupported file must not abort the remaining batches.
                # Record it by stem, the same way do_parse reports per-document failures.
                logger.exception(e)
                failed_paths.append(file_stem)
                continue
            file_name_list.append(file_stem)
            pdf_bytes_list.append(pdf_bytes)
            lang_list.append(lang)
        if not file_name_list:
            continue

        logger.info(f"start to process {batch_start//batch_size + 1}, with {len(file_name_list)} files...")
        batch_failed_paths = do_parse(
            output_dir=output_dir,
            pdf_file_names=file_name_list,
            pdf_bytes_list=pdf_bytes_list,
            p_lang_list=lang_list,
            backend=backend,
            parse_method=method,
            server_url=server_url,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            p_table_enable=enable_table,
        )
        if batch_failed_paths:
            failed_paths.extend(batch_failed_paths)
    return failed_paths


if __name__ == "__main__":
    # Pin the run to GPU 6 unless the caller already chose devices.
    # NOTE(review): presumably only effective if nothing imported above has initialised CUDA yet.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")
    # args
    # data_root = "data/chartdata/chinese"  # TODO chinese
    data_root = "data/chartdata/english"
    output_root = "data/chartdata/processed/mineru_python"
    # pdf_dir_names = [
    #     "CNNIC",
    #     "Guizhou",
    #     "NationalBureauOfStatistics",
    #     "telecomworld",
    # ]
    pdf_dir_names = [
        # "worldbank",
        # "statista",
        # "GlobalTerrorismDatabase",
        "oced",
    ]
    # 1-s2.0-S2214629624005085-main

    pdf_suffixes = [".pdf"]
    image_suffixes = [".png", ".jpeg", ".jpg"]
    doc_suffixes = pdf_suffixes + image_suffixes
    failed_dir = os.path.join(output_root, "failed")

    for dir_name in pdf_dir_names:
        # pdf_files_dir = "data/chartdata/test"
        pdf_files_dir = os.path.join(data_root, dir_name, "pdf")
        output_dir = os.path.join(output_root, dir_name)
        os.makedirs(output_dir, exist_ok=True)

        # Sorted so batches are formed in a deterministic order across runs.
        doc_path_list = sorted(
            doc_path for doc_path in Path(pdf_files_dir).glob("*") if doc_path.suffix in doc_suffixes
        )

        # Use pipeline mode if your environment does not support VLM
        failed_paths = parse_doc(doc_path_list, output_dir, backend="pipeline", enable_table=False)

        if failed_paths:
            # Nothing else in this script creates the "failed" directory; without this
            # the open() below raises FileNotFoundError.
            os.makedirs(failed_dir, exist_ok=True)
            failed_file = os.path.join(failed_dir, dir_name + ".txt")
            with open(failed_file, "w", encoding="utf-8") as f:
                f.writelines(path + "\n" for path in failed_paths)
            logger.info(f"write failed pdfs to {failed_file}")
        # exit()
