import torch
from models.attacks.matrix import Matrix
from models.attacks.pgd import PGD
from models.attacks.AdvGAN import AdvGAN


class Adversary:
    """Adversary that perturbs points against a defender using a named attack model.

    The concrete attack model is chosen by ``attack_model_type``:
    ``"matrix"`` -> ``Matrix``, ``"pgd"`` -> ``PGD``, ``"advgan"`` -> ``AdvGAN``.
    """

    def __init__(self, dataset, defender, args, attack_model_type, K=20, rho=None):
        """
        Args:
            dataset: The dataset being attacked (stored on the instance).
            defender: The model under attack; passed to the attack model constructor.
            args: Configuration namespace. The "pgd" branch reads ``pgd_eps``,
                ``pgd_step_size`` and ``pgd_num_steps``; the "advgan" branch passes
                ``args`` to ``AdvGAN``; the "matrix" branch falls back to
                ``args.rho`` when ``rho`` is not given.
            attack_model_type (str): One of "matrix", "pgd" or "advgan".
            K (int, optional): Size of the adversarial set. Stored on the instance;
                not used by the methods of this class. Defaults to 20.
            rho (optional): Parameter forwarded to the "matrix" attack model. If
                None, ``args.rho`` is used when present, otherwise 1.

        Raises:
            ValueError: If ``attack_model_type`` is not recognised.
        """
        self.dataset = dataset
        self.defender = defender
        self.attack_model_type = attack_model_type
        self.K = K
        self.args = args
        # parse_attack_model() reads self.rho in the "matrix" branch, so it must be
        # assigned before that call (previously it was never set at all).
        self.rho = rho if rho is not None else getattr(args, "rho", 1)
        self.parse_attack_model()

    def parse_attack_model(self):
        """Instantiate ``self.attack_model`` according to ``self.attack_model_type``.

        Raises:
            ValueError: If ``self.attack_model_type`` is not a known attack type.
        """
        if self.attack_model_type == "matrix":
            self.attack_model = Matrix(self.defender, self.rho)
        elif self.attack_model_type == "pgd":
            eps = self.args.pgd_eps
            step_size = self.args.pgd_step_size
            num_steps = self.args.pgd_num_steps
            # Report the values actually used rather than hard-coded numbers.
            print(
                f"PGD adversary with params: epsilon={eps}, "
                f"epsilon_iter={step_size}, num_steps={num_steps}"
            )
            self.attack_model = PGD(
                self.defender,
                epsilon=eps,
                epsilon_iter=step_size,
                num_steps=num_steps,
            )
        elif self.attack_model_type == "advgan":
            print("AdvGAN attack model!")
            self.attack_model = AdvGAN(self.defender, self.args)
        else:
            # Fail fast: otherwise self.attack_model is never assigned and the
            # problem only surfaces later as an AttributeError.
            raise ValueError(
                f"Unknown attack_model_type {self.attack_model_type!r}; "
                "expected 'matrix', 'pgd' or 'advgan'"
            )

    def compute_attack_for_set(self, points, labels):
        """Optionally train the attack model on (points, labels), then perturb them.

        If ``self.attack_model.requires_training`` is truthy, ``train_on_set`` is
        called first.

        Returns:
            The value returned by ``self.attack_model.get_perturbed(points, labels)``.
        """
        if self.attack_model.requires_training:
            self.attack_model.train_on_set(points, labels)
        return self.attack_model.get_perturbed(points, labels)

# class Adversary:
#     def __init__(self, dataset, defender, args, attack_model_type, rho=1, K=20, epochs=15):
#         """
#         Initializes the adversary with the following arguments
#         Args:
#             attack_model_type (str): The attack model type used by the adversary.
#                                         This is the specific way in which the adversary perturbs a given point or set of points.
#             dataset (PyTorch Dataset): The dataset being used.
#             K (int, optional): The size of the adversarial set of points. Defaults to 20.
#             epochs (int, optional): The number of epochs for which the attack model is trained for each new S. Defaults to 15.
#         """
#         self.S_ind = []
    #     self.dataset = dataset
    #     self.defender = defender
    #     self.attack_model_type = args.attack_model_type
    #     self.rho = args.rho
    #     self.K = args.K
    #     self.epochs = args.epochs
    #     self.args = args
    #     self.defender_losses = {}
    #     self.reconstruction_losses = {}
    #     self.parse_attack_model()
    
    # def parse_attack_model(self):
    #     if self.attack_model_type=="matrix":
    #         self.attack_model = Matrix(self.defender, self.rho)
    #     elif self.attack_model_type=="pgd":
    #         self.attack_model = PGD(self.defender, epsilon=0.3, epsilon_iter=0.003, num_steps=40)
    
