"""
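Generate dense captions for the images stored in tar shards, spreading the
shards over several GPU worker processes and writing one TSV file per shard.

Example: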

python gen_dense_caption_from_shards_multiple_gpus.py --input_dirs dataset/cc3m_wds/val \
--output_dir /dataset/cc3m_wds/annotation/val_dense_caption --world_size 2 --process_per_gpu 2
"""

import os
from PIL import Image
import time
import torch
import math
import argparse
from torch.utils.data import Dataset
import csv
from torch.utils.data import DataLoader
import torch.nn as nn
import random
import torch.distributed as dist
import torch.multiprocessing as mp
import glob
import numpy as np
import tarfile
import re
import json
import io
from models.grit_src.image_dense_captions import dense_pred_to_caption, setup_cfg, get_parser, VisualizationDemo, dense_pred_to_caption_with_norm
from utils.util import resize_long_edge_cv2, resize_image_cv2
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from models.grit_src.grit.predictor import  CustomBatchPrediction

# some hyperparameters
# NOTE(review): datatype, beam_size and num_captions are not referenced
# anywhere in the visible part of this file -- confirm whether they are still needed.
datatype = torch.float16
beam_size = 10
num_captions = 5
# Lower-case file extensions (with leading dot) that are treated as images by
# check_if_image_file and process_tar_file.
valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']

def check_if_image_file(file_path):
    """
    Tell whether ``file_path`` is an existing regular file with an image extension.

    Args:
        file_path (str): path to inspect.

    Returns:
        bool: True only when the path is a file and its lower-cased name ends
        with one of ``valid_image_extensions``.
    """
    # Guard first: anything that is not a regular file can never qualify.
    if not os.path.isfile(file_path):
        return False
    lowered_path = file_path.lower()
    # str.endswith accepts a tuple of candidate suffixes.
    return lowered_path.endswith(tuple(valid_image_extensions))

def extract_numeric_part(filename):
    """
    Return the integer encoded at the start of a path's last component.

    The last '/'-separated component is cut at its first '.', and the part
    before that dot is parsed with ``int`` (e.g. ``"a/b/00012.tar"`` -> 12).

    Args:
        filename (str): a '/'-separated path or bare file name.

    Returns:
        int: the parsed number.

    Raises:
        ValueError: if the part before the first dot is not an integer.
    """
    last_component = filename.rsplit('/', 1)[-1]
    stem, _, _ = last_component.partition('.')
    return int(stem)

# https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
# RGB -> YUV coefficient matrix; convert_PIL_to_numpy multiplies RGB pixels by
# its transpose when the requested format is "YUV-BT.601".
_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
# YUV -> RGB coefficient matrix (not referenced in the visible part of this file).
_M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]

def convert_PIL_to_numpy(image, format):
    """
    Turn a PIL image into a numpy array laid out in the requested format.

    Args:
        image (PIL.Image): the image to convert
        format (str): target format. ``None`` keeps the image as it is;
            "BGR" and "YUV-BT.601" are not PIL modes, so they are derived
            from an intermediate RGB conversion.

    Returns:
        (np.ndarray): the pixel data. "L" output gets an explicit trailing
        channel axis, "BGR" output is the RGB array with its channel axis
        reversed, and "YUV-BT.601" output is floating point.
        Also see `read_image`.
    """
    derived_from_rgb = ("BGR", "YUV-BT.601")
    if format is not None:
        # PIL cannot produce BGR/YUV directly; go through RGB for those.
        pil_mode = "RGB" if format in derived_from_rgb else format
        image = image.convert(pil_mode)
    pixels = np.asarray(image)

    if format == "L":
        # PIL drops the channel axis for single-channel images; restore HWC.
        return np.expand_dims(pixels, -1)
    if format == "BGR":
        # Reverse the channel axis: RGB -> BGR.
        return pixels[:, :, ::-1]
    if format == "YUV-BT.601":
        # Scale to [0, 1] before applying the RGB->YUV matrix.
        return np.dot(pixels / 255.0, np.array(_M_RGB2YUV).T)
    return pixels

def gen_captions_for_images(images, device_id):
    """
    Run the dense-caption model over a list of images in fixed-size batches.

    Every image is converted to a BGR array, passed through
    ``resize_image_cv2`` with a 640x360 target and transposed to
    channel-first order before being handed to the batch predictor.

    Args:
        images (list): ``[image_name, PIL.Image]`` pairs.
        device_id (int): GPU index forwarded to ``get_parser`` / ``setup_cfg``.

    Returns:
        list: ``[image_name, dense_caption]`` pairs, one per input image, in
        input order. Empty when ``images`` is empty (the model is not built
        in that case).
    """
    caption_list = []
    if not images:
        # Nothing to caption: skip building the model entirely.
        return caption_list

    args2 = get_parser(device_id)
    print("device_id", device_id)
    cfg = setup_cfg(args2, device_id=device_id)
    demo = CustomBatchPrediction(cfg)
    batch_size = 16
    new_height = 360
    new_width = 640
    # Captured once: the loop must not depend on a name that gets rebound.
    last_index = len(images) - 1
    batch_images = []
    image_name_lists = []

    with torch.no_grad():
        for image_idx, (image_name, image) in enumerate(images):
            # Always queue the current image first, so the image that
            # completes a batch (or is the last one) is captioned as well.
            img = convert_PIL_to_numpy(image, format="BGR")
            # reshape to same dimension so the images can be batched
            img = resize_image_cv2(img, new_width=new_width, new_height=new_height)
            img = np.transpose(img, (2, 0, 1))
            print("img.shape", img.shape)
            batch_images.append(img)
            image_name_lists.append(image_name)

            # Flush when the batch is full or when the input is exhausted
            # (the final batch may be smaller than batch_size).
            if len(batch_images) == batch_size or image_idx == last_index:
                batch_inputs = [
                    {'image': torch.from_numpy(single_image)}
                    for single_image in batch_images
                ]
                predictions = demo.run_on_batch(batch_inputs)
                for idx, prediction in enumerate(predictions):
                    dense_caption = dense_pred_to_caption_with_norm(
                        prediction, image_height=new_height, image_width=new_width
                    )
                    caption_list.append([image_name_lists[idx], dense_caption])
                batch_images = []
                image_name_lists = []

            if image_idx % 10 == 0:
                print(f"Processed {image_idx} images.")
    return caption_list

def process_tar_file(input_tar_file_path, device_id):
    """
    Caption every image stored in one tar shard.

    Args:
        input_tar_file_path (str): path of the tar archive to read.
        device_id (int): GPU index forwarded to ``gen_captions_for_images``.

    Returns:
        list: the ``[image_name, dense_caption]`` pairs returned by
        ``gen_captions_for_images``; ``image_name`` is the member's name
        inside the archive.
    """
    images = []
    print(f"Start Opening and Processing {input_tar_file_path}")

    with tarfile.open(input_tar_file_path, "r") as tar:
        for tarinfo in tar:
            if not tarinfo.isfile():
                continue
            file_name, ext = os.path.splitext(tarinfo.name)
            # Compare case-insensitively, as check_if_image_file does, so
            # members such as "0001.JPG" are not silently skipped.
            if ext.lower() not in valid_image_extensions:
                # Non-image members are skipped without reading their payload.
                continue
            # Pass the TarInfo itself so the member is not looked up again by name.
            file_content = tar.extractfile(tarinfo).read()
            # The bytes are fully in memory, so the image stays usable after
            # the archive is closed.
            images.append([file_name + ext, Image.open(io.BytesIO(file_content))])

    # Generate captions for all images
    return gen_captions_for_images(images, device_id)
    
def run_inference(rank, args, dir_list, output_dir):
    """
    Worker entry point: caption each assigned tar shard and write one TSV per shard.

    Args:
        rank (int): index of this worker process.
        args (argparse.Namespace): parsed arguments; only ``process_per_gpu``
            is read here, to map the worker rank onto a GPU index.
        dir_list (list): paths of the tar shards assigned to this worker.
        output_dir (str): directory receiving the TSV files. It is created if
            missing. Each file is named ``<shard file name>.tsv``.

    Returns:
        None
    """
    torch.manual_seed(0)  # set a manual seed for reproducibility
    # Consecutive ranks share a GPU: ranks [k*process_per_gpu, (k+1)*process_per_gpu) -> GPU k.
    real_rank = rank // args.process_per_gpu
    device = torch.device("cuda", real_rank)
    print(f"Process rank: {rank}, device: {device}")

    # Create the destination up front so a missing directory cannot discard a
    # shard's captions after inference has already been spent on it.
    # exist_ok=True keeps this safe when several workers do it concurrently.
    os.makedirs(output_dir, exist_ok=True)

    for dir_index, subdir in enumerate(dir_list):
        total_begin_time = time.time()

        output_tsv_file = os.path.join(output_dir, os.path.basename(subdir) + ".tsv")
        generated_captions = process_tar_file(subdir, device_id=real_rank)

        with open(output_tsv_file, 'w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file, delimiter='\t')
            # Header row followed by one row per captioned image.
            writer.writerow(["image_file_name", "image_dense_caption"])
            writer.writerows(generated_captions)

        total_end_time = time.time()
        total_time = total_end_time - total_begin_time
        print(f"Total time taken to process {subdir} minutes: {divmod(total_time, 60)}")
        # Rough ETA: assume every remaining shard takes as long as this one.
        total_left_time = (len(dir_list) - dir_index - 1) * total_time
        hours, remainder = divmod(total_left_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        print(f"Estimated time left: {int(hours)}:{int(minutes)}:{int(seconds)}")


def main():
    """
    Parse the command line, pick the tar shards to process and fan them out
    over ``world_size * process_per_gpu`` worker processes.

    Shards are the ``*.tar`` files directly inside ``--input_dirs``, sorted by
    the number in their file name; ``--begin_index`` (inclusive) and
    ``--end_index`` (exclusive) select a range of that sorted list. Each
    worker receives a contiguous slice and runs ``run_inference`` on it.
    """
    parser = argparse.ArgumentParser(description="Generate captions for images in directories.")
    parser.add_argument("--input_dirs", type=str, default="/dataset/coco2014_wds_test", help="Path to the directory containing tar files.")
    parser.add_argument("--output_dir", type=str, default="/dataset/cc3m_wds/annotation/", help="Path to the output directory to save generated captions.")
    parser.add_argument("--begin_index", type=int, default=0, help="The index of the first tar file to process.")
    parser.add_argument("--end_index", type=int, default=9999999, help="The index of the last tar file to process.")
    parser.add_argument("--world_size", type=int, default=4, help="The number of GPUs you have.")
    parser.add_argument("--process_per_gpu", type=int, default=1, help="The number of processes to run on each GPU.")
    args = parser.parse_args()

    # Total number of worker processes (GPUs x processes per GPU).
    world_size = args.world_size * args.process_per_gpu
    if world_size <= 0:
        # Fail with a usage message instead of a ZeroDivisionError further down.
        parser.error("--world_size and --process_per_gpu must both be positive")

    # Initialize the multiprocessing context
    mp.set_start_method('spawn')

    # Only real ".tar" shards: a substring test would also match names such as
    # "00000.tar.tsv" or "00000.tar.gz".
    sub_dirs = [
        os.path.join(args.input_dirs, entry)
        for entry in os.listdir(args.input_dirs)
        if entry.endswith('.tar')
    ]

    # Sort the shards by the number in their file name
    dir_list = sorted(sub_dirs, key=extract_numeric_part)

    # Keep the shards whose position lies in [begin_index, end_index)
    selected_dir_list = [
        subdir
        for dir_index, subdir in enumerate(dir_list)
        if args.begin_index <= dir_index < args.end_index
    ]

    if not selected_dir_list:
        print(f"No tar files selected in {args.input_dirs}; nothing to do.")
        return

    print(selected_dir_list[0])
    # Split the shard list into contiguous slices, one per worker (ceil division)
    num_images_per_gpu = (len(selected_dir_list) + world_size - 1) // world_size
    dir_list_per_gpu = [selected_dir_list[i*num_images_per_gpu:(i+1)*num_images_per_gpu] for i in range(world_size)]

    # Start one process per non-empty slice. The rank is kept as-is because
    # run_inference derives the GPU index from it.
    processes = []
    for rank in range(world_size):
        if not dir_list_per_gpu[rank]:
            continue
        p = mp.Process(target=run_inference, args=(rank, args, dir_list_per_gpu[rank], args.output_dir))
        p.start()
        processes.append(p)

    # Wait for all processes to finish
    for p in processes:
        p.join()

if __name__ == "__main__":
    main()