import os
import sys
import torch
import time
import json
import pickle
from pathlib import Path

from tqdm import tqdm
import random
import math
import numpy as np
import copy
from PIL import Image

from torchvision import transforms

sys.path.append("../../")
from dataset.randaugment import RandomAugment

# --- augmentation pipeline --------------------------------------------------
# Side length (pixels) of the square output produced by RandomResizedCrop.
image_res = 480
# Positional arguments for the project-local RandomAugment; presumably
# N = number of ops applied per image and M = op magnitude -- confirm in
# dataset/randaugment.py.
aug_n = 2
aug_m = 5
# Range handed to RandomResizedCrop's ``scale`` argument.
aug_scale = (0.7, 1.0)

# Candidate ops RandomAugment may pick from.
_rand_aug_ops = [
    'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
    'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
]

# Crop -> horizontal flip -> RandomAugment (on the PIL image) -> float tensor.
_crop = transforms.RandomResizedCrop(
    image_res,
    scale=aug_scale,
    interpolation=Image.BICUBIC,
)
train_transform = transforms.Compose([
    _crop,
    transforms.RandomHorizontalFlip(),
    RandomAugment(aug_n, aug_m, isPIL=True, augs=_rand_aug_ops),
    transforms.ToTensor(),
])

# Sub-directory created under each dataset's image root to hold the augmented
# copies; it is also the prefix of the "image" path written into the new
# annotation files. The name appears to encode the augmentation settings
# (N=2, M=5, min crop scale 0.7) -- keep it in sync with aug_n/aug_m/aug_scale.
IMAGE_DIR_NAME = "randaug2-5-0.7"

# One entry per dataset to augment, as a 4-tuple:
#   (name,                       -- human-readable label
#    image root,                 -- ann["image"] paths are relative to this
#    source annotation JSON,     -- original training annotations
#    output annotation template) -- "{}" is filled with the copy index j
dataset_info_list = [
    (
        "flickr30k",
        "/data/dataset/Flickr30k/",
        "/data/dataset/dataset_json/data/flickr30k_train.json",
        "/data/dataset/dataset_json/RandAug2-5-0.7/flickr30k_train_{}.json",
    ),
    (
        "coco",
        "/data/dataset/MSCOCO/",
        "/data/dataset/dataset_json/data/coco_train.json",
        "/data/dataset/dataset_json/RandAug2-5-0.7/coco_train_{}.json",
    ),
]

# For every dataset in dataset_info_list, write NUM_AUG independently
# augmented copies of each unique training image to
# <image_dir>/<IMAGE_DIR_NAME>/idx<j>/, plus one annotation JSON per copy that
# points at those images.

NUM_AUG = 4  # independently augmented copies generated per source image
SAVE_EVERY = 100  # checkpoint the output annotation files every N unique images


def load_anns(file):
    """Return the parsed contents of the JSON annotation file ``file``."""
    with open(file, "r") as f:
        return json.load(f)


def _dump_anns(new_anns_list, new_train_file):
    """Write ``new_anns_list[j]`` to ``new_train_file.format(j)`` for every copy j.

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace``, so a run interrupted mid-write never leaves a
    truncated JSON file behind.
    """
    for j, new_anns in enumerate(new_anns_list):
        _new_train_file = new_train_file.format(j)
        tmp_file = _new_train_file + ".tmp"
        with open(tmp_file, "w") as f:
            json.dump(new_anns, f, indent=4)
        os.replace(tmp_file, _new_train_file)
        # tqdm.write keeps the progress bar intact (a plain print breaks it).
        tqdm.write(f"Saved to  {_new_train_file}")


for dataset_name, image_dir, orig_train_file, new_train_file in dataset_info_list:
    # Output layout: <image_dir>/<IMAGE_DIR_NAME>/idx<j>/<stem>_<j><ext>
    aug_dirs = [
        os.path.join(image_dir, IMAGE_DIR_NAME, f"idx{j}") for j in range(NUM_AUG)
    ]
    for aug_dir in aug_dirs:
        os.makedirs(aug_dir, exist_ok=True)
    os.makedirs(os.path.dirname(new_train_file), exist_ok=True)

    anns = load_anns(orig_train_file)
    to_pil = transforms.ToPILImage()

    new_anns_list = [[] for _ in range(NUM_AUG)]
    # A set gives O(1) membership tests; a list made this loop quadratic in
    # the number of images.
    seen_image_ids = set()
    for ann in tqdm(anns, desc=dataset_name):
        # NOTE(review): only the FIRST annotation of each image_id is carried
        # over; later annotations of the same image (e.g. extra captions) are
        # dropped from the new files. Confirm this is intended.
        if ann["image_id"] in seen_image_ids:
            continue
        seen_image_ids.add(ann["image_id"])
        image_file = ann["image"]

        # Decode inside a context manager so the file handle is released, and
        # force 3-channel RGB: grayscale / CMYK / palette sources would
        # otherwise round-trip through ToTensor/ToPILImage in a non-RGB mode
        # (a 4-channel tensor comes back as RGBA, which JPEG cannot store).
        with Image.open(os.path.join(image_dir, image_file)) as src:
            img = src.convert("RGB")

        stem, ext = os.path.splitext(os.path.basename(image_file))
        for j, aug_dir in enumerate(aug_dirs):
            # train_transform is re-applied per copy, so each copy gets its
            # own random crop/flip/RandAugment draw.
            patch_img = to_pil(train_transform(img))
            new_name = f"{stem}_{j}{ext}"
            patch_img.save(os.path.join(aug_dir, new_name))

            # Annotation for copy j: identical to the source annotation except
            # for the image path, which stays relative to image_dir.
            new_ann = copy.deepcopy(ann)
            new_ann["image"] = f"{IMAGE_DIR_NAME}/idx{j}/{new_name}"
            new_anns_list[j].append(new_ann)

        # Periodic checkpoint so partial progress survives an interruption.
        if len(seen_image_ids) % SAVE_EVERY == 0:
            _dump_anns(new_anns_list, new_train_file)

    _dump_anns(new_anns_list, new_train_file)


