from abc import ABCMeta, abstractmethod
from torch import nn
# FGSM based attacks
from torchattacks import FGSM, BIM, FFGSM, MIFGSM, RFGSM, DIFGSM, TIFGSM, NIFGSM, SINIFGSM, VMIFGSM, VNIFGSM, LGV
# Projected Gradient Descent (PGD) based attacks
from torchattacks import PGD, EOTPGD, APGD, APGDT, UPGD, TPGD, Jitter, PGDRS
# Baseline perturbations: VANILA (returns inputs unchanged) and GN (Gaussian noise)
from torchattacks import VANILA, GN
# L2 attacks
from torchattacks import PGDL2, PGDRSL2
# Carlini-Wagner based attacks
from torchattacks import CW
from torchattacks import DeepFool

# Linf, L2 attacks
from torchattacks import FAB
from torchattacks import AutoAttack
from torchattacks import Square

# L0 attacks
from torchattacks import SparseFool
from torchattacks import OnePixel
from torchattacks import Pixle

import argparse
import inspect

import yaml
from yaml.loader import SafeLoader


class AttackMeta(ABCMeta):
    """Metaclass that auto-registers every concrete attack factory.

    Each concrete class created with this metaclass is stored in
    ``AttackMeta.registered_attacks`` under the part of its class name that
    precedes the first ``"Attack"`` substring, e.g. ``FGSMAttack`` -> ``"FGSM"``,
    ``AutoAttackAttack`` -> ``"Auto"``, ``SquareAttackAttack`` -> ``"Square"``.
    Abstract classes (such as the ``AttackFactory`` base) are not registered,
    so they can never be looked up and instantiated by name.
    """

    # Shared registry: short attack name -> factory class.
    registered_attacks = {}

    def __new__(cls, name, bases, attrs):
        new_class = super().__new__(cls, name, bases, attrs)
        # ABCMeta.__new__ has already computed __abstractmethods__ here, so
        # isabstract() reliably filters out the abstract base. The old
        # ``name != "Attack"`` guard alone missed "AttackFactory", which was
        # then registered under the empty key "".
        if name != "Attack" and not inspect.isabstract(new_class):
            AttackMeta.registered_attacks[name.split("Attack")[0]] = new_class
        return new_class

class AttackFactory(metaclass=AttackMeta):
    """Abstract base for factories that build ``torchattacks`` attack objects.

    Concrete subclasses are registered automatically by ``AttackMeta``; use
    :meth:`get_attack` to instantiate one by its short registry name.
    """

    # Alias of the metaclass registry. A separate empty dict here used to
    # shadow AttackMeta.registered_attacks, so ``AttackFactory.registered_attacks``
    # always looked empty even though factories were being registered.
    registered_attacks = AttackMeta.registered_attacks

    def __init__(self, args):
        # Stored as-is; no factory defined in this module reads it.
        self.args = args
        # Most recently built attack object; None until __call__ runs.
        self.attack = None

    @abstractmethod
    def __call__(self, model, *args, **kwargs):
        """Build the attack for ``model``, store it in ``self.attack`` and return it."""

    @abstractmethod
    def get_class(self):
        """Return the wrapped attack class (the class itself, not an instance)."""

    @staticmethod
    def get_attack(name, args):
        """Instantiate the factory registered under ``name``.

        Args:
            name: Registry key, e.g. ``"FGSM"``.
            args: Passed unchanged to the factory constructor.

        Returns:
            A new factory instance.

        Raises:
            KeyError: if no factory is registered under ``name``.
        """
        try:
            factory_cls = AttackMeta.registered_attacks[name]
        except KeyError:
            known = ", ".join(sorted(AttackMeta.registered_attacks))
            raise KeyError(f"Unknown attack {name!r}; registered: {known}") from None
        return factory_cls(args)

class _TorchattacksFactoryMixin:
    """Shared implementation for factories that wrap one torchattacks class.

    Each concrete factory below sets ``_attack_cls`` and lists this mixin
    *before* ``AttackFactory`` in its bases, so these concrete methods satisfy
    AttackFactory's abstract interface. The mixin is a plain class (not built
    by ``AttackMeta``), so it is never added to the attack registry itself.
    """

    # Wrapped torchattacks class; set by each concrete factory.
    _attack_cls = None

    def __call__(self, model, *args, **kwargs):
        """Build the wrapped attack for ``model``, remember it and return it."""
        self.attack = self._attack_cls(model, *args, **kwargs)
        return self.attack

    def get_class(self):
        """Return the wrapped torchattacks class."""
        return self._attack_cls


# FGSM-family attacks.
class FGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = FGSM


class FFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = FFGSM


class MIFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = MIFGSM


class RFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = RFGSM


class DIFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = DIFGSM


class TIFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = TIFGSM


class NIFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = NIFGSM


class SINIFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = SINIFGSM


class VMIFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = VMIFGSM


class VNIFGSMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = VNIFGSM


class LGVAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = LGV


# PGD-family attacks.
class PGDAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = PGD


class EOTPGDAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = EOTPGD


class APGDAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = APGD


class APGDTAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = APGDT


class UPGDAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = UPGD


class TPGDAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = TPGD


class JitterAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = Jitter


class PGDRSAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = PGDRS


class VANILAAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = VANILA


class GNAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = GN


class PGDL2Attack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = PGDL2


class PGDRSL2Attack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = PGDRSL2


class CWAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = CW


class DeepFoolAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = DeepFool


class SparseFoolAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = SparseFool


class BIMAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = BIM


class FABAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = FAB


# The doubled "Attack" suffix is kept for backward compatibility; under
# AttackMeta's ``name.split("Attack")[0]`` rule this registers as "Auto".
class AutoAttackAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = AutoAttack


# Likewise registered as "Square" by AttackMeta's split rule.
class SquareAttackAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = Square


class OnePixelAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = OnePixel


class PixleAttack(_TorchattacksFactoryMixin, AttackFactory):
    _attack_cls = Pixle


def get_attack_args_kwargs(attack_name: str = "FGSM", *args, **kwargs):
    """Return the constructor parameter names of a registered attack class.

    Args:
        attack_name: Registry key of the attack factory, e.g. ``"FGSM"``.
        *args: Forwarded (as a tuple) to the factory constructor.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        ``(arg_names, kwonly_names)``: the ``args`` and ``kwonlyargs`` lists that
        ``inspect.getfullargspec`` reports for the wrapped attack class (for a
        class this reflects its ``__init__``).

    Raises:
        KeyError: if ``attack_name`` is not registered.
    """
    factory = AttackMeta.registered_attacks[attack_name](args)
    # Inspect once and use distinct names instead of rebinding the function's
    # own *args/**kwargs parameters.
    spec = inspect.getfullargspec(factory.get_class())
    return spec.args, spec.kwonlyargs


def get_attack_from_file(file_path: str = "attack_config.yaml"):
    """Load attack constructor settings from a YAML config file.

    The file must contain a top-level ``config`` key whose value is used as
    the attack's keyword arguments. A missing file is reported on stdout and
    treated as "no settings".

    Args:
        file_path: Path to the YAML file.

    Returns:
        ``(attack_args, attack_kwargs)``. ``attack_args`` is always an empty
        list. ``attack_kwargs`` is the ``config`` value, or ``{}`` when the file
        does not exist or the ``config`` value is null/empty.

    Raises:
        KeyError: if the document has no top-level ``config`` key.
        TypeError: if the document is empty or not a mapping.
    """
    attack_args = []
    try:
        # YAML is UTF-8 by spec; do not depend on the platform's locale encoding.
        with open(file_path, 'r', encoding="utf-8") as f:
            # NOTE(review): FullLoader can construct some Python-specific types.
            # If these configs can ever come from an untrusted source, switch
            # to SafeLoader / yaml.safe_load.
            config = yaml.load(f, Loader=yaml.FullLoader)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return attack_args, {}

    # ``config:`` with no value parses as None; normalise it so callers can
    # always unpack the result with ``**``.
    attack_kwargs = config["config"] or {}
    print(attack_kwargs)
    return attack_args, attack_kwargs
    

def get_attack_instance(attack_name: str = "FGSM", model: nn.Module = None, *args, **kwargs):
    """Build a torchattacks attack for ``model`` using its YAML config.

    Settings are read from ``configs/adversarial_attacks/<attack_name>.yaml``
    via ``get_attack_from_file``.

    Args:
        attack_name: Registry key of the attack factory, e.g. ``"FGSM"``.
        model: Model the attack is built against.
        *args: Extra positional arguments, appended after those from the file.
        **kwargs: Extra keyword arguments; they override same-named file settings.

    Returns:
        The attack object produced by the factory.

    Raises:
        KeyError: if ``attack_name`` is not registered.
    """
    # Resolve the factory first so an unknown name fails before any file I/O.
    factory_cls = AttackMeta.registered_attacks[attack_name]
    file_args, file_kwargs = get_attack_from_file(f"configs/adversarial_attacks/{attack_name}.yaml")
    # The caller's *args/**kwargs used to be overwritten and silently dropped;
    # merge them with the file settings, letting explicit kwargs win.
    call_args = [*file_args, *args]
    call_kwargs = {**(file_kwargs or {}), **kwargs}
    return factory_cls(file_args)(model, *call_args, **call_kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch Attack Wrapper')
    # ``choices`` makes argparse reject unknown names with a usage message
    # instead of a KeyError traceback from the registry lookup below.
    parser.add_argument('--attack', default='FGSM', type=str,
                        choices=sorted(AttackMeta.registered_attacks),
                        help='attack method')
    # NOTE(review): --eps is parsed but not used anywhere in this smoke test.
    parser.add_argument('--eps', default=0.031, type=float, help='epsilon')
    args = parser.parse_args()

    print(AttackMeta.registered_attacks)

    attack = AttackMeta.registered_attacks[args.attack](args)
    # Show which torchattacks class the selected factory wraps. Printing the
    # AttackFactory base class gave the same output for every --attack value.
    print(attack.get_class())
    