import os
import io
import time
import json
import torch
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
from PIL import Image, ImageFile
from collections import defaultdict
from anno_prompts import get_description_prompt, get_annotation_prompt
from anno_utils import normalize_description, normalize_annotation

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tools.lvlm_tools import name_lvlm, get_lvlm, run_lvlm

ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None


def _read_json(path):
    """Return the parsed contents of the JSON file at ``path``."""
    with open(path, "r") as f:
        return json.load(f)


def _write_json(path, payload):
    """Overwrite ``path`` with ``payload`` serialized as 4-space-indented JSON."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=4)


def main(args):
    """Run the two-step (description -> annotation) LVLM pipeline over a dataset.

    For every selected sample the model is first asked for a description
    (or a cached one is reused), then asked for an annotation; both results
    are collected as ``{'idx', 'description', 'annotation'}`` records and
    written to a JSON file, with periodic checkpoints so a run can resume.

    Attributes read from ``args``:
        dataset: only ``"WikiWeb"`` is supported.
        data_folder: root folder containing ``WikiWeb/wiki_data.pkl`` (and
            ``WikiWeb/intersect_idx.json`` when ``use_intersect == "inter"``).
        mode: ``"chunk"`` processes the ``chunk_idx``-th window of
            ``chunk_size`` samples; any other mode selects samples via
            ``slice``. ``"modify"`` additionally writes the new results back
            into the existing chunk files after the run.
        slice: slice text (contains ``":"``, e.g. ``"0:10"``) or an iterable
            expression (e.g. ``"[3, 7]"``); required when mode is not "chunk".
        use_intersect, use_regenerate, use_norm, use_wiki_text: string
            switches controlling filtering, description reuse, output
            normalization and the output path name respectively.
        save_interval: checkpoint every this many loop iterations; ``0``
            disables periodic checkpoints (the final save always happens).

    Raises:
        NotImplementedError: if ``args.dataset`` is not ``"WikiWeb"``.
        ValueError: if mode is not "chunk" and ``args.slice`` is empty.
    """
    ### Loading datasets
    if args.dataset == "WikiWeb":
        data_folder = args.data_folder
        dataset_path = os.path.join(data_folder, "WikiWeb/wiki_data.pkl")
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point --data_folder at trusted datasets.
        with open(dataset_path, 'rb') as file:
            dataset = pickle.load(file)
        idx_list = list(range(len(dataset)))
        if args.use_intersect == "inter":
            intersect_idx_path = os.path.join(data_folder, "WikiWeb/intersect_idx.json")
            intersect_idx_dict = _read_json(intersect_idx_path)
        if args.mode == "chunk":
            idx_list = idx_list[args.chunk_idx * args.chunk_size:(args.chunk_idx + 1) * args.chunk_size]
        elif not args.slice.strip():
            # Without this guard the empty default reaches eval("") and dies
            # with an opaque SyntaxError.
            raise ValueError(
                "--slice must be a slice such as '0:10' or a list expression "
                "such as '[3, 7]' when --mode is not 'chunk'"
            )
        # NOTE(security): eval() runs arbitrary Python taken from the command
        # line. It is kept (inline, so expressions see the same local names as
        # before) for backward compatibility; never feed it untrusted input.
        elif ":" in args.slice:
            idx_list = eval(f"idx_list[{args.slice}]")
        else:
            idx_list = list(eval(args.slice))
    else:
        raise NotImplementedError

    ### Creating output directories
    # exist_ok=True avoids the check-then-create race when several chunk jobs
    # start at the same time and share these directories.
    output_dir = os.path.join(os.getcwd(), "../output/annotation")
    os.makedirs(output_dir, exist_ok=True)

    if args.mode == "chunk":
        chunk_dir = os.path.join(output_dir, f"{name_lvlm(args)}_{args.use_wiki_text}/chunks")
        os.makedirs(chunk_dir, exist_ok=True)
        output_file = os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    else:
        output_file = os.path.join(output_dir, f"anno_{name_lvlm(args)}_{args.use_wiki_text}_{time.strftime('%m%d_%H%M%S')}.json")

    ### Loading existing output files
    # Resume support: everything up to the idx of the last saved record is
    # treated as already done.
    annotations = []
    last_idx = -1
    if os.path.exists(output_file):
        annotations = _read_json(output_file)
        if annotations:
            last_idx = annotations[-1]['idx']

    ### Loading temporary files
    # Cached descriptions are only trusted when the concatenated temp files
    # hold exactly one entry per dataset sample.
    has_temp_data = False
    temp_dir = os.path.join(output_dir, f"{name_lvlm(args)}_{args.use_wiki_text}/temp")
    if args.use_regenerate == "no-regen" and os.path.exists(temp_dir):
        temp_data = []
        for temp_idx in range(len(os.listdir(temp_dir))):
            file_name = os.path.join(temp_dir, f"temp_{temp_idx}.json")
            if os.path.exists(file_name):
                temp_data += _read_json(file_name)
        has_temp_data = len(temp_data) == len(dataset)

    ### Loading models
    load_start_time = datetime.now()
    processor, model = get_lvlm(args)
    load_end_time = datetime.now()
    print(f"Loading duration: {str(load_end_time - load_start_time)[:-7]}")

    ### Running models
    for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Processing diagrams"):
        # Skip finished samples before checkpointing so a resumed run does not
        # keep rewriting an unchanged output file.
        if idx <= last_idx:
            continue
        # A save_interval of 0 disables periodic checkpoints instead of
        # raising ZeroDivisionError.
        if args.save_interval and cnt % args.save_interval == 0:
            _write_json(output_file, annotations)
        data = dataset[idx]
        # Apply the intersection filter before touching the image so that
        # filtered samples cost nothing beyond the placeholder record.
        if args.use_intersect == "inter":
            if str(idx) not in intersect_idx_dict:
                annotations.append({'idx': idx, 'description': 'Filtered: Tagging Intersection', 'annotation': 'Filtered: Tagging Intersection'})
                continue
            data['tag'] = intersect_idx_dict[str(idx)]
        image = Image.open(io.BytesIO(data['image_bytes']))

        # Step 1: Description
        if args.use_regenerate == "regen" or not has_temp_data:
            description_prompt = get_description_prompt(args, data)
            description_str = run_lvlm(args, model, processor, image, description_prompt)
        else:
            # NOTE(review): positional lookup -- assumes the temp entries are
            # stored in dataset order; confirm against how temp files are written.
            description_str = temp_data[idx]['description']
        data['model_description'] = description_str
        description = normalize_description(description_str) if args.use_norm == "norm" else description_str

        # Step 2: Annotation
        anno_prompt = get_annotation_prompt(args, data)
        anno_str = run_lvlm(args, model, processor, image, anno_prompt)
        # The error message returned by the normalizer is not used here.
        anno, _ = normalize_annotation(anno_str) if args.use_norm == "norm" else (anno_str, None)

        # Collecting results
        annotations.append({'idx': idx, 'description': description, 'annotation': anno})

    run_end_time = datetime.now()
    print(f"Running duration: {str(run_end_time - load_end_time)[:-7]}")

    ### Saving results
    _write_json(output_file, annotations)

    if args.mode == "chunk" and args.use_regenerate == "regen":
        # Cache this chunk's descriptions for later "no-regen" runs.
        os.makedirs(temp_dir, exist_ok=True)
        temp_file = os.path.join(temp_dir, f"temp_{args.chunk_idx}.json")
        temp_data = [{'idx': item['idx'], 'description': item['description']} for item in annotations]
        _write_json(temp_file, temp_data)

    if args.mode == "modify":
        # Splice the freshly produced records back into the chunk files. The
        # chunk size is taken from the length of chunk_0.json; each record
        # lands in chunk (idx // chunk_size) at position (idx % chunk_size).
        idx_dict, enum_idx_dict = defaultdict(list), defaultdict(list)
        chunk_dir = os.path.join(output_dir, f"{name_lvlm(args)}_{args.use_wiki_text}/chunks")
        chunk_size = len(_read_json(os.path.join(chunk_dir, "chunk_0.json")))
        for enum_idx, item in enumerate(annotations):
            idx_dict[item['idx'] // chunk_size].append(item['idx'] % chunk_size)
            enum_idx_dict[item['idx'] // chunk_size].append(enum_idx)
        num_chunks = len(os.listdir(chunk_dir))
        for chunk_idx in range(num_chunks):
            file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
            chunk_data = _read_json(file_name)
            for idx, enum_idx in zip(idx_dict[chunk_idx], enum_idx_dict[chunk_idx]):
                chunk_data[idx] = annotations[enum_idx]
            _write_json(file_name, chunk_data)


if __name__ == '__main__':
    # Script entry point: declare the CLI, echo the run configuration,
    # execute the pipeline, then report when it finished.
    parser = argparse.ArgumentParser(description='')

    # One row per option: (flag, type, default, help). Rows are registered in
    # this order; a help of None matches omitting the keyword.
    option_table = [
        # General
        ('--dataset', str, "WikiWeb", "[WikiWeb]"),
        ('--data_folder', str, "../../Datasets", None),
        ('--lvlm_path', str, "allenai/Molmo-7B-D-0924", None),
        ('--max_new_tokens', int, 1000, None),
        ('--save_interval', int, 100, None),
        # Task
        ('--use_wiki_text', str, 'wiki', "[wiki, no-wiki]"),
        ('--use_norm', str, 'norm', "[norm, no-norm]"),
        ('--use_regenerate', str, 'no-regen', "[regen, no-regen]"),
        ('--use_intersect', str, "no-inter", "[inter, no-inter]"),
        # Mode
        ('--mode', str, "debug", "[debug, chunk, modify]"),
        ('--slice', str, "", None),
        ('--chunk_size', int, 5000, None),
        ('--chunk_idx', int, 0, None),
    ]
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # Run summary, printed one line at a time before the pipeline starts.
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    summary_lines = [
        "########## Information ##########",
        f"Model: {args.lvlm_path}",
        f"Max new tokens: {args.max_new_tokens}",
        f"Use of wiki text: {args.use_wiki_text}",
        f"Use of normalization: {args.use_norm}",
        f"Mode: {args.mode}",
        f"Slice: {args.slice}",
        f"Chunk size: {args.chunk_size}",
        f"Chunk idx: {args.chunk_idx}",
        f"Starting time: {datetime.now().strftime(timestamp_format)}",
    ]
    for summary_line in summary_lines:
        print(summary_line)

    main(args)

    finished_at = datetime.now().strftime(timestamp_format)
    print(f"Ending time: {finished_at}")
