# https://github.com/openai/CLIP/issues/83

import argparse
import os
import sys
import time
import glob
import gc

# import ruamel.yaml as yaml
import yaml
from ruamel.yaml import YAML
yaml = YAML(typ='rt')

import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import pprint
import copy

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F

from PIL import Image


from easydict import EasyDict as edict
from models.get_model import load_model

# Code adapted from https://github.com/salesforce/ALBEF/blob/main/Retrieval.py
from dataset import create_dataset_no_norm, create_sampler, create_loader
from scheduler import create_scheduler_each_step

# from optim import create_optimizer
from constants import images_normalize

from models.clip_model import clip

import utils.utils as utils
from utils.utils_attack import get_attacker, attack_batch_train
from utils.utils_eval import eval_pipeline
from utils.utils_optimizer import get_trainable_params, get_optimizer
from utils.utils_visualization import vis_img_txt_pairs
from utils.utils_loss import BCEwithProj

from attacks.MMA import eda

from utils.FARE import train as FARE_train
from utils.FARE import pgd_train as FARE_pgd_train

# Module-level default device. Note: `main` builds its own local `device` from `--device`
# and `train` receives `device` as a parameter, so this global is shadowed inside both.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# KL-divergence criterion with summed reduction.
# NOTE(review): not referenced anywhere in the visible part of this file — verify whether it
# is still used (e.g. by the `__main__` section) before removing it.
criterion_kl = nn.KLDivLoss(reduction="sum").to(device)


def check_empty_text(text):
    """Return a copy of the captions in which every empty caption becomes "a photo".

    Debug guard used by `train` before text augmentation and tokenization.
    A message is printed for each replacement.

    Args:
        text: Iterable of caption strings.

    Returns:
        A new list of captions, with zero-length entries replaced by "a photo".
    """
    cleaned = []
    for caption in text:
        if len(caption) != 0:
            cleaned.append(caption)
            continue
        print("Empty text: replaced with 'a photo'")
        cleaned.append("a photo")
    return cleaned


def eda_text(text, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=1):
    """Augment each caption with `eda` and keep the first augmented variant.

    `eda` comes from `attacks.MMA`. Judging by the parameter names, it is presumably the
    EDA text augmentation (synonym replacement / random insertion / random swap /
    random deletion) — confirm in `attacks.MMA`.

    Augmentation is best-effort: if `eda` fails for a caption, the error is printed and
    the caption is replaced by "a photo" so that training can continue.

    Args:
        text: Iterable of caption strings.
        alpha_sr, alpha_ri, alpha_rs, p_rd, num_aug: Forwarded unchanged to `eda`.

    Returns:
        A new list with one (augmented or fallback) caption per input caption.
    """
    new_text = []
    for t in text:
        try:
            t = eda(t, alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, p_rd=p_rd, num_aug=num_aug)[
                0
            ]
        except Exception as err:  # narrow from bare `except:` so Ctrl-C / SystemExit still propagate
            print("Error in EDA")
            print("t:", t)
            print("reason:", repr(err))
            t = "a photo"
        new_text.append(t)
    return new_text


def _fare_l2_loss(out, targets, reduction="mean"):
    """Squared L2 distance between embeddings, as used by the FARE objective.

    The per-sample loss is the sum of squared errors over the embedding dimension.
    It is NOT divided by the embedding size.

    Args:
        out: Embeddings, expected shape (batch_size, embedding_size).
        targets: Reference embeddings with exactly the same shape as ``out``.
        reduction: "mean" averages the per-sample losses over the batch. Any other
            value returns the per-sample losses.

    Returns:
        A scalar tensor for ``reduction="mean"``, otherwise a tensor of shape (batch_size,).

    Raises:
        ValueError: if the shapes differ or the batch has a single sample.
    """
    if out.shape != targets.shape:
        raise ValueError(f"{out.shape} != {targets.shape}")
    if out.shape[0] <= 1:
        raise ValueError("FARE l2 loss expects a batch with more than one sample")
    per_sample = F.mse_loss(out, targets, reduction="none").sum(dim=1)
    if reduction == "mean":
        return per_sample.mean()
    return per_sample


