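"""
Attack utilities: build adversarial attackers by name (``get_attacker``) and
run them on a batch for evaluation (``attack_batch_eval``) or adversarial
training (``attack_batch_train``).
"""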

import sys
import numpy as np
import torch

# attacks
from attacks.fast_set_attack import (
    FSGAttacker,
    ImageAttacker as FSGA_ImageAttacker,
    TextAttacker as FSGA_TextAttacker,
)
from attacks.fast_set_attack_pde import PDE_MMAttacker, PDE_ImageAttacker, PDE_TextAttacker
from attacks.SupPGD import (
    SupPGD,
    ImageAttacker as SupPGD_ImageAttacker,
)
from attacks.UnsupPGD import (
    UnsupPGD,
    ImageAttacker as UnsupPGD_ImageAttacker,
    TextAttacker as UnsupPGD_TextAttacker,
)
from attacks.MMA import (
    MMA,
    ImageAttacker as MMA_ImageAttacker,
)

# SGA
# sys.path.append("SGA")
from attacks.SGA.attacker import (
    SGAttacker as SGAttacker,
    ImageAttacker as SGA_ImageAttacker,
    TextAttacker as SGA_TextAttacker,
)

# ACMMM2022: https://github.com/adversarial-for-goodness/Co-Attack/tree/main
# attack image_embed and text_embed
from attacks.co_attack_modified import ImageAttacker as PGD_ImageAttacker
from attacks.co_attack_modified import BertAttack, BertAttackFusion
from attacks.co_attack_modified import MultiModalAttacker as CoAttacker

from attacks.MMA_clean import MultiAttacker, eda

from constants import images_normalize

from easydict import EasyDict as edict


def set_mode_for_attack(model):
    """
    Switch ``model`` into train mode before running an attack.

    Models containing recurrent layers (e.g. GRU, LSTM) should be in train
    mode to conduct the attack. Every submodule whose class name starts with
    "GRU" is additionally put into train mode and reported on stdout.

    NOTE(review): ``model.train()`` already puts every submodule (GRUs
    included) into train mode, so the per-GRU call is redundant. If the
    intent was to keep the rest of the model in eval mode, the first call
    should probably be ``model.eval()``. Confirm with callers before changing.
    """
    model.train()
    recurrent_layers = (
        layer
        for layer in model.modules()
        if type(layer).__name__.startswith("GRU")
    )
    for layer in recurrent_layers:
        layer.train()
        print(f"Set {type(layer).__name__} in train mode.")


class EDAAttacker:
    """Text "attacker" that replaces each caption with one EDA augmentation."""

    def __init__(self, alpha=0.3):
        self.alpha = alpha

    def attack(self, texts, **kwargs):
        """Return one EDA-augmented version of every sentence in ``texts``."""
        # NOTE(review): a passed ``alpha`` overwrites ``self.alpha`` and so
        # also applies to later calls; kept as-is to preserve behavior.
        if "alpha" in kwargs:
            self.alpha = kwargs["alpha"]
        return [
            eda(
                sentence,
                alpha_sr=self.alpha,
                alpha_ri=self.alpha,
                alpha_rs=self.alpha,
                p_rd=self.alpha,
                num_aug=1,
            )[0]
            for sentence in texts
        ]


# (scale_ver, txt_sup_k) for the fixed-configuration MMA analysis variants.
_ANALYSIS_SUP_SETTINGS = {
    "Sup1-1": (0, 1),
    "Sup5-5": (2, 5),
}


def _mma_image_attacker(eps, steps, step_size, alpha_unsup, alpha_sup):
    """Build an MMA image attacker; ``eps``/``step_size`` are divided by 255."""
    return MMA_ImageAttacker(
        images_normalize,
        eps=eps / 255.0,
        steps=steps,
        step_size=step_size / 255.0,
        alpha_unsup=alpha_unsup,
        alpha_sup=alpha_sup,
    )


def _bert_text_attacker(ref_model, tokenizer, cls, attack_fused_emb):
    """Pick BertAttackFusion when attacking the fused embedding, else BertAttack."""
    if attack_fused_emb:
        return BertAttackFusion(ref_model, tokenizer, cls=cls)
    return BertAttack(ref_model, tokenizer, cls=cls)


def _analysis_mma_config(scale_ver, txt_sup_k, alpha=0.1):
    """MMA config for the analysis runs (alpha_unsup=0, alpha_sup=1, no text aug)."""
    attack_config = edict()
    attack_config["is_use_gt_caps"] = False
    for key in ("alpha_sr", "alpha_ri", "alpha_rs", "p_rd"):
        attack_config[key] = alpha
    attack_config["alpha_unsup"] = 0
    attack_config["alpha_sup"] = 1
    attack_config["is_txt_aug"] = False
    attack_config["txt_aug"] = "rand"
    attack_config["scale_ver"] = scale_ver
    attack_config["txt_sup_k"] = txt_sup_k
    return attack_config


def get_attacker(
    args,
    train_config,
    attack_name,
    model,
    ref_model,
    tokenizer,
    cls=False,
    attack_fused_emb=False,
    eps=2.0,
    steps=10,
    step_size=0.5,
    **kwargs,
):
    """
    Build the attacker object for ``attack_name``.

    Args:
        args: run arguments; only ``args.img_attack_loss`` is read (PDE-MMA).
        train_config: forwarded to the PDE-MMA image attacker.
        attack_name: one of "SGA", "BERT", "PGD", "Sep-Attack", "Co-Attack",
            "Clean", "Bert-Sup", "EDA-Sup", "FSGA", "PDE-MMA", "SupPGD",
            "TeCoA", "TeCoA_Orig", "UnsupPGD", "MMA", "Sup1-1", "Sup5-5",
            "FARE".
        model: model under attack.
        ref_model: reference model handed to the text attackers.
        tokenizer: tokenizer handed to the attackers.
        cls: forwarded to the attackers that accept it (SGA hard-codes False).
        attack_fused_emb: for the Co-Attack family and "Bert-Sup", use
            BertAttackFusion instead of BertAttack.
        eps, step_size: image perturbation budget / step, divided by 255
            before use (presumably given on a 0-255 pixel scale).
        steps: number of image attack steps (the Co-Attack family gets its
            iteration count at run time instead).
        **kwargs: "MMA" requires ``kwargs["attack_config"]["MMA"]``.

    Returns:
        The attacker object, or None for "FARE".

    Raises:
        ValueError: if ``attack_name`` is unknown.
    """
    if attack_name == "SGA":
        img_attacker = SGA_ImageAttacker(
            images_normalize, eps=eps / 255.0, steps=steps, step_size=step_size / 255.0
        )
        # NOTE(review): ``cls`` is hard-coded to False instead of forwarding
        # the ``cls`` argument; confirm this is intentional for SGA.
        txt_attacker = SGA_TextAttacker(
            ref_model,
            tokenizer,
            cls=False,
            max_length=30,
            number_perturbation=1,
            topk=10,
            threshold_pred_score=0.3,
        )
        return SGAttacker(model, img_attacker, txt_attacker)

    if attack_name in ("BERT", "PGD", "Sep-Attack", "Co-Attack", "Clean"):
        # Co-Attack family: the iteration count is passed later to
        # ``run`` / ``run_before_fusion``, so ``steps`` is not used here.
        image_attacker = PGD_ImageAttacker(
            eps / 255.0,
            preprocess=images_normalize,
            bounding=(0, 1),
            cls=cls,
        )
        text_attacker = _bert_text_attacker(ref_model, tokenizer, cls, attack_fused_emb)
        return CoAttacker(model, image_attacker, text_attacker, tokenizer, cls=cls)

    if attack_name in ("Bert-Sup", "EDA-Sup"):
        # Text attack + supervised-only image attack.
        image_attacker = _mma_image_attacker(eps, steps, step_size, alpha_unsup=0, alpha_sup=1)
        if attack_name == "Bert-Sup":
            text_attacker = _bert_text_attacker(ref_model, tokenizer, cls, attack_fused_emb)
        else:
            text_attacker = EDAAttacker(alpha=0.2)
        return MultiAttacker(model, image_attacker, text_attacker, tokenizer)

    if attack_name == "FSGA":
        img_attacker = FSGA_ImageAttacker(
            images_normalize, eps=eps / 255.0, steps=steps, step_size=step_size / 255.0
        )
        txt_attacker = FSGA_TextAttacker(tokenizer)
        return FSGAttacker(model, img_attacker, txt_attacker)

    if attack_name == "PDE-MMA":
        img_attacker = PDE_ImageAttacker(
            train_config,
            args.img_attack_loss,
            images_normalize,
            eps=eps / 255.0,
            steps=steps,
            step_size=step_size / 255.0,
        )
        txt_attacker = PDE_TextAttacker(ref_model, tokenizer, cls=cls)
        return PDE_MMAttacker(model, img_attacker, txt_attacker)

    if attack_name in ("SupPGD", "TeCoA", "TeCoA_Orig"):
        img_attacker = SupPGD_ImageAttacker(
            images_normalize, eps=eps / 255.0, steps=steps, step_size=step_size / 255.0
        )
        return SupPGD(model, img_attacker, tokenizer)

    if attack_name == "UnsupPGD":
        img_attacker = UnsupPGD_ImageAttacker(
            images_normalize, eps=eps / 255.0, steps=steps, step_size=step_size / 255.0, cls=cls
        )
        txt_attacker = UnsupPGD_TextAttacker(tokenizer)
        return UnsupPGD(model, img_attacker, txt_attacker)

    if attack_name == "MMA":
        attack_config = kwargs["attack_config"]["MMA"]
        img_attacker = _mma_image_attacker(
            eps,
            steps,
            step_size,
            alpha_unsup=attack_config["alpha_unsup"],
            alpha_sup=attack_config["alpha_sup"],
        )
        return MMA(model, img_attacker, tokenizer, attack_config)

    if attack_name in _ANALYSIS_SUP_SETTINGS:
        # Analysis-only MMA variants with a fixed configuration.
        scale_ver, txt_sup_k = _ANALYSIS_SUP_SETTINGS[attack_name]
        attack_config = _analysis_mma_config(scale_ver, txt_sup_k)
        img_attacker = _mma_image_attacker(
            eps,
            steps,
            step_size,
            alpha_unsup=attack_config["alpha_unsup"],
            alpha_sup=attack_config["alpha_sup"],
        )
        return MMA(model, img_attacker, tokenizer, attack_config)

    if attack_name == "FARE":
        return None

    raise ValueError(f"Invalid attack mode: {attack_name}")


def attack_batch_eval(
    args,
    attack_name,
    attacker,
    images,
    texts,
    txt2img,
    device,
    num_iters=10,  # default
    alpha=3.0,  # alpha for Co-Attack. Unsup + alpha * Sup.
    return_pair=False,
    attack_fused_emb=False,
):
    """
    Attack one batch for evaluation.

    Uses a fixed number of iterations and a fixed configuration for every attack.

    Args:
        args: run arguments; ``args.model`` is read by the Co-Attack family
            when ``attack_fused_emb`` is False.
        attack_name: "SGA", "Clean", "BERT", "PGD", "Sep-Attack", "Co-Attack",
            "SupPGD", "UnsupPGD", "MMA", "Sup1-1" or "Sup5-5".
        attacker: object returned by ``get_attacker`` for ``attack_name``.
        images: torch.Tensor of shape (B, 3, H, W).
        texts: list of captions (typically B * 5).
        txt2img: list; ``txt2img[j]`` is the index of the image paired with
            ``texts[j]``.
        device: forwarded to the attacker.
        num_iters, alpha: Co-Attack iteration count and loss weight.
        return_pair: for the Co-Attack family, return one adversarial image
            per text instead of one per image.
        attack_fused_emb: for the Co-Attack family, attack the fused embedding
            (``attacker.run``) instead of ``attacker.run_before_fusion``.

    Returns:
        (adv_images, adv_texts). ``adv_images`` has the shape of ``images``,
        except with ``return_pair`` where it has one row per text.

    Raises:
        ValueError: if ``attack_name`` is unsupported.
    """
    img_shape = images.shape
    txt_shape = len(texts)
    if attack_name == "SGA":
        scales = [0.5, 0.75, 1.25, 1.5]
        # ``max_lemgth`` (sic) is passed verbatim; presumably it mirrors the
        # SGAttacker.attack parameter spelling. Verify before renaming it.
        adv_images, adv_texts = attacker.attack(
            images, texts, txt2img, device=device, max_lemgth=30, scales=scales
        )
    elif attack_name in ("Clean", "BERT", "PGD", "Sep-Attack", "Co-Attack"):
        # Co-Attack ``adv`` mode codes; "Clean" maps to 0.
        adv_mode = {"Clean": 0, "BERT": 1, "PGD": 2, "Sep-Attack": 3, "Co-Attack": 4}[attack_name]
        # Co-Attack needs paired (image, text) input: image i is repeated once
        # per caption pointing at it. This pairs rows positionally, so it
        # assumes ``texts`` are grouped by image in ascending image order.
        img2txt_n = [txt2img.count(i) for i in range(len(images))]
        repeated_images = torch.stack(
            [im for im, n in zip(images, img2txt_n) for _ in range(n)], dim=0
        )
        if attack_fused_emb:
            adv_images, adv_texts = attacker.run(
                repeated_images,
                texts,
                adv=adv_mode,
                num_iters=num_iters,
                alpha=alpha,
            )
        else:
            # 77 presumably matches CLIP's text context length; other models
            # get a large cap. Kept an int (was the float 1e3) so it is safe
            # to use as a slice/length bound.
            max_length = 77 if "CLIP" in args.model else 1000
            adv_images, adv_texts = attacker.run_before_fusion(
                repeated_images,
                texts,
                adv=adv_mode,
                num_iters=num_iters,
                alpha=alpha,
                max_length=max_length,
            )
        if not return_pair:
            # Keep one adversarial image per original image (its last copy).
            adv_images = adv_images[np.cumsum(img2txt_n) - 1]
    elif attack_name == "SupPGD":
        # use all GT texts to perturb images
        adv_images, _ = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            scales=[],
            txt_att_k=0,
            txt_attack=None,
            img_attack_loss="sim",
            gt_caps_list=None,
            is_train=False,
        )
        adv_texts = texts
    elif attack_name == "UnsupPGD":
        adv_images, _ = attacker.attack(
            images, texts, device=device, max_length=30, txt_att_k=0, is_train=False
        )
        adv_texts = texts
    elif attack_name == "MMA":
        adv_images, adv_texts = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            scales=None,  # Use predefined scales: set in attack_config!
            gt_caps_list=None,  # only used if gt_caps is not None
            is_train=False,
        )
    elif attack_name in ("Sup1-1", "Sup5-5"):
        assert len(images) == len(texts), f"{len(images)} != {len(texts)}"
        adv_images, _ = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            scales=None,  # Use predefined scales: set in attack_config!
            gt_caps_list=None,  # only used if gt_caps is not None
            is_train=False,
        )
        adv_texts = texts
    else:
        raise ValueError(f"Invalid attack mode: {attack_name}")

    if return_pair:
        assert len(adv_images) == len(adv_texts), f"{len(adv_images)} != {len(adv_texts)}"
        # Paired output has one image row per text, so only the per-image
        # dimensions must match the input.
        assert adv_images.shape[1:] == img_shape[1:], f"{adv_images.shape} != {img_shape}"
    else:
        assert adv_images.shape == img_shape, f"{adv_images.shape} != {img_shape}"
    assert len(adv_texts) == txt_shape, f"{len(adv_texts)} != {txt_shape}"
    return adv_images, adv_texts


def attack_batch_train(
    args,
    attack_name,
    attacker,
    images,
    texts,
    txt2img,
    device,
    return_pair=False,
    gt_caps_list=None,
):
    """
    Attack one batch for adversarial training.

    Configuration comes from ``args``, so different runs can use different settings.

    Args:
        args: run arguments. Read per attack: ``scale_ver``, ``txt_att_k``,
            ``txt_attack``, ``img_attack_loss`` (FSGA / PDE-MMA / SupPGD);
            ``attack_fused_emb``, ``num_iters``, ``alpha``, ``model``
            (Co-Attack family); ``txt_att_k`` (UnsupPGD).
        attack_name: name of the attack; must match the ``attacker`` built by
            ``get_attacker``.
        attacker: object returned by ``get_attacker`` for ``attack_name``.
        images: torch.Tensor of shape (B, 3, H, W).
        texts: list of captions (typically B * 5).
        txt2img: list; ``txt2img[j]`` is the index of the image paired with
            ``texts[j]``.
        device: forwarded to the attacker.
        return_pair: for the Co-Attack family, return one adversarial image
            per text; for "MMA" it is forwarded as ``is_train``.
        gt_caps_list: forwarded to FSGA / PDE-MMA / SupPGD and MMA.

    Returns:
        (adv_images, adv_texts) as produced by the attacker.

    Raises:
        ValueError: if ``attack_name`` is unsupported.
    """
    if attack_name in ("FSGA", "PDE-MMA", "SupPGD"):
        scales_list = [
            [],
            [0.75, 1.25],
            [0.5, 0.75, 1.25, 1.5],
        ]
        adv_images, adv_texts = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            scales=scales_list[args.scale_ver],
            txt_att_k=args.txt_att_k,
            txt_attack=args.txt_attack,
            img_attack_loss=args.img_attack_loss,
            gt_caps_list=gt_caps_list,  # only used if gt_caps is not None
        )
    elif attack_name in ("Clean", "BERT", "PGD", "Sep-Attack", "Co-Attack"):
        # Co-Attack ``adv`` mode codes; "Clean" maps to 0.
        adv_mode = {"Clean": 0, "BERT": 1, "PGD": 2, "Sep-Attack": 3, "Co-Attack": 4}[attack_name]
        # Co-Attack needs paired (image, text) input: image i is repeated once
        # per caption pointing at it. This pairs rows positionally, so it
        # assumes ``texts`` are grouped by image in ascending image order.
        img2txt_n = [txt2img.count(i) for i in range(len(images))]
        repeated_images = torch.stack(
            [im for im, n in zip(images, img2txt_n) for _ in range(n)], dim=0
        )
        if args.attack_fused_emb:
            adv_images, adv_texts = attacker.run(
                repeated_images,
                texts,
                adv=adv_mode,
                num_iters=args.num_iters,
                alpha=args.alpha,
            )
        else:
            # 77 presumably matches CLIP's text context length; other models
            # get a large cap. Kept an int (was the float 1e3) so it is safe
            # to use as a slice/length bound.
            max_length = 77 if "CLIP" in args.model else 1000
            adv_images, adv_texts = attacker.run_before_fusion(
                repeated_images,
                texts,
                adv=adv_mode,
                num_iters=args.num_iters,
                alpha=args.alpha,
                max_length=max_length,
            )
        if not return_pair:
            # Keep one adversarial image per original image (its last copy).
            adv_images = adv_images[np.cumsum(img2txt_n) - 1]
    elif attack_name == "UnsupPGD":
        adv_images, adv_texts = attacker.attack(
            images,
            texts,
            device=device,
            max_length=30,
            txt_att_k=args.txt_att_k,
        )
    elif attack_name == "MMA":
        adv_images, adv_texts = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            scales=None,  # Use predefined scales: set in attack_config!
            gt_caps_list=gt_caps_list,  # only used if gt_caps is not None
            is_train=return_pair,
        )
    elif attack_name == "SGA":
        scales = [0.5, 0.75, 1.25, 1.5]
        # ``max_lemgth`` (sic) is passed verbatim; presumably it mirrors the
        # SGAttacker.attack parameter spelling. Verify before renaming it.
        adv_images, adv_texts = attacker.attack(
            images, texts, txt2img, device=device, max_lemgth=30, scales=scales
        )
    elif attack_name == "TeCoA":
        adv_images, adv_texts = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            img_attack_loss="clip",
        )
    elif attack_name == "TeCoA_Orig":
        adv_images, adv_texts = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            img_attack_loss="clip_i",
        )
    elif attack_name in ("Bert-Sup", "EDA-Sup"):
        adv_images, adv_texts = attacker.attack(
            images,
            texts,
            txt2img,
            device=device,
            max_length=30,
            img_attack_loss="sim",
        )
    else:
        raise ValueError(f"Invalid attack mode: {attack_name}")
    if return_pair and gt_caps_list is None:
        assert len(adv_images) == len(adv_texts)
    return adv_images, adv_texts
