import os
import io
import time
import json
import torch
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
from PIL import Image, ImageFile
from collections import defaultdict
from syn_prompts import get_description_prompt, get_triple_prompt
from syn_utils import normalize_triple

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tools.lvlm_tools import name_lvlm, get_lvlm, run_lvlm

# Pillow configuration for the diagram images decoded below.
# Disable Pillow's decompression-bomb pixel limit so very large diagrams open.
# NOTE(review): this also removes the guard against memory exhaustion from
# maliciously huge images; acceptable only for a trusted dataset.
Image.MAX_IMAGE_PIXELS = None
# Allow Pillow to load image files whose data ends prematurely instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def main(args):
    ### Loading datasets
    if args.dataset == "WikiWeb":
        data_folder = args.data_folder
        dataset_path = os.path.join(data_folder, "WikiWeb/wiki_data.pkl")
        with open(dataset_path, 'rb') as file:
            dataset = pickle.load(file)
            idx_list = list(range(len(dataset)))
        if args.use_intersect == "inter":
            intersect_idx_path = os.path.join(data_folder, "WikiWeb/intersect_idx.json")
            with open(intersect_idx_path, 'r') as file:
                intersect_idx_dict = json.load(file)
        if args.mode == "chunk":
            idx_list = idx_list[args.chunk_idx * args.chunk_size:(args.chunk_idx + 1) * args.chunk_size]
        else:
            idx_list = eval(f"idx_list[{args.slice}]") if ":" in args.slice else [idx for idx in eval(args.slice)]
    else:
        raise NotImplementedError

    ### Creating output directories
    output_dir = os.path.join(os.getcwd(), f"../output/synthesis")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.mode == "chunk":
        chunk_dir = os.path.join(output_dir, f"{name_lvlm(args)}_{args.use_wiki_text}/chunks")
        if not os.path.exists(chunk_dir):
            os.makedirs(chunk_dir)
        output_file = os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    else:
        output_file = os.path.join(output_dir, f"triple_{name_lvlm(args)}_{args.use_wiki_text}_{time.strftime('%m%d_%H%M%S')}.json")

    ### Loading existing output files
    triples = []
    last_idx = -1
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            triples = json.load(f)
            if len(triples) > 0:
                last_idx = triples[-1]['idx']

    ### Loading temporary files
    has_temp_data = False
    temp_dir = os.path.join(output_dir, f"{name_lvlm(args)}_{args.use_wiki_text}/temp")
    if args.use_regenerate == "no-regen":
        if os.path.exists(temp_dir):
            temp_data = []
            for temp_idx in range(len(os.listdir(temp_dir))):
                file_name = os.path.join(temp_dir, f"temp_{temp_idx}.json")
                if os.path.exists(file_name):
                    with open(file_name, "r") as f:
                        temp_data += json.load(f)
            if len(temp_data) == len(dataset):
                has_temp_data = True

    ### Loading models
    load_start_time = datetime.now()
    processor, model = get_lvlm(args)
    load_end_time = datetime.now()
    print(f"Loading duration: {str(load_end_time - load_start_time)[:-7]}")

    ### Running models
    for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Processing diagrams"):
        if cnt % args.save_interval == 0:
            with open(output_file, "w") as f:
                json.dump(triples, f, indent=4)
        if idx <= last_idx:
            continue
        data = dataset[idx]
        image = Image.open(io.BytesIO(data['image_bytes']))
        if args.use_intersect == "inter":
            if str(idx) in intersect_idx_dict:
                data['tag'] = intersect_idx_dict[str(idx)]
            else:
                triples.append({'idx': idx, 'description': 'Filtered: Tagging Intersection', 'triple': 'Filtered: Tagging Intersection'})
                continue

        # Step 1: Description
        if args.use_regenerate == "regen" or not has_temp_data:
            description_prompt = get_description_prompt(args, data)
            description_str = run_lvlm(args, model, processor, image, description_prompt)
        else:
            description_str = temp_data[idx]['description']
        data['description'] = description_str

        # Step 2: Getting triples
        triple_prompt = get_triple_prompt(args, data)
        triple_str = run_lvlm(args, model, processor, image, triple_prompt)
        triple = normalize_triple(triple_str) if args.use_norm == "norm" else triple_str

        # Collecting results
        triples.append({'idx': idx, 'description': description_str, 'triple': triple})

    run_end_time = datetime.now()
    print(f"Running duration: {str(run_end_time - load_end_time)[:-7]}")

    ### Saving results
    with open(output_file, "w") as f:
        json.dump(triples, f, indent=4)

    if args.mode == "chunk" and args.use_regenerate == "regen":
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir)
        temp_file = os.path.join(temp_dir, f"temp_{args.chunk_idx}.json")
        with open(temp_file, "w") as f:
            temp_data = [{'idx': item['idx'], 'description': item['description']} for item in triples]
            json.dump(temp_data, f, indent=4)

    if args.mode == "modify":
        idx_dict, enum_idx_dict = defaultdict(list), defaultdict(list)
        chunk_dir = os.path.join(output_dir, f"{name_lvlm(args)}_{args.use_wiki_text}/chunks")
        with open(os.path.join(chunk_dir, f"chunk_0.json"), "r") as f:
            chunk_size = len(json.load(f))
        for enum_idx, item in enumerate(triples):
            idx_dict[item['idx'] // chunk_size].append(item['idx'] % chunk_size)
            enum_idx_dict[item['idx'] // chunk_size].append(enum_idx)
        num_chunks = len(os.listdir(chunk_dir))
        for chunk_idx in range(num_chunks):
            file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
            with open(file_name, "r") as f:
                chunk_data = json.load(f)
            for idx, enum_idx in zip(idx_dict[chunk_idx], enum_idx_dict[chunk_idx]):
                chunk_data[idx] = triples[enum_idx]
            with open(file_name, "w") as f:
                json.dump(chunk_data, f, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    # General
    parser.add_argument('--dataset', type=str, default="WikiWeb", help="[WikiWeb]")
    parser.add_argument('--data_folder', type=str, default="../../Datasets")
    parser.add_argument('--lvlm_path', type=str, default="allenai/Molmo-7B-D-0924")
    parser.add_argument('--max_new_tokens', type=int, default=1000)
    parser.add_argument('--save_interval', type=int, default=100)
    # Task
    parser.add_argument('--use_wiki_text', type=str, default='wiki', help="[wiki, no-wiki]")
    parser.add_argument('--use_norm', type=str, default='norm', help="[norm, no-norm]")
    parser.add_argument('--use_regenerate', type=str, default='no-regen', help="[regen, no-regen]")
    parser.add_argument('--use_intersect', type=str, default="no-inter", help="[inter, no-inter]")
    # Mode
    parser.add_argument('--mode', type=str, default="debug", help="[debug, chunk, modify]")
    parser.add_argument('--slice', type=str, default="")
    parser.add_argument('--chunk_size', type=int, default=2500)
    parser.add_argument('--chunk_idx', type=int, default=0)
    args = parser.parse_args()

    print("########## Information ##########")
    print(f"Model: {args.lvlm_path}")
    print(f"Max new tokens: {args.max_new_tokens}")
    print(f"Use of wiki text: {args.use_wiki_text}")
    print(f"Use of normalization: {args.use_norm}")
    print(f"Use of regeneration: {args.use_regenerate}")
    print(f"Use of intersection: {args.use_intersect}")
    print(f"Mode: {args.mode}")
    print(f"Slice: {args.slice}")
    print(f"Chunk size: {args.chunk_size}")
    print(f"Chunk idx: {args.chunk_idx}")
    print(f"Starting time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    main(args)
    print(f"Ending time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