def train(
    args,
    model,
    model_without_ddp,
    data_loader,
    optimizer,
    epoch,
    device,
    scheduler,
    attacker=None,
    model_ori=None,
):
    """Run one training epoch (optionally adversarial) and return its statistics.

    For each batch, the function:
    1. Replaces empty captions and optionally applies EDA text augmentation.
    2. Optionally perturbs the batch (FARE PGD, TeCoA, or another attacker-based attack).
    3. Normalizes the images and runs the forward pass.
    4. Takes an optimizer step and a per-step scheduler step.

    Args:
        args: Parsed CLI arguments. Reads ``attack``, ``train_loss``, ``is_eda`` and the
            EDA rates, ``epsilon``, ``num_iters``, ``step_size``, ``is_use_gt_caps`` and
            ``total_steps``.
        model: Model being trained (may be wrapped in ``DataParallel``).
        model_without_ddp: Unwrapped model. Unused here; kept for API compatibility.
        data_loader: Yields ``(image, text, idx)`` or ``(image, text, idx, gt_caps_list)``.
            Images are not normalized yet.
        optimizer: Stepped once per batch.
        epoch: Current epoch index, used for logging, visualization and scheduler steps.
        device: Device the batch is moved to.
        scheduler: Stepped with the global step index after each batch.
        attacker: Used by the TeCoA / generic attack branches. ``None`` disables those
            attacks (``main`` passes ``None`` for no-attack warm-up epochs). FARE does not
            use it.
        model_ori: Frozen reference model. Required when ``args.attack == "FARE"``.

    Returns:
        ``(stats, loss_list)``:
        - ``stats`` maps each meter name to its global average, formatted with 6 decimals.
        - ``loss_list`` holds the per-batch loss values.

    Raises:
        ValueError: if any of the following holds:
            - ``args.train_loss`` is unknown or the model output is unexpected;
            - ``train_loss == "FARE"`` is used without ``attack == "FARE"``;
            - ``is_use_gt_caps`` is set but the loader does not provide ``gt_caps_list``.
    """
    # NOTE(review): `VIS_DIR`, `loss_img` and `loss_txt` are module-level names that are not
    # defined in this part of the file — presumably set in the `__main__` section; confirm.
    model.float()  # float32
    if model_ori is not None:
        model_ori.float()
    model.train()

    # FARE: embeddings are not L2-normalized, neither inside the PGD attack nor in the loss.
    OUTPUT_NORM = False

    loss_list = []
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter("loss", utils.SmoothedValue(window_size=1, fmt="{value:.4f}"))
    header = "Train Epoch: [{}]".format(epoch)
    print_freq = 50

    for batch_idx, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        if len(data) == 4:
            image, text, idx, gt_caps_list = data
        else:
            image, text, idx = data
            # Reset explicitly so a value from an earlier batch can never be reused.
            gt_caps_list = None
        image = image.to(device, non_blocking=True)
        idx = idx.to(device, non_blocking=True)

        # NOTE: image is not normalized yet.

        # Replace empty captions (debug guard).
        text = check_empty_text(text)

        ###############################
        ##### online Text Aug. ########
        ###############################
        # Online image augmentation is already done inside the data loader.
        if args.is_eda:
            text = eda_text(
                text,
                alpha_sr=args.alpha_sr,
                alpha_ri=args.alpha_ri,
                alpha_rs=args.alpha_rs,
                p_rd=args.p_rd,
                num_aug=1,
            )

        ###############################
        ###### Attack image/text ######
        ###############################
        # Images must NOT be normalized here; normalization is applied inside the attack.
        if args.attack == "FARE":
            # Modified from FARE:
            # https://github.com/chs20/RobustVLM/blob/main/train/adversarial_training_clip.py
            # FARE does not use `attacker`, so it also runs during no-attack warm-up epochs.
            with torch.no_grad():
                normalized_image = images_normalize(image)
                embedding_orig = model_ori.encode_image(normalized_image)
                if OUTPUT_NORM:
                    embedding_orig = embedding_orig / embedding_orig.norm(dim=-1, keepdim=True)

            model.eval()  # FARE attacks the model in eval mode
            loss_inner_wrapper = FARE_train.ComputeLossWrapper(
                embedding_orig,
                None,  # not using embedding_text_labels_norm
                reduction="mean",  # mean for pgd
                loss="l2",
                logit_scale=None,
            )
            EPS = args.epsilon / 255.0
            image = FARE_pgd_train.pgd(
                forward=model.encode_image,
                images_normalize=images_normalize,
                loss_fn=loss_inner_wrapper,
                data_clean=image,  # image
                targets=None,  # not used
                norm="linf",
                eps=EPS,
                iterations=args.num_iters,  # default 10
                stepsize=args.step_size / 255.0,  # default 1.0 / 255.0
                output_normalize=OUTPUT_NORM,
                perturbation=torch.zeros_like(image).uniform_(-EPS, EPS).requires_grad_(True),
                mode="max",
                verbose=False,
            )
            del loss_inner_wrapper
            model.train()  # back to train mode

        elif args.attack is None or attacker is None:
            # Standard training, or a no-attack warm-up epoch (`main` passes attacker=None).
            pass

        elif args.attack == "TeCoA":
            # SupPGD with clip loss
            txt2img = np.arange(len(image)).tolist()
            image, text = attack_batch_train(
                args,
                args.attack,
                attacker,
                image,
                text,
                txt2img,
                device,
                return_pair=True,
            )
            image.detach_()

        else:
            # Every other attacker-based attack.
            if args.is_use_gt_caps and gt_caps_list is None:
                raise ValueError("is_use_gt_caps requires the data loader to yield gt_caps_list")
            txt2img = np.arange(len(image)).tolist()
            image, text = attack_batch_train(
                args,
                args.attack,
                attacker,
                image,
                text,
                txt2img,
                device,
                return_pair=True,
                gt_caps_list=gt_caps_list if args.is_use_gt_caps else None,
            )
            image.detach_()

        ########################################
        ### visualize image and text pairs #####
        ########################################
        if epoch < 2 and batch_idx < 3:
            file_name = f"ep{epoch}_batch{batch_idx}"
            vis_img_txt_pairs(
                image.cpu().detach().numpy().transpose(0, 2, 3, 1), text, VIS_DIR, file_name, show_n=10
            )
            print("Visualized image-text pairs:", file_name)

        ######################################
        ###### normalize before forward ######
        ######################################
        image = images_normalize(image)

        ################################
        ######### CLIP training ########
        ################################
        text_input_ids = clip.tokenize(text, truncate=True).to(device)

        outs = model(image, text_input_ids)
        if len(outs) != 3:
            raise ValueError("Invalid output")
        logits_per_image, logits_per_text, _ = outs
        gt = torch.arange(len(image), dtype=torch.long, device=device)  # ground truth label (1: 1)

        if args.train_loss == "clip":
            _loss_img = loss_img(logits_per_image, gt)
            _loss_txt = loss_txt(logits_per_text, gt)
            loss = (_loss_img + _loss_txt) / 2
        elif args.train_loss == "clip_i":
            loss = loss_img(logits_per_image, gt)
        elif args.train_loss == "FARE":
            # FARE needs `embedding_orig`, which only the FARE attack branch computes.
            if args.attack != "FARE":
                raise ValueError('train_loss "FARE" requires attack "FARE"')
            # The FARE reference code can add a clean-image term. Its weight is 0 here, so the
            # clean embedding is not computed. The adversarial image is already normalized.
            embedding_adv = model.encode_image(image)
            if OUTPUT_NORM:
                embedding_adv = embedding_adv / embedding_adv.norm(dim=-1, keepdim=True)
            loss = _fare_l2_loss(embedding_adv, embedding_orig, reduction="mean")
        else:
            raise ValueError(f"Invalid train_loss: {args.train_loss}")

        acc_per_image = (logits_per_image.argmax(1) == gt).float().mean()
        acc_per_text = (logits_per_text.argmax(1) == gt).float().mean()
        if random.random() < 0.01 or batch_idx == 0:
            print("Batch idx:", batch_idx)
            print("clip_loss:", loss)
            print("acc_per_image:", acc_per_image)
            print("acc_per_text:", acc_per_text)

        # Abort on a NaN loss with a non-zero exit status (plain `exit()` reported success).
        if torch.isnan(loss):
            print("loss is nan")
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # single device->host sync
        metric_logger.update(loss=loss_value)
        loss_list.append(loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Step the scheduler with the global step index.
        n = epoch * len(data_loader) + batch_idx
        scheduler.step(n)

        if args.total_steps is not None and n >= args.total_steps:
            print("Total steps reached: ", n)
            break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.6f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}, loss_list


