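"""
Generate object-level region captions for images stored in tar shards.

Every entry of ``--input_dirs`` whose name contains ".tar" is opened. Each image
inside it is segmented (SegmentAnything masks plus SemanticSegment class labels),
and the largest non-background regions are summarised as
"class_name: [bbox normalized by image width/height]; ..." strings.
The results for each shard are written to ``<output_dir>/<shard name>.tsv``.
Shards whose TSV already exists are skipped, so an interrupted run can be resumed.

example:
    python <this_script>.py --input_dirs /dataset/coco2014_wds_test --output_dir /dataset/cc3m_wds/annotation/ --world_size 4 --process_per_gpu 1
"""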

import argparse
import csv
import functools
import glob
import io
import json
import math
import os
import random
import re
import tarfile
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from models.segment_models.semantic_segment_anything_model import SemanticSegment
from models.segment_models.semgent_anything_model import SegmentAnything
from utils.util import resize_long_edge_cv2

# Module-level settings.
# NOTE(review): `datatype`, `beam_size` and `num_captions` are not referenced by
# any code in this file; they look like leftovers from a captioning script.
datatype = torch.float16
beam_size = 10
num_captions = 5

# File extensions (lower-case, with the leading dot) that are treated as images.
valid_image_extensions = [
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".bmp",
    ".tiff",
]

def check_if_image_file(file_path):
    """Return True if ``file_path`` is an existing file whose name has an image extension.

    The extension test is case-insensitive: the path is lower-cased before it is
    compared against ``valid_image_extensions``.
    """
    is_existing_file = os.path.isfile(file_path)
    if not is_existing_file:
        return False
    lowered_path = file_path.lower()
    return any(map(lowered_path.endswith, valid_image_extensions))

def extract_numeric_part(filename):
    """Return the integer stem of a shard path, e.g. ``"/data/00012.tar" -> 12``.

    Only the text after the last ``/`` and before the first ``.`` is parsed, so
    ``int`` raises ``ValueError`` when that stem is not purely numeric.
    """
    base_name = filename.rsplit('/', 1)[-1]
    stem, _, _ = base_name.partition('.')
    return int(stem)

# BT.601 (SDTV) colour-space matrices, row-major.
# Source: https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
# _M_RGB2YUV is applied by convert_PIL_to_numpy to RGB pixels scaled to [0, 1].
# _M_YUV2RGB is the inverse transform; no code in this file references it.
_M_RGB2YUV = [
    [0.299, 0.587, 0.114],
    [-0.14713, -0.28886, 0.436],
    [0.615, -0.51499, -0.10001],
]
_M_YUV2RGB = [
    [1.0, 0.0, 1.13983],
    [1.0, -0.39465, -0.58060],
    [1.0, 2.03211, 0.0],
]


def convert_PIL_to_numpy(image, format):
    """
    Turn a PIL image into a numpy array laid out in the requested ``format``.

    Args:
        image (PIL.Image): source image.
        format (str or None): target layout. ``None`` skips PIL conversion
            entirely. "BGR" and "YUV-BT.601" are derived from an RGB
            conversion, because PIL cannot produce them directly. Any other
            value is passed straight to ``image.convert``.

    Returns:
        (np.ndarray): depends on ``format``:
            - "L": an H x W x 1 array.
            - "BGR": the RGB array with its last axis reversed (a view).
            - "YUV-BT.601": a float array in which pixels scaled to [0, 1] are
              multiplied by the BT.601 RGB->YUV matrix.
    """
    if format is not None:
        # PIL has no BGR/YUV modes: decode as RGB and post-process below.
        pil_mode = "RGB" if format in ("BGR", "YUV-BT.601") else format
        image = image.convert(pil_mode)
    pixels = np.asarray(image)

    if format == "L":
        # PIL drops the channel axis for single-channel images; restore HWC.
        return np.expand_dims(pixels, -1)
    if format == "BGR":
        return pixels[:, :, ::-1]
    if format == "YUV-BT.601":
        return np.dot(pixels / 255.0, np.array(_M_RGB2YUV).T)
    return pixels

# Substrings that mark background / "stuff" classes excluded from prompts.
# Matching is by substring, so labels such as "wall-other-merged" are excluded too.
# NOTE(review): so is any class that merely contains one of these words
# (e.g. "skyscraper"); confirm that is acceptable.
COMMON_OBJECTS = {'background', 'wall', 'floor', 'image', 'sky'}


def semantic_prompt_gen(anns, height, width, topk=10):
    """Build a textual "class: [bbox]" prompt from classified segmentation regions.

    Regions are ranked by ``area`` (largest first), and only the ``topk``
    largest are considered. Among those:
        1. regions whose ``class_name`` contains any entry of ``COMMON_OBJECTS``
           as a substring are dropped;
        2. only the first (i.e. largest) region of each class is kept.
    The top-k cut happens before this filtering, so fewer than ``topk`` entries
    may be emitted.

    Args:
        anns: list of region dicts. Each must provide ``'class_name'`` (str),
            ``'area'`` (sortable) and ``'bbox'`` (at least 4 numbers, in pixels;
            presumably SAM's [x, y, w, h] - TODO confirm). The dicts and their
            bboxes are NOT modified.
        height: image height in pixels; divides bbox elements 1 and 3.
        width: image width in pixels; divides bbox elements 0 and 2.
        topk: number of largest regions to consider (default 10).

    Returns:
        str: concatenation of ``"<class_name>: ['0.12', '0.34', '0.56', '0.78']; "``
        entries. Each coordinate is normalized and formatted with two decimals,
        and the list is rendered with ``str()``, hence the quoted values.
        Returns ``""`` when no region survives filtering.
    """
    ranked = sorted(anns, key=lambda region: region['area'], reverse=True)
    prompt_parts = []
    seen_classes = set()  # ensures at most one entry (the largest) per class
    for region in ranked[:topk]:
        class_name = region['class_name']
        if class_name in seen_classes or any(obj in class_name for obj in COMMON_OBJECTS):
            continue
        seen_classes.add(class_name)
        # Work on a copy: the old code overwrote region['bbox'] with strings,
        # corrupting the caller's data and breaking any second call.
        bbox = list(region['bbox'])
        for i in range(4):
            scale = width if i % 2 == 0 else height
            bbox[i] = format(bbox[i] / scale, '.2f')
        prompt_parts.append(f"{class_name}: {bbox}; ")
    return "".join(prompt_parts)


@functools.lru_cache(maxsize=1)
def _load_segmentation_models(device_id, pretrained_model_path):
    """Construct and memoize the (SemanticSegment, SegmentAnything) model pair.

    process_tar_file calls gen_object_captions_for_images once per image batch.
    Memoizing here builds the models once per process instead of once per batch.
    maxsize=1 keeps only the most recent (device_id, path) pair alive.
    """
    semantic_segment_model = SemanticSegment(device_id)
    segment_model = SegmentAnything(device_id, arch="vit_h", path=pretrained_model_path)
    return semantic_segment_model, segment_model


def gen_object_captions_for_images(images, device_id, pretrained_model_path):
    """Return ``[image_name, region_caption]`` pairs for a list of images.

    Args:
        images: iterable of ``[image_name, PIL.Image]`` pairs.
        device_id: device index forwarded to the segmentation model wrappers.
        pretrained_model_path: checkpoint location forwarded to SegmentAnything.

    Returns:
        list of ``[image_name, caption]``, where caption comes from
        semantic_prompt_gen. Processing is best effort: an image that raises is
        reported on stdout and omitted, so the result may be shorter than
        ``images``.
    """
    caption_list = []
    semantic_segment_model, segment_model = _load_segmentation_models(device_id, pretrained_model_path)
    with torch.no_grad():
        for image_idx, (image_name, image) in enumerate(images):
            try:
                img = convert_PIL_to_numpy(image, format="BGR")
                img = resize_long_edge_cv2(img, 384)
                # Guarantee an H x W x 3 array. A 2-D array gets a channel axis,
                # then a single channel is replicated three times. The old code
                # produced a 4-D (H, W, 3, 1) array for (H, W, 1) input.
                if img.ndim == 2:
                    img = img[:, :, np.newaxis]
                if img.shape[2] == 1:
                    img = np.repeat(img, 3, axis=2)
                height, width = img.shape[:2]
                anns = segment_model.generate_mask_from_image(img)
                anns_w_class = semantic_segment_model.semantic_class_w_mask_from_image(img, anns)
                region_caption = semantic_prompt_gen(anns_w_class, height, width)
                caption_list.append([image_name, region_caption])
            except Exception as e:
                # Best effort: one bad image must not abort the whole batch.
                print(e)
                print(f"Error processing {image_name}")
                continue
            if image_idx % 10 == 0:
                print(f"Processed {image_idx} images.")
    return caption_list

def process_tar_file(input_tar_file_path, device_id, pretrained_model_path, batch_size=64):
    """Caption every image inside one tar archive.

    Image members are decoded from the archive and captioned in batches of
    ``batch_size`` via gen_object_captions_for_images. All other members
    (e.g. .txt / .json sidecars) are ignored.

    Args:
        input_tar_file_path: path to the tar archive. Mode "r" lets tarfile
            detect compression transparently.
        device_id: device index forwarded to the captioning models.
        pretrained_model_path: checkpoint location forwarded to the models.
        batch_size: number of images buffered before each captioning call.

    Returns:
        list of ``[member_name, caption]`` pairs in archive order. Images that
        cannot be opened or captioned are reported on stdout and omitted.
    """
    print(f"Start Opening and Processing {input_tar_file_path}")
    generated_captions = []
    images = []
    with tarfile.open(input_tar_file_path, "r") as tar:
        for tarinfo in tar:
            if not tarinfo.isfile():
                continue
            # Case-insensitive so that e.g. "0001.JPG" is not silently skipped.
            ext = os.path.splitext(tarinfo.name)[1].lower()
            if ext not in valid_image_extensions:
                continue
            # Pass the TarInfo itself. extractfile(name) looks the name up in the
            # full member list on every call (quadratic on large shards), and
            # resolves duplicate names to the last occurrence, not this member.
            file_content = tar.extractfile(tarinfo).read()
            try:
                image = Image.open(io.BytesIO(file_content))
            except OSError as e:
                # Unrecognised/corrupt data (PIL raises OSError subclasses); skip it
                # instead of aborting the whole shard.
                print(e)
                print(f"Error opening {tarinfo.name}")
                continue
            images.append([tarinfo.name, image])
            if len(images) == batch_size:
                generated_captions.extend(
                    gen_object_captions_for_images(images, device_id, pretrained_model_path)
                )
                images = []
    if images:
        # Flush the final, partially filled batch.
        generated_captions.extend(
            gen_object_captions_for_images(images, device_id, pretrained_model_path)
        )
    return generated_captions
    # write it into a new file
    
def run_inference(rank, args, dir_list, output_dir, pretrained_model_path):
    """Worker entry point: caption every tar in ``dir_list`` and write one TSV per tar.

    Args:
        rank: global worker index. ``process_per_gpu`` consecutive ranks share
            one GPU index.
        args: parsed CLI namespace; only ``args.process_per_gpu`` is read here.
        dir_list: tar file paths assigned to this worker.
        output_dir: directory receiving ``<tar basename>.tsv`` files. It is
            created if missing.
        pretrained_model_path: checkpoint location forwarded to process_tar_file.

    A tar whose TSV already exists is skipped, which makes reruns resumable.
    """
    torch.manual_seed(0)  # set a manual seed for reproducibility
    real_rank = rank // args.process_per_gpu
    device = torch.device("cuda", real_rank)  # GPU index shared by process_per_gpu ranks
    print(f"Process rank: {rank}, device: {device}")
    os.makedirs(output_dir, exist_ok=True)
    for dir_index, subdir in enumerate(dir_list):
        total_begin_time = time.time()

        # ".../00012.tar" -> "<output_dir>/00012.tar.tsv". This name doubles as
        # the resume marker checked below, so it must stay stable.
        output_tsv_file = os.path.join(output_dir, os.path.basename(subdir) + ".tsv")
        if os.path.exists(output_tsv_file):
            print(f"File {output_tsv_file} already exists. Skip processing {subdir}")
            continue

        generated_captions = process_tar_file(subdir, device_id=real_rank, pretrained_model_path=pretrained_model_path)

        # Write to a temporary file, then rename it into place with os.replace
        # (atomic when both paths are on the same filesystem). An interrupted
        # write then never leaves a partial TSV that a rerun would skip as done.
        tmp_tsv_file = output_tsv_file + ".tmp"
        with open(tmp_tsv_file, 'w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file, delimiter='\t')
            writer.writerow(["image_file_name", "image_object_caption"])
            writer.writerows(generated_captions)
        os.replace(tmp_tsv_file, output_tsv_file)

        total_time = time.time() - total_begin_time
        print(f"Total time taken to process {subdir} minutes: {divmod(total_time, 60)}")
        # Rough ETA: assumes every remaining tar takes as long as this one.
        total_left_time = (len(dir_list) - dir_index - 1) * total_time
        hours, remainder = divmod(total_left_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        print(f"Estimated time left: {int(hours)}:{int(minutes)}:{int(seconds)}")


def main():
    """Command-line entry point.

    Steps:
        1. List the tar shards in ``--input_dirs`` and sort them by numeric stem.
        2. Keep the shards whose sorted index lies in ``[--begin_index, --end_index)``.
        3. Split them into contiguous chunks across
           ``world_size * process_per_gpu`` processes (start method "spawn").
        4. Run ``run_inference`` in each process and wait for all of them.
    """
    parser = argparse.ArgumentParser(description="Generate captions for images in directories.")
    parser.add_argument("--input_dirs", type=str, default="/dataset/coco2014_wds_test", help="Path to the directory containing tar files.")
    parser.add_argument("--output_dir", type=str, default="/dataset/cc3m_wds/annotation/", help="Path to the output directory to save generated captions.")
    parser.add_argument("--begin_index", type=int, default=0, help="The index of the first tar file to process.")
    parser.add_argument("--end_index", type=int, default=9999999, help="The index of the last tar file to process.")
    parser.add_argument("--world_size", type=int, default=4, help="The number of GPUs you have.")
    parser.add_argument("--process_per_gpu", type=int, default=1, help="The number of processes to run on each GPU.")
    parser.add_argument("--pretrained_model_path", type=str, default="pretrained_models", help="The path to the pretrained model.")
    args = parser.parse_args()

    if args.world_size < 1 or args.process_per_gpu < 1:
        parser.error("--world_size and --process_per_gpu must both be positive.")

    # Initialize the multiprocessing context
    mp.set_start_method('spawn')

    # Any entry whose name contains ".tar" counts (e.g. "*.tar.gz" too).
    sub_dirs = [
        os.path.join(args.input_dirs, name)
        for name in os.listdir(args.input_dirs)
        if '.tar' in name
    ]
    # Sort shards by their numeric stem so the selection below is deterministic.
    dir_list = sorted(sub_dirs, key=extract_numeric_part)

    # begin_index <= index < end_index; negative bounds clamp to 0.
    selected_dir_list = dir_list[max(args.begin_index, 0):max(args.end_index, 0)]
    if not selected_dir_list:
        print(f"No tar files selected from {args.input_dirs} for index range "
              f"[{args.begin_index}, {args.end_index}); nothing to do.")
        return

    # Total number of worker processes (GPUs x processes per GPU).
    world_size = args.world_size * args.process_per_gpu

    print(selected_dir_list[0])
    # Contiguous split: each worker gets at most ceil(n / world_size) shards;
    # trailing workers may receive an empty list.
    shards_per_worker = (len(selected_dir_list) + world_size - 1) // world_size
    dir_list_per_gpu = [
        selected_dir_list[i * shards_per_worker:(i + 1) * shards_per_worker]
        for i in range(world_size)
    ]

    # Start one process per worker rank
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=run_inference, args=(rank, args, dir_list_per_gpu[rank], args.output_dir, args.pretrained_model_path))
        p.start()
        processes.append(p)

    # Wait for all processes to finish
    for p in processes:
        p.join()

    # Don't let worker crashes pass silently.
    failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed_ranks:
        print(f"Warning: worker processes {failed_ranks} exited with a non-zero code; "
              f"some of their tar files may not have been processed.")

# Script entry point. The guard is required, not merely conventional: main()
# selects the "spawn" start method, which re-imports this module in every
# worker process.
if __name__ == "__main__":
    main()