"""
example:
"""
import os
import time
import argparse
import csv
import multiprocessing as mp
from gpt4_utils import call_llm_summary_without_history

def extract_numeric_part(filename):
    """Return the integer before the first ``.`` of the last path component.

    E.g. ``"/data/shards/00042.tar"`` -> ``42``. Used as a sort key so shard
    files are ordered numerically rather than lexicographically.

    Raises:
        ValueError: if that leading part is not an integer literal.
    """
    base_name = filename.rsplit('/', 1)[-1]
    stem, _, _ = base_name.partition('.')
    return int(stem)
def read_captions_into_dict(file_path):
    """Read a tab-separated caption file into ``{image_name: caption}``.

    Each row is expected to be ``<image file name>\\t<caption>``. The image
    name is truncated at its first ``.``, so ``"000123.jpg"`` and ``"000123"``
    map to the same key. When a key repeats, the later row wins.

    Blank lines and rows with fewer than two columns are skipped instead of
    raising ``IndexError``. Columns after the second are ignored.

    Args:
        file_path: Path to the UTF-8 TSV file.

    Returns:
        dict mapping image name (without extension) to caption text.
    """
    captions_dict = {}
    # newline='' lets the csv module handle line endings itself, so quoted
    # fields that contain newlines are parsed correctly.
    with open(file_path, 'r', encoding='utf-8', newline='') as file:
        for row in csv.reader(file, delimiter='\t'):
            if len(row) < 2:
                continue
            # split('.')[0] returns the whole name when there is no '.'.
            image_name = row[0].split('.')[0]
            captions_dict[image_name] = row[1]
    return captions_dict

def save_summaries(file_name, image_name, summaries):
    """Write (overwriting) a one-row TSV with a header to *file_name*.

    The file holds the header ``image_file_name<TAB>summary_caption``
    followed by the single row ``image_name<TAB>summaries``.
    """
    rows = (
        ("image_file_name", "summary_caption"),
        (image_name, summaries),
    )
    with open(file_name, 'w', newline='', encoding='utf-8') as out_file:
        csv.writer(out_file, delimiter='\t').writerows(rows)



def preprocess_dense_caption(dense_caption):
    """Prepare a dense caption before it is sent to the summarizer.

    The intent is to shorten the bounding-box and confidence annotations.
    TODO: not implemented yet; the input is currently returned unchanged.
    """
    # Pass-through until the shortening logic is written.
    unchanged = dense_caption
    return unchanged

def summarize_captions(caption_list):
    """Summarize one image's captions with the LLM.

    Args:
        caption_list: ``[image_caption, dense_caption, object_caption]``.
            It must have at least three items, or ``IndexError`` is raised.

    Returns:
        The result of ``call_llm_summary_without_history`` for the image
        caption and the preprocessed dense caption.
    """
    image_caption, dense_caption = caption_list[0], caption_list[1]
    # The object caption is read but is not currently sent to the LLM.
    # NOTE(review): confirm whether it should be part of the summary input.
    _object_caption = caption_list[2]
    llm_inputs = [image_caption, preprocess_dense_caption(dense_caption)]
    return call_llm_summary_without_history(llm_inputs)

def process_one_file(input_file_path, args):
    """Summarize every image listed in one caption shard file.

    Only the base name of *input_file_path* is used. A file with that name is
    read (via ``read_captions_into_dict``) from each of
    ``args.input_caption_dirs``, ``args.input_dense_caption_dirs`` and
    ``args.input_object_caption_dirs``. For every image in the caption file,
    a summary is written to ``<args.output_dir>/<stem>/<image_name>.tsv``,
    where ``stem`` is the base name up to its first ``.``.

    Behavior per image:
      - If its output file already exists, the image is skipped, so an
        interrupted run can be resumed.
      - Captions missing from the dense/object files are passed as "".
      - A ``None`` summary is reported and nothing is written.

    Args:
        input_file_path: Path (or bare name) of a shard in the caption dir.
        args: Parsed CLI namespace. Reads the three input dirs and
            ``output_dir``.
    """
    shard_name = input_file_path.split('/')[-1]
    output_file_path = os.path.join(args.output_dir, shard_name.split('.')[0])

    captions_dict1 = read_captions_into_dict(os.path.join(args.input_caption_dirs, shard_name))
    captions_dict2 = read_captions_into_dict(os.path.join(args.input_dense_caption_dirs, shard_name))
    captions_dict3 = read_captions_into_dict(os.path.join(args.input_object_caption_dirs, shard_name))

    count = 0  # images attempted in this call (skipped ones not counted)
    for image_name in captions_dict1:
        output_tsv = os.path.join(output_file_path, f'{image_name}.tsv')
        if os.path.exists(output_tsv):
            print(f"Skip processing {image_name} because {output_tsv} already exists.")
            continue
        count += 1
        # exist_ok avoids the race of a separate exists() check followed by
        # makedirs(). The dir is created lazily, only when there is work.
        os.makedirs(output_file_path, exist_ok=True)
        all_captions = [
            captions_dict1.get(image_name, ""),
            captions_dict2.get(image_name, ""),
            captions_dict3.get(image_name, ""),
        ]
        summary = summarize_captions(all_captions)
        if summary is None:
            print(f"Summary for {image_name} is None!")
            continue
        save_summaries(output_tsv, image_name, summary)
        if count % 10 == 0:
            print(f'Summaries for {image_name} saved!')
            print(f'Processed {count} images in {output_file_path}!')
    
def run_inference(rank, args, dir_list):
    """Worker entry point: summarize each shard in *dir_list* in turn.

    After each shard it prints the time that shard took and an estimate of
    this worker's remaining time. The estimate uses the average time per
    shard so far, which is less noisy than extrapolating from the last shard
    alone (e.g. one that was mostly skipped because its outputs existed).

    Args:
        rank: Index of this worker process (used only for logging).
        args: Parsed CLI namespace, forwarded to ``process_one_file``.
        dir_list: Shard paths assigned to this worker.
    """
    print(f"Process rank: {rank}")
    # perf_counter is monotonic, so durations are unaffected by wall-clock
    # adjustments during a long run.
    run_begin_time = time.perf_counter()
    for dir_index, subdir in enumerate(dir_list):
        file_begin_time = time.perf_counter()
        process_one_file(subdir, args)
        file_time = time.perf_counter() - file_begin_time
        file_minutes, file_seconds = divmod(file_time, 60)
        print(f"Total time taken to process {subdir}: {int(file_minutes)} min {file_seconds:.1f} s")

        files_done = dir_index + 1
        avg_time = (time.perf_counter() - run_begin_time) / files_done
        total_left_time = (len(dir_list) - files_done) * avg_time
        hours, remainder = divmod(total_left_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        print(f"Estimated time left: {int(hours)}:{int(minutes):02d}:{int(seconds):02d}")


def main():
    """Parse CLI arguments, select caption shards and summarize them in parallel.

    Shards are files in ``--input_caption_dirs`` whose names contain
    ``.tar``. They are sorted by the integer before the first ``.`` in the
    file name, and the half-open index range
    ``[--begin_index, --end_index)`` is selected. The selection is split into
    ``--processes`` contiguous chunks, each handled by its own
    ``multiprocessing.Process`` running ``run_inference``.

    Exits with a non-zero status if any worker process fails.
    """
    parser = argparse.ArgumentParser(description="Generate captions for images in directories.")
    parser.add_argument("--input_caption_dirs", type=str, default="/dataset/cc3m_wds/annotation/val_caption", help="Path to the directory containing tar files.")
    parser.add_argument("--input_dense_caption_dirs", type=str, default="/dataset/cc3m_wds/annotation/val_dense_caption", help="Path to the directory containing tar files.")
    parser.add_argument("--input_object_caption_dirs", type=str, default="/dataset/cc3m_wds/annotation/val_sam_caption", help="Path to the directory containing tar files.")
    parser.add_argument("--output_dir", type=str, default="/dataset/cc3m_wds/annotation/val_summary", help="Path to the output directory to save generated captions.")
    parser.add_argument("--begin_index", type=int, default=0, help="The index of the first tar file to process.")
    parser.add_argument("--end_index", type=int, default=9999999, help="The index of the last tar file to process.")
    parser.add_argument("--processes", type=int, default=1, help="The number of processes run.")
    args = parser.parse_args()
    # Without this check, 0 raises ZeroDivisionError below and a negative
    # value silently starts no workers.
    if args.processes < 1:
        parser.error("--processes must be >= 1")

    # Shard files live directly in the caption dir. process_one_file expects
    # files with the same names in the dense/object caption dirs.
    sub_dirs = [
        os.path.join(args.input_caption_dirs, name)
        for name in os.listdir(args.input_caption_dirs)
        if '.tar' in name
    ]
    # Numeric sort, so "10.tar" comes after "9.tar".
    dir_list = sorted(sub_dirs, key=extract_numeric_part)

    # Half-open range [begin_index, end_index). Negative bounds clamp to 0,
    # matching the previous index-comparison loop.
    selected_dir_list = dir_list[max(args.begin_index, 0):max(args.end_index, 0)]
    if not selected_dir_list:
        print("No shard files selected; nothing to do.")
        return

    world_size = args.processes
    print(selected_dir_list[0])
    # Ceil-divide so every shard is assigned. The last workers may get fewer
    # (or zero) shards when the count does not divide evenly.
    files_per_process = (len(selected_dir_list) + world_size - 1) // world_size
    dir_list_per_process = [
        selected_dir_list[rank * files_per_process:(rank + 1) * files_per_process]
        for rank in range(world_size)
    ]

    processes = []
    for rank in range(world_size):
        p = mp.Process(target=run_inference, args=(rank, args, dir_list_per_process[rank]))
        p.start()
        processes.append(p)

    # Wait for all workers, then report any that crashed instead of exiting 0.
    for p in processes:
        p.join()
    failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed_ranks:
        raise SystemExit(f"Worker process(es) failed: ranks {failed_ranks}")

if __name__ == "__main__":
    main()