import os
import time
import json
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
from prompts_v2 import get_wiki_prompt, get_triple_prompt, get_empty_prompt
from anno_analysis_utils import get_chunk_data


def _dump_json_atomic(path, payload):
    """Serialize ``payload`` as indented JSON to ``path`` without leaving a partial file.

    The data is first written to ``<path>.tmp`` and then moved over ``path`` with
    ``os.replace``, so an interruption during a checkpoint cannot truncate the
    previously saved file that the resume logic reads back.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(payload, f, indent=4)
    os.replace(tmp_path, path)


def main(args: argparse.Namespace) -> None:
    """Build prompt inputs for the selected dataset indices and dump them to JSON.

    Attributes read from ``args``:
        dataset: only ``"WikiWeb"`` is supported.
        data_folder: folder containing ``WikiWeb/wiki_data.pkl`` (and
            ``WikiWeb/intersect_idx.json`` when ``use_intersect == "inter"``).
        use_intersect: ``"inter"`` keeps only indices listed in the intersect
            file and stores the listed value under ``data['tag']``.
        mode: ``"chunk"`` selects the ``chunk_idx``-th window of ``chunk_size``
            indices; any other value selects indices through ``slice``.
        slice: either a slice expression containing ``":"`` (e.g. ``"0:100"``)
            or an expression evaluating to an index or an iterable of indices
            (e.g. ``"[3, 7]"``). An empty string selects every index.
        dump_goal: ``"empty"``, ``"wiki"`` or ``"triple"``; picks the prompt builder.
        save_interval: checkpoint the output every this many loop iterations;
            ``0`` disables periodic checkpoints (the final save still happens).

    Each output record has the form
    ``{'idx': ..., 'image_url': ..., 'prompt': {'fold_1': {...}, 'fold_2': {...}}}``.

    Raises:
        NotImplementedError: if ``args.dataset`` or ``args.dump_goal`` is unsupported.
    """
    ### Validating arguments before any expensive loading
    if args.dataset != "WikiWeb":
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")
    if args.dump_goal not in ("empty", "wiki", "triple"):
        raise NotImplementedError(f"Unsupported dump goal: {args.dump_goal!r}")

    ### Loading datasets
    data_folder = args.data_folder
    dataset_path = os.path.join(data_folder, "WikiWeb/wiki_data.pkl")
    # NOTE(security): pickle.load can execute arbitrary code; only use trusted data folders.
    with open(dataset_path, 'rb') as file:
        dataset = pickle.load(file)
    idx_list = list(range(len(dataset)))

    use_intersect = args.use_intersect == "inter"
    intersect_idx_dict = {}
    if use_intersect:
        intersect_idx_path = os.path.join(data_folder, "WikiWeb/intersect_idx.json")
        with open(intersect_idx_path, 'r') as file:
            intersect_idx_dict = json.load(file)

    if args.mode == "chunk":
        chunk_start = args.chunk_idx * args.chunk_size
        idx_list = idx_list[chunk_start:chunk_start + args.chunk_size]
    else:
        slice_spec = args.slice.strip()
        # An empty slice (the CLI default) keeps every index instead of crashing in eval("").
        if slice_spec:
            # NOTE(security): eval runs arbitrary code from --slice; only pass trusted values.
            if ":" in slice_spec:
                idx_list = eval(f"idx_list[{slice_spec}]")
            else:
                selected = eval(slice_spec)
                # Accept a single index such as "5" as well as an iterable of indices.
                idx_list = [selected] if isinstance(selected, int) else list(selected)

    ### Loading annotations and triples
    annotation_dir = os.path.join(os.getcwd(), "../../output/annotation/gemini_wiki")
    annotations, annotation_idx_dict = get_chunk_data(annotation_dir)
    if args.dump_goal == "triple":
        triple_dir = os.path.join(os.getcwd(), "../../output/synthesis/gemini_wiki")
        triples, triple_idx_dict = get_chunk_data(triple_dir)

    ### Creating output directories
    output_dir = os.path.join(os.getcwd(), "../../output/annotation/analysis")
    os.makedirs(output_dir, exist_ok=True)

    if args.mode == "chunk":
        chunk_dir = os.path.join(output_dir, f"gemini_inputs_v2_{args.dump_goal}/chunks")
        os.makedirs(chunk_dir, exist_ok=True)
        output_file = os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    else:
        output_file = os.path.join(output_dir, f"gemini_inputs_v2_{args.dump_goal}_{time.strftime('%m%d_%H%M%S')}.json")

    ### Loading existing output files
    # Resume support: indices up to the last saved record's idx are skipped below,
    # which presumes idx_list is processed in ascending order.
    inputs = []
    last_idx = -1
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            inputs = json.load(f)
        if inputs:
            last_idx = inputs[-1]['idx']

    ### Dumping
    prompt_types = ['Recognition', 'Understanding', 'Grounding', 'Reasoning']
    for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Processing diagrams"):
        # Periodic checkpoint; a save_interval of 0 disables it rather than dividing by zero.
        if args.save_interval and cnt % args.save_interval == 0:
            _dump_json_atomic(output_file, inputs)
        if idx <= last_idx:
            continue
        data = dataset[idx]
        if use_intersect:
            if str(idx) not in intersect_idx_dict:
                continue
            data['tag'] = intersect_idx_dict[str(idx)]

        if idx not in annotation_idx_dict:
            continue
        annotation = annotations[annotation_idx_dict[idx]]['annotation']
        if annotation is None:
            continue

        # fold_2 is built from the same inputs as fold_1 but with rev=True.
        if args.dump_goal == "empty":
            fold_1 = {t: get_empty_prompt(annotation[t]) for t in prompt_types}
            fold_2 = {t: get_empty_prompt(annotation[t], rev=True) for t in prompt_types}
        elif args.dump_goal == "wiki":
            fold_1 = {t: get_wiki_prompt(data, annotation[t]) for t in prompt_types}
            fold_2 = {t: get_wiki_prompt(data, annotation[t], rev=True) for t in prompt_types}
        elif args.dump_goal == "triple":
            if idx not in triple_idx_dict:
                continue
            triple = triples[triple_idx_dict[idx]]['triple']
            fold_1 = {t: get_triple_prompt(triple, annotation[t]) for t in prompt_types}
            fold_2 = {t: get_triple_prompt(triple, annotation[t], rev=True) for t in prompt_types}
        else:
            # Unreachable after the up-front validation; kept as a defensive guard.
            raise NotImplementedError(f"Unsupported dump goal: {args.dump_goal!r}")

        inputs.append({'idx': idx, 'image_url': data['image_url'], 'prompt': {'fold_1': fold_1, 'fold_2': fold_2}})

    ### Saving results
    _dump_json_atomic(output_file, inputs)


if __name__ == '__main__':
    # Command-line entry point: parse the options, echo the run configuration,
    # then run main() and report start/end timestamps.
    parser = argparse.ArgumentParser(description='')
    # General
    parser.add_argument('--dataset', type=str, default="WikiWeb", help="[WikiWeb]")
    parser.add_argument('--data_folder', type=str, default="../../../Datasets",
                        help="folder that contains the WikiWeb/ data files")
    parser.add_argument('--save_interval', type=int, default=100)
    # Task
    parser.add_argument('--dump_goal', type=str, default='empty', help="[empty, wiki, triple]")
    parser.add_argument('--use_intersect', type=str, default="no-inter", help="[inter, no-inter]")
    # Mode
    parser.add_argument('--mode', type=str, default="debug", help="[debug, chunk, modify]")
    parser.add_argument('--slice', type=str, default="",
                        help="slice expression (e.g. '0:100') or iterable of indices (e.g. '[1, 5, 9]'); "
                             "used when mode is not 'chunk'")
    parser.add_argument('--chunk_size', type=int, default=5000)
    parser.add_argument('--chunk_idx', type=int, default=0)
    args = parser.parse_args()

    print("########## Information ##########")
    print(f"Dump goal: {args.dump_goal}")
    print(f"Mode: {args.mode}")
    print(f"Slice: {args.slice}")
    print(f"Chunk size: {args.chunk_size}")
    print(f"Chunk idx: {args.chunk_idx}")
    print(f"Starting time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    main(args)
    print(f"Ending time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