def _create_loaders(args, config):
    """Create the datasets and data loaders used by `main`.

    Datasets come from `create_dataset_no_norm`. `train` notes that its images are not
    normalized yet and normalizes them itself after any attack.

    Distributed sampling is applied only to the training set. Evaluation loaders always
    use ``None`` samplers.

    Args:
        args: Parsed CLI arguments (augmentation options, ``distributed``,
            ``control_aug_ratio``).
        config: Config dict. Reads ``batch_size_train`` and ``batch_size_test`` and is
            forwarded to `create_dataset_no_norm`.

    Returns:
        ``(train_loader, val_loader, test_loader, train_subset_loader)``.
    """
    train_dataset, val_dataset, test_dataset, train_dataset_for_eval = create_dataset_no_norm(
        "re",
        config,
        get_train_eval=True,
        control_aug_ratio=args.control_aug_ratio,
        aug_n=args.aug_n,
        aug_m=args.aug_m,
        aug_scale=(args.aug_scale, 1.0),
        n_holes=args.n_holes,
        length_ratio=args.length_ratio,
        degrees=args.degrees,
        translate=args.translate,
        scale=args.scale,
        color_aug_strength=args.color_aug_strength,
    )
    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        train_sampler = create_sampler([train_dataset], [True], num_tasks, global_rank)[0]
    else:
        train_sampler = None

    if args.control_aug_ratio is not None:
        # The train-set sampler must never be attached to the evaluation datasets.
        val_loader, test_loader, train_subset_loader = create_loader(
            [val_dataset, test_dataset, train_dataset_for_eval],
            [None, None, None],
            batch_size=[config["batch_size_test"]] * 3,
            num_workers=[4] * 3,
            is_trains=[False, False, False],
            collate_fns=[None] * 3,
        )

        print("Control Augmentation Ratio:", args.control_aug_ratio)
        from dataset.control_ratio_sampler import ControlRatioSampler

        train_batch_sampler = ControlRatioSampler(
            train_dataset.ori_inds,
            train_dataset.aug_inds,
            config["batch_size_train"],
            aug_fraction=args.control_aug_ratio,
            num_batches=int(np.ceil(
                len(train_dataset.ori_inds) * (1 + args.control_aug_ratio) / config["batch_size_train"])
            ),  # see all original data in one epoch
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_sampler=train_batch_sampler, num_workers=4, pin_memory=True
        )
    else:
        train_loader, val_loader, test_loader, train_subset_loader = create_loader(
            [train_dataset, val_dataset, test_dataset, train_dataset_for_eval],
            [train_sampler, None, None, None],
            batch_size=[config["batch_size_train"]] + [config["batch_size_test"]] * 3,
            num_workers=[4] * 4,
            is_trains=[True, False, False, False],
            collate_fns=[None] * 4,
        )
    return train_loader, val_loader, test_loader, train_subset_loader


def _record_eval(eval_results_dict, scores_dict, key, *pipeline_args, **pipeline_kwargs):
    """Run `eval_pipeline`, store its result and i2t/t2i score matrices under `key`, return the result."""
    eval_result, score_matrix_i2t, score_matrix_t2i, _, _, _ = eval_pipeline(*pipeline_args, **pipeline_kwargs)
    eval_results_dict[key] = eval_result
    scores_dict[key] = {"i2t": score_matrix_i2t, "t2i": score_matrix_t2i}
    return eval_result


def _eval_attacks(
    eval_results_dict, scores_dict, key_prefix, eval_attacker_dict, args, model, loader, tokenizer, device, config
):
    """Evaluate `model` on `loader` under every attacker, storing results as f"{key_prefix}{att}"."""
    for att, eval_attacker in eval_attacker_dict.items():
        print(f"Start evaluation with {att}")
        eval_result = _record_eval(
            eval_results_dict,
            scores_dict,
            f"{key_prefix}{att}",
            args,
            model,
            loader,
            tokenizer,
            device,
            config,
            attacker=eval_attacker,
            attack_name=att,
        )
        print("eval result: ", att)
        print(eval_result)


