import os
import io
import time
import json
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
from PIL import Image, ImageFile
from collections import defaultdict
from tag_prompts import get_description_prompt, get_tagging_prompt
from tag_utils import normalize_tag

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tools.lvlm_tools import name_lvlm, get_lvlm, run_lvlm

# Let Pillow decode images whose data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable Pillow's decompression-bomb pixel limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None


def _format_duration(delta):
    """Render a ``timedelta`` as text without its fractional-seconds part."""
    # str(timedelta) omits ".ffffff" when microseconds == 0, so slicing off a
    # fixed number of trailing characters would eat real digits in that case.
    return str(delta).split(".")[0]


def _save_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON, overwriting the file."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=4)


def _merge_into_chunks(taggings, chunk_dir):
    """Overwrite entries of existing chunk files with re-generated taggings.

    The chunk size is taken from the number of entries in ``chunk_0.json``, so
    that file must exist and is assumed to be a complete chunk. Each chunk file
    is assumed to store its entries positionally: the entry for dataset index
    ``i`` lives at position ``i % chunk_size`` of ``chunk_{i // chunk_size}.json``.

    Only chunk files that actually receive an update are rewritten. A chunk
    file that does not exist is reported and skipped.

    Raises:
        FileNotFoundError: if ``chunk_0.json`` is missing.
        ValueError: if ``chunk_0.json`` is empty while there are taggings to merge.
        IndexError: if a target position lies beyond the end of its chunk file.
    """
    with open(os.path.join(chunk_dir, "chunk_0.json"), "r") as f:
        chunk_size = len(json.load(f))
    if taggings and chunk_size == 0:
        raise ValueError("chunk_0.json is empty; cannot infer the chunk size")

    # chunk index -> [(position inside the chunk, replacement entry), ...]
    updates = defaultdict(list)
    for item in taggings:
        chunk_idx, offset = divmod(item['idx'], chunk_size)
        updates[chunk_idx].append((offset, item))

    for chunk_idx in sorted(updates):
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        if not os.path.exists(file_name):
            print(f"Warning: {file_name} not found; skipping {len(updates[chunk_idx])} modified item(s)")
            continue
        with open(file_name, "r") as f:
            chunk_data = json.load(f)
        for offset, item in updates[chunk_idx]:
            chunk_data[offset] = item
        _save_json(file_name, chunk_data)


