import os
import sys
import torch
import time
import json
import pickle
from pathlib import Path

from tqdm import tqdm
import random
import math
import numpy as np
import copy

sys.path.append("../../")
from attacks.MMA import eda

# Dataset whose training annotations are augmented. The input and output
# annotation paths below are derived from this name ("coco" or "flickr30k").
dataset = "coco"

# EDA strength(s) to generate data for. Each value is passed to eda() as
# alpha_sr, alpha_ri, alpha_rs and p_rd alike. Alternative sweep: [0.4, 0.5].
ALPHAS = [0.3]

# Number of augmentations requested from eda() for every caption.
NUM_AUG = 5
# NOTE(review): N_CAPS is not referenced anywhere in this script.
N_CAPS = 1
# Number of augmented annotation files written per alpha (suffixes _1 .. _N).
NUM_VARIANTS = 4

DATA_ROOT = "/data/dataset/dataset_json"


def load_anns(file):
    """Return the parsed JSON content of the annotation file ``file``."""
    with open(file, "r", encoding="utf-8") as f:
        return json.load(f)


def _augment_first_captions(anns, alpha, num_aug, num_variants):
    """Build ``num_variants`` augmented copies of the annotation set.

    Only the first annotation encountered for each ``image_id`` is used; later
    annotations of the same image are skipped. The caption of that annotation
    is passed to ``eda`` once, and the i-th returned caption is stored in a
    deep copy of the annotation appended to the i-th output list.

    Args:
        anns: iterable of annotation dicts with "image_id" and "caption" keys.
        alpha: value forwarded to ``eda`` as alpha_sr, alpha_ri, alpha_rs, p_rd.
        num_aug: value forwarded to ``eda`` as ``num_aug``.
        num_variants: how many of the returned captions to keep per annotation.

    Returns:
        A list of ``num_variants`` lists of annotation dicts.

    Raises:
        ValueError: if ``eda`` returns fewer than ``num_variants`` captions.
    """
    variants = [[] for _ in range(num_variants)]
    # A set keeps the "already seen" check O(1) per annotation
    # (assumes image ids are hashable JSON scalars such as int or str).
    seen_image_ids = set()

    for ann in tqdm(anns):
        image_id = ann["image_id"]
        if image_id in seen_image_ids:
            continue
        seen_image_ids.add(image_id)

        caps = eda(
            ann["caption"],
            alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha,
            num_aug=num_aug,
        )
        if len(caps) < num_variants:
            # Fail with context instead of a bare IndexError deep in the loop.
            raise ValueError(
                "eda() returned {} caption(s) for image_id {!r}, "
                "but {} are required".format(len(caps), image_id, num_variants)
            )

        # zip() stops after num_variants captions; any extras are ignored.
        for bucket, cap in zip(variants, caps):
            new_ann = copy.deepcopy(ann)
            new_ann["caption"] = cap
            bucket.append(new_ann)

    return variants


def _write_variants(variants, base_file):
    """Write each variant list to ``<base>_<i><ext>`` with i starting at 1."""
    root, ext = os.path.splitext(base_file)
    for idx, variant in enumerate(variants, start=1):
        with open(f"{root}_{idx}{ext}", "w", encoding="utf-8") as f:
            json.dump(variant, f, indent=4)


# The source annotations do not depend on alpha, so load them only once.
orig_train_file = f"{DATA_ROOT}/data/{dataset}_train.json"
anns = load_anns(orig_train_file)

for alpha in ALPHAS:
    AUG_NAME = f"EDA{alpha}"
    new_train_file = f"{DATA_ROOT}/{AUG_NAME}/{dataset}_train.json"
    os.makedirs(os.path.dirname(new_train_file), exist_ok=True)

    variants = _augment_first_captions(anns, alpha, NUM_AUG, NUM_VARIANTS)

    # save: <dataset>_train_1.json ... <dataset>_train_<NUM_VARIANTS>.json
    _write_variants(variants, new_train_file)