import os
import re
import time
import json
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
import google.generativeai as genai
from collections import defaultdict
from tag_prompts import get_description_prompt, get_tagging_prompt
from tag_utils import normalize_tag


def _format_duration(delta):
    """Render a ``timedelta`` as ``[D day(s), ]H:MM:SS`` without microseconds.

    ``str(timedelta)`` only appends a ``.ffffff`` suffix when the microseconds
    are non-zero, so chopping a fixed 7 characters can erase the whole value;
    splitting on the decimal point is correct in both cases.
    """
    return str(delta).split(".")[0]


def _select_indices(idx_list, slice_spec):
    """Pick the dataset indices to process from the ``--slice`` CLI string.

    Args:
        idx_list: All candidate dataset indices, in order.
        slice_spec: Either a Python slice body such as ``"0:100"`` or ``"::2"``
            (applied to ``idx_list``), or a Python expression evaluating to an
            iterable of indices (e.g. ``"[3, 7, 42]"``, ``"range(10, 20)"``) or
            to a single integer index.

    Returns:
        A list of dataset indices.

    Raises:
        ValueError: If ``slice_spec`` is empty.
    """
    spec = (slice_spec or "").strip()
    if not spec:
        raise ValueError(
            "--slice is required unless --mode is 'chunk' (e.g. '0:100' or '[3, 7, 42]')"
        )
    # NOTE(review): ``eval`` on a CLI argument is tolerable only because the
    # operator controls the command line; never route untrusted input here.
    if ":" in spec:
        return list(eval(f"idx_list[{spec}]"))
    selected = eval(spec)
    if isinstance(selected, int):
        return [selected]
    return list(selected)


def _load_dataset(args):
    """Load the dataset named by ``args.dataset`` and the indices to process.

    Returns:
        ``(dataset, idx_list)``: the unpickled dataset and the list of indices
        selected by ``--mode chunk`` (chunk window) or ``--slice`` otherwise.

    Raises:
        NotImplementedError: For any dataset other than ``"WikiWeb"``.
    """
    if args.dataset != "WikiWeb":
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")
    dataset_path = os.path.join(args.data_folder, "WikiWeb/wiki_data.pkl")
    # pickle.load can execute arbitrary code; only load dataset files you trust.
    with open(dataset_path, 'rb') as file:
        dataset = pickle.load(file)
    idx_list = list(range(len(dataset)))
    if args.mode == "chunk":
        start = args.chunk_idx * args.chunk_size
        return dataset, idx_list[start:start + args.chunk_size]
    return dataset, _select_indices(idx_list, args.slice)


def _chunk_dir(output_dir, args):
    """Return the directory holding per-chunk result files for this option set."""
    return os.path.join(output_dir, f"gemini_{args.use_option}/chunks")


def _resolve_output_file(args):
    """Create the output directories and return ``(output_dir, output_file)``.

    Chunk mode writes to a stable ``chunk_<idx>.json`` (so runs can resume);
    other modes write to a fresh timestamped file.
    """
    output_dir = os.path.join(os.getcwd(), "../output/tagging")
    os.makedirs(output_dir, exist_ok=True)
    if args.mode == "chunk":
        chunk_dir = _chunk_dir(output_dir, args)
        os.makedirs(chunk_dir, exist_ok=True)
        return output_dir, os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    stamp = time.strftime('%m%d_%H%M%S')
    return output_dir, os.path.join(output_dir, f"tag_gemini_{args.use_option}_{stamp}.json")


def _load_existing(output_file):
    """Return ``(taggings, last_idx)`` from a previous run, or ``([], -1)``."""
    if not os.path.exists(output_file):
        return [], -1
    with open(output_file, "r") as f:
        taggings = json.load(f)
    last_idx = taggings[-1]['idx'] if taggings else -1
    return taggings, last_idx


def _save_json(obj, path):
    """Write ``obj`` to ``path`` as indented JSON (overwriting)."""
    with open(path, "w") as f:
        json.dump(obj, f, indent=4)


def _strip_code_fence(text):
    """Remove a surrounding Markdown code fence (```` ```json ```` / ```` ``` ````)."""
    text = text.strip()
    text = re.sub(r'^```(?:json)?', '', text, flags=re.IGNORECASE)
    text = re.sub(r'```$', '', text)
    return text.strip()


def _propagate_to_chunks(taggings, chunk_dir):
    """Overwrite the matching records of existing chunk files with re-run results.

    The chunk size is taken from the length of ``chunk_0.json``; record ``idx``
    is written to position ``idx % chunk_size`` of ``chunk_<idx // chunk_size>.json``.
    Only chunk files that receive an update are rewritten.
    """
    with open(os.path.join(chunk_dir, "chunk_0.json"), "r") as f:
        chunk_size = len(json.load(f))
    if chunk_size == 0:
        raise ValueError(f"{os.path.join(chunk_dir, 'chunk_0.json')} is empty; "
                         "cannot infer the chunk size")

    updates = defaultdict(list)  # chunk index -> [(position in chunk, record)]
    for item in taggings:
        chunk_idx, position = divmod(item['idx'], chunk_size)
        updates[chunk_idx].append((position, item))

    for chunk_idx in sorted(updates):
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        if not os.path.exists(file_name):
            print(f"Warning: {file_name} not found; "
                  f"skipping {len(updates[chunk_idx])} re-tagged item(s)")
            continue
        with open(file_name, "r") as f:
            chunk_data = json.load(f)
        for position, item in updates[chunk_idx]:
            chunk_data[position] = item
        _save_json(chunk_data, file_name)