def main(args):
    """Run the (optional) description step and the tagging step over a dataset.

    Reads from ``args``: ``dataset``, ``data_folder``, ``mode``, ``slice``,
    ``chunk_idx``, ``chunk_size``, ``use_option``, ``use_desc``, ``use_norm``
    and ``save_interval``. ``args`` is also forwarded to the LVLM and prompt
    helpers, which may read further fields.

    Modes:
        * ``chunk``: process the ``chunk_idx``-th block of ``chunk_size`` items
          and write to ``chunks/chunk_{chunk_idx}.json``; an existing file is
          resumed after the last index it contains.
        * anything else (``debug``, ``modify``): process the items selected by
          ``slice`` and write to a timestamped file. In ``modify`` mode the
          results are additionally merged back into the existing chunk files.

    ``save_interval`` controls how often (in loop iterations) the partial
    results are written out; a value of 0 or less disables periodic saves.

    Raises:
        NotImplementedError: for any dataset other than ``WikiWeb``.
        ValueError: if a non-chunk mode is used with an empty ``slice``.
    """
    ### Loading datasets
    if args.dataset != "WikiWeb":
        raise NotImplementedError(f"Unsupported dataset: {args.dataset}")
    # Fail fast: without this check an empty --slice only surfaces as an
    # opaque SyntaxError from eval() after the whole pickle has been loaded.
    if args.mode != "chunk" and not args.slice.strip():
        raise ValueError(
            "--slice must be provided when --mode is not 'chunk' "
            "(e.g. '0:100', '3,7,42' or 'range(0, 100, 7)')"
        )

    dataset_path = os.path.join(args.data_folder, "WikiWeb/wiki_data.pkl")
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # --data_folder at data from a trusted source.
    with open(dataset_path, 'rb') as file:
        dataset = pickle.load(file)
    idx_list = list(range(len(dataset)))

    if args.mode == "chunk":
        start = args.chunk_idx * args.chunk_size
        idx_list = idx_list[start:start + args.chunk_size]
    else:
        # NOTE(security): --slice is evaluated as a Python expression so that
        # forms such as "10:20", "1,2,3" or "range(0, 100, 7)" all work. It
        # must never be fed from untrusted input.
        if ":" in args.slice:
            idx_list = eval(f"idx_list[{args.slice}]")
        else:
            selected = eval(args.slice)
            # A bare integer selects exactly one item.
            idx_list = [selected] if isinstance(selected, int) else list(selected)

    ### Creating output directories
    output_dir = os.path.join(os.getcwd(), "../output/tagging")
    os.makedirs(output_dir, exist_ok=True)

    run_name = f"{name_lvlm(args)}_{args.use_option}"
    chunk_dir = os.path.join(output_dir, f"{run_name}/chunks")
    if args.mode == "chunk":
        os.makedirs(chunk_dir, exist_ok=True)
        output_file = os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    else:
        output_file = os.path.join(output_dir, f"tag_{run_name}_{time.strftime('%m%d_%H%M%S')}.json")

    ### Loading existing output files
    taggings = []
    last_idx = -1
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            taggings = json.load(f)
        if taggings:
            # Resume after the last index already stored in the file.
            last_idx = taggings[-1]['idx']

    ### Loading models
    load_start_time = datetime.now()
    processor, model = get_lvlm(args)
    load_end_time = datetime.now()
    print(f"Loading duration: {_format_duration(load_end_time - load_start_time)}")

    ### Running models
    for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Processing diagrams"):
        # Already present in a resumed output file: nothing to do, and in
        # particular no reason to rewrite the unchanged checkpoint.
        if idx <= last_idx:
            continue
        if args.save_interval > 0 and cnt % args.save_interval == 0:
            _save_json(output_file, taggings)
        data = dataset[idx]
        image = Image.open(io.BytesIO(data['image_bytes']))

        # Step 1: Description
        if args.use_desc == "desc":
            description_prompt = get_description_prompt(args, data)
            description = run_lvlm(args, model, processor, image, description_prompt)
            # Stored on the record before the tagging prompt is built from it.
            data['model_description'] = description
        else:
            description = None

        # Step 2: Tagging
        tag_prompt = get_tagging_prompt(args, data)
        tag_str = run_lvlm(args, model, processor, image, tag_prompt)
        if args.use_norm == "norm":
            # normalize_tag returns a pair; its second element is not used here.
            tag, _err_msg = normalize_tag(tag_str)
        else:
            tag = tag_str

        # Collecting results
        taggings.append({'idx': idx, 'description': description, 'tag': tag})

    run_end_time = datetime.now()
    print(f"Running duration: {_format_duration(run_end_time - load_end_time)}")

    ### Saving results
    _save_json(output_file, taggings)
    if args.mode == "modify":
        _merge_into_chunks(taggings, chunk_dir)


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword options); they are
    # registered in exactly this order.
    cli_options = [
        # General
        ('--dataset', dict(type=str, default="WikiWeb", help="[WikiWeb]")),
        ('--data_folder', dict(type=str, default="../../Datasets")),
        ('--lvlm_path', dict(type=str, default="allenai/Molmo-7B-D-0924")),
        ('--max_new_tokens', dict(type=int, default=1000)),
        ('--save_interval', dict(type=int, default=100)),
        # Task
        ('--use_option', dict(default='no-option', help="[option, no-option]")),
        ('--use_desc', dict(default='desc', help="[desc, no-desc]")),
        ('--use_wiki_text', dict(type=str, default='no-wiki', help="[wiki, no-wiki]")),
        ('--use_norm', dict(type=str, default='norm', help="[norm, no-norm]")),
        # Mode
        ('--mode', dict(type=str, default="debug", help="[debug, chunk, modify]")),
        ('--slice', dict(type=str, default="")),
        ('--chunk_size', dict(type=int, default=5000)),
        ('--chunk_idx', dict(type=int, default=0)),
    ]
    parser = argparse.ArgumentParser(description='')
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Echo the effective configuration before the run starts.
    banner = [
        "########## Information ##########",
        f"Model: {args.lvlm_path}",
        f"Max new tokens: {args.max_new_tokens}",
        f"Use of options: {args.use_option}",
        f"Use of description: {args.use_desc}",
        f"Use of wiki text: {args.use_wiki_text}",
        f"Use of normalization: {args.use_norm}",
        f"Mode: {args.mode}",
        f"Slice: {args.slice}",
        f"Chunk size: {args.chunk_size}",
        f"Chunk idx: {args.chunk_idx}",
    ]
    for line in banner:
        print(line)
    # The timestamps are taken at print time, right before and after the run.
    print(f"Starting time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    main(args)
    print(f"Ending time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
