import os
import time
import json
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
from anno_prompts import get_description_prompt, get_annotation_prompt


def _write_json_atomic(path, payload):
    """Write *payload* to *path* as indented JSON via a temporary sibling file.

    The data is first written to ``<path>.tmp`` and then moved over *path*
    with ``os.replace``, so an interruption during a checkpoint leaves the
    previous file intact instead of a truncated one that would break the
    resume step in ``main``.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=4)
    os.replace(tmp_path, path)


def main(args):
    """Build one prompt record per selected dataset entry and dump them to JSON.

    Steps:
      1. Load the pickled WikiWeb dataset from
         ``<data_folder>/WikiWeb/wiki_data.pkl``.
      2. Select indices: in ``chunk`` mode the window
         ``[chunk_idx * chunk_size, (chunk_idx + 1) * chunk_size)``; in any
         other mode whatever ``--slice`` describes (a slice such as ``0:10``
         applied to the full index list, an iterable of indices such as
         ``[3, 7, 42]``, or a single integer index).
      3. Resume from an existing output file, if any, by skipping every index
         that is <= the ``idx`` of the last stored record.
      4. Append ``{'idx', 'image_url', 'prompt'}`` records, checkpointing the
         output file every ``save_interval`` iterations and once at the end.

    Args:
        args: Parsed CLI namespace. This function reads ``dataset``,
            ``data_folder``, ``use_intersect``, ``mode``, ``chunk_idx``,
            ``chunk_size``, ``slice``, ``dump_goal`` and ``save_interval``;
            the namespace is also forwarded to ``get_description_prompt``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``"WikiWeb"``.
        ValueError: If ``args.mode`` is not ``"chunk"`` and ``args.slice`` is
            empty.
    """
    if args.dataset != "WikiWeb":
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    # Fail fast with an actionable message: an empty --slice previously reached
    # eval("") and surfaced as an opaque SyntaxError after the dataset load.
    if args.mode != "chunk" and not args.slice.strip():
        raise ValueError(
            "--slice must be provided when --mode is not 'chunk' "
            "(e.g. '0:10' or '[3, 7, 42]')"
        )

    ### Loading datasets
    data_folder = args.data_folder
    dataset_path = os.path.join(data_folder, "WikiWeb/wiki_data.pkl")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point --data_folder at trusted data.
    with open(dataset_path, 'rb') as file:
        dataset = pickle.load(file)
    idx_list = list(range(len(dataset)))

    # Optional idx -> tag mapping; entries missing from it are skipped below.
    intersect_idx_dict = None
    if args.use_intersect == "inter":
        intersect_idx_path = os.path.join(data_folder, "WikiWeb/intersect_idx.json")
        with open(intersect_idx_path, 'r', encoding="utf-8") as file:
            intersect_idx_dict = json.load(file)

    if args.mode == "chunk":
        chunk_start = args.chunk_idx * args.chunk_size
        idx_list = idx_list[chunk_start:chunk_start + args.chunk_size]
    elif ":" in args.slice:
        # NOTE(security): eval() runs arbitrary code taken from --slice. It is
        # kept so existing invocations keep working; never feed it untrusted
        # input.
        idx_list = eval(f"idx_list[{args.slice}]")
    else:
        # NOTE(security): same eval() caveat as above.
        selected = eval(args.slice)
        # A lone integer (e.g. --slice 5) selects just that index instead of
        # failing with "int is not iterable".
        idx_list = [selected] if isinstance(selected, int) else list(selected)

    ### Creating output directories
    output_dir = os.path.join(os.getcwd(), "../output/annotation")
    os.makedirs(output_dir, exist_ok=True)

    if args.mode == "chunk":
        chunk_dir = os.path.join(output_dir, f"gemini_inputs_{args.dump_goal}/chunks")
        os.makedirs(chunk_dir, exist_ok=True)
        output_file = os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    else:
        output_file = os.path.join(output_dir, f"gemini_inputs_{args.dump_goal}_{time.strftime('%m%d_%H%M%S')}.json")

    ### Loading existing output files
    inputs = []
    last_idx = -1
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            inputs = json.load(f)
        if inputs:
            last_idx = inputs[-1]['idx']

    ### Dumping
    for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Processing diagrams"):
        # A save_interval of 0 disables periodic checkpoints rather than
        # raising ZeroDivisionError on the modulo.
        if args.save_interval and cnt % args.save_interval == 0:
            _write_json_atomic(output_file, inputs)
        # Already covered by a previous run of the same output file.
        if idx <= last_idx:
            continue
        data = dataset[idx]
        if intersect_idx_dict is not None:
            if str(idx) not in intersect_idx_dict:
                continue
            data['tag'] = intersect_idx_dict[str(idx)]
        prompt = get_description_prompt(args, data)
        inputs.append({'idx': idx, 'image_url': data['image_url'], 'prompt': prompt})

    ### Saving results
    _write_json_atomic(output_file, inputs)


if __name__ == '__main__':
    # Script entry point: declare the CLI, echo the run configuration, then
    # hand the parsed namespace to main().
    #
    # Each row is (flag, type, default, help). A help of None is argparse's
    # own default, i.e. the option carries no description.
    option_table = (
        # General
        ('--dataset', str, "WikiWeb", "[WikiWeb]"),
        ('--data_folder', str, "../../Datasets", None),
        ('--lvlm_path', str, "", None),
        ('--save_interval', int, 100, None),
        # Task
        ('--dump_goal', str, 'annotation', None),
        ('--use_wiki_text', str, 'wiki', "[wiki, no-wiki]"),
        ('--use_intersect', str, "no-inter", "[inter, no-inter]"),
        # Mode
        ('--mode', str, "debug", "[debug, chunk, modify]"),
        ('--slice', str, "", None),
        ('--chunk_size', int, 5000, None),
        ('--chunk_idx', int, 0, None),
    )

    arg_parser = argparse.ArgumentParser(description='')
    for flag, value_type, default_value, help_text in option_table:
        arg_parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = arg_parser.parse_args()

    # Run banner: one "label: value" line per setting worth recording.
    stamp_format = '%Y-%m-%d %H:%M:%S'
    print("########## Information ##########")
    banner_rows = (
        ("Model", args.lvlm_path),
        ("Use of wiki text", args.use_wiki_text),
        ("Mode", args.mode),
        ("Slice", args.slice),
        ("Chunk size", args.chunk_size),
        ("Chunk idx", args.chunk_idx),
    )
    for label, value in banner_rows:
        print(f"{label}: {value}")
    print(f"Starting time: {datetime.now().strftime(stamp_format)}")

    main(args)
    print(f"Ending time: {datetime.now().strftime(stamp_format)}")

