import argparse
import ast
import json
import os
import pickle
import re
import time
from collections import defaultdict
from datetime import datetime

import google.generativeai as genai
from tqdm import tqdm

from anno_prompts import get_description_prompt, get_annotation_prompt
from anno_utils import normalize_description, normalize_annotation


def _parse_slice(spec, idx_list):
    """Select dataset indices from ``idx_list`` according to a ``--slice`` spec.

    Parsed without ``eval`` (the spec comes straight from the command line).
    Supported forms:
      * ``""``                    -> every index in ``idx_list``
      * ``"start:stop[:step]"``   -> ``idx_list[start:stop:step]`` (parts may be empty)
      * ``"range(a[, b[, c]])"``  -> ``list(range(a, b, c))``
      * ``"5"``, ``"1,2,3"``, ``"[1, 2, 3]"`` -> exactly those indices

    Raises:
        ValueError: if ``spec`` matches none of the forms above.
    """
    spec = spec.strip()
    if not spec:
        return list(idx_list)
    if ":" in spec:
        parts = [part.strip() for part in spec.split(":")]
        if len(parts) > 3:
            raise ValueError(f"Invalid slice spec: {spec!r}")
        bounds = [int(part) if part else None for part in parts]
        return idx_list[slice(*bounds)]
    range_match = re.fullmatch(r"range\(([^)]*)\)", spec)
    if range_match:
        range_args = [int(arg) for arg in range_match.group(1).split(",") if arg.strip()]
        return list(range(*range_args))
    try:
        value = ast.literal_eval(spec)
    except (ValueError, SyntaxError) as err:
        raise ValueError(f"Invalid slice spec: {spec!r}") from err
    if isinstance(value, int):
        return [value]
    return [int(v) for v in value]


def _strip_code_fence(text):
    """Strip whitespace plus a leading ```json / trailing ``` fence from a model reply."""
    return re.sub(r'^```json', '', re.sub(r'```$', '', text.strip())).strip()


def _save_json(path, obj):
    """Write ``obj`` to ``path`` as indented JSON (overwrites the file)."""
    with open(path, "w") as f:
        json.dump(obj, f, indent=4)


def _format_duration(delta):
    """Format a timedelta as H:MM:SS, dropping any fractional seconds."""
    # str(delta)[:-7] would erase the whole string when microseconds == 0,
    # because the ".ffffff" suffix is then absent.
    return str(delta).split(".")[0]


def _merge_into_chunks(chunk_dir, annotations):
    """Write re-generated ``annotations`` back into their slots in the chunk files.

    The chunk size is taken from the length of ``chunk_0.json``; an item with
    global ``idx`` goes to ``chunk_{idx // size}.json`` at position ``idx % size``.
    Only chunk files that receive at least one update are rewritten.

    Raises:
        ValueError: if ``chunk_0.json`` is empty (chunk size cannot be inferred).
    """
    with open(os.path.join(chunk_dir, "chunk_0.json"), "r") as f:
        chunk_size = len(json.load(f))
    if chunk_size == 0:
        raise ValueError(f"{os.path.join(chunk_dir, 'chunk_0.json')} is empty; cannot infer chunk size")
    updates = defaultdict(list)  # chunk index -> [(position in chunk, annotation item)]
    for item in annotations:
        chunk_idx, pos = divmod(item['idx'], chunk_size)
        updates[chunk_idx].append((pos, item))
    for chunk_idx in sorted(updates):
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        with open(file_name, "r") as f:
            chunk_data = json.load(f)
        for pos, item in updates[chunk_idx]:
            chunk_data[pos] = item
        _save_json(file_name, chunk_data)


