import torch
import torch.nn as nn
import torch.nn.functional as F
from FLAlgorithms.users.userbase import User
import numpy as np

class UserFGSM(User):
    def __init__(self, device, numeric_id, train_data, test_data, model, batch_size, learning_rate, robust, gamma, local_epochs, K):
        super().__init__(device, numeric_id, train_data, test_data, model[0], batch_size, learning_rate, robust, gamma, local_epochs)

        self.loss = nn.CrossEntropyLoss() 
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
        #self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def train(self, epochs, adv_option):
        LOSS = 0
        self.model.train()
        for epoch in range(1, self.local_epochs + 1):
            for X,y in self.trainloader:
                X, y = X.to(self.device), y.long().to(self.device)
                # Adversarial training
                X = self.fgsm(X = X, y = y, epsilon = adv_option[0], alpha =adv_option[1])
                self.optimizer.zero_grad()
                loss = self.loss(self.model(X), y)
                loss.backward()
                self.optimizer.step()
                LOSS += loss
        return LOSS
