"""
modified based on the BLIP code base
https://github.com/salesforce/BLIP.git
"""
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from collections import defaultdict
import sys
from pathlib import Path
import time
import copy
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip import blip_decoder
import utils
from utils import cosine_lr_schedule
from data import create_dataset, create_sampler, create_loader, reinit_loader, create_iter_loader, reinit_sampler
from data.utils import save_result, coco_caption_eval, flickr_caption_eval
from data.flickr30k_dataset import flickr30k_train_al, flickr30k_caption_eval
from data.coco_karpathy_dataset import coco_karpathy_caption_eval, coco_karpathy_train_al
from torchvision import transforms
from transform.randaugment import RandomAugment
from torchvision.transforms.functional import InterpolationMode


# `d` is the directory one level above the folder that contains this file.
d = Path(__file__).resolve().parents[1]
# Append `<d>/stablediffusion/` to the module search path.
sys.path.append(os.path.join(d, "stablediffusion/")) # add sd scripts
# Must stay below the sys.path tweak: presumably `scripts.sd_gen` is resolved
# through the directory appended above -- TODO confirm.
from scripts.sd_gen import sd_gen


def clean_img_id(cur_img_id):
    """Return the leading id component of an image path or file name.

    Only the part of the base name before the first underscore is kept, e.g.
    ``3874741721_0_sd_0_s1041.jpg`` -> ``3874741721``. A base name without any
    underscore is returned unchanged (extension included).
    """
    file_name = os.path.basename(cur_img_id)
    leading_part, _, _ = file_name.partition("_")
    return leading_part


class DataBuilder(object):
    """Build the captioning datasets and per-image caption lookups.

    The dataset family is chosen from the config: when
    ``config["image_root"]`` contains the substring ``"coco"`` the
    COCO-Karpathy dataset classes are used, otherwise the Flickr30k ones.
    """

    def __init__(self, config, min_scale=0.5) -> None:
        """Store ``config`` and build the train / test image transforms.

        Args:
            config: mapping; ``config['image_size']`` is read here, further
                keys are read by the other methods.
            min_scale: lower bound of the random-resized-crop scale range
                used by the training transform.
        """
        self.config = config

        image_size = config['image_size']
        normalize = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )

        # Training: random crop + flip + RandomAugment, then tensor + normalize.
        self.transform_train = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(min_scale, 1.0), interpolation=InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
                                                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
        # Evaluation: deterministic resize to a square, then tensor + normalize.
        self.transform_test = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def _is_coco(self):
        """Return True when the configured image root points at COCO data."""
        return "coco" in self.config["image_root"]

    def create_train(self, ann, epoch=-1):
        """Create the training dataset from the annotation list ``ann``.

        Args:
            ann: list of annotation records to train on.
            epoch: forwarded unchanged to the dataset constructor.

        Returns:
            A ``coco_karpathy_train_al`` or ``flickr30k_train_al`` instance.
        """
        print("create train dataset, number of samples included: %d" % len(ann))
        dataset_cls = coco_karpathy_train_al if self._is_coco() else flickr30k_train_al
        return dataset_cls(
            self.transform_train,
            self.config['image_root'],
            self.config['aug_data_root'],
            ann=ann,
            prompt=self.config['prompt'],
            epoch=epoch,
        )

    def create_val_test(self):
        """Create the validation and test datasets.

        Returns:
            Tuple ``(val_dataset, test_dataset)``.
        """
        eval_cls = coco_karpathy_caption_eval if self._is_coco() else flickr30k_caption_eval
        val_dataset = eval_cls(self.transform_test, self.config['image_root'], self.config['ann_root'], 'val')
        test_dataset = eval_cls(self.transform_test, self.config['image_root'], self.config['ann_root'], 'test')
        return val_dataset, test_dataset

    def get_sample_dict(self, ann):
        """Group annotations by image id.

        The result maps an image key to a dict with two lists::

            {"<image key>": {
                "samples": [<every annotation record of that image>],
                "qualified_captions": [<captions with fewer than 20 words>],
            }}

        For COCO the key is ``record["image_id"]``; otherwise it is
        ``clean_img_id(record["image"])``.

        Returns:
            ``defaultdict(dict)`` -- looking up an unknown key therefore
            yields (and stores) an empty dict.
        """

        def qualified_caption(caption, max_caption_len=20):
            # "Qualified" means: strictly fewer than max_caption_len words.
            return len(caption.split()) < max_caption_len

        use_coco_ids = self._is_coco()
        sample_dict = defaultdict(dict)
        for cur_ann in ann:
            if use_coco_ids:
                cur_img_id = cur_ann["image_id"]
            else:
                cur_img_id = clean_img_id(cur_ann["image"])

            if cur_img_id not in sample_dict:
                sample_dict[cur_img_id] = {"samples": [], "qualified_captions": []}
            entry = sample_dict[cur_img_id]

            entry["samples"].append(cur_ann)
            if qualified_caption(cur_ann["caption"]):
                entry["qualified_captions"].append(cur_ann["caption"])
        return sample_dict

    def get_caption_list(self, ann):
        """Collect every caption of each image.

        For COCO the key is ``record["image_id"]``. Otherwise the key is the
        leading id of the image file name with exactly one ``.jpg`` suffix, so
        ``1000092795.jpg`` and ``1000092795_0_sd_0_s1.jpg`` both map to
        ``1000092795.jpg``.

        Returns:
            ``defaultdict(list)`` mapping image key -> list of captions.
        """
        caption_dict = defaultdict(list)

        if self._is_coco():
            for cur_ann in ann:
                caption_dict[cur_ann["image_id"]].append(cur_ann["caption"])
        else:
            for cur_ann in ann:
                key = os.path.basename(cur_ann["image"]).split("_")[0]
                # A name without "_" still carries its extension here; only add
                # ".jpg" when missing so the key never ends in ".jpg.jpg".
                if not key.endswith(".jpg"):
                    key += ".jpg"
                caption_dict[key].append(cur_ann["caption"])

        return caption_dict



def train(model, data_loader, optimizer, epoch, accum_iter, device):
    """Run one training epoch with gradient accumulation.

    Each batch loss is divided by ``accum_iter`` before ``backward()``; the
    optimizer steps (and gradients are cleared) every ``accum_iter`` batches
    and on the final batch. The logged ``loss`` meter receives the scaled
    value.

    Returns:
        Dict mapping each meter name to its global average formatted with
        three decimals.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = f'Train Caption Epoch: [{epoch}]'
    print_freq = 50

    batches = metric_logger.log_every(data_loader, print_freq, header)
    for step, (image, caption, _, _, _) in enumerate(batches, start=1):
        image = image.to(device)

        with torch.set_grad_enabled(True):
            # Scale so the accumulated gradient averages over accum_iter batches.
            scaled_loss = model(image, caption) / accum_iter
            scaled_loss.backward()

            # Step on every accum_iter-th batch, and flush on the last batch.
            if step % accum_iter == 0 or step == len(data_loader):
                optimizer.step()
                optimizer.zero_grad()

            metric_logger.update(loss=scaled_loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {
        name: "{:.3f}".format(meter.global_avg)
        for name, meter in metric_logger.meters.items()
    }


@torch.no_grad()
def evaluate(model, data_loader, device, config):
    """Generate one caption per image of ``data_loader``.

    Captions come from ``model.generate`` with ``sample=False`` and the
    ``num_beams`` / ``max_length`` / ``min_length`` values from ``config``.

    Returns:
        List of ``{"image_id": ..., "caption": ...}`` dicts. String ids are
        kept as-is; any other id is converted with ``.item()``.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Caption generation:'
    print_freq = 10

    result = []
    for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
        image = image.to(device)

        captions = model.generate(
            image,
            sample=False,
            num_beams=config['num_beams'],
            max_length=config['max_length'],
            min_length=config['min_length'],
        )

        for caption, img_id in zip(captions, image_id):
            plain_id = img_id if isinstance(img_id, str) else img_id.item()
            result.append({"image_id": plain_id, "caption": caption})

    return result

@torch.no_grad()
def record_train_img_loss(model, data_loader, device, config):
    """Record the loss the model assigns to every training sample.

    Args:
        model: called as ``model(image, gt_captions, reduction="none")``.
        data_loader: yields ``(image, gt_captions, image_path, caption_ids,
            sample_ids)`` batches.
        device: device the image batch is moved to.
        config: unused here; kept for interface compatibility.

    Returns:
        Tuple ``(result, all_loss)``: ``result`` is a list with one record per
        sample (image, caption, source, loss_val, caption_id, sample_id) and
        ``all_loss`` is the list of per-batch loss lists.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Evaluate loss on train dataset...'
    print_freq = 10

    result = []
    all_loss = []
    for image, gt_captions, image_path, caption_ids, sample_ids in metric_logger.log_every(data_loader, print_freq, header):
        image = image.to(device)

        loss = model(image, gt_captions, reduction="none")
        loss_values = loss.detach().cpu().tolist()
        # Tensor.tolist() returns a bare Python number for a 0-dim tensor.
        # Wrap it so the zip below works; a float has no .tolist() to call.
        if not isinstance(loss_values, list):
            loss_values = [loss_values]

        all_loss.append(loss_values)

        for img_path, gt_caption, loss_val, caption_id, sample_id in zip(image_path, gt_captions, loss_values, caption_ids, sample_ids):
            # Paths mentioning "sd" or "coco" are stored in full; any other
            # path is reduced to its base name.
            if "sd" not in img_path and "coco" not in img_path:
                img_bn = os.path.basename(img_path)
            else:
                img_bn = img_path
            img_source = "sd" if "sd" in img_bn else "original"
            result.append(
                {
                    "image": img_bn,
                    "caption": gt_caption,
                    "source": img_source,
                    "loss_val": loss_val,
                    "caption_id": caption_id.item(),
                    "sample_id": sample_id
                })

    print("evaluated train samples: %d" % len(result))
    return result, all_loss


def sort_caption_by_train_loss(train_result_file):
    """Group ``(caption, loss)`` pairs per image, lowest loss first.

    Reads the JSON list in ``train_result_file``. The key is the base name of
    each record's ``image`` up to the first underscore with any ``.jpg``
    removed.

    Returns:
        ``defaultdict(list)`` mapping image id -> list of
        ``(caption, loss_val)`` tuples in ascending loss order.
    """
    print("sorting caption dict by loss from last epoch..")
    with open(train_result_file, "r") as result_json:
        train_results = json.load(result_json)

    by_ascending_loss = sorted(train_results, key=lambda record: record["loss_val"])

    caption_loss_dict = defaultdict(list)
    for record in by_ascending_loss:
        base_name = os.path.basename(record["image"])
        img_id = base_name.split("_")[0].replace(".jpg", "")
        pair = (record["caption"], record["loss_val"])
        caption_loss_dict[img_id].append(pair)

    return caption_loss_dict



def train_loss_eval(epoch, train_dataset, model, device, args):
    """Evaluate the per-sample loss on ``train_dataset`` and save it.

    NOTE(review): ``config`` is not a parameter; it is read from module scope
    (``config['batch_size']`` and the call to ``record_train_img_loss``), so a
    module-level ``config`` must exist when this runs.

    Returns:
        Whatever ``save_result`` returns for the ``train_epoch<epoch>`` results.
    """
    print("Epoch %s: evaluating on %d training samples to identify augmenting target..." % (epoch, len(train_dataset.annotation)))

    # Unshuffled distributed sampler when running distributed, else no sampler.
    sampler = None
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(
            train_dataset,
            num_replicas=utils.get_world_size(),
            rank=utils.get_rank(),
            shuffle=False,
        )

    train_eval_loader = reinit_loader(
        train_dataset,
        sampler,
        batch_size=config['batch_size'],
        num_workers=4,
        is_train=False,
        collate_fn=None,
    )
    train_result, _ = record_train_img_loss(model, train_eval_loader, device, config)
    return save_result(train_result, args.result_dir, 'train_epoch%d'%epoch, args, remove_duplicate="sample_id")

def get_top_loss_samples(train_result_file, args):
    """Split saved train results into top-loss samples and the rest.

    Args:
        train_result_file: path of a JSON list of records with a
            ``loss_val`` field.
        args: namespace providing ``dynamic_aug``, ``curation_ratio`` and
            ``curation_option``.

    With ``args.dynamic_aug`` the top samples are those whose loss exceeds
    ``mean + 2 * std`` (file order is kept). Otherwise the
    ``round(len * curation_ratio)`` highest-loss samples are taken and both
    lists are in descending loss order.

    Returns:
        Tuple ``(top_loss_samples, rest_samples)``; every input sample ends up
        in exactly one of the two lists.
    """
    with open(train_result_file, "r") as result_json:
        train_results = json.load(result_json)

    if not train_results:
        # Nothing to split; also avoids a KeyError on an empty DataFrame.
        top_loss_samples, rest_samples = [], []
    elif args.dynamic_aug:
        df = pd.DataFrame(train_results)
        std, mean = df["loss_val"].std(), df["loss_val"].mean()
        is_top = df["loss_val"] > (mean + 2*std)
        top_loss_samples = df[is_top].to_dict('records')
        # Use the complement, not "<=": with a NaN threshold (single sample)
        # or NaN losses both comparisons are False and rows would be lost.
        rest_samples = df[~is_top].to_dict('records')
    else:
        sorted_train_results = sorted(train_results, key=lambda x: x["loss_val"], reverse=True)
        num_to_augment = round(len(sorted_train_results) * args.curation_ratio)
        top_loss_samples = sorted_train_results[:num_to_augment]
        rest_samples = sorted_train_results[num_to_augment:]

    print("total number of augmenting samples we will %s: %d" % (args.curation_option, len(top_loss_samples)))
    return top_loss_samples, rest_samples

def save_top_loss_samples(args, epoch, top_loss_samples):
    """Write ``top_loss_samples`` as JSON below ``<args.output_dir>/al_samples``.

    The file is named ``top_loss_samples_epoch<epoch>.json``; the directory is
    created when missing.
    """
    target_dir = os.path.join(args.output_dir, "al_samples")
    os.makedirs(target_dir, exist_ok=True)

    file_name = "top_loss_samples_epoch%d.json" % epoch
    save_path = os.path.join(target_dir, file_name)
    print("saving top loss samples at epoch %d to %s" % (epoch, save_path))

    # The context manager closes the file; no explicit close() is needed.
    with open(save_path, "w") as output_json:
        json.dump(top_loss_samples, output_json)

def replace_top_loss_samples(top_loss_samples, sample_dict):
    """Build replacement records that swap in a different caption.

    For each sample a caption is drawn at random from the image's
    ``qualified_captions`` in ``sample_dict`` (looked up by ``sample["image"]``).
    When there are none, the draw is made from the (up to) three captions with
    the fewest words.

    Returns:
        List of new records with ``source`` set to ``"aug"`` and ``sample_id``
        suffixed with ``"_r"``.
    """
    replaced_samples = []
    for sample in top_loss_samples:
        image_key = sample["image"]

        cand_captions = sample_dict[image_key]["qualified_captions"]
        if not cand_captions:
            # Fall back to the three shortest captions (by word count).
            every_caption = [entry["caption"] for entry in sample_dict[image_key]["samples"]]
            every_caption.sort(key=lambda text: len(text.split()))
            cand_captions = every_caption[:3]

        replacement = {
            "image": image_key,
            "caption": random.choice(cand_captions),
            "source": "aug",
            "caption_id": sample["caption_id"],
            "sample_id": str(sample["sample_id"]) + "_r",
        }
        replaced_samples.append(replacement)
    return replaced_samples

def get_origin2sd(sd_train):
    """Map each original image name to its generated (sd) image name.

    The original name is the sd image name up to the first underscore plus
    ``.jpg``. When several sd samples share an origin, the last one wins.
    """
    return {
        sample["image"].partition("_")[0] + ".jpg": sample["image"]
        for sample in sd_train
    }

def replace_top_loss_images_with_sd(top_loss_samples, origin2sd, caption_dict, caption_dict_by_loss, sample_dict, change_caption_by="loss"):
    """Replace top-loss samples with pre-generated sd images and new captions.

    Args:
        top_loss_samples: sample records to replace; each record is deep-copied.
        origin2sd: maps an origin image key to a list of unused sd images.
            Used images are popped, so this mapping is mutated.
        caption_dict: maps an image key to its captions. In ``"loss"`` mode the
            sample's current caption is removed from the stored list whenever
            another candidate exists, so this mapping is mutated as well.
        caption_dict_by_loss: maps an image id (without ``.jpg``) to
            ``(caption, loss)`` pairs; used by ``"sorted_loss"`` /
            ``"lowest_loss"``.
        sample_dict: maps an image key to ``{"samples", "qualified_captions"}``;
            used by ``"length"``.
        change_caption_by: one of ``"loss"``, ``"sorted_loss"``,
            ``"lowest_loss"``, ``"length"``; any other value keeps the caption.

    Returns:
        Tuple ``(replace_train, origin2sd, caption_dict)``.
    """

    def get_origin_key(sample):
        # "<id>_..._sd_....jpg" or "<id>.jpg" -> "<id>.jpg"
        origin_image = os.path.basename(sample["image"]).split("_")[0]
        if not origin_image.endswith(".jpg"):
            origin_image += ".jpg"
        return origin_image

    replace_train = []
    for sample in top_loss_samples:
        new_sample = copy.deepcopy(sample)

        # A missing or non-string sample_id used to raise TypeError on the
        # substring test; such samples now take the file-name based key.
        sample_id = sample.get("sample_id")
        if isinstance(sample_id, str) and "coco" in sample_id:
            img_id = "_".join(sample_id.split("_")[:2])
        else:
            img_id = get_origin_key(sample)

        sd_images = origin2sd.get(img_id, [])
        if len(sd_images) != 0:
            # Consume the first unused sd image; popping in place updates the
            # list stored in origin2sd.
            sd_id = sd_images.pop(0)
        else:
            # No sd image left (or the sample already is one): keep the image.
            sd_id = sample["image"]

        if sd_id != sample["image"]:
            new_sample["image"] = sd_id
            new_sample["source"] = "sd"

        if change_caption_by == "loss":
            # Any other caption of the same image, chosen at random.
            cur_caption = sample["caption"]
            cand_captions = caption_dict.get(img_id, [])
            if cur_caption in cand_captions and len(cand_captions) > 1:
                cand_captions.remove(cur_caption)
            if len(cand_captions) > 0:
                rep_caption = random.choice(cand_captions)
            else:
                rep_caption = cur_caption
            new_sample["caption"] = rep_caption
        elif change_caption_by in ("sorted_loss", "lowest_loss"):
            # Restrict to (at most three) captions that scored a lower loss.
            rep_caption = sample["caption"]
            all_captions = caption_dict_by_loss.get(img_id.replace(".jpg", ""), [])
            if len(all_captions) > 0:
                cand_captions = [x[0] for x in all_captions if x[1] < sample["loss_val"]][:3]
                if len(cand_captions) > 0:
                    if change_caption_by == "sorted_loss":
                        rep_caption = random.choice(cand_captions)
                    else:
                        rep_caption = cand_captions[0]
            new_sample["caption"] = rep_caption
        elif change_caption_by == "length":
            # Prefer qualified captions, else the three shortest ones.
            cand_captions = sample_dict[img_id].get("qualified_captions", [])
            if len(cand_captions) == 0:
                all_captions = [s["caption"] for s in sample_dict[img_id]["samples"]]
                sorted_captions = sorted(all_captions, key=lambda x: len(x.split()))
                cand_captions = sorted_captions[:3]

            new_sample["caption"] = random.choice(cand_captions)

        replace_train.append(new_sample)

    return replace_train, origin2sd, caption_dict

def generate_sd_samples(top_loss_samples, sample_dict, epoch):
    """Generate one sd image per top-loss sample and save the annotations.

    NOTE(review): ``config`` is read from module scope
    (``config["aug_data_root"]``), it is not a parameter.

    Images go to ``<aug_data_root>/images`` and the annotation list to
    ``<aug_data_root>/annotation/sd_ann_epoch_<epoch>.json``.

    Returns:
        Path of the written annotation file.
    """
    generator = sd_gen()
    sd_samples = []

    for sample in top_loss_samples:
        entry = sample_dict[sample["image"]]

        cand_captions = entry.get("qualified_captions", [])
        if len(cand_captions) == 0 and entry.get("samples") is not None:
            # No qualified caption: use the three shortest ones (by word count).
            by_length = sorted(
                [s["caption"] for s in entry["samples"]],
                key=lambda text: len(text.split()),
            )
            cand_captions = by_length[:3]

        imgs = generator.generate_with_captions(cand_captions)
        rep_caption = random.choice(cand_captions)
        gimg = imgs[0]

        # sd image will be named as <sample_id>_sd_<epoch>.jpg, the sample_id corresponds to the sample_id in the original training data
        sd_img_name = os.path.basename(sample["sample_id"] + "_sd_%d.jpg"%epoch)
        image_dir = os.path.join(config["aug_data_root"], "images")
        img_save_path = os.path.join(image_dir, sd_img_name)

        os.makedirs(os.path.join(config["aug_data_root"], "images"), exist_ok=True)
        gimg.save(img_save_path)

        sd_samples.append({
            "image": sd_img_name,
            "caption": rep_caption,
            "source": "sd",
            "caption_id": sample["caption_id"],
            "sample_id": sd_img_name.replace(".jpg", ""),
        })
    print("generated %d sd samples for augmentation..." % len(sd_samples))

    sd_ann_root = os.path.join(config["aug_data_root"], "annotation")
    os.makedirs(sd_ann_root, exist_ok=True)
    sd_ann_file = os.path.join(sd_ann_root, "sd_ann_epoch_%d.json"%epoch)
    with open(sd_ann_file, "w") as output_json:
        json.dump(sd_samples, output_json)
    return sd_ann_file

def batch_generate_sd_samples(top_loss_samples, sample_dict, epoch, args):
    """Generate sd samples in batch.

    First builds one annotation (with its generation prompt) per top-loss
    sample, then hands the whole list to ``generator.generate`` and finally
    writes the annotations to
    ``<aug_data_root>/annotation/sd_ann_epoch_<epoch>.json``.

    NOTE(review): ``config`` is read from module scope, it is not a parameter.
    ``args`` supplies ``sd_batch_size``, ``concat_prompt`` and ``add_styler``.

    Returns:
        Path of the written annotation file.
    """
    generator = sd_gen(batch_size=args.sd_batch_size)
    sd_samples = []

    # set output path
    sd_ann_root = os.path.join(config["aug_data_root"], "annotation")
    os.makedirs(sd_ann_root, exist_ok=True)

    sd_img_root = os.path.join(config["aug_data_root"], "images")
    os.makedirs(sd_img_root, exist_ok=True)

    print("generate sd annotations first...")
    progress = tqdm(top_loss_samples, total=len(top_loss_samples))
    for sample_idx, sample in enumerate(progress):
        entry = sample_dict[clean_img_id(sample["image"])]

        cand_captions = entry.get("qualified_captions", [])
        if len(cand_captions) == 0 and entry.get("samples") is not None:
            # No qualified caption: use the three shortest ones (by word count).
            by_length = sorted(
                [s["caption"] for s in entry["samples"]],
                key=lambda text: len(text.split()),
            )
            cand_captions = by_length[:3]

        has_candidates = len(cand_captions) != 0

        # use a reasonable caption for generation
        if has_candidates:
            rep_caption = random.choice(cand_captions)
        else:
            print("can not obtain a better catpion....") # this should not happen tho
            rep_caption = sample["caption"]

        if not args.concat_prompt:
            prompt = rep_caption
        elif has_candidates:
            # De-duplicated candidates joined into one prompt.
            prompt = ". ".join(set(cand_captions))
        else:
            prompt = sample["caption"]

        if args.add_styler:
            prompt = prompt + " national geographic, high quality photography, Canon EOS R3, Flickr"

        # sd image will be named as <sample_id>_sd_<epoch>_s<id>.jpg, the sample_id corresponds to the sample_id in the original training data
        sd_img_name = os.path.basename(sample["sample_id"] + "_sd_%d_s%d.jpg"%(epoch, sample_idx))

        sd_samples.append({
            "image": sd_img_name,
            "caption": rep_caption,
            "source": "sd",
            "caption_id": sample["caption_id"],
            "sample_id": sd_img_name.replace(".jpg", ""),
            "prompt": prompt,
        })

    # generate images in batch for expected sd samples
    print("generate sd images")
    generator.generate(sd_samples, outdir=sd_img_root)
    print("generated %d sd samples for augmentation..." % len(sd_samples))

    sd_ann_file = os.path.join(sd_ann_root, "sd_ann_epoch_%d.json"%epoch)
    with open(sd_ann_file, "w") as output_json:
        json.dump(sd_samples, output_json)
    return sd_ann_file


def get_offline_sd_samples(epoch, config):
    """Load the pre-generated sd samples of ``epoch``, if any.

    Looks for
    ``<config["aug_data_root"]>/annotations/top_loss_samples_epoch<epoch>_generated.json``.

    Returns:
        The parsed JSON content, or an empty list when the file does not exist.
    """
    annotation_root = os.path.join(config["aug_data_root"], "annotations")
    # Format only the file name: applying "%" to the joined path would also
    # interpret any "%" contained in aug_data_root.
    annotation_name = "top_loss_samples_epoch%s_generated.json" % str(epoch)
    annotation_file = os.path.join(annotation_root, annotation_name)

    if os.path.exists(annotation_file):
        with open(annotation_file, "r") as input_json:
            offline_sd_samples = json.load(input_json)
    else:
        print("no offline sd samples are available for epoch %s" % str(epoch))
        offline_sd_samples = []
    return offline_sd_samples

def get_raw2sd(sd_ann_file, rewrite_img_path=True):
    """Map original image names to sd image names read from ``sd_ann_file``.

    The original name is the sd image name up to the first underscore plus
    ``.jpg``. With ``rewrite_img_path`` the sd name is prefixed with
    ``<dir>/images``, where ``<dir>`` is the third-from-last ``/``-separated
    component of ``sd_ann_file``; otherwise the bare sd name is stored.
    """
    with open(sd_ann_file, "r") as input_json:
        sd_train = json.load(input_json)
        path_parts = sd_ann_file.split("/")
        image_root = os.path.join(path_parts[-3], "images")

    # build corresponding origin_img - sd_img mapping
    origin2sd = {}
    for sample in sd_train:
        sd_image = sample["image"]
        origin_image = sd_image.split("_")[0] + ".jpg"
        if rewrite_img_path:
            target = os.path.join(image_root, sd_image)
        else:
            target = sd_image
        origin2sd[origin_image] = target
    return origin2sd

def main(args, config):
    utils.init_distributed_mode(args)    
    
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset #### 
    print("Creating captioning dataset")
    # train_dataset, val_dataset, test_dataset = create_dataset('caption_%s'%config['dataset'], config) 
    dataset_builder = DataBuilder(config)
    print("Read train ann from %s" % config["train_file_name"])
    train_ann = json.load(open(os.path.join(config["ann_root"], config["train_file_name"]),'r'))
    sample_dict = dataset_builder.get_sample_dict(train_ann)  

    if args.sd_train:
        sd_ann = json.load(open(config["sd_ann_file"],'r'))
        train_dataset = dataset_builder.create_train(train_ann + sd_ann , epoch=0)
    else:
        train_dataset = dataset_builder.create_train(train_ann, epoch=0)
    
    val_dataset, test_dataset = dataset_builder.create_val_test()

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()            
        samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)         
    else:
        samplers = [None, None, None]
    
    train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
                                                          batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
                                                          is_trains=[True, False, False], collate_fns=[None,None,None])   


    #### Model #### 
    print("Creating model")
    model = blip_decoder(
        pretrained=config['pretrained'], 
        image_size=config['image_size'], 
        vit=config['vit'], 
        vit_grad_ckpt=config['vit_grad_ckpt'], 
        vit_ckpt_layer=config['vit_ckpt_layer'],
        prompt=config['prompt']
    )

    model = model.to(device)
    
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module    
    
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
            
    best = 0
    best_epoch = 0

    if args.curation_option == "replace_img":
        print("we need to load an original to sd mapping first...")
        origin2sd_mapping = json.load(open(args.origin2sd_mapping, "r"))
        caption_dict = dataset_builder.get_caption_list(train_ann)
        print("loaded %s"%args.origin2sd_mapping)
    else:
        origin2sd_mapping = None
        caption_dict = None
        

    print("Start training")
    start_time = time.time() 

    for epoch in range(0, config['max_epoch']):
        # train 
        print("Starting Epoch: %d" % epoch)
        if not args.evaluate:        
            if args.distributed:
                if epoch!=0 and args.active_learning:
                    print("reinit train sampler and train loader")
                    train_sampler = reinit_sampler(train_dataset, True, num_tasks=num_tasks, global_rank=global_rank)
                    train_loader = reinit_loader(train_dataset, train_sampler, batch_size=config['batch_size'], num_workers=4, is_train=True, collate_fn=None)
                train_loader.sampler.set_epoch(epoch)
                
            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
                
            train_stats = train(model, train_loader, optimizer, epoch, args.accum_iter, device) 

            if args.eval_train_loss:
                train_result_file = train_loss_eval(epoch, train_dataset, model, device, args)
                top_loss_samples, rest_samples = get_top_loss_samples(train_result_file, args)
                if args.distributed:
                    dist.barrier()
                save_top_loss_samples(args, epoch, top_loss_samples)

            # active learning process
            if args.active_learning and int(epoch) % int(args.every_train_eval) == 0:
                
                if args.curation_option == "offline_sd" or args.curation_option == "offline_remove":
                    # instead of using current top loss samples, we use offline samples obtained beforehand
                    train_result_file = os.path.join(args.offline_result_root, "train_epoch%s.json"%str(epoch))
                    assert os.path.exists(train_result_file) == True
                else:
                    train_result_file = train_loss_eval(epoch, train_dataset, model, device, args)
                
                if args.distributed:
                    dist.barrier()

                top_loss_samples, rest_samples = get_top_loss_samples(train_result_file, args)
                caption_dict_by_loss = sort_caption_by_train_loss(train_result_file)


                # curation with REMOVE approach
                if (args.curation_option == "remove" or args.curation_option == "offline_remove") and int(config['max_epoch']) - int(epoch) >= args.num_epoch_noal:
                    print("reinit train dataset for next epoch, number of samples included in new dataset: %d" % len(rest_samples))
                    train_dataset = dataset_builder.create_train(rest_samples)

                # curation with REPLACECAP approach
                if args.curation_option == "replace_cap": 
                    replace_samples = replace_top_loss_samples(top_loss_samples, sample_dict)
                    new_samples = rest_samples +  replace_samples
                    print("reinit train dataset for next epoch, number of samples included in new dataset: %d" % len(new_samples))
                    train_dataset = dataset_builder.create_train(new_samples)
                    print("peek from updated dataset: " , train_dataset.__getitem__(0))

                # curation with REPLACEIMG approach
                if args.curation_option == "replace_img" and int(config['max_epoch']) - int(epoch) >= args.num_epoch_noal: # reserve num epoch with no augmentation
                    # as we already have a full sd train set, here we can do online AL but avoid online sd generation
                    replace_samples, origin2sd_mapping, caption_dict = replace_top_loss_images_with_sd(top_loss_samples, origin2sd_mapping, caption_dict, caption_dict_by_loss, sample_dict, change_caption_by=args.change_caption)
                    new_samples = rest_samples +  replace_samples
                    print("reinit train dataset for next epoch, number of samples included in new dataset: %d" % len(new_samples))
                    json.dump(new_samples, open(os.path.join(args.output_dir, "updated_train_epoch%d.json"%epoch),"w"))
                    train_dataset = dataset_builder.create_train(new_samples)
                    print("peek from updated dataset: " , train_dataset.__getitem__(0))

                # online augmentation with stable diffusion generation
                if args.curation_option == "sd": 
                    # only generate at main gpu
                    # if utils.is_main_process():
                    print("generate sd images in main process...")
                    start_time = time.time()
                    sd_ann_file = batch_generate_sd_samples(top_loss_samples, sample_dict, epoch, args)
                    execution_time = (time.time() - start_time)
                    print("generation complete! total execution time: %s" % str(execution_time))
                    
                    # dist.barrier()
                    
                    sd_ann_root = os.path.join(config["aug_data_root"], "annotation")
                    if os.path.exists(os.path.join(sd_ann_root, "sd_ann_epoch_%d.json"%epoch)):
                        sd_samples = json.load(open(sd_ann_file, "r"))
                        new_samples = rest_samples + sd_samples
                        print("reinit train dataset for next epoch, containing %d rest samples and %d new sd samples" % \
                            (len(rest_samples), len(sd_samples)))
                        train_dataset = dataset_builder.create_train(new_samples)
                        # sample_dict = dataset_builder.get_sample_dict(new_samples)
                    
                if args.curation_option == "offline_sd":
                    # sd_samples = get_offline_sd_samples(epoch, config)
                    # new_samples = rest_samples + sd_samples

                    # prepare for next epoch's train
                    new_samples = json.load(open(os.path.join("/annotation/offline_train", "flickr30k_train_preprocessed_epoch%s.json"%str(epoch+1)),'r'))
                    print("reinit train dataset for next epoch by adding offline sd samples, number of samples included in new dataset: %d" % len(new_samples))
                    train_dataset = dataset_builder.create_train(new_samples, epoch=epoch)
                
                save_top_loss_samples(args, epoch, top_loss_samples)

        if args.train_eval:
            train_result_file = train_loss_eval(epoch, train_dataset, model, device, args)
            dist.barrier()
            print("loss eval for train samples saved to :", train_result_file)

        # eval
        print("Evaluate on dev...")
        val_result = evaluate(model_without_ddp, val_loader, device, config)  
        val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, args, remove_duplicate='image_id')        
  
        print("Evaluate on test...")
        test_result = evaluate(model_without_ddp, test_loader, device, config)  
        test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, args, remove_duplicate='image_id') 

        print("calculate val and test metrics...")
        if utils.is_main_process() or args.distributed==False:   
            if 'flickr_gt_root' in config:
                coco_val = flickr_caption_eval(config['flickr_gt_root'],val_result_file,'val')
                coco_test = flickr_caption_eval(config['flickr_gt_root'],test_result_file,'test')
            elif 'coco_gt_root' in config:
                coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
                coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
            
            if args.evaluate:            
                log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
                             **{f'test_{k}': v for k, v in coco_test.eval.items()},                       
                            }
                with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
                    f.write(json.dumps(log_stats) + "\n")                   
            else:             
                save_obj = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'epoch': epoch,
                }

                if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
                    best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
                    best_epoch = epoch                
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 
                    
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in coco_val.eval.items()},
                             **{f'test_{k}': v for k, v in coco_test.eval.items()},                       
                             'epoch': epoch,
                             'best_epoch': best_epoch,
                            }
                with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
                    f.write(json.dumps(log_stats) + "\n")     
                    
        if args.evaluate: 
            break
        if args.distributed:
            dist.barrier()  

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str)) 


def parse_bool_flag(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy. This converter makes values such
    as ``False``, ``0`` or ``no`` actually disable an option.

    Args:
        value: the raw command-line token (or an existing bool).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, matching what ``type=bool`` used to do.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def build_arg_parser():
    """Build the command-line parser for the captioning / active-learning script.

    Returns:
        argparse.ArgumentParser: parser with all script options registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/caption_coco.yaml')
    parser.add_argument('--output_dir', default='output/Caption_coco')
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--active_learning', type=parse_bool_flag, default=True)
    # type=int so a value given on the command line matches the int default.
    parser.add_argument('--every_train_eval', default=1, type=int, help="number of epoch that we update training samples based on loss")
    parser.add_argument('--max_caption_len', default=20, type=int, help="maximum length of caption that we use to generate augment images")
    parser.add_argument('--curation_ratio', default=0.05, type=float, help="ratio of top loss training samples")
    parser.add_argument('--curation_option', type=str, choices=["sd", "offline_sd", "remove", "offline_remove", "replace_cap", "replace_img"], help="generate new images or remove high loss images")
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--accum_iter', default=2, type=int, help='accumulate gradients by 2')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=parse_bool_flag)
    parser.add_argument('--offline_result_root', type=str, required=False, help="folder that saved previous train loss samples for offline sd")
    parser.add_argument('--sd_batch_size', type=int, default=4, help="batch size for sd generation")
    parser.add_argument('--concat_prompt', type=parse_bool_flag, default=False, help="concat prompt for sd generation")
    parser.add_argument('--add_styler', type=parse_bool_flag, default=False, help="add styler for sd generation")
    parser.add_argument('--sd_train', type=parse_bool_flag, default=False)
    parser.add_argument('--eval_train_loss', type=parse_bool_flag, default=False)
    parser.add_argument('--dynamic_aug', action="store_true", default=False)
    parser.add_argument('--num_epoch_noal', type=int, default=0)
    parser.add_argument('--train_eval', action="store_true", default=False)
    parser.add_argument('--origin2sd_mapping', default="annotation/replace_img_mapping.json", type=str)
    parser.add_argument('--change_caption', type=str, default="loss", help="if true, change caption when change with sd images, either by loss or by caption length")
    return parser


if __name__ == '__main__':
    parser = build_arg_parser()
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # load configuration files from trusted sources.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)
    Path(config["aug_data_root"]).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the effective config next to the run outputs; the context
    # manager guarantees the file is flushed and closed before training starts.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_out:
        yaml.dump(config, config_out)

    # prevent duplicate evaluation
    if args.active_learning:
        args.eval_train_loss = False

    main(args, config)