from random import gauss
import math
from models.defense.svm import SVM
from models.defense.nn_mnist import NN_MNIST
from models.defense.nn_cifar10 import NN_CIFAR10
from models.defense.resnet import ResNet18
import matplotlib.pyplot as plt
from copy import deepcopy
import torch
import numpy as np
import pickle
from torch.utils.data import Subset
from torch.utils.data import DataLoader
import time
import os
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import cleverhans
import math
import subprocess
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)

from distorted_greedy import distorted_greedy, distorted_greedy_pointwise
import models.defense.baselines.TRADES.train_trades_cifar10 as trades_cifar10
import models.defense.baselines.TRADES.train_trades_mnist as trades_mnist
import models.defense.baselines.PGDAT.pgd_attack_fmnist as fmnist_pgdat
import models.defense.baselines.MART.train_MART as mart_cifar10
class Defender:
    def __init__(self, classifier_type, dataset, args):
        # Takes as input a classifier type, which could be NN, SVM, etc
        if torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            self.device = torch.device('cuda')
            torch.cuda.set_device(args.gpu)
            print('Using Device: ', torch.cuda.get_device_name())
        else:
            self.device = torch.device('cpu')
        self.base_classifier = None
        self.classifier = None
        self.classifier_type = classifier_type
        self.adversary = None              ## the defender need to have an object of adversary
        self.full_dataset = dataset
        self.parse_classifier_type()
        self.rho = args.rho
        self.lam = args.lam
        self.save_dir = args.save_dir
        # if torch.cuda.is_available():
        #     self.device = torch.device('cuda')
        # else:
        #     self.device = torch.device('cpu')
        self.pre_computed_diff = None
        self.pre_computed_regularizer = None
        self.timer=False
        self.perturbed_x = None
        self.S_dict = {}
        self.S_losses = []
        self.diff_losses = []
        self.S_len = []
        self.total_losses = []
        self.avg_S_losses = []
        self.avg_diff_losses = []
        self.eta = args.eta                 # Possibly reset by update_defender()
        self.val_att_accs_log = {}
        self.val_unatt_accs_log = {}
        self.val_comb_accs_log = {}
        self.S_losses_log = {}
        self.diff_losses_log = {}
        self.overall_epoch_num = 0
        self.adjust_lr = args.adjust_lr
        self.timestep = 0
        self.debug = args.debug
        self.args = args
        # self.split_dataset_into_trainval()
        self.lazy_attack_update = (args.lazy_attack_update > 0) # False by default!
        self.lr = args.lr
        self.new_val_attack_logs = {}
        self.new_val_unattacked_logs = {}
        self.new_val_combined_logs = {}
        self.remote_save_dir = None
        if self.lazy_attack_update:
            print("Lazy attack enabled!")
            self.lazy_attack_timestep = args.lazy_attack_update
        
        self.val_attacked_accuracies = {}
        self.val_unattacked_accuracies = {}
        self.val_combined_accuracies = {}
        
        self.test_attacked_accuracies = {}
        self.test_unattacked_accuracies = {}
        self.test_combined_accuracies = {}

        self.early_exit = False
        self.max_comb_val_acc = -1
        self.save_curr_model = False
        self.gs_times = []
        self.dg_times = []
        print("Attack model type is", self.args.attack_model_type)

    def split_dataset_into_trainval(self, frac=0.05):
        generator = torch.Generator(self.device)
        train, val = torch.utils.data.random_split(self.full_dataset, [40000, 10000], generator=generator)
        self.dataset = train
        self.val_ds = val
    
    def set_test_ds(self, test_attacked, test_unattacked):
        """Register the fixed test splits used by ``get_test_set_accs``.

        Args:
            test_attacked: test examples evaluated under attack.
            test_unattacked: test examples evaluated clean.
        """
        self.test_attack_ds, self.test_unattacked_ds = test_attacked, test_unattacked
    
    def parse_classifier_type(self):
        """Instantiate ``self.classifier`` and ``self.base_classifier`` from ``self.classifier_type``.

        Recognised types are "svm", "nn_mnist" and "nn_cifar10". Each type gets
        two independent instances. Any other value leaves both attributes
        untouched. NOTE(review): ResNet18 is imported by this module but is not
        selectable here.
        """
        kind = self.classifier_type
        model_cls = None
        if kind == "svm":
            model_cls = SVM
        elif kind == "nn_mnist":
            model_cls = NN_MNIST
        elif kind == "nn_cifar10":
            model_cls = NN_CIFAR10
        if model_cls is None:
            return
        self.classifier = model_cls()
        self.base_classifier = model_cls()
        
        
    def train_base_classifier(self, dataset, train_args=None, val_dataset=None):
        # Train base classifier using train_args
        # Also update dataset
        print("*************************************************************************")
        print("Base classifier training begins!")
        self.classifier.train_model(dataset, train_args, val_dataset)
        self.base_classifier = deepcopy(self.classifier)
        print("*************************************************************************")
        print("Base classifier training ended!")
    
    def create_checkpoint(self, T):
        checkpoint = {
                'epoch' : T,
                'optimizer' : self.optimizer.state_dict(),
                'state dict' : self.classifier.model.state_dict()
            }
    
    def update_classifier(self, dataset, train_args=None):
        # Main robust-training loop: optional warm start with a baseline
        # adversarial-training method, then T timesteps alternating between
        # (a) training on the current adversarial subset S and (b) re-selecting S.
        '''
        Updates the classifier to make it robust against adversarial attack on subsets.

        Phases:
            1. Optional warm start, chosen by the first true flag among
               train_args.trades_training / pgdat_training / mart_training.
               Each warm-start epoch logs validation accuracies and saves
               model/optimizer checkpoints to self.save_dir.
            2. For t in range(train_args.T):
               - if a subset S exists (it is empty on the first timestep), split
                 `dataset` with split_subsets and run train_step; log losses and
                 accuracies, save the model and a checkpoint to self.save_dir, and
                 (for t > 5) save the model with the best combined validation
                 accuracy to args.rho_dir/model-final;
               - re-select S with distorted_greedy_pointwise(self, train_args).

        inputs:
            dataset: passed to split_subsets together with the selected indices.
                NOTE(review): an older docstring described it as a
                (train_dataset, test_dataset) tuple; confirm against split_subsets.
            train_args: arguments object (required despite the None default).
                Reads T, gradient_steps, lr, weight_decay, ts_batch_size, lam,
                rho, gamma, eta, K, the trades/pgdat/mart *_training flags, their
                init_*_epochs counts and trades_early_stop. It is also passed whole
                to distorted_greedy_pointwise.
        '''
        ## Modified algorithm from convex submodular paper
        # Move both the current and the base classifier to the defender's device.
        self.classifier = self.classifier.to(self.device)
        self.base_classifier = self.base_classifier.to(self.device)

        T = train_args.T
        grad_steps = train_args.gradient_steps
        lr = train_args.lr  # NOTE(review): not used below
        weight_decay = train_args.weight_decay
        ts_batch_size = train_args.ts_batch_size
        S = []
        print("Parameters:")
        print("lambda =", train_args.lam)
        print("rho =", train_args.rho)
        # print(train_args)
        print("gamma =", train_args.gamma)
        # self.classifier = deepcopy(self.base_classifier)
        # eta (percentage of each batch drawn from the attacked subset) now comes from train_args.
        self.eta = train_args.eta
        # Baseline validation accuracies (attacked, unattacked, combined), logged under key 0.
        val_acc_logs = self.get_val_set_acc(vanilla=False)
        self.val_att_accs_log[0] = [val_acc_logs[0]]
        self.val_unatt_accs_log[0] = [val_acc_logs[1]]
        self.val_comb_accs_log[0] = [val_acc_logs[2]]

        # Number of warm-start epochs; offsets the per-timestep keys of the val-acc logs below.
        end_tr_epoch = 0
        
        # ---- Optional warm start: TRADES ----
        if train_args.trades_training:    
            if self.args.data == 'cifar10':
                trades = trades_cifar10
            elif self.args.data == 'fmnist':
                trades = trades_mnist
            else:
                raise NotImplementedError

            end_tr_epoch = train_args.init_trades_epochs
            print('Trades training started.........')
            # NOTE(review): the training loader is not shuffled (shuffle=False).
            train_loader = DataLoader(self.dataset, batch_size=256, shuffle=False)
            test_loader = DataLoader(self.val_ds, batch_size=256, shuffle=False)

            # Per-dataset TRADES hyper-parameters; the class name is rebound to an instance below.
            class trades_args():
                def __init__(self, data):
                    if data == 'cifar10':
                        self.step_size = 0.007
                        self.epsilon = 0.031
                        self.num_steps = 20
                        self.beta = 6.0
                        self.log_interval = 1
                    elif data == 'fmnist':
                        self.step_size = 0.01
                        self.epsilon = 0.3
                        self.num_steps = 40
                        self.beta = 1.0
                        self.log_interval = 1
                    else:
                        raise NotImplementedError
            trades_args = trades_args(self.args.data)
                
            # Snapshot restored if early stopping triggers.
            saved_state = self.get_model_state()
            for epoch in range(1, train_args.init_trades_epochs + 1):
                # adversarial training
                trades.train(trades_args, self.classifier.model, self.device, train_loader, self.optimizer, epoch)

                # evaluation on natural examples
                print('================================================================')
                trades.eval_train(self.classifier.model, self.device, train_loader)
                trades.eval_test(self.classifier.model, self.device, test_loader)
                print('================================================================')

                (acc, un_acc, comb_acc) = self.get_val_set_acc(vanilla=False) 
                self.val_att_accs_log[epoch] = [acc]
                self.val_unatt_accs_log[epoch] = [un_acc]
                self.val_comb_accs_log[epoch] = [comb_acc]
                self.save_info_dicts()
                
                # Early stop (after min_epochs) as soon as attacked val accuracy drops;
                # roll back to the previous epoch's model and optimizer.
                min_epochs = 10
                if (epoch > min_epochs) and (acc < self.val_att_accs_log[epoch-1][0]) and train_args.trades_early_stop:
                    end_tr_epoch = epoch-1
                    self.set_model_state(saved_state)
                    break
                
                else:
                    saved_state = self.get_model_state()
                    # save checkpoint
                    torch.save(self.classifier.model.state_dict(),
                            os.path.join(self.save_dir, 'trades_tr_epoch{}.pt'.format(epoch)))
                    torch.save(self.optimizer.state_dict(),
                            os.path.join(self.save_dir, 'opt_trades_tr_epoch{}.tar'.format(epoch)))

            print(f'Completed {end_tr_epoch} epochs')
            print('Trades training finished.........')
        # ---- Optional warm start: PGD adversarial training ----
        # NOTE(review): unlike the TRADES branch, end_tr_epoch is not advanced here, so
        # the T-loop's log entries (keyed t+end_tr_epoch) overwrite these epoch entries.
        elif train_args.pgdat_training:
            # Initialized with PGD-AT classifier. Training PGD-AT for a few epochs
            # print("Zero PGD-AT epochs.... training... Let's see....")
            # PGD hyper-parameters read from self.args (pgd_eps, pgd_num_steps, pgd_step_size).
            class pgdat_args():
                def __init__(self, data):
                    self.epsilon = data.pgd_eps
                    self.num_steps = data.pgd_num_steps
                    self.step_size = data.pgd_step_size
                    self.log_interval = 100
                    print(f"Arguments are for PGDAT are epsilon:{self.epsilon}, ns:{self.num_steps} and step_size:{self.step_size}")
            train_loader = DataLoader(self.dataset, batch_size=128, shuffle=False)
            for epoch in range(1, train_args.init_pgdat_epochs + 1):
                # train pgdat
                start = time.time()
                pgdat_train_args = pgdat_args(self.args)
                print(f"PGDAT training epoch {epoch}...")
                # NOTE(review): only fmnist is actually trained; for other datasets this
                # loop only evaluates and checkpoints.
                if self.args.data=="fmnist":
                    fmnist_pgdat.train(pgdat_train_args, self.classifier.model, self.device, train_loader, self.optimizer, epoch)
                end = time.time()
                print(f"Ended PGDAT training epoch {epoch}... Time taken is {end-start}")
                # Evaluate and store
                (acc, un_acc, comb_acc) = self.get_val_set_acc(vanilla=False) 
                self.val_att_accs_log[epoch] = [acc]
                self.val_unatt_accs_log[epoch] = [un_acc]
                self.val_comb_accs_log[epoch] = [comb_acc]
                self.save_info_dicts()
                
                # Same early-stop rule as TRADES (controlled by the trades_early_stop flag).
                # saved_state is first assigned in the else branch below, which has always
                # run at least once before epoch > min_epochs can hold.
                min_epochs = 10
                if (epoch > min_epochs) and (acc < self.val_att_accs_log[epoch-1][0]) and train_args.trades_early_stop:
                    end_tr_epoch = epoch-1
                    self.set_model_state(saved_state)
                    break
                
                else:
                    saved_state = self.get_model_state()
                    # save checkpoint
                    torch.save(self.classifier.model.state_dict(),
                            os.path.join(self.save_dir, 'pgdat_tr_epoch{}.pt'.format(epoch)))
                    torch.save(self.optimizer.state_dict(),
                            os.path.join(self.save_dir, 'opt_pgdat_tr_epoch{}.tar'.format(epoch)))
        # ---- Optional warm start: MART ----
        # NOTE(review): the mart.train call below is commented out, so this loop only
        # evaluates and checkpoints; the start message also says 'Trades', and
        # end_tr_epoch is not advanced.
        elif train_args.mart_training:
            if self.args.data == 'cifar10':
                mart = mart_cifar10
            else:
                raise NotImplementedError
            
            class mart_args():
                def __init__(self, args):
                    self.epsilon = args.pgd_eps
                    self.num_steps = args.pgd_num_steps
                    self.step_size = args.pgd_step_size
                    self.log_interval = 100
                    if args.data == 'cifar10':
                        self.beta = 5.0
                    elif args.data == 'fmnist':
                        self.beta = 1.0
                    else:
                        raise NotImplementedError
                    print(f"Arguments are for MART are epsilon:{self.epsilon}, ns:{self.num_steps}, step_size:{self.step_size}, beta: {self.beta}")
            mart_args = mart_args(self.args)

            train_loader = DataLoader(self.dataset, batch_size=128, shuffle=False)
            print('Trades training started.........')
            
            for epoch in range(1, train_args.init_mart_epochs + 1):
                # adversarial training
                # mart.train(args, model, device, train_loader, optimizer, epoch)
                
                # Evaluate and store
                (acc, un_acc, comb_acc) = self.get_val_set_acc(vanilla=False) 
                self.val_att_accs_log[epoch] = [acc]
                self.val_unatt_accs_log[epoch] = [un_acc]
                self.val_comb_accs_log[epoch] = [comb_acc]
                self.save_info_dicts()

                # save checkpoint
                torch.save(self.classifier.model.state_dict(),
                        os.path.join(self.save_dir, 'mart_tr_epoch{}.pt'.format(epoch)))
                torch.save(self.optimizer.state_dict(),
                        os.path.join(self.save_dir, 'opt_mart_tr_epoch{}.tar'.format(epoch)))
            print('MART training finished.........')  
            

        # ---- Main loop: alternate classifier training and subset selection ----
        for t in range(T):
            print("***********************************************************************************")
            print(f"In timestep {t+1}/{T}")
            self.timestep = (t+1)
            # S is empty on the first timestep, so the classifier is only updated once
            # a subset has been selected at the end of a previous timestep.
            if len(S)>0:
                S_adv = self.compute_adv_subset(take_prev_S=0)
                train_subset, rem_subset = self.split_subsets(dataset, S_adv)
                # print("Train subset size, rem_subset size:", len(train_subset), "\t", len(rem_subset))
                vanilla = False
                # if t<2:
                #     vanilla=True
                #     print("=================================Vanilla Training!=================================")
                # train_step sets early_exit when it rolls back a step.
                self.early_exit = False
                self.save_curr_model = False
                # start_ts = time.time()
                S_losses, diff_losses, val_acc_logs = self.train_step(train_subset, rem_subset, grad_steps, weight_decay, ts_batch_size, vanilla=vanilla)
                # end_ts = time.time()
                # self.gs_times.append(end-start)
                # print(f"REQUIRED NUMBERS------------ Time for gradient step is {end-start}")
                self.val_att_accs_log[t+end_tr_epoch] = val_acc_logs[0]
                self.val_unatt_accs_log[t+end_tr_epoch] = val_acc_logs[1]
                self.val_comb_accs_log[t+end_tr_epoch] = val_acc_logs[2]
                self.S_losses_log[t] = S_losses
                self.diff_losses_log[t] = diff_losses
                self.S_losses.append(S_losses[-1])
                self.diff_losses.append(diff_losses[-1])
                timestep_stamp = "model_"+str(t+1)+"_timestep"
                def_save_path = os.path.join(self.save_dir, timestep_stamp)
                self.classifier.save_model(def_save_path)
                # Best-model tracking only starts after the first few timesteps.
                if t>5:
                    if self.early_exit==True:
                        # On early exit, train_step appended the rejected step's accuracy
                        # last and restored the previous model, so use the second-to-last.
                        current_val_acc = val_acc_logs[2][-2]
                        print("Early exit... val acc is", current_val_acc)
                        print("Val accs this timestep:", val_acc_logs[2])
                    else:
                        current_val_acc = val_acc_logs[2][-1]
                        print("Normal exit... val acc is", current_val_acc)
                        print("Val accs this timestep:", val_acc_logs[2])
                    if current_val_acc > self.max_comb_val_acc:
                        print("Max val acc for this model... Saving to model-final")
                        self.max_comb_val_acc = current_val_acc
                        self.save_curr_model = True

                    if self.save_curr_model:
                        print(f"Saving max acc model..... model-{t+1}-timestep")
                        rho_dir_model_name = os.path.join(self.args.rho_dir, "model-final")
                        self.classifier.save_model(rho_dir_model_name)
                    print(f"Max combined val accuracy till now is {self.max_comb_val_acc}")
                self.save_info_dicts()
                # Store checkpoint as well....
                # NOTE(review): confirm create_checkpoint returns the checkpoint dict;
                # if it returns None, None is what torch.save writes below.
                checkpoint = self.create_checkpoint(t+1)
                checkpoint_stamp = "checkpoint_"+str(t+1)+"_timestep"
                checkpoint_save_path = os.path.join(self.save_dir, checkpoint_stamp)
                print(f"\tSaving model checkpoint to {checkpoint_save_path}")
                torch.save(checkpoint, checkpoint_save_path)
                
                print(f"\tS loss is {S_losses[-1]} and Diff loss is {diff_losses[-1]}")
            # Every lazy_attack_timestep timesteps, snapshot the classifier as
            # self.lazy_attack_classifier (presumably read by the attack/selection code).
            if self.lazy_attack_update and t%self.lazy_attack_timestep==0:
                print("************************ Lazy attack update! Updating classifier *****************")
                self.lazy_attack_classifier = deepcopy(self.classifier)
            # Re-select the adversarial subset used at the next timestep.
            start = time.time()
            S = distorted_greedy_pointwise(self, train_args)
            end = time.time()
            self.dg_times.append(end-start)
            print("REQUIRED NUMBERS-------------- GS TIMES!!!!!", self.gs_times)
            print("REQUIRED NUMBERS-------------- DG TIMES!!!!!", self.dg_times)
            self.S_dict[t] = deepcopy(S)
            self.S_len.append(len(S))
            if t>2:
                self.plot_losses()
                self.plot_S_sizes()
                # NOTE(review): fname is unused; the whole defender is always pickled to defender.pkl.
                fname = "defender_model_K_"+str(train_args.K)+".pkl"
                with open(os.path.join(self.save_dir, 'defender.pkl'), 'wb') as f:
                    pickle.dump(self, f)
            # Periodically retrain the AdvGAN attacker on the newly selected subset.
            if self.args.attack_model_type=="advgan" and t%self.args.adv_gan_retrain_timestep==0:
                print("Retraining AdvGAN!")
                sub_ds = Subset(self.dataset, S)
                self.adversary.attack_model.train_advgan(sub_ds, self.args, self.args.adv_gan_retraining_epochs)
            print("Current set size is ", len(S))
        self.plot_losses()
        self.plot_S_sizes()
    
    def save_to_remote(self):
        return None
        # subprocess.run(["scp", "-r", self.args.rho_dir, self.remote_save_dir])
        
    def save_info_dicts(self):
        print("Saving INFO DICTS!!!!!!!!!!")
        old_dict = {}
        old_dict["S losses"] = deepcopy(self.S_losses_log)
        old_dict["Diff losses"] = deepcopy(self.diff_losses_log)
        old_dict["Val attacked accs"] = deepcopy(self.val_att_accs_log)
        old_dict["Val unattacked accs"] = deepcopy(self.val_unatt_accs_log)
        old_dict["Val combined accs"] = deepcopy(self.val_comb_accs_log)
        with open(os.path.join(self.args.rho_dir, 'val_acc_train_losses_OLD.pkl'), 'wb') as f:
            pickle.dump(old_dict, f)
        
        # Store the real val and test accuracy dicts!
        val_dict = {}
        val_dict["Attacked Accuracies"] =  self.val_attacked_accuracies
        val_dict["Unattacked Accuracies"] = self.val_unattacked_accuracies
        val_dict["Combined Accuracies"] = self.val_combined_accuracies
        with open(os.path.join(self.args.rho_dir, 'val_accs.pkl'), 'wb') as fval:
            pickle.dump(val_dict, fval)
        
        test_dict = {}
        test_dict["Attacked Accuracies"] =  self.test_attacked_accuracies
        test_dict["Unattacked Accuracies"] = self.test_unattacked_accuracies
        test_dict["Combined Accuracies"] = self.test_combined_accuracies
        with open(os.path.join(self.args.rho_dir, 'test_accs.pkl'), 'wb') as ftest:
            pickle.dump(test_dict, ftest)
        self.save_to_remote()
    
    def log_val_test_accs(self, val_acc, val_un_acc, val_comb_acc, test_acc, test_un_acc, test_comb_acc, gs):
        """Record validation and test accuracies under the key ``(self.timestep, gs)``.

        Args:
            val_acc, val_un_acc, val_comb_acc: validation attacked/unattacked/combined accuracy.
            test_acc, test_un_acc, test_comb_acc: test attacked/unattacked/combined accuracy.
            gs: gradient-step index (train_step passes -1 for an early exit).
        """
        key = (self.timestep, gs)
        targets = (
            (self.val_attacked_accuracies, val_acc),
            (self.val_unattacked_accuracies, val_un_acc),
            (self.val_combined_accuracies, val_comb_acc),
            (self.test_attacked_accuracies, test_acc),
            (self.test_unattacked_accuracies, test_un_acc),
            (self.test_combined_accuracies, test_comb_acc),
        )
        for store, value in targets:
            store[key] = value
    
    def compute_adv_subset(self, take_prev_S=0):
        """Return the union of the most recently selected adversarial subsets.

        Takes the subsets stored in ``self.S_dict`` for the last
        ``take_prev_S + 1`` timesteps (clamped to however many exist). It
        concatenates them in timestep order and removes duplicate indices,
        keeping the first occurrence so the result order is deterministic.
        Previously the deduplicated list was assigned to the wrong variable, so
        duplicates were returned.

        Args:
            take_prev_S: number of *previous* subsets to union with the latest one.

        Returns:
            list of unique indices (empty if no subset has been recorded yet).
        """
        keys = list(self.S_dict)
        if (take_prev_S + 1) > len(keys):
            # Not enough history: take the union of whatever is available.
            take_prev_S = len(keys) - 1
        t_idxs = keys[-(take_prev_S+1):]
        print(f"\tUnion of last {len(t_idxs)} sets: {t_idxs}")
        S_adv = [idx for t in t_idxs for idx in self.S_dict[t]]
        print(f"\tLength of S_adv before union is {len(S_adv)}", end=" ")
        # Order-preserving de-duplication.
        S_adv = list(dict.fromkeys(S_adv))
        print(f" which becomes {len(S_adv)}")
        return S_adv
    
    def init_optimizer(self, optimizer):
        if optimizer=="Adam":
            print("Initializing optimizer with Adam.......")
            self.optimizer = optim.Adam(self.classifier.model.parameters(), lr=self.lr)
        elif optimizer=="SGD":
            print("Initializing optimizer with SGD.......")
            self.optimizer = optim.SGD(self.classifier.model.parameters(), lr=self.lr, momentum=0.1, weight_decay=2e-4)
    
    def set_model_state(self, state):
        self.classifier.model.load_state_dict(state["state dict"])
        self.optimizer.load_state_dict(state["optimizer"])
    
    def get_model_state(self):
        curr_state = {
            "state dict" : deepcopy(self.classifier.model.state_dict()),
            "optimizer" : deepcopy(self.optimizer.state_dict())
        }
        return curr_state
    
    def compare_models(self, sd1, sd2):
        models_differ = 0
        for key_item_1, key_item_2 in zip(sd1.items(), sd2.items()):
            if torch.equal(key_item_1[1], key_item_2[1]):
                pass
            else:
                models_differ += 1
                if (key_item_1[0] == key_item_2[0]):
                    print('Mismtach found at', key_item_1[0])
                else:
                    raise Exception
        if models_differ == 0:
            # print('Models match perfectly! :)')
            pass
    
    def compute_att_unatt_batch_sizes_for_train(self, final_batch_size):
        """Split a combined batch size into (attacked, unattacked) batch sizes.

        The attacked part is ``ceil(final_batch_size * eta / 100)``. Both parts
        are clamped to at least 1, so their sum can exceed ``final_batch_size``.

        Returns:
            tuple (num_attack, num_unattack).
        """
        attacked = max(1, int(math.ceil(final_batch_size * (self.eta / 100.0))))
        unattacked = max(1, final_batch_size - attacked)
        return attacked, unattacked
    
    def adjust_learning_rate(self):
        effective_epoch_for_lr_update = 0
        if self.adjust_lr=="None":
            print("No LR adjustment")
            return 
        elif self.adjust_lr=="Epoch":
            print("Epoch LR adjustment!")
            effective_epoch_for_lr_update = self.overall_epoch_num
        elif self.adjust_lr=="Timestep":
            print("Timestep LR adjustment!")
            effective_epoch_for_lr_update = self.timestep
        self.classifier.adjust_learning_rate(self.optimizer, effective_epoch_for_lr_update, self.lr)
        

    def train_step(self, train_subset, rem_subset, grad_steps, weight_decay, ts_batch_size, save_between=False, save_freq=10, save_path=None, vanilla=False):
        """Run up to ``grad_steps`` passes of weighted training on S (attacked) and D\\S.

        Each pass walks both loaders in lockstep. Every mini-batch concatenates a
        batch of ``train_subset`` (perturbed by ``self.adversary`` unless
        ``vanilla``) with a batch of ``rem_subset``. Per-example cross-entropy is
        weighted by 1 for attacked examples and ``self.rho`` for unattacked ones,
        then averaged over the combined batch.

        After each pass the model is evaluated on the validation and test sets.
        If the validation attacked accuracy drops (and ``vanilla`` is False), the
        model and optimizer are restored to the state saved after the previous
        pass, ``self.early_exit`` is set and training for this timestep stops.

        Args:
            train_subset: dataset of examples in the adversarial subset S.
            rem_subset: dataset of the remaining (unattacked) examples.
            grad_steps: maximum number of passes over the loaders.
            weight_decay: unused; kept for interface compatibility.
            ts_batch_size: total combined batch size, split by eta via
                ``compute_att_unatt_batch_sizes_for_train``.
            save_between, save_freq, save_path: unused; kept for compatibility.
            vanilla: if True, skip perturbation and the early-exit check.

        Returns:
            (S_losses, rem_ce_losses, (val_att_accs, val_unatt_accs, val_comb_accs)).
            The loss lists hold per-pass summed (rho-weighted) losses; the accuracy
            lists hold per-pass validation accuracies. On early exit the rejected
            pass's accuracies are appended last.
        """
        print("Train step called!")
        self.classifier.model.train()

        self.adjust_learning_rate()

        print(f"\tTrain subset size is {len(train_subset)} and rem_subset size is {len(rem_subset)}")

        attack_batch_size, rem_batch_size = self.compute_att_unatt_batch_sizes_for_train(ts_batch_size)

        attack_dl = DataLoader(train_subset, batch_size=attack_batch_size, shuffle=False)
        generator = torch.Generator(self.device)
        rem_dl = DataLoader(rem_subset, batch_size=rem_batch_size, shuffle=True, generator=generator)

        print(f"\tSize of attack set is {len(train_subset)} and unattacked is {len(rem_subset)}")
        print(f"\tAttack batch size is {attack_batch_size} and unattacked batch size is {rem_batch_size}")

        val_att_accs = []
        val_unatt_accs = []
        val_comb_accs = []
        S_losses = []
        rem_ce_losses = []

        # Snapshot to roll back to if the validation attacked accuracy drops.
        saved_state = self.get_model_state()
        for step in range(grad_steps):
            S_loss = 0.0
            rem_ce_loss = 0.0
            begin_step = time.time()
            print(f"In gradient step {step+1} of {grad_steps}...")

            attack_iter = iter(attack_dl)
            rem_iter = iter(rem_dl)

            # Every attacked batch is paired with an unattacked one; the pass length
            # is driven by the (longer) unattacked loader.
            assert(len(rem_dl)>=len(attack_dl))
            num_batches = max(len(rem_dl), len(attack_dl))
            if self.debug:
                print("Debugging mode ON!")
                num_batches = 2
            for batch in range(num_batches):
                self.classifier.model.train()
                self.optimizer.zero_grad()

                if (batch+1)%5==0:
                    print(f"\tIn batch {batch+1}/{num_batches}...")

                batch_parts = self._build_combined_batch(attack_iter, rem_iter, vanilla)
                if batch_parts is None:
                    # Both loaders exhausted (possible when debug mode forces num_batches).
                    break
                combined_x, combined_y, combined_rho, bitmask_attack = batch_parts

                logits = self.classifier(combined_x)

                # Per-example CE weighted by rho (1 for S), averaged over the combined batch.
                loss_total = F.cross_entropy(logits, combined_y, reduction="none")
                loss_with_rho = loss_total*combined_rho
                loss_full = torch.sum(loss_with_rho)/len(combined_x)

                loss_on_cpu = loss_with_rho.detach().cpu().numpy()
                bitmask_attack_cpu = bitmask_attack.detach().cpu().numpy()
                num_attacked = np.sum(bitmask_attack_cpu)
                num_unattacked = np.sum(1-bitmask_attack_cpu)
                if (batch+1)%5==0:
                    print(f"\t\tAttacked: {num_attacked}, Unattacked: {num_unattacked}")
                # Summed weighted losses, split by attacked/unattacked membership.
                curr_S_loss = np.sum(bitmask_attack_cpu*loss_on_cpu)
                curr_ce_loss = np.sum((1-bitmask_attack_cpu)*loss_on_cpu)

                S_loss += curr_S_loss
                rem_ce_loss += curr_ce_loss

                loss_full.backward()
                self.optimizer.step()
                curr_S_avg = 0 if num_attacked==0 else (curr_S_loss/num_attacked)
                curr_rem_avg = 0 if num_unattacked==0 else (curr_ce_loss/num_unattacked)
                if (batch+1)%5==0:
                    print(f"\t\tAvg: {round(float(loss_full), 5)} Adv loss is: {round(float(curr_S_avg), 5)} and rem loss is {round(float(curr_rem_avg), 5)}")

            finish_step = time.time()
            time_req = finish_step-begin_step
            self.gs_times.append(time_req)

            # Append losses for logging:
            S_losses.append(S_loss)
            rem_ce_losses.append(rem_ce_loss)

            # Get validation accuracies and see if model needs to be discarded and previous needs to be stored.
            (acc, un_acc, comb_acc) = self.get_val_set_acc(vanilla=vanilla)
            (test_acc, test_un_acc, test_comb_acc) = self.get_test_set_accs(vanilla=vanilla)
            if (not vanilla) and (len(val_att_accs)>0) and (acc<val_att_accs[-1]):
                # Val attack acc dropped, revert to previous model
                print(f"------ Curr accuracy is {acc} which is less than prev acc of {val_att_accs[-1]}. Stopping training for this timestep.")
                self.early_exit = True
                val_att_accs.append(acc)
                val_unatt_accs.append(un_acc)
                val_comb_accs.append(comb_acc)
                self.set_model_state(saved_state)
                self.log_val_test_accs(acc, un_acc, comb_acc, test_acc, test_un_acc, test_comb_acc, -1)
                return (S_losses, rem_ce_losses, (val_att_accs, val_unatt_accs, val_comb_accs))
            else:
                # Val acc did not drop: this becomes the new rollback point.
                saved_state = self.get_model_state()
                val_att_accs.append(acc)
                val_unatt_accs.append(un_acc)
                val_comb_accs.append(comb_acc)
                self.log_val_test_accs(acc, un_acc, comb_acc, test_acc, test_un_acc, test_comb_acc, step+1)
                self.overall_epoch_num += 1
            print(f"\tValidation accuracies are: {val_att_accs}")

            # Compute average loss for printing
            print(f"\tS_loss is {S_loss}")
            S_loss_avg = S_loss/len(attack_dl.dataset)
            rem_ce_loss_avg = rem_ce_loss/len(rem_dl.dataset)
            print(f"\tTrain: Avg S_loss is {S_loss_avg} and avg CE loss on D\S is {rem_ce_loss_avg}. Time taken: {round(time_req, 3)}")
            self.save_to_remote()

        print("Train step Ended!")
        return (S_losses, rem_ce_losses, (val_att_accs, val_unatt_accs, val_comb_accs))

    def _build_combined_batch(self, attack_iter, rem_iter, vanilla):
        """Draw the next attacked/unattacked batches and concatenate them into one batch.

        Attacked examples are perturbed by ``self.adversary.attack_model`` unless
        ``vanilla`` is set. They get loss weight 1 and mask 1. Unattacked examples
        always get loss weight ``self.rho`` and mask 0, even when no attacked
        batch remains (previously they fell back to weight 1 in that case).

        Returns:
            ``(x, y, rho_weights, attack_mask)``, or ``None`` when both iterators
            are exhausted.
        """
        device = self.device
        xs, ys, weights, masks = [], [], [], []

        attack_batch = next(attack_iter, None)
        if attack_batch is not None:
            attack_x, attack_y = attack_batch
            attack_x = attack_x.to(device)
            attack_y = attack_y.to(device)
            if not vanilla:
                # Perturb with the model in eval mode, presumably so train-mode layers
                # (dropout/batch-norm) do not affect the attack; then switch back.
                self.classifier.model.eval()
                attack_x = self.adversary.attack_model.get_perturbed(attack_x).to(device)
                self.classifier.model.train()
            n_attack = attack_y.shape[0]
            xs.append(attack_x)
            ys.append(attack_y)
            weights.append(torch.ones(n_attack))
            masks.append(torch.ones(n_attack))

        rem_batch = next(rem_iter, None)
        if rem_batch is not None:
            rem_x, rem_y = rem_batch
            n_rem = rem_y.shape[0]
            xs.append(rem_x.to(device))
            ys.append(rem_y.to(device))
            weights.append(self.rho*torch.ones(n_rem))
            masks.append(torch.zeros(n_rem))

        if not xs:
            return None
        return torch.cat(xs), torch.cat(ys), torch.cat(weights), torch.cat(masks)
    
    def get_test_set_accs(self, vanilla=False):
        """
            Evaluates on the test set and returns attacked, unattacked and combined accuracies
            Same attacked/unattacked set used across the entire run of the model
        """
        self.classifier.model.eval()
        eta_frac = self.eta/100
        attack_size = len(self.test_attack_ds)
        rem_size = len(self.test_unattacked_ds)
        eps = self.args.pgd_eps
        epsiter = self.args.pgd_step_size
        num_steps = self.args.pgd_num_steps
        
        # Evaluate on attack set
        attack_dl = DataLoader(self.test_attack_ds, batch_size=512, shuffle=False)
        attacked_correct = 0
        attacked_loss = 0
        for (x, y) in attack_dl:
            x = x.to(self.device)
            y = y.to(self.device)
            if vanilla==False:
                x = projected_gradient_descent(self.classifier.model, x, eps, epsiter, num_steps, np.inf)
            acc, loss = list(self.classifier.evaluate(x, y).values())
            attacked_correct += acc*x.shape[0]
            attacked_loss += loss*x.shape[0]

        # Evaluate on unattacked set
        unattacked_correct = 0
        unattacked_loss = 0
        rem_dl = DataLoader(self.test_unattacked_ds, batch_size=512, shuffle=False)
        for (x, y) in rem_dl:
            x = x.to(self.device)
            y = y.to(self.device)
            acc, loss = list(self.classifier.evaluate(x, y).values())
            unattacked_correct += acc*x.shape[0]
            unattacked_loss += loss*x.shape[0]
        
        # Compute average losses on attacked, unattacked and the combined
        attacked_loss = attacked_loss/len(self.test_attack_ds)
        unattacked_loss = unattacked_loss/len(self.test_unattacked_ds)
        combined_loss = eta_frac*attacked_loss + (1-eta_frac)*unattacked_loss

        attacked_acc = attacked_correct/len(self.test_attack_ds)
        unattacked_acc = unattacked_correct/len(self.test_unattacked_ds)
        combined_acc = eta_frac*attacked_acc + (1-eta_frac)*unattacked_acc
        print("\tTest Set Performance:")
        print(f"\t\tTest: Attacked set accuracy is {attacked_acc} and avg loss is {attacked_loss}")
        print(f"\t\tTest: Unattacked set accuracy is {unattacked_acc} and avg loss is {unattacked_loss}")
        print(f"\t\tTest: Combined accuracy is {combined_acc} and avg loss is {combined_loss}")
        return (attacked_acc, unattacked_acc, combined_acc)
    
    def get_val_set_acc(self, vanilla=False):
        """
            Returns adv accuracy on validation set
        """
        self.classifier.model.eval()
        eps = self.args.pgd_eps
        epsiter = self.args.pgd_step_size
        num_steps = self.args.pgd_num_steps
        dataset = self.val_ds
        eta_frac = self.eta/100
        attack_size = int(eta_frac*len(dataset))
        rem_set_size = len(dataset) - attack_size
        if self.args.debug:
            attack_size = 10
            rem_set_size = 10
            ignore_set_size = len(dataset) - 20
        print(f"Attack set size is {attack_size}, rem set size is {rem_set_size}")
        generator = torch.Generator(self.device)
        if self.args.debug:
            attack_set, rem_set, _ = torch.utils.data.random_split(self.val_ds, [attack_size, rem_set_size, ignore_set_size], generator=generator)
        else:
            attack_set, rem_set = torch.utils.data.random_split(self.val_ds, [attack_size, rem_set_size], generator=generator)
        
        print(f"\tIn validation with attack and unattacked sizes: {len(attack_set)} and {(len(rem_set))} resp")

        # Evaluate on attack set
        attack_dl = DataLoader(attack_set, batch_size=512, shuffle=False)
        attacked_correct = 0
        attacked_loss = 0
        for (x, y) in attack_dl:
            x = x.to(self.device)
            y = y.to(self.device)
            if vanilla==False:
                x = projected_gradient_descent(self.classifier.model, x, eps, epsiter, num_steps, np.inf)
            acc, loss = list(self.classifier.evaluate(x, y).values())
            attacked_correct += acc*x.shape[0]
            attacked_loss += loss*x.shape[0]

        # Evaluate on unattacked set
        unattacked_correct = 0
        unattacked_loss = 0
        rem_dl = DataLoader(rem_set, batch_size=512, shuffle=False)
        for (x, y) in rem_dl:
            x = x.to(self.device)
            y = y.to(self.device)
            acc, loss = list(self.classifier.evaluate(x, y).values())
            unattacked_correct += acc*x.shape[0]
            unattacked_loss += loss*x.shape[0]
        
        # Compute average losses on attacked, unattacked and the combined
        attacked_loss = attacked_loss/len(attack_set)
        unattacked_loss = unattacked_loss/len(rem_set)
        combined_loss = eta_frac*attacked_loss + (1-eta_frac)*unattacked_loss

        attacked_acc = attacked_correct/len(attack_set)
        unattacked_acc = unattacked_correct/len(rem_set)
        combined_acc = eta_frac*attacked_acc + (1-eta_frac)*unattacked_acc
        print("\tVal Set Performance:")
        print(f"\t\tAttacked set accuracy is {attacked_acc} and avg loss is {attacked_loss}")
        print(f"\t\tUnattacked set accuracy is {unattacked_acc} and avg loss is {unattacked_loss}")
        print(f"\t\tCombined accuracy is {combined_acc} and avg loss is {combined_loss}")
        return (attacked_acc, unattacked_acc, combined_acc)
        
    def compute_perturbed_dl(self, dl):
        for idx, (x, y) in enumerate(dl):
            pert = self.adversary.attack_model.get_perturbed(x, y)
            if self.perturbed_x is None:
                self.perturbed_x = pert
            else:
                self.perturbed_x = torch.cat(pert, self.perturbed_x)
        pass

    def get_classifier_loss(self, points, labels, requires_mean=True):
        return self.classifier.get_loss(points, labels, requires_mean)
    
    def get_base_classifier_loss(self, points, labels, requires_mean=True):
        return self.base_classifier.get_loss(points, labels, requires_mean)
    
    def load_base_classifier(self, load_path):
        if self.classifier_type=="nn_mnist":
            self.base_classifier.load_model(load_path)
            self.classifier = deepcopy(self.base_classifier)
        elif self.classifier_type=="nn_cifar10":
            self.base_classifier.load_model(load_path)
            self.classifier = deepcopy(self.base_classifier)
    
    def split_subsets(self, dataset, S_idxs):
        S = Subset(dataset, S_idxs)
        remaining_idxs = [idx for idx in range(len(dataset)) if idx not in S_idxs]
        # print("Splitting into sizes, ", len(S_idxs), " and ", len(remaining_idxs))
        remaining = Subset(dataset, remaining_idxs)
        # print("Size of remaining is", len(remaining))
        for idx in S_idxs:
            if idx in remaining_idxs:
                print("Erroneous split!")
        return (S, remaining)
    
    def zero_grad(self):
        """
        Zeroes the gradients of the classifier
        """
        self.classifier.zero_grad()
    
    def set_adversary(self, adversary):
        self.adversary = adversary
    
    def compute_diff(self):
        if self.pre_computed_diff is None:
            start = time.time()
            print("\t Pre-computing!")
            lossfn = nn.CrossEntropyLoss(reduction="sum")
            with torch.no_grad():
                # Torch.no_grad is extremely important here
                # Without it, we get out of memory error on cuda!
                # https://stackoverflow.com/questions/55322434/how-to-clear-cuda-memory-in-pytorch
                # Compute the diff and save it
                total_diff = 0
                dl = DataLoader(self.dataset, batch_size=32)
                for idx, (x, y) in enumerate(dl):
                    # print(f"\tIdx is {idx}")
                    x = x.to(self.device)
                    y = y.to(self.device)
                    logits = self.classifier(x)
                    nat_loss = lossfn(logits, y)
                    total_diff += nat_loss
            self.pre_computed_diff = total_diff
            end = time.time()
            print(f"\t Pre-computation of diff done in {round(end-start, 3)} seconds")
            print(f"\t Diff val is {float(self.pre_computed_diff)}")
        return self.pre_computed_diff
    
    def compute_regularizer(self):
        if self.pre_computed_regularizer is None:
            # Compute the regularizer and save it
            self.pre_computed_regularizer = self.classifier.get_regularizer()
        return self.pre_computed_regularizer

    def compute_g(self, S):
        if len(S)==0:
            return 0
        if self.timer:
            start = time.time()
        L = self.compute_L(S)
        diff = self.compute_diff() #self.classifier(X_d) - self.base_classifier(X_d)
        g = L + self.rho*diff + self.lam*self.compute_regularizer()*len(S)
        if self.timer:
            end = time.time()
            print(f"\t\tCompute_g on {S} took {round(end-start, 4)}")
        return g
    
    def compute_g_on_batch(self, X, y):
        L = self.compute_L_on_batch(X, y)
        # diff_val = self.rho*self.compute_diff()
        # print("\t\tL values are:", L)
        return L + self.lam

    def compute_c_on_batch(self, X, y):
        lossfn = nn.CrossEntropyLoss(reduction="none")
        with torch.no_grad():
            X = X.to(self.device)
            y = y.to(self.device)
            logits = self.classifier(X)
            nat_loss = self.rho*lossfn(logits, y)
            assert (nat_loss.shape==y.shape), ("Incompatible shapes in compute_c. Check lossfn.")
            # out = nat_loss + self.lam
        return nat_loss + self.lam
    
    def compute_L_on_batch(self, X, y):
        X = X.to(self.device)
        y = y.to(self.device)
        if self.adversary.attack_model.requires_training==True:
            theta_star = self.adversary.compute_attack_for_set(X, y) 
        else:
            theta_star = self.adversary.attack_model
        if self.timer:
            start = time.time()
        X_perturbed = theta_star.get_perturbed(X, y)

        if self.timer:
            end = time.time()
            # print(f"\t\tComputing perturbed examples took {round(end-start, 3)}")

        L = self.get_classifier_loss(X_perturbed, y, requires_mean=False)
        # print(L.shape)
        return L

    def compute_c(self, S):
        if len(S)==0:
            return 0
        X_s, Y_s = self.unzip_dataset(self.dataset, S)
        num_S = len(S)
        diff = self.classifier(X_s) - self.base_classifier(X_s)
        c = self.rho*(self.calc_sq_norm_sum(diff)) + self.lam*self.classifier.get_regularizer()*num_S
        return c

    def compute_L(self, S):
        num_S = len(S)
        if num_S == 0:
            return 0
        X_s, Y_s = self.unzip_dataset(self.dataset, S)
        if self.adversary.attack_model.requires_training==True:
            theta_star = self.adversary.compute_attack_for_set(X_s, Y_s)  
        else:
            theta_star = self.adversary.attack_model
        # X_s, Y_s = self.dataset[S]
        if self.timer:
            start = time.time()
        X_s_perturbed = theta_star.get_perturbed(X_s, Y_s)
        if self.timer:
            end = time.time()
            print(f"\t\tcompute_L perturbed for {S} took {round(end-start, 4)}")
            start = time.time()
        L = self.get_classifier_loss(X_s_perturbed, Y_s, requires_mean=True)*num_S
        if self.timer:
            end = time.time()
            print(f"\t\tcompute_L get_loss for {S} took {round(end-start, 4)}")
            print(f"\t\tValue of L is", L)
        return L

    def calc_sq_norm_sum(self, m):
        return torch.sum(torch.linalg.norm(m, dim=1)**2)
    
    def reset_pre_computed(self):
        self.pre_computed_diff = None
        self.pre_computed_regularizer = None
    
    def unzip_dataset(self, ds, idxs):
        if self.timer:
            start = time.time()
        subset = Subset(ds, idxs)
        dl = DataLoader(subset, batch_size=len(subset))
        dl_iter = iter(dl)
        x, y = dl_iter.next()
        x = x.to(self.device)
        y = y.to(self.device)
        if self.timer:
            end = time.time()
            print(f"\t\tUnzip dataset took {round(end-start, 4)}")
        return (x, y)
    
    def plot_losses(self):
        """
        Save line plots of the per-timestep training losses under ``self.save_dir``.

        Three figures are written; their directories are created if missing:
          * ``S_loss_plots/``     -- ``self.S_losses`` (loss on the selected subset S)
          * ``diff_plots/``       -- ``self.diff_losses``
          * ``total_loss_plots/`` -- element-wise sum of the two series
        The x-axis starts at timestep 2. Each file name embeds the last
        timestep plotted, which is ``len(series) + 1``.
        """
        self._save_series_plot(
            "S_loss_plots", np.array(self.S_losses),
            "Loss on subset S (due to perturbation) vs T",
            "Loss value (sum over S)", "S_losses",
        )
        self._save_series_plot(
            "diff_plots", np.array(self.diff_losses),
            "Loss due to diff with base classifier vs T",
            "Difference loss on D\\S (summed)", "diff_losses",
        )
        self._save_series_plot(
            "total_loss_plots", np.array(self.S_losses) + np.array(self.diff_losses),
            "Total loss classifier training vs T",
            "Total Loss", "total_losses",
        )

    def _save_series_plot(self, subdir, values, title, ylabel, file_prefix, first_t=2):
        """Plot ``values`` against T = first_t, first_t+1, ... and save it as
        <save_dir>/<subdir>/<file_prefix>_timestep_<last T>.png."""
        out_dir = os.path.join(self.save_dir, subdir)
        os.makedirs(out_dir, exist_ok=True)
        last_t = first_t + len(values) - 1
        plt.plot(np.arange(first_t, last_t + 1), values)
        plt.title(title)
        plt.xlabel("T")
        plt.ylabel(ylabel)
        plt.savefig(os.path.join(out_dir, f"{file_prefix}_timestep_{last_t}.png"))
        plt.close()

    
    def plot_S_sizes(self):
        """
        Plot the size of the selected set S at each timestep.

        The figure is saved to <save_dir>/S_length_plots/S_lengths_timestep_<len>.png.
        The x-axis starts at timestep 1, and the directory is created if missing.
        """
        S_len_save_dir = os.path.join(self.save_dir, "S_length_plots")
        os.makedirs(S_len_save_dir, exist_ok=True)
        T = len(self.S_len)+1
        plt.plot(np.arange(T)[1:], np.array(self.S_len))
        plt.title("Length of selected set S vs T")
        plt.xlabel("T")
        plt.ylabel("|S|")
        plt.savefig(os.path.join(S_len_save_dir, "S_lengths_timestep_" + str(T-1) + ".png"))
        plt.close()