import os
import json
import logging
from api_key import api_key,base_url
from prompts.debug import Prompts
from utils.gpt import Debug_with_GPT4V
from utils.functions import merge_label_files, merge_tag_files
from preprocess import get_attrs, get_tags, refine_tags, merge_tags, get_labels
import pickle
import argparse
import glob


def _collect_images(base_dir, class_ids, class_names):
    """Collect image paths for each class folder.

    For every folder id in ``class_ids``, list all entries of
    ``base_dir/<id>`` sorted by file name, normalise backslashes to ``/``,
    and store the list under the display name ``class_names[<id>]``.

    Args:
        base_dir: directory containing one sub-folder per class id.
        class_ids: folder names to scan.
        class_names: mapping from folder id to class display name.

    Returns:
        dict mapping class display name -> list of path strings.

    Raises:
        KeyError: if a class id has no entry in ``class_names``.
    """
    images = {}
    for class_id in class_ids:
        image_paths = glob.glob(os.path.join(base_dir, class_id, '*'))
        image_paths.sort(key=os.path.basename)
        # Normalise Windows separators so stored paths look the same on every OS.
        images[class_names[class_id]] = [path.replace("\\", '/') for path in image_paths]
    return images


def _split_for_worker(data, worker, total_worker):
    """Return this worker's contiguous shard of every class's item list.

    Each list is cut into chunks of ``len(items) // total_worker`` items.
    Worker ``w`` takes chunk ``w``; the last worker also takes the remainder.
    When a class has fewer items than workers, every worker except the last
    gets an empty list for that class.

    Args:
        data: dict mapping class name -> list of items.
        worker: this worker's index, expected in ``[0, total_worker)``.
        total_worker: total number of workers (>= 1).

    Returns:
        dict with the same keys as ``data`` holding this worker's slices.
    """
    last_worker = total_worker - 1
    shard = {}
    for cls, items in data.items():
        per_worker = len(items) // total_worker
        start = worker * per_worker
        end = None if worker == last_worker else start + per_worker
        shard[cls] = items[start:end]
    return shard


def _load_json_or_compute(path, compute, *, load_msg, compute_msg):
    """Return the JSON cached at ``path``, or compute, cache and return it.

    If ``path`` exists, its JSON content is returned and ``compute`` is not
    called. Otherwise ``compute()`` is called, its result is written to
    ``path`` (indent=4) and returned. The matching message is logged at INFO.

    Args:
        path: cache file location; its parent directory must already exist.
        compute: zero-argument callable producing a JSON-serialisable value.
        load_msg: log line emitted when the cache is reused.
        compute_msg: log line emitted before calling ``compute``.
    """
    if os.path.exists(path):
        logging.info(load_msg)
        with open(path, 'r') as f:
            return json.load(f)
    logging.info(compute_msg)
    result = compute()
    with open(path, 'w') as f:
        json.dump(result, f, indent=4)
    return result


if __name__ == "__main__":
    # Bear-class labeling pipeline; every stage caches its output on disk and
    # is skipped on re-runs when that output already exists.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--worker", type=int, default=0)
    parser.add_argument("--total_worker", type=int, default=20,
                        help="number of parallel workers the images are sharded across")
    args = parser.parse_args()
    total_worker = args.total_worker
    worker = args.worker
    if total_worker < 1:
        parser.error("--total_worker must be >= 1")
    if not 0 <= worker < total_worker:
        # An out-of-range index would silently receive an empty or wrong shard.
        parser.error(f"--worker must be in [0, {total_worker - 1}]")

    # prepare data
    logging.basicConfig(level=logging.INFO)
    logging.info('[0*] initializing GPT agent and data...')
    class_ids = ['n04399382', 'n02132136', 'n02133161', 'n02134084', 'n02134418']
    with open('cases/imagenet_cls_names.json') as f:
        class_names = json.load(f)
    data = _collect_images('./data/classification', class_ids, class_names)

    # output file names
    output_dir = 'outputs/classification'
    attrs_file = 'outputs/classification/bear_attrs.json'
    tags_file = 'outputs/classification/bear_tags.json'
    tags_refined_file = 'outputs/classification/bear_tags_refined.json'
    label_file = 'outputs/classification/bear_labels_v3.pkl'
    # Writing any of the files above fails if the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    # init gpt and prompts
    debug_agent = Debug_with_GPT4V(api_key, base_url)
    prompt_agent = Prompts()
    logging.info('[0 ] GPT agent and data initialized')

    # [1] attribute extraction
    pairs_per_class = 5  # dataset-related attributes
    attributes = _load_json_or_compute(
        attrs_file,
        lambda: get_attrs(prompt_agent, debug_agent,
                          data, pairs_per_class, task="classification"),
        load_msg=f'[1*] loading attributes from {attrs_file}...',
        compute_msg='[1*] extracting attributes...',
    )
    logging.info('[1 ] attributes ready')

    # [2] attribute to tags
    tags = _load_json_or_compute(
        tags_file,
        lambda: get_tags(prompt_agent, debug_agent, attributes),
        load_msg=f'[2*] loading tags from {tags_file}...',
        compute_msg='[2*] tagging attributes...',
    )
    logging.info('[2 ] tags ready')

    # Split the jobs across workers: launch one process per --worker index
    # (e.g. from a shell script). More workers run faster, up to the limit
    # of your network throughput.
    data_for_worker = _split_for_worker(data, worker, total_worker)

    # [3] refine tags using this worker's share of the images
    if os.path.exists(tags_refined_file):
        logging.info(f'[3*] loading tags from {tags_refined_file}...')
        with open(tags_refined_file, 'r') as f:
            tags = json.load(f)
    else:
        logging.info('[3*] refine tags..., worker %d' % worker)
        tag_file_worker = os.path.splitext(tags_refined_file)[0] + '_%d.json' % worker
        new_tags = refine_tags(prompt_agent, debug_agent, data_for_worker, tags, worker, maximum_update_data=400)
        with open(tag_file_worker, 'w') as f:
            json.dump(new_tags, f, indent=4)
        # NOTE(review): merge_tag_files runs right after this worker writes its
        # own shard. Whether it waits for, or tolerates, shards from other
        # workers is not visible here. Confirm before running workers in parallel.
        new_tags = merge_tag_files(tags_refined_file, total_worker)
        tags = merge_tags(prompt_agent, debug_agent, tags, new_tags)
        with open(tags_refined_file, 'w') as f:
            json.dump(tags, f, indent=4)
    logging.info('[3 ] refined tags ready')

    # [4] image labeling
    if os.path.exists(label_file):
        logging.info(f'[4*] loading labels from {label_file}...')
        # The pickle is this script's own earlier output, not external input.
        with open(label_file, 'rb') as f:
            labels = pickle.load(f)
    else:
        logging.info('[4*] labeling images..., worker %d' % worker)
        label_file_worker = os.path.splitext(label_file)[0] + '_%d.pkl' % worker
        labels, tags = get_labels(prompt_agent, debug_agent, data_for_worker, tags, worker)
        with open(label_file_worker, "wb") as file:
            pickle.dump(labels, file)
        # NOTE(review): same cross-worker merge caveat as in stage [3].
        labels = merge_label_files(label_file, total_worker)
        with open(label_file, 'wb') as file:
            pickle.dump(labels, file)
    logging.info('[4 ] labels ready')