import argparse
from blip_vqa_eval import evaluate_direcotry_using_blip_vqa
from datetime import datetime
import os


def parse_args():
    """Parse and validate the command-line arguments of the evaluation script.

    Returns:
        argparse.Namespace: with the attributes ``num_chunks``, ``chunk_idx``,
        ``generator_directory_names`` and ``compbench_category_name``.

    Raises:
        ValueError: if ``--chunk_idx`` is not in the half-open range
            ``[0, --num_chunks)``.
    """
    parser = argparse.ArgumentParser(description="VQA Evaluation of Generated T2I CompBench dataset")
    parser.add_argument(
        "--num_chunks",
        type=int,
        default=20,
        help="Number of chunks the sorted prompt list is split into.",
    )
    parser.add_argument(
        "--chunk_idx",
        type=int,
        required=True,
        help="Zero-based index of the chunk to evaluate; must be smaller than --num_chunks.",
    )
    parser.add_argument(
        "--generator_directory_names",
        type=str,
        nargs='+',
        default=["syngen", "deepfloyd", "syngen_sd_v2_1"],
        help="Names of the generator sub-directories to evaluate inside each prompt directory.",
    )
    parser.add_argument(
        "--compbench_category_name",
        type=str,
        default="color",
        choices=["color", "texture", "shape"],
        help="Dataset category; selects both the prompt file and the image base directory.",
    )

    args = parser.parse_args()

    # Valid indices are 0 .. num_chunks - 1. A non-positive --num_chunks leaves
    # no valid index at all, so it is rejected by this same check.
    if args.chunk_idx < 0 or args.chunk_idx >= args.num_chunks:
        raise ValueError(
            "--chunk_idx must satisfy 0 <= --chunk_idx < --num_chunks "
            f"(got --chunk_idx={args.chunk_idx}, --num_chunks={args.num_chunks})"
        )

    return args


def get_list_chunk(arr, num_chunks, chunk_idx):
    """Return the ``chunk_idx``-th of ``num_chunks`` contiguous slices of ``arr``.

    The chunk size is ``ceil(len(arr) / num_chunks)``. Because of the rounding
    up, the trailing chunks can be empty (e.g. 5 items split into 4 chunks
    gives sizes 2, 2, 1, 0); an empty slice is returned for those instead of
    failing while printing the chunk boundaries.

    Args:
        arr: Sequence to split; must support ``len()`` and slicing.
        num_chunks: Total number of chunks; must be positive.
        chunk_idx: Zero-based index of the requested chunk.

    Returns:
        The slice of ``arr`` belonging to the requested chunk (possibly empty).

    Raises:
        ValueError: if ``num_chunks`` is not positive or ``chunk_idx`` is not
            in ``[0, num_chunks)``.
    """
    if num_chunks <= 0:
        raise ValueError(f"num_chunks must be positive, got {num_chunks}")
    if not 0 <= chunk_idx < num_chunks:
        raise ValueError(
            f"chunk_idx must satisfy 0 <= chunk_idx < num_chunks "
            f"(got chunk_idx={chunk_idx}, num_chunks={num_chunks})"
        )

    arr_len = len(arr)

    # Ceiling division so that num_chunks chunks always cover the whole list.
    chunk_size = (arr_len + num_chunks - 1) // num_chunks

    # Clamp both ends to the list length; a chunk past the end becomes empty.
    start_index = min(chunk_size * chunk_idx, arr_len)
    end_index = min((chunk_idx + 1) * chunk_size, arr_len)

    print(f"Choosing chunk ({start_index}:{end_index})")
    if start_index >= end_index:
        # Nothing to index into, so skip the first/last item report.
        print("Chunk is empty")
        return arr[start_index:end_index]

    print(f"First item of the chunk: \"{arr[start_index]}\"")
    print(f"Last item of the chunk: \"{arr[end_index-1]}\"")

    return arr[start_index:end_index]


if __name__ == '__main__':
    # Script entry point: evaluate one chunk of the prompt directories of a
    # T2I-CompBench category with BLIP-VQA, for every requested generator.
    args = parse_args()

    # Read the category's prompt list. Each prompt has its '.' characters
    # stripped from both ends, duplicates are dropped, and the result is
    # sorted so that every run sees the same prompt order and chunking.
    with open(f'T2I-CompBench-dataset/{args.compbench_category_name}.txt', 'r', encoding='utf-8') as f:
        prompt_lines = f.read().splitlines()
    # Blank lines would become an empty prompt that can never match a
    # directory name, so they are skipped.
    prompts = sorted({line.strip('.') for line in prompt_lines} - {''})

    base_dir = f'./T2I-CompBench-dataset/{args.compbench_category_name}'

    # Every prompt must have a directory of the same name under base_dir.
    # This is an explicit check rather than an `assert` so that it still runs
    # under `python -O` and reports which directories are missing.
    existing_directories = set(os.listdir(base_dir))
    missing_prompts = [p for p in prompts if p not in existing_directories]
    if missing_prompts:
        preview = ', '.join(repr(p) for p in missing_prompts[:5])
        raise FileNotFoundError(
            f"{len(missing_prompts)} of {len(prompts)} prompt directories are missing "
            f"under {base_dir}; first missing: {preview}"
        )

    generated_prompts_directories = get_list_chunk(prompts, args.num_chunks, args.chunk_idx)

    print(f"------ Evaluation for generators: {', '.join(args.generator_directory_names)} ------")

    for prompt_directory in generated_prompts_directories:
        prompt_directory_path = os.path.join(base_dir, prompt_directory)
        for prompt_generator_directory in args.generator_directory_names:
            prompt_generator_directory_path = os.path.join(prompt_directory_path, prompt_generator_directory)
            # A generator without output for this prompt is reported and
            # skipped instead of aborting the whole chunk.
            if not os.path.exists(prompt_generator_directory_path):
                print("!"*100)
                print("!!! Generator path does not exists. Skipping this...")
                print(f"!!! Path: {prompt_generator_directory_path}")
                print("!"*100, flush=True)
                continue

            # Progress banner, flushed so it shows up in the log before the
            # evaluation of this directory starts.
            print("="*80)
            print(f"[Date and Time] {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
            print(f"[Prompt] {prompt_directory}")
            print(f"[Generator] {prompt_generator_directory}")
            print("="*80, flush=True)

            evaluate_direcotry_using_blip_vqa(image_folder_path=prompt_generator_directory_path)

    print("Done!")
