from ast import parse
import os
import sys
from copy import deepcopy
import argparse
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from copy import deepcopy

# Pick the GPU when one is visible; in that case newly created float tensors
# also default to CUDA.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases (torch.set_default_device / torch.set_default_dtype replace it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

import path
# Put the repository root (the parent of this script's directory) on sys.path so
# the project packages imported below (sub_selection, data, models) can be found.
folder_path = path.Path(__file__).abspath().dirname().dirname()
sys.path.append(folder_path)

from sub_selection import RandomSubset, ClassRandomSubset
from data.pytorch_datasets import get_dataset

from models.defense.nn_mnist import NN_MNIST
from models.defender import Defender

# Seed the global NumPy and PyTorch RNGs with 0 so repeated runs draw the same
# random numbers from them.
np.random.seed(0)
torch.manual_seed(0)

parser = argparse.ArgumentParser()

# --- which trained defender to load -------------------------------------
parser.add_argument('-d', '--data', type=str, default='mnist')
parser.add_argument('-t', '--t', type=int, default=20,
                    help='timestep of the model to be loaded')
parser.add_argument('--save_dir', type=str, default='saved_models')
parser.add_argument('--model_dir', type=str,
                    default='pgd_nnmnist_gs_10_rho_pt1')

# --- classifier retraining hyper-parameters -----------------------------
parser.add_argument('-lr', '--lr', type=float, default=0.001)
parser.add_argument('-gs', '--gradient_steps', type=int, default=20)
parser.add_argument('--save_freq', type=int, default=5)
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--ts_batch_size', type=int, default=64,
                    help='batch size for the train step while updating classifier')

# --- output ---------------------------------------------------------------
parser.add_argument('-log', type=str, help='name of the log file')
# A '-ci/--clean_init' store_true flag (False by default) is currently disabled.



def main():
    """Retrain a fresh classifier for several attack budgets and pickle the results.

    Loads ``defender.pkl`` from ``<save_dir>/<data>/<model_dir>``. For each budget
    ``b`` (a percentage of the dataset size), a deep copy of the loaded defender
    receives the first ``int(len(dataset) * b / 100)`` entries of
    ``S_dict[t - 1]``. It then splits the dataset with ``split_subsets``, swaps in
    a freshly constructed ``NN_MNIST`` classifier and runs ``train_step``, writing
    intermediate models under ``b_models/``. Finally the copy is pickled under
    ``defenders/``. Everything printed after argument parsing goes to
    ``log/<log>.txt``.
    """
    from contextlib import redirect_stdout

    args = parser.parse_args()
    print(args)  # echo the configuration on the console before redirecting

    # Fall back to a name derived from the run configuration when -log is
    # omitted; otherwise the log would silently be written to 'log/None.txt'.
    log_name = args.log or f'{args.data}_{args.model_dir}_t{args.t}'
    log_path = f'log/{log_name}.txt'
    # open() raises FileNotFoundError if the log directory does not exist yet.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    # buffering=1 -> line-buffered, so the log can be followed during training.
    # redirect_stdout restores the real stdout afterwards, and `with` closes the file.
    with open(log_path, 'w', 1) as log_file, redirect_stdout(log_file):
        print(args)  # keep the run configuration in the log file as well

        ## load the defender
        load_path = os.path.join(args.save_dir, args.data, args.model_dir)
        # NOTE: pickle.load can execute arbitrary code stored in the file; only
        # load defender pickles produced by this project.
        with open(os.path.join(load_path, 'defender.pkl'), 'rb') as f:
            trained_defender = pickle.load(f)

        ## dataset
        print('Loading dataset...')
        dataset = trained_defender.dataset

        # Selection stored at step t-1; use (t-2) instead to train on the S on
        # which the model executed train_step.
        S_B = trained_defender.S_dict[args.t - 1]

        # Output directories do not depend on the budget, so create them once.
        model_save_dir = os.path.join(load_path, "b_models")
        pkl_save_dir = os.path.join(load_path, "defenders/")
        os.makedirs(model_save_dir, exist_ok=True)
        os.makedirs(pkl_save_dir, exist_ok=True)

        # Budgets are percentages of the dataset size.
        b_vals = (1, 2, 3, 4, 5)
        for b in b_vals:
            # Multiply before dividing: exact for integer percentages (no
            # float rounding from len/100), still accepts fractional b.
            attack_set_size = int(len(dataset) * b / 100)
            print(f"## Training for attacks on b={b}% of dataset and attack set size={attack_set_size}")

            # Each budget starts from an independent copy of the loaded defender.
            defender = deepcopy(trained_defender)
            S_b = S_B[:attack_set_size]
            train_subset, rem_subset = defender.split_subsets(dataset, S_b)

            model_save_path = os.path.join(model_save_dir, f'model_t{args.t}_b{b}')
            # Train from a newly constructed network rather than the loaded weights.
            defender.classifier = NN_MNIST()
            defender.train_step(train_subset, rem_subset, args.gradient_steps, args.lr,
                                args.weight_decay, args.ts_batch_size, save_between=True,
                                save_freq=args.save_freq, save_path=model_save_path)

            pkl_save_path = os.path.join(
                pkl_save_dir,
                f'defender_t{args.t}_gs{args.gradient_steps}_b{attack_set_size}.pkl')
            with open(pkl_save_path, 'wb') as f:
                pickle.dump(defender, f)


if __name__ == "__main__":
    main()