def main(args, config):
    """Train (or only evaluate) a retrieval model, checkpointing and evaluating every epoch.

    Per-epoch outputs in ``args.output_dir``:
    - ``checkpoint_ep{epoch}.pth`` every ``args.save_model_interval`` epochs;
    - ``checkpoint_last.pth`` after the final epoch;
    - an appended ``log.json`` record;
    - ``eval_results_ep{epoch}.json``.

    End-of-run outputs in ``args.output_dir``: ``log.txt``, ``final_results.json``,
    ``loss_list.json`` and ``scores_dict.npy``.

    Args:
        args: Parsed command-line arguments.
        config: Config dict forwarded to the dataset, loader, model and evaluation helpers.
    """
    # utils.init_distributed_mode(args)
    # NOTE(review): relies on module-level names that are not defined in this part of the file
    # (`train_config`, `attack_config`, `ATTACK_EVAL_LIST`) — presumably set in the `__main__`
    # section; confirm. `train_config` is mutated below when --total_steps is given.

    device = torch.device(args.device)

    ########################
    ###### set seed ########
    ########################
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    ########################
    ###### load data #######
    ########################
    train_loader, val_loader, test_loader, train_subset_loader = _create_loaders(args, config)

    if args.total_steps is not None:
        print("Total steps:", args.total_steps)
        train_config["schedular"]["epochs"] = args.total_steps // len(train_loader) + 1

    ########################
    ###### load model ######
    ########################
    print("Loading model")
    model, ref_model, tokenizer = load_model(
        config, args.model, args.ckpt, args.text_encoder, device=device, train_config=train_config
    )
    model_ori = None
    if args.is_pretrained_model_guided or args.attack == "FARE":
        print("Loading model_ori")
        model_ori, _, _ = load_model(
            config, args.model, args.ckpt, args.text_encoder, device=device, train_config=train_config
        )
        model_ori.eval()

    if args.evaluate:
        ckpt_path = args.eval_ckpt_path
        if ckpt_path is None:
            print("\n=====\nEVALUATING PRETRAINED MODEL!!!!!!\n=====\n")
        elif os.path.isfile(ckpt_path):
            print("Loading checkpoint:", ckpt_path)
            eval_checkpoint = torch.load(ckpt_path, map_location="cpu")
            model.load_state_dict(eval_checkpoint["model"])
        else:
            raise ValueError("Invalid ckpt path")

    # Loaded once and reused below for the optimizer / scheduler state.
    resume_checkpoint = None
    if args.resume_ckpt is not None:
        print("Loading checkpoint:", args.resume_ckpt)
        resume_checkpoint = torch.load(args.resume_ckpt, map_location="cpu")
        model.load_state_dict(resume_checkpoint["model"])

    model = model.to(device)
    ref_model = ref_model.to(device)

    # fix gaussian encoder
    if args.fix_gau:
        print("Fix gau_encoder")
        for name, param in model.named_parameters():
            if "gau_encoder" in name:
                param.requires_grad_(False)

    ########################
    ###### optimizer #######
    ########################
    if args.train_only_vision_encoder:
        print("Train only vision encoder")
        parameters = get_trainable_params(model.visual, train_config)
        for name, param in model.named_parameters():
            if "visual" not in name:
                param.requires_grad_(False)
    else:
        parameters = get_trainable_params(model, train_config)
    opt_config = utils.AttrDict(train_config["optimizer"])
    optimizer = get_optimizer(parameters, opt_config)
    arg_sche = utils.AttrDict(train_config["schedular"])
    lr_scheduler, _ = create_scheduler_each_step(arg_sche, optimizer, train_loader)

    start_ep = 0
    # Resume only when a checkpoint was actually loaded. The old code raised NameError on
    # `resume_ep` when `is_resume_opt` was set without `resume_ckpt`.
    if resume_checkpoint is not None and args.is_resume_opt is not None:
        print("Resume Opt:", args.resume_ckpt)
        print("Epoch:", resume_checkpoint["epoch"] + 1)
        optimizer.load_state_dict(resume_checkpoint["optimizer"])
        lr_scheduler.load_state_dict(resume_checkpoint["lr_scheduler"])
        start_ep = resume_checkpoint["epoch"] + 1

    # distributed training
    model_without_ddp = model
    if args.multi_gpu:
        print(args.gpu)
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
        model.to(device)
        model_without_ddp = model.module
        print(f"Model is on: {next(model.parameters()).device}")
        print(f"Model parallel: {model.device_ids}")

    ########################
    ###### attacker ########
    ########################
    attacker = None
    if args.attack is not None:
        attacker = get_attacker(
            args,
            train_config,
            args.attack,
            model,  # for DDP
            ref_model,
            tokenizer,
            attack_config=attack_config,
            eps=args.epsilon,
            steps=args.num_iters,
            step_size=args.step_size,
        )

    # note: eps, steps, step_size are default values
    eval_attacker_dict = {
        att: get_attacker(args, train_config, att, model_without_ddp, ref_model, tokenizer)
        for att in ATTACK_EVAL_LIST
    }

    ########################
    ###### training ########
    ########################
    max_epoch = train_config["schedular"]["epochs"]
    print("max_epoch:", max_epoch)

    print("Start training")
    start_time = time.time()
    loss_list_all = []
    # Pre-initialized so the final dumps work even if the epoch loop runs zero times.
    log_stats = {}
    scores_dict = {}
    if args.evaluate:
        start_ep = max_epoch - 1
    for epoch in range(start_ep, max_epoch):
        print("Epoch:", epoch)
        if train_config["schedular"]["freeze_backbone_epochs"] > 0:
            if epoch < train_config["schedular"]["freeze_backbone_epochs"]:
                print("Freeze backbone")
                model_without_ddp.freeze_backbone()
            else:
                print("Unfreeze backbone")
                model_without_ddp.unfreeze_backbone()
        train_stats = {}
        if not args.evaluate and epoch >= 0:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
            is_warmup = args.no_attack_warmup_epoch > epoch
            if is_warmup:
                print("No attack epoch: ", epoch)
            train_stats, loss_list = train(
                args,
                model,
                model_without_ddp,
                train_loader,
                optimizer,
                epoch,
                device,
                lr_scheduler,
                attacker=None if is_warmup else attacker,
                model_ori=model_ori,
            )
            loss_list_all.extend(loss_list)

        # save
        save_obj = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "config": config,
            "epoch": epoch,
        }
        if epoch % args.save_model_interval == 0:
            torch.save(save_obj, os.path.join(args.output_dir, f"checkpoint_ep{epoch}.pth"))
        if epoch == max_epoch - 1:
            torch.save(save_obj, os.path.join(args.output_dir, "checkpoint_last.pth"))

        ##################
        ###### eval ######
        ##################
        scores_dict = {}
        eval_results_dict = {}
        if utils.is_main_process() and not args.skip_eval:
            # clean test-set evaluation, every epoch
            _record_eval(
                eval_results_dict, scores_dict, "Clean",
                args, model_without_ddp, test_loader, tokenizer, device, config,
            )

            if epoch == max_epoch - 1 or args.eval_every_epoch:
                if args.eval_train_subset:
                    print("~~ Start evaluation with train subset ~~")
                    _record_eval(
                        eval_results_dict, scores_dict, "train_CLEAN",
                        args, model_without_ddp, train_subset_loader, tokenizer, device, config,
                    )
                    _eval_attacks(
                        eval_results_dict, scores_dict, "train_", eval_attacker_dict,
                        args, model_without_ddp, train_subset_loader, tokenizer, device, config,
                    )

                if args.eval_val:
                    print("~~ Start evaluation with validation subset ~~")
                    _record_eval(
                        eval_results_dict, scores_dict, "val_CLEAN",
                        args, model_without_ddp, val_loader, tokenizer, device, config,
                    )
                    # Adversarial validation uses val_loader (previously train_subset_loader by mistake).
                    _eval_attacks(
                        eval_results_dict, scores_dict, "val_", eval_attacker_dict,
                        args, model_without_ddp, val_loader, tokenizer, device, config,
                    )

                print("~~ Start evaluation with test set ~~")
                _eval_attacks(
                    eval_results_dict, scores_dict, "", eval_attacker_dict,
                    args, model_without_ddp, test_loader, tokenizer, device, config,
                )

        ##################
        ###### log #######
        ##################
        log_stats = {
            **train_stats,
            **eval_results_dict,
            "epoch": epoch,
        }
        log_stats.update({"output_dir": args.output_dir})
        with open(os.path.join(args.output_dir, "log.json"), "a") as f:
            json.dump(log_stats, f, indent=4)
            f.write("\n")
        print(eval_results_dict)

        # save eval results
        with open(os.path.join(args.output_dir, f"eval_results_ep{epoch}.json"), "w") as f:
            json.dump(eval_results_dict, f, indent=4)

        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))

    print(log_stats)

    # log txt
    with open(os.path.join(args.output_dir, "log.txt"), "w") as f:
        f.write(str(log_stats))

    # save the last log_stats
    with open(os.path.join(args.output_dir, "final_results.json"), "w") as f:
        json.dump(log_stats, f, indent=4)

    # save loss_list
    with open(os.path.join(args.output_dir, "loss_list.json"), "w") as f:
        json.dump(loss_list_all, f, indent=4)

    # save scores (score matrices of the last evaluated epoch)
    with open(os.path.join(args.output_dir, "scores_dict.npy"), "wb") as f:
        np.save(f, scores_dict)