def main(args):
    """Annotate WikiWeb diagrams with a Gemini model and save the results as JSON.

    For each selected dataset index:
      1. get a description of the diagram, either from the model or from the
         cached ``temp/`` files;
      2. ask the model for an annotation;
      3. append ``{'idx', 'description', 'annotation'}`` to the output list.

    The output list is flushed to disk every ``args.save_interval`` items and
    again whenever the loop exits, whether normally or via an exception.

    Modes (``args.mode``):
        chunk  -- process ``idx_list[chunk_idx*chunk_size:(chunk_idx+1)*chunk_size]``,
                  resuming from ``chunks/chunk_{chunk_idx}.json`` if it exists;
                  with ``use_regenerate == "regen"`` also dumps the descriptions
                  to ``temp/temp_{chunk_idx}.json``.
        other  -- process the indices selected by ``args.slice``; in "modify"
                  mode the results are then written back into the chunk files.

    Raises:
        NotImplementedError: for any dataset other than "WikiWeb".
        ValueError: if ``args.slice`` cannot be parsed.
    """
    ### Loading datasets
    if args.dataset != "WikiWeb":
        raise NotImplementedError(f"Unsupported dataset: {args.dataset}")
    data_folder = args.data_folder
    # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
    with open(os.path.join(data_folder, "WikiWeb/wiki_data.pkl"), 'rb') as file:
        dataset = pickle.load(file)
    idx_list = list(range(len(dataset)))
    if args.use_intersect == "inter":
        with open(os.path.join(data_folder, "WikiWeb/intersect_idx.json"), 'r') as file:
            intersect_idx_dict = json.load(file)
    if args.mode == "chunk":
        start = args.chunk_idx * args.chunk_size
        idx_list = idx_list[start:start + args.chunk_size]
    else:
        idx_list = _parse_slice(args.slice, idx_list)

    ### Creating output directories
    output_dir = os.path.join(os.getcwd(), "../output/annotation")
    model_dir = os.path.join(output_dir, f"gemini_{args.use_wiki_text}")
    chunk_dir = os.path.join(model_dir, "chunks")
    temp_dir = os.path.join(model_dir, "temp")
    os.makedirs(output_dir, exist_ok=True)
    if args.mode == "chunk":
        os.makedirs(chunk_dir, exist_ok=True)
        output_file = os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    else:
        output_file = os.path.join(output_dir, f"anno_gemini_{args.use_wiki_text}_{time.strftime('%m%d_%H%M%S')}.json")

    ### Loading existing output files (resume support: skip idx <= last saved idx)
    annotations = []
    last_idx = -1
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            annotations = json.load(f)
        if annotations:
            last_idx = annotations[-1]['idx']

    ### Loading cached descriptions from temporary files
    has_temp_data = False
    temp_data = []
    if args.use_regenerate == "no-regen" and os.path.exists(temp_dir):
        for temp_idx in range(len(os.listdir(temp_dir))):
            file_name = os.path.join(temp_dir, f"temp_{temp_idx}.json")
            if os.path.exists(file_name):
                with open(file_name, "r") as f:
                    temp_data += json.load(f)
        # The cache is indexed positionally by dataset idx below, so it is only
        # trusted when it covers the whole dataset.
        has_temp_data = len(temp_data) == len(dataset)

    ### Loading models
    load_start_time = datetime.now()
    genai.configure(api_key=args.API_KEY)
    model = genai.GenerativeModel(args.lvlm_path)
    load_end_time = datetime.now()
    print(f"Loading duration: {_format_duration(load_end_time - load_start_time)}")

    ### Running models
    sleep_time = getattr(args, "sleep_time", 10)  # pause after each item (API rate limiting)
    save_interval = args.save_interval
    try:
        for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Processing diagrams"):
            if save_interval > 0 and cnt % save_interval == 0:
                _save_json(output_file, annotations)
            if idx <= last_idx:
                continue  # already present in the resumed output file
            data = dataset[idx]
            if args.use_intersect == "inter":
                if str(idx) not in intersect_idx_dict:
                    annotations.append({'idx': idx, 'description': 'Filtered: Tagging Intersection', 'annotation': 'Filtered: Tagging Intersection'})
                    continue
                data['tag'] = intersect_idx_dict[str(idx)]

            # NOTE(review): the image is passed only as a markdown link inside the
            # text prompt; confirm the model actually receives the image content.

            # Step 1: Description
            if args.use_regenerate == "regen" or not has_temp_data:
                description_prompt = get_description_prompt(args, data)
                description_prompt += f"![Diagram Image]({data['image_url']})"
                description_str = _strip_code_fence(model.generate_content([description_prompt]).text)
            else:
                # NOTE(review): the temp cache stores the already-normalized
                # 'description' (see the temp dump after the loop), so with
                # use_norm == "norm" it is normalized a second time here —
                # confirm normalize_description accepts its own output.
                description_str = temp_data[idx]['description']
            data['model_description'] = description_str
            description = normalize_description(description_str) if args.use_norm == "norm" else description_str

            # Step 2: Annotation
            anno_prompt = get_annotation_prompt(args, data)
            anno_prompt += f"![Diagram Image]({data['image_url']})"
            anno_str = _strip_code_fence(model.generate_content([anno_prompt]).text)
            anno, _err_msg = normalize_annotation(anno_str) if args.use_norm == "norm" else (anno_str, None)

            # Collecting results
            annotations.append({'idx': idx, 'description': description, 'annotation': anno})
            time.sleep(sleep_time)
    finally:
        # Persist everything collected so far even if an API call raised,
        # so a re-run in chunk mode resumes after the last saved idx.
        _save_json(output_file, annotations)

    run_end_time = datetime.now()
    print(f"Running duration: {_format_duration(run_end_time - load_end_time)}")

    if args.mode == "chunk" and args.use_regenerate == "regen":
        os.makedirs(temp_dir, exist_ok=True)
        temp_file = os.path.join(temp_dir, f"temp_{args.chunk_idx}.json")
        _save_json(temp_file, [{'idx': item['idx'], 'description': item['description']} for item in annotations])

    if args.mode == "modify":
        _merge_into_chunks(chunk_dir, annotations)