#     def compute_attack(self):
#         """
#             Compute the attack and return the set S
#         """
#         if self.attack_model.requires_training:
#             self.compute_attack_trained()
#         else:
#             self.compute_attack_untrained()
    
#     def compute_attack_untrained(self):
#         pass
    
#     def compute_attack_trained(self):
#         """
#         !! Not relevant according to the new formulation

#         Compute new S and attack model for C_{r-1}
#         If needed, serialize results
#         Args:
#             defender (Defender): The defender against which we wish to compute the attack
#         Returns:
#             attacks (Dictionary) = Dictionary containing the attack model at each time step
#         """
#         final_attack_models = {}
#         self.S_ind = []
#         for i in range(self.K):
#             print(f"|S| = {len(self.S_ind)}")
#             # attack_model_curr = self.attack_model.clone(requires_grad=True)
#             pt_idxs = self.sample_points()
            
#             min_loss = None
#             opt_attack_for_i = None
#             opt_idx_for_i = None
            
#             for count, idx in enumerate(pt_idxs):
#                 if idx in self.S_ind:
#                     # If we already have this index in the set S, then skip
#                     continue
                
#                 attack_model_idx = self.compute_optimal_attack_for_set(self.S_ind+[idx], self.attack_model)
                
                
#                 pts, labels = self.indices_to_points_and_labels(self.S_ind+[idx])
#                 loss = attack_model_idx.get_loss(pts, labels)
#                 if min_loss is None or loss < min_loss:
#                     min_loss = loss
#                     opt_idx_i = idx
#                     opt_attack_for_i = attack_model_idx
            
#             self.attack_model = opt_attack_for_i
#             self.S_ind.append(opt_idx_i)
#             final_attack_models[i+1] = self.attack_model.clone(False)
            
#             # Get the losses and save them with the adversary.
#             with torch.no_grad():
#                 points, labels = self.indices_to_points_and_labels(self.S_ind)
#                 self.attack_model.get_loss(points, labels, requires_mean=False, save_losses=True)
#                 self.defender_losses[i+1] = self.attack_model.defender_loss
#                 self.reconstruction_losses[i+1] = self.attack_model.reconstruction_loss
#             print(f"S_ind = {str(self.S_ind)}")
#             print(f"Samples in S: {self.dataset[self.S_ind]}")
            
#         return final_attack_models
    
#     def compute_optimal_attack_for_set(self, S_inds, attack_model_init):
#         """
#         Computes optimal attack model for the set S_inds
#         Args:
#             S_inds (list): The set of indices of the points in S
#             attack_model_curr (AttackModel): The attack model being used to initialize our attack model for training on S.

#         Returns:
#             attack_model (AttackModel): The final trained attack model for the set S
#         """
#         # Computes optimal attack model for the set S_inds
#         # attack model initialized to attack_model_curr
#         if self.attack_model.requires_training:
#             attack_model = attack_model_init.clone(requires_grad=True)
#             if self.attack_model_type=="matrix":
#                 points, labels = self.indices_to_points_and_labels(S_inds)
#                 attack_model.train(points, labels, self.epochs)
#             else:
#                 attack_model.train(self.dataset, self.args)
#             return attack_model
#         else:
#             return self.attack_model
    
#     def sample_points(self):
#         """
#         Utility function: returns the indices of candidate points for selection (currently the entire dataset; could be replaced by stochastic-greedy sampling).

#         Returns:
#             (List): List of point indices.
#         """
#         return list(range(len(self.dataset)))
    
#     def objective(self):
#         # The adversary's objective, computed on a set of points indexed by S_inds
#         pass
    
#     def indices_to_points_and_labels(self, indices):
#         """
#         Utility function: Takes in indices and returns corresponding dataset slice.
        
#         Args:
#             indices (List): List of indices. Could also be a numpy array.

#         Returns:
#             A slice of the dataset, indexed by the input indices. 
#         """
#         return self.dataset[indices]

# class TrainableAdversary(Adversary):
#     pass

# class NonTraininableAdversary(Adversary):
#     pass