"""
example:
"""

import os
from PIL import Image
import time
import argparse
import csv
import random
import glob
import numpy as np
import tarfile
import io
import multiprocessing as mp

def extract_numeric_part(filename):
    """Return the integer encoded at the start of a file's base name.

    The text after the last ``/`` and before the first ``.`` is parsed with
    ``int``; for example ``"/data/00012.tar"`` yields ``12``. ``int`` raises
    ``ValueError`` when that text is not a valid integer.
    """
    _, _, base_name = filename.rpartition('/')
    stem, _, _ = base_name.partition('.')
    return int(stem)

def process_tar_file(input_tar_file_path):
    """Collect the captions stored in the ``.txt`` members of a tar archive.

    Each ``.txt`` member is decoded (default UTF-8) and split on tab
    characters; the second field is taken as the caption.

    Args:
        input_tar_file_path: Path to the tar archive to scan.

    Returns:
        A list of ``[member_name_without_extension, caption]`` pairs in
        archive order.

    Raises:
        IndexError: If a ``.txt`` member contains no tab character.
    """
    captions = []
    print(f"Start Opening and Processing {input_tar_file_path}")

    with tarfile.open(input_tar_file_path, "r") as tar:
        for tarinfo in tar:
            if not tarinfo.isfile():
                continue
            file_name, ext = os.path.splitext(tarinfo.name)
            # Filter on the extension *before* touching the payload so that
            # non-caption members (e.g. images) are never read into memory.
            if ext != '.txt':
                continue
            # Pass the TarInfo itself rather than its name: a name lookup
            # searches the member list on every call and, when names are
            # duplicated, resolves to the last entry instead of this one.
            with tar.extractfile(tarinfo) as member:
                all_caption = member.read().decode()
            gen_captions = all_caption.split('\t')[1]
            captions.append([file_name, gen_captions])

    return captions
    
def run_inference(dir_list, output_dir):
    """Write one caption TSV file per tar archive in ``dir_list``.

    For every archive, the captions returned by ``process_tar_file`` are
    written to ``<output_dir>/<archive base name>.tsv`` with a header row,
    then the elapsed time and a rough estimate of the remaining time are
    printed.

    Args:
        dir_list: Paths of the tar archives this worker should process.
        output_dir: Directory that receives the TSV files; it is created if
            it does not exist yet.
    """
    # Create the destination up front so the first open() cannot fail on a
    # missing directory. exist_ok keeps this safe when several workers race.
    # An empty output_dir means "current directory" to os.path.join, and
    # os.makedirs("") would raise, so skip it in that case.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for dir_index, subdir in enumerate(dir_list):
        total_begin_time = time.time()

        output_tsv_file = os.path.join(output_dir, os.path.basename(subdir) + ".tsv")
        generated_captions = process_tar_file(subdir)

        with open(output_tsv_file, 'w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file, delimiter='\t')
            writer.writerow(["image_file_name", "image_caption"])
            writer.writerows(generated_captions)

        total_time = time.time() - total_begin_time
        elapsed_minutes, elapsed_seconds = divmod(total_time, 60)
        print(f"Total time taken to process {subdir}: {int(elapsed_minutes)} min {elapsed_seconds:.2f} s")

        # Extrapolate from the archive just processed: remaining archives
        # are assumed to take as long as this one did.
        total_left_time = (len(dir_list) - dir_index - 1) * total_time
        hours, remainder = divmod(total_left_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        print(f"Estimated time left: {int(hours)}:{int(minutes)}:{int(seconds)}")


def main():
    """Parse the command line and fan caption extraction out over processes.

    Every entry of ``--input_dirs`` whose name contains ``.tar`` is sorted
    by the number at the start of its name. The slice
    ``[begin_index, end_index)`` of that list is split into contiguous
    chunks, and each chunk is handled by one ``run_inference`` worker.

    Raises:
        SystemExit: If ``--processes`` is smaller than 1 (via
            ``parser.error``) or if any worker exits with a non-zero code.
    """
    parser = argparse.ArgumentParser(description="Generate captions for images in directories.")
    parser.add_argument("--input_dirs", type=str, default="/dataset/coco2014_wds_test", help="Path to the directory containing tar files.")
    parser.add_argument("--output_dir", type=str, default="/dataset/cc3m_wds/annotation/", help="Path to the output directory to save generated captions.")
    parser.add_argument("--begin_index", type=int, default=0, help="The index of the first tar file to process.")
    parser.add_argument("--end_index", type=int, default=9999999, help="The index one past the last tar file to process (exclusive).")
    parser.add_argument("--processes", type=int, default=8, help="The number of processes to run.")
    # Not used by this script; kept so existing command lines keep parsing.
    parser.add_argument("--pretrained_model_path", type=str, default="pretrained_models/grit_b_densecap_objectdet.pth", help="The path to the pretrained model.")
    args = parser.parse_args()

    # A worker count below 1 would otherwise divide by zero or silently
    # process nothing.
    if args.processes < 1:
        parser.error("--processes must be at least 1")

    # Initialize the multiprocessing context
    mp.set_start_method('spawn')

    tar_paths = [
        os.path.join(args.input_dirs, entry)
        for entry in os.listdir(args.input_dirs)
        if '.tar' in entry
    ]

    # Process the archives in numeric order of their file names.
    dir_list = sorted(tar_paths, key=extract_numeric_part)

    # Clamp to zero so negative values are never interpreted as
    # "count from the end" by the slice.
    first = max(args.begin_index, 0)
    stop = max(args.end_index, 0)
    selected_dir_list = dir_list[first:stop]

    if not selected_dir_list:
        print(f"No tar files selected in {args.input_dirs}; nothing to do.")
        return

    print(selected_dir_list[0])

    # Split the selection into contiguous chunks, one per worker. Stepping
    # by the chunk size never yields an empty chunk, so no idle worker is
    # spawned when there are fewer archives than processes.
    world_size = args.processes
    chunk_size = (len(selected_dir_list) + world_size - 1) // world_size
    chunks = [
        selected_dir_list[start:start + chunk_size]
        for start in range(0, len(selected_dir_list), chunk_size)
    ]

    # Start one process per chunk
    processes = []
    for chunk in chunks:
        p = mp.Process(target=run_inference, args=(chunk, args.output_dir))
        p.start()
        processes.append(p)

    # Wait for all processes to finish
    for p in processes:
        p.join()

    # Surface worker failures instead of exiting successfully regardless.
    failed = [p.name for p in processes if p.exitcode != 0]
    if failed:
        raise SystemExit(f"{len(failed)} worker process(es) failed: {', '.join(failed)}")

if __name__ == "__main__":
    main()