if __name__ == '__main__':
    # Command-line entry point: parse options, print a run summary, run main().
    parser = argparse.ArgumentParser(description='')
    # General
    parser.add_argument('--dataset', type=str, default="WikiWeb", help="[WikiWeb]")
    parser.add_argument('--data_folder', type=str, default="../../Datasets")
    # Fall back to the GOOGLE_API_KEY environment variable so the key does not
    # have to be passed on the command line (shell history, process listings).
    parser.add_argument('--API_KEY', type=str, default=os.environ.get("GOOGLE_API_KEY", "XXX"))
    parser.add_argument('--lvlm_path', type=str, default="gemini-1.5-flash")
    parser.add_argument('--save_interval', type=int, default=25)
    parser.add_argument('--sleep_time', type=float, default=10, help="Seconds to wait after each annotated item (API rate limiting)")
    # Task
    parser.add_argument('--use_wiki_text', type=str, default='wiki', help="[wiki, no-wiki]")
    # choices= makes typos fail fast instead of silently taking the "other" branch.
    parser.add_argument('--use_norm', type=str, default='norm', choices=['norm', 'no-norm'], help="[norm, no-norm]")
    parser.add_argument('--use_regenerate', type=str, default='no-regen', choices=['regen', 'no-regen'], help="[regen, no-regen]")
    parser.add_argument('--use_intersect', type=str, default="no-inter", choices=['inter', 'no-inter'], help="[inter, no-inter]")
    # Mode
    parser.add_argument('--mode', type=str, default="debug", choices=['debug', 'chunk', 'modify'], help="[debug, chunk, modify]")
    parser.add_argument('--slice', type=str, default="")
    parser.add_argument('--chunk_size', type=int, default=5000)
    parser.add_argument('--chunk_idx', type=int, default=0)
    args = parser.parse_args()

    # Never echo the full API key to stdout / log files.
    api_key = args.API_KEY
    masked_key = '*' * (len(api_key) - 4) + api_key[-4:] if len(api_key) > 8 else '***'

    print("########## Information ##########")
    print(f"Model: {args.lvlm_path}")
    print(f"API key: {masked_key}")
    print(f"Use of wiki text: {args.use_wiki_text}")
    print(f"Use of normalization: {args.use_norm}")
    print(f"Mode: {args.mode}")
    print(f"Slice: {args.slice}")
    print(f"Chunk size: {args.chunk_size}")
    print(f"Chunk idx: {args.chunk_idx}")
    print(f"Starting time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    main(args)
    print(f"Ending time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