def main(args):
    """Describe and tag dataset diagrams with a Gemini model, saving JSON results.

    For every selected dataset index, optionally asks the model for a diagram
    description (``--use_desc desc``), then asks for tags, optionally normalised
    with ``normalize_tag`` (``--use_norm norm``). Results are written to
    ``../output/tagging`` relative to the current working directory. In chunk
    mode an existing chunk file is resumed after its last saved ``idx``; in
    modify mode the results are also patched back into the chunk files.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` block).

    Raises:
        NotImplementedError: If ``args.dataset`` is not supported.
        ValueError: If ``--slice`` is empty outside chunk mode.
    """
    ### Loading datasets
    dataset, idx_list = _load_dataset(args)

    ### Creating output directories and loading existing output files
    output_dir, output_file = _resolve_output_file(args)
    taggings, last_idx = _load_existing(output_file)

    ### Loading models
    load_start_time = datetime.now()
    genai.configure(api_key=args.API_KEY)
    model = genai.GenerativeModel(args.lvlm_path)
    load_end_time = datetime.now()
    print(f"Loading duration: {_format_duration(load_end_time - load_start_time)}")

    ### Running models
    save_interval = args.save_interval
    # Delay between items, presumably to stay under the API rate limit.
    sleep_seconds = getattr(args, "sleep_seconds", 10)
    try:
        for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Processing diagrams"):
            if save_interval > 0 and cnt % save_interval == 0:
                _save_json(taggings, output_file)
            # Resume support: skip indices already present in the loaded file
            # (chunk mode processes a contiguous ascending range).
            if idx <= last_idx:
                continue
            data = dataset[idx]
            # NOTE(review): the image is only referenced as a Markdown URL inside a
            # text prompt; confirm the model actually fetches and sees it.
            image_ref = f"![Diagram Image]({data['image_url']})"

            # Step 1: Description
            if args.use_desc == "desc":
                description_prompt = get_description_prompt(args, data) + image_ref
                description = model.generate_content([description_prompt]).text.strip()
                # Stored on the record, presumably so get_tagging_prompt can use it
                # (TODO confirm in tag_prompts).
                data['model_description'] = description
            else:
                description = None

            # Step 2: Tagging
            tag_prompt = get_tagging_prompt(args, data) + image_ref
            tag_str = _strip_code_fence(model.generate_content([tag_prompt]).text)
            if args.use_norm == "norm":
                tag, _ = normalize_tag(tag_str)
            else:
                tag = tag_str

            # Collecting results
            taggings.append({'idx': idx, 'description': description, 'tag': tag})
            time.sleep(sleep_seconds)
    finally:
        # Always persist what was collected, even if an API call raised, so a
        # chunk-mode rerun can resume from the last saved idx.
        _save_json(taggings, output_file)

    run_end_time = datetime.now()
    print(f"Running duration: {_format_duration(run_end_time - load_end_time)}")

    ### Propagating re-run results back into the chunk files
    if args.mode == "modify":
        _propagate_to_chunks(taggings, _chunk_dir(output_dir, args))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Describe and tag diagrams with a Gemini model.')
    # General
    parser.add_argument('--dataset', type=str, default="WikiWeb", help="[WikiWeb]")
    parser.add_argument('--data_folder', type=str, default="../../Datasets")
    # Falls back to $GOOGLE_API_KEY so the key need not appear on the command line.
    parser.add_argument('--API_KEY', type=str, default=os.environ.get("GOOGLE_API_KEY", "XXX"),
                        help="Gemini API key (defaults to $GOOGLE_API_KEY)")
    parser.add_argument('--lvlm_path', type=str, default="gemini-1.5-flash")
    parser.add_argument('--save_interval', type=int, default=25)
    parser.add_argument('--sleep_seconds', type=float, default=10,
                        help="Delay in seconds after each processed item")
    # Task
    parser.add_argument('--use_option', default='no-option', help="[option, no-option]")
    parser.add_argument('--use_desc', default='desc', help="[desc, no-desc]")
    parser.add_argument('--use_wiki_text', type=str, default='no-wiki', help="[wiki, no-wiki]")
    parser.add_argument('--use_norm', type=str, default='norm', help="[norm, no-norm]")
    # Mode
    parser.add_argument('--mode', type=str, default="debug", help="[debug, chunk, modify]")
    parser.add_argument('--slice', type=str, default="",
                        help="Indices for non-chunk modes, e.g. '0:100' or '[3, 7, 42]'")
    parser.add_argument('--chunk_size', type=int, default=5000)
    parser.add_argument('--chunk_idx', type=int, default=0)
    args = parser.parse_args()

    # Never echo the full API key: stdout of long runs typically ends up in logs.
    masked_key = f"{args.API_KEY[:4]}***" if len(args.API_KEY) > 8 else "***"

    print("########## Information ##########")
    print(f"Model: {args.lvlm_path}")
    print(f"API key: {masked_key}")
    print(f"Use of options: {args.use_option}")
    print(f"Use of description: {args.use_desc}")
    print(f"Use of wiki text: {args.use_wiki_text}")
    print(f"Use of normalization: {args.use_norm}")
    print(f"Mode: {args.mode}")
    print(f"Slice: {args.slice}")
    print(f"Chunk size: {args.chunk_size}")
    print(f"Chunk idx: {args.chunk_idx}")
    print(f"Starting time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    main(args)
    print(f"Ending time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
