import os 
import json
import copy

from PIL import Image
import torch
from torchvision import transforms

from transformers import OFATokenizer, OFAModel

# import sys
# sys.path.append("../../")
# from dataset.randaugment import RandomAugment

# from generate import sequence_generator
from tqdm import tqdm

# Per-channel normalisation statistics and the square side length that every
# augmented view is cropped/resized to before being fed to OFA.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
resolution = 480


def _to_rgb(image):
    """Return ``image`` converted to PIL's 3-channel "RGB" mode."""
    return image.convert("RGB")


# Stochastic augmentation pipeline: each call yields a different view of the
# same picture, so captioning several views produces varied captions.
#   * random crop covering 50-100% of the area, resized to `resolution`
#   * random horizontal flip
#   * random rotation of up to 20 degrees in either direction
# followed by conversion to a tensor and normalisation with `mean` / `std`.
random_crop_transform = transforms.Compose(
    [
        _to_rgb,
        transforms.RandomResizedCrop(resolution, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# def rotate_img(image, angle):
#     return image.rotate(angle)

# def resize_pad_img(image, scale):
#     new_size = (int(image.size[0] * scale), int(image.size[1] * scale))
#     new_img = image.resize(new_size, Image.BICUBIC)
#     pad = Image.new("RGB", (resolution, resolution))
#     pad.paste(new_img, ((resolution - new_size[0]) // 2, (resolution - new_size[1]) // 2))
#     return pad

# patch_resize_transform = transforms.Compose([
#         lambda image: image.convert("RGB"),
#         transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
#         transforms.ToTensor(), 
#         transforms.Normalize(mean=mean, std=std)
#     ])

# Hugging Face checkpoint id of the OFA captioning model; both the tokenizer
# created here and the model are loaded from it.
ckpt_dir = "OFA-Sys/ofa-large-caption"
tokenizer = OFATokenizer.from_pretrained(pretrained_model_name_or_path=ckpt_dir)


# image_path_list = [
#     "flickr30k-images/1007129816.jpg",
#     "flickr30k-images/1009434119.jpg",
# ]

# # using the generator of huggingface version
# model = OFAModel.from_pretrained(ckpt_dir, use_cache=False)

# for path_to_image in image_path_list:

#     path_to_image = os.path.join("/data/dataset/Flickr30k", path_to_image)

#     txt = " what does the image describe?"
#     inputs = tokenizer([txt], return_tensors="pt").input_ids
#     img = Image.open(path_to_image)
#     patch_img = patch_resize_transform(img).unsqueeze(0)

#     gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3) 

#     print(tokenizer.batch_decode(gen, skip_special_tokens=True))



############ Flickr30k ############
# dataset = "flickr30k"
# MODEL_ID = "OFA-large-caption"
# N_CAPTIONS = 4

# image_dir = "/data/dataset/Flickr30k"
# orig_train_file=f"/data/dataset/dataset_json/data/{dataset}_train.json"
# new_train_file=f"/data/dataset/dataset_json/{MODEL_ID}/{dataset}_train.json"

############ COCO ############
dataset = "coco"
MODEL_ID = "OFA-large-caption"
N_CAPTIONS = 4  # number of augmented views (and thus captions) per image

image_dir = "/data/dataset/MSCOCO"
# Source annotations, and the base path for the generated-caption output
# (written as one file per caption slot by the generation loop).
orig_train_file = f"/data/dataset/dataset_json/data/{dataset}_train.json"
new_train_file = f"/data/dataset/dataset_json/{MODEL_ID}/{dataset}_train.json"

os.makedirs(os.path.dirname(new_train_file), exist_ok=True)

# JSON text is UTF-8 (RFC 8259): pass the encoding explicitly instead of
# relying on the platform's locale-dependent default.
# The generation loop expects a list of dicts with at least "image_id",
# "image" (path relative to image_dir) and "caption" keys.
with open(orig_train_file, "r", encoding="utf-8") as f:
    anns = json.load(f)

# using the generator of huggingface version
# NOTE(review): use_cache=False presumably disables the decoder key/value
# cache during generation — kept as in the upstream OFA example; confirm.
model = OFAModel.from_pretrained(ckpt_dir, use_cache=False)
model.cuda()
model.eval()  # inference only

SAVE_EVERY = 100  # write intermediate results after this many unique images


def _save_caption_files(per_slot_anns, base_file):
    """Dump each caption slot's annotations to its own JSON file.

    Slot ``j`` is written next to ``base_file`` with ``_{j}`` inserted before
    the extension, e.g. ``coco_train.json`` -> ``coco_train_0.json``.

    Args:
        per_slot_anns: one list of annotation dicts per caption slot.
        base_file: path the per-slot file names are derived from.
    """
    stem, ext = os.path.splitext(base_file)
    for j, slot_anns in enumerate(per_slot_anns):
        with open(f"{stem}_{j}{ext}", "w", encoding="utf-8") as f:
            json.dump(slot_anns, f, indent=4)


# Unique image ids in the order they were captioned. The set gives O(1)
# duplicate checks; the annotation list can contain several entries per image.
image_id_list = []
_seen_image_ids = set()

# The prompt is the same for every image, so tokenise it once; it is repeated
# N_CAPTIONS times to get one batch row per augmented view.
txt = "what does the image describe?"
inputs = tokenizer([txt] * N_CAPTIONS, return_tensors="pt").input_ids.cuda()

# new_anns[j] collects the annotations whose caption came from view j.
new_anns = [[] for _ in range(N_CAPTIONS)]
n_images = 0

for ann in tqdm(anns):
    image_id = ann["image_id"]
    if image_id in _seen_image_ids:
        continue  # this image was already captioned via an earlier annotation
    _seen_image_ids.add(image_id)
    image_id_list.append(image_id)

    # N_CAPTIONS independently augmented views of the same image; the `with`
    # closes the underlying file handle once the views are built.
    with Image.open(os.path.join(image_dir, ann["image"])) as img:
        patch_imgs = torch.cat(
            [random_crop_transform(img).unsqueeze(0) for _ in range(N_CAPTIONS)],
            dim=0,
        )
    patch_imgs = patch_imgs.cuda()

    # Explicit no_grad: no autograd state is recorded during generation.
    with torch.no_grad():
        gen = model.generate(
            inputs, patch_images=patch_imgs, num_beams=5, no_repeat_ngram_size=3
        )

    # Decode the whole batch once instead of once per caption slot.
    captions = tokenizer.batch_decode(gen, skip_special_tokens=True)
    for j in range(N_CAPTIONS):
        cap = captions[j]
        print(image_id, j, cap)
        new_ann = copy.deepcopy(ann)
        new_ann["caption"] = cap
        new_anns[j].append(new_ann)

    # Count unique images (not annotation indices) so checkpoints are never
    # skipped when the annotation at a multiple of 100 is a duplicate.
    n_images += 1
    if n_images % SAVE_EVERY == 0:
        print(f"Processed {n_images} images")
        _save_caption_files(new_anns, new_train_file)

# Final write: without it, every image after the last periodic checkpoint
# would be dropped from the output files.
print(f"Processed {n_images} images")
_save_caption_files(new_anns, new_train_file)