def t2bool(t):
    """Parse "true"/"false" (case-insensitive) into a bool; usable as an argparse ``type=``.

    Raises:
        ValueError: for any other string.
    """
    lowered = t.lower()
    lookup = {"true": True, "false": False}
    if lowered not in lookup:
        raise ValueError("Invalid value")
    return lookup[lowered]


def t2fl(t):
    """Parse a comma-separated string into a list of floats, e.g. "0.1,2" -> [0.1, 2.0]."""
    return list(map(float, t.split(",")))


def t2il(t):
    """Convert a comma-separated string (e.g. ``"10,20,30"``) into a list of ints.

    Raises ValueError if any comma-separated piece is not a valid integer.
    """
    pieces = t.split(",")
    return list(map(int, pieces))


if __name__ == "__main__":
    ################################
    ## Command-line interface
    ################################
    # Flags that take an explicit boolean value (e.g. "--cls false") are parsed
    # with t2bool: argparse's `type=bool` turns ANY non-empty string, including
    # "False", into True.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/Retrieval_flickr_train.yaml")
    parser.add_argument("--train_file", default=None, type=str, nargs="+")
    parser.add_argument("--seed", default=42, type=int)

    parser.add_argument("--model", default="CLIP_ViT-B-16_PT", type=str)  # model architecture
    parser.add_argument("--model_name", default="CLIP_ViT-B-16_PT", type=str)  # id for the model
    parser.add_argument("--text_encoder", default="bert-base-uncased", type=str)
    parser.add_argument("--ckpt", default=None, type=str)

    # training config
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--world_size", default=1, type=int, help="number of distributed processes")
    # NOTE(review): the default is a bare int while `nargs="+"` yields a list when
    # the flag is given (and the multi-GPU branch below assigns a list); readers
    # of args.gpu must handle both shapes.
    parser.add_argument("--gpu", default=0, nargs="+", type=int, help="GPU id to use.")
    parser.add_argument("--dist_url", default="env://", help="url used to set up distributed training")
    parser.add_argument("--distributed", default=False, type=t2bool)
    parser.add_argument("--multi_gpu", default=False, type=t2bool)

    # adversarial training config
    parser.add_argument(
        "--attack",
        default=None,
        type=str,
        choices=[
            "SGA",
            "Co-Attack",
            "TeCoA",
            "TeCoA_Orig",
            "FARE",
            "Sep-Attack",
            "PGD",
            "BERT",
            "Clean",
            "FSGA",
            "PDE-MMA",
            "SupPGD",
            "UnsupPGD",
            "MMA",
            "Bert-Sup",
            "EDA-Sup"
        ],
    )
    parser.add_argument("--is_rand_mask", default=False, action="store_true")
    parser.add_argument("--attack_fused_emb", default=False, type=t2bool)
    parser.add_argument("--cls", default=False, type=t2bool)
    parser.add_argument("--output_dir", default="../train_results", type=str)
    parser.add_argument("--epsilon", default=2.0, type=float)
    parser.add_argument("--alpha", default=3.0, type=float)  # for Co-Attack
    parser.add_argument("--num_iters", default=10, type=int)
    parser.add_argument("--step_size", default=0.5, type=float)

    # dataset
    parser.add_argument(
        "--caps_k", default=None, type=int
    )  # how many captions per image is used for training.

    ### augmentation
    parser.add_argument("--img_aug_type", default="randaug",
        type=str,
        choices=["randaug", "affine", "color", "cutout"]
    )
    # randaug
    parser.add_argument("--aug_n", default=2, type=int)
    parser.add_argument("--aug_m", default=7, type=int)
    parser.add_argument("--aug_scale", default=0.5, type=float)
    # cutout
    parser.add_argument("--n_holes", default=1, type=int)
    parser.add_argument("--length_ratio", default=0.5, type=float)
    # affine
    parser.add_argument("--degrees", default=20, type=int)
    parser.add_argument("--translate", default=0.2, type=float)
    parser.add_argument("--scale", default=0.1, type=float)
    # color
    parser.add_argument("--color_aug_strength", default=0.5, type=float)
    # to control the ratio of augmented data (prepared json file)
    parser.add_argument("--control_aug_ratio", default=None, type=float)

    parser.add_argument("--is_aug_image_feat", default=False, type=t2bool)
    parser.add_argument("--aug_image_feat_alpha", default=0.1, type=float)

    parser.add_argument("--is_eda", default=False, type=t2bool, help="easy data augmentation")

    # FSGA config
    parser.add_argument("--scale_ver", default=0, type=int)
    parser.add_argument("--txt_att_k", default=0, type=int)
    parser.add_argument("--txt_attack", default=None, type=str, choices=["rand", "bert"])
    # NOTE(review): the default is False rather than None, so the `is not None`
    # checks further down always append "-img-<value>" (e.g. "-img-False") to
    # the run directory name. Kept as-is so existing result paths still match.
    parser.add_argument("--img_attack_loss", default=False, type=str)

    # MMA config
    parser.add_argument("--is_use_gt_caps", default=False, type=t2bool)
    parser.add_argument(
        "--txt_sup_k", default=1, type=int
    )  # if > 1, use augmented texts for text-supervised image attack
    parser.add_argument("--alpha_sr", default=0.1, type=float)
    parser.add_argument("--alpha_ri", default=0.1, type=float)
    parser.add_argument("--alpha_rs", default=0.1, type=float)
    parser.add_argument("--p_rd", default=0.1, type=float)
    parser.add_argument("--alpha_unsup", default=0.0, type=float)
    parser.add_argument("--alpha_sup", default=1.0, type=float)
    parser.add_argument("--is_txt_aug", default=False, type=t2bool)
    parser.add_argument("--txt_aug", default="sr", type=str)
    parser.add_argument("--curric_eps", default=None, type=t2fl)
    parser.add_argument("--curric_iter", default=None, type=t2il)

    # train config
    parser.add_argument("--train_loss", default="clip", type=str)
    parser.add_argument("--train_only_vision_encoder", default=False, type=t2bool)
    # Effectively required: it is loaded unconditionally below.
    parser.add_argument("--train_config", default=None, type=str)
    parser.add_argument("--fix_gau", default=False, action="store_true")
    parser.add_argument("--label_smoothing", default=0.0, type=float)
    parser.add_argument("--is_aug_txt", default=False, type=t2bool)
    parser.add_argument("--aug_alpha", default=0.3, type=float)
    parser.add_argument("--no_attack_warmup_epoch", default=0, type=int)

    # pre-trained model guided adversarial training: https://github.com/serendipity1122/Pre-trained-Model-Guided-Fine-Tuning-for-Zero-Shot-Adversarial-Robustness/
    parser.add_argument("--is_pretrained_model_guided", default=False, action="store_true")

    # other
    parser.add_argument("--skip_eval", default=False, action="store_true")
    parser.add_argument("--eval_every_epoch", default=False, action="store_true")
    parser.add_argument("--resume_ckpt", default=None, type=str)
    parser.add_argument("--is_resume_opt", default=None, type=str)

    # evaluation
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--eval_ckpt_path", default=None, type=str)
    parser.add_argument("--eval_only_clean", default=False, action="store_true")
    # NOTE(review): default=True combined with store_true means these two flags
    # are always True and cannot be switched off from the command line.
    parser.add_argument("--eval_train_subset", default=True, action="store_true")
    parser.add_argument("--eval_val", default=True, action="store_true")

    # overwrite config for grid search
    parser.add_argument("--lr", default=None, type=float)
    parser.add_argument("--pde_mul_lr", default=None, type=float)
    parser.add_argument("--epochs", default=None, type=int)
    parser.add_argument("--total_steps", default=None, type=int)
    parser.add_argument("--batch_size", default=None, type=int)
    parser.add_argument("--lr_scheduler", default=None, type=str)

    # experiment
    parser.add_argument("--mark", default=None, type=str)
    # Parsed with t2bool: as a plain str, "--overwrite False" was truthy.
    parser.add_argument("--overwrite", default=False, type=t2bool)
    parser.add_argument("--save_model_interval", default=1, type=int)

    args = parser.parse_args()

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if args.model in ["ALBEF", "ALBEF_PT", "TCL", "TCL_PT"]:
        parser.error(f"--model {args.model} is not supported by this script")
    if args.train_config is None:
        parser.error("--train_config is required")

    # Module-level global; presumably read by main() (defined earlier in the file).
    ATTACK_EVAL_LIST = ["SupPGD", "UnsupPGD", "BERT", "Co-Attack", "SGA"]
    if args.eval_only_clean:
        ATTACK_EVAL_LIST = []
    # ATTACK_EVAL_LIST = ["PGD"]
    # ATTACK_EVAL_LIST = ["SGA"]

    # `yaml` is the ruamel round-trip loader created at import time. Use `with`
    # so the config file handles are closed once parsed.
    with open(args.config, "r") as config_file:
        config = edict(yaml.load(config_file))

    with open(args.train_config, "r") as train_config_file:
        train_config = edict(yaml.load(train_config_file))

    attack_config = train_config["attack"]

    ################################
    ## Use multi-gpu if available ##
    ################################
    if torch.cuda.device_count() > 1:
        args.multi_gpu = True
        args.world_size = torch.cuda.device_count()
        # args.gpu = args.world_size - 1
        args.gpu = [i for i in range(args.world_size)]
        print("\n******** Multi-GPU training ********")
        print("- World size:", args.world_size)
        print("- GPUs:", args.gpu)
        print("\n")

    ########################################
    ### overwrite config for grid search ###
    ########################################
    if args.lr is not None:
        train_config["optimizer"]["lr"] = args.lr
        print("Overwrite lr:", args.lr)
    if args.pde_mul_lr is not None:
        train_config["pde_mul_lr"] = args.pde_mul_lr
        print("Overwrite pde_mul_lr:", args.pde_mul_lr)
    if args.epochs is not None:
        train_config["schedular"]["epochs"] = args.epochs
        print("Overwrite epochs:", args.epochs)
    if args.batch_size is not None:
        config["batch_size_train"] = args.batch_size
        print("Overwrite batch_size:", args.batch_size)
    if args.lr_scheduler is not None:
        train_config["schedular"]["sched"] = args.lr_scheduler
        print("Overwrite lr_scheduler:", args.lr_scheduler)

    if args.is_use_gt_caps:
        config["is_return_set_data"] = True
        attack_config["MMA"]["is_use_gt_caps"] = True
    if args.caps_k is not None:
        config["caps_k"] = args.caps_k

    if args.attack == "MMA":
        # Copy the MMA-specific CLI options into the attack section of train_config.
        attack_config["MMA"]["is_use_gt_caps"] = args.is_use_gt_caps
        attack_config["MMA"]["txt_sup_k"] = args.txt_sup_k
        attack_config["MMA"]["alpha_sr"] = args.alpha_sr
        attack_config["MMA"]["alpha_ri"] = args.alpha_ri
        attack_config["MMA"]["alpha_rs"] = args.alpha_rs
        attack_config["MMA"]["p_rd"] = args.p_rd
        attack_config["MMA"]["alpha_unsup"] = args.alpha_unsup
        attack_config["MMA"]["alpha_sup"] = args.alpha_sup
        attack_config["MMA"]["is_txt_aug"] = args.is_txt_aug
        attack_config["MMA"]["txt_aug"] = args.txt_aug
        attack_config["MMA"]["scale_ver"] = args.scale_ver

        attack_config["MMA"]["is_aug_txt"] = args.is_aug_txt
        print("is_aug_txt:", args.is_aug_txt)

    ############################
    ## create output directory
    ############################
    dataset_name = config["dataset_name"]

    if args.train_file is not None:
        config["train_file"] = args.train_file

    # exp_id encodes data files and optimisation hyper-parameters; it becomes a
    # path component of the output directory (unless --mark is given).
    exp_id = ""
    # NOTE(review): assumes config["train_file"] is a list of paths; a plain
    # string from the YAML config would be iterated character by character.
    if len(config["train_file"]) >= 1:
        add_file = ""
        for i, train_file in enumerate(config["train_file"]):
            add_file += f"{i}=" + train_file.split("/")[-1].split(".")[0][-10:] + "="
        exp_id += "DATA-" + add_file + "-"
    if args.caps_k is not None:
        exp_id += f"UseCaps={args.caps_k}-"
    if args.resume_ckpt is not None:
        exp_id += "RESUME-"
    if "PCMEPP" in args.model:
        if train_config["pcmepp"]["model"]["is_probabilistic_model"]:
            exp_id += "PCMEPP-"
    if "PT" in args.model:
        exp_id += f"v{train_config['vision_depth']}_l{train_config['language_depth']}_vn{train_config['vision_ctx']}_ln{train_config['language_ctx']}"
    # NOTE: no separator between the batch size and the optimizer name; kept so
    # that previously generated experiment paths remain stable.
    exp_id += f"_bs{config['batch_size_train']}"
    exp_id += f"{train_config['optimizer']['opt']}_lr{train_config['optimizer']['lr']}_wd{train_config['optimizer']['weight_decay']}"
    if args.total_steps is not None:
        exp_id += f"_totalStep{args.total_steps}"
    else:
        exp_id += f"_ep{train_config['schedular']['epochs']}"

    if train_config["schedular"]["warmup_steps"] != 0:
        exp_id += f"_warmup{train_config['schedular']['warmup_steps']}"
    if args.lr_scheduler is not None:
        exp_id += f"_{args.lr_scheduler}"
    if args.label_smoothing != 0.0:
        exp_id += f"_ls{args.label_smoothing}"
    # Augmentation settings are only recorded when they differ from the CLI defaults.
    if args.aug_n != 2:
        exp_id += f"_augN{args.aug_n}"
    if args.aug_m != 7:
        exp_id += f"_augM{args.aug_m}"
    if args.aug_scale != 0.5:
        exp_id += f"_augS{args.aug_scale}"

    if args.is_eda:
        exp_id += "_eda"

    date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # attack name
    if args.attack is not None:
        DIR_NAME = f"{args.attack}-iters{args.num_iters}-step{args.step_size}"
        if "FSGA" in args.attack or "SupPGD" in args.attack or "UnsupPGD" in args.attack:
            DIR_NAME += f"-scale{args.scale_ver}"
            if args.is_use_gt_caps:
                DIR_NAME += "-txt-gtCaps"
            else:
                DIR_NAME += f"-txt-att-k{args.txt_att_k}"
            if args.img_attack_loss is not None:
                DIR_NAME += f"-img-{args.img_attack_loss}"
        if "PDE-MMA" in args.attack:
            DIR_NAME += f"-scale{args.scale_ver}"
            if args.img_attack_loss is not None:
                DIR_NAME += f"-img-{args.img_attack_loss}"
            if args.txt_attack is not None:
                DIR_NAME += f"-txt-{args.txt_attack}"
        if args.attack == "MMA":
            DIR_NAME += f"-scale{args.scale_ver}"
            att_cfg = attack_config["MMA"]
            if args.is_use_gt_caps:
                DIR_NAME += "-txt-gtCaps"
            else:
                DIR_NAME += f"-txtSup{att_cfg.txt_sup_k}-sr{att_cfg.alpha_sr}-ri{att_cfg.alpha_ri}-rs{att_cfg.alpha_rs}-rd{att_cfg.p_rd}"
            DIR_NAME += f"-img-unsup{att_cfg.alpha_unsup}-sup{att_cfg.alpha_sup}"
            if att_cfg.is_txt_aug:
                DIR_NAME += f"-txtAug-{att_cfg.txt_aug}"

        if args.no_attack_warmup_epoch > 0:
            DIR_NAME += f"-noAtt{args.no_attack_warmup_epoch}"

        if args.curric_eps is not None:
            # The curriculum needs one (eps, iters) pair per epoch. curric_iter
            # may be omitted on the CLI, so check for None before taking len().
            if args.curric_iter is None or len(args.curric_eps) != len(args.curric_iter):
                parser.error("--curric_iter must list one value per entry of --curric_eps")
            if len(args.curric_eps) != train_config["schedular"]["epochs"]:
                parser.error("--curric_eps must list one value per training epoch")
            DIR_NAME += f"-curricEps{args.curric_eps}-Iter{args.curric_iter}"

    else:
        DIR_NAME = "std"
    if args.is_rand_mask:
        DIR_NAME += "-rand-mask"
    if args.fix_gau:
        DIR_NAME += "-fixGau"

    if args.control_aug_ratio is not None:
        DIR_NAME = f"controlAugRatio{args.control_aug_ratio}" + DIR_NAME

    DIR_NAME += f"-trainLoss={args.train_loss}"

    # add mark: a user-supplied mark replaces exp_id in the path
    if args.mark is not None:
        DIR_PATH = os.path.join(args.output_dir, dataset_name, args.model_name, args.mark, DIR_NAME)
    else:
        DIR_PATH = os.path.join(args.output_dir, dataset_name, args.model_name, exp_id, DIR_NAME)

    print("Exp ID:", exp_id)

    # Skip the run if any timestamped sub-run of this configuration already
    # finished (wrote final_results.json), unless --overwrite is set.
    if os.path.exists(DIR_PATH) and not args.overwrite:
        results_paths = glob.glob(os.path.join(DIR_PATH, "*/final_results.json"))
        if len(results_paths) > 0:
            print("Already exists:", DIR_PATH)
            print("Skip\n\n")
            sys.exit()

    args.output_dir = os.path.join(DIR_PATH, date)
    os.makedirs(args.output_dir, exist_ok=True)
    print("\n\n")
    print("Output directory:", args.output_dir)
    print("\n\n")

    if args.evaluate:
        print("Evaluate mode")
        if args.eval_ckpt_path is not None:
            if os.path.isfile(args.eval_ckpt_path):
                # Write evaluation artifacts next to the checkpoint being evaluated.
                # args.eval_ckpt_path = "/path/to/checkpoint.pth"
                output_dir = os.path.dirname(args.eval_ckpt_path)
                print("Output directory:", output_dir)
                args.output_dir = output_dir

    VIS_DIR = os.path.join(args.output_dir, "vis")
    os.makedirs(VIS_DIR, exist_ok=True)

    # log print: mirror stdout into out.txt; the handle intentionally stays
    # open for the lifetime of the process.
    sys.stdout = utils.Tee(sys.stdout, open(os.path.join(args.output_dir, "out.txt"), "w"))

    ############################
    ## save args
    ############################
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)
    # save config
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
    # train_config is guaranteed to be set (validated right after parse_args).
    with open(os.path.join(args.output_dir, "train_config.json"), "w") as f:
        json.dump(train_config, f, indent=4)

    # Module-level globals; presumably consumed by the training loop in main().
    loss_img = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).to(device)
    loss_txt = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).to(device)

    main(args, config)