import numpy as np
import matplotlib.pyplot as plt
import torch
import path
import sys
import json

# if torch.cuda.is_available():
#     torch.set_default_tensor_type('torch.cuda.FloatTensor')
#     device = torch.device('cuda')
# else:
#     device = torch.device('cpu')
    
# print('Using device:',device)

import pickle
import argparse
import os
import importlib.util
import datetime
import time
import subprocess

# Make the project root and its models/defense package importable; e.g.
# `from resnet import ResNet18` below relies on models/defense being on sys.path.
folder_path = path.Path(__file__).abspath().parent.parent
defense_models_path = os.path.join(folder_path, "models/defense")
sys.path.extend([folder_path, defense_models_path])

from data.pytorch_datasets import get_dataset
# from utils import plotter
from models.defender import Defender
from models.adversary import Adversary
# from function_mapping import train   ## define functions for all the algos in function_mapping
from models.defense.nn_mnist import NN_MNIST
from models.defense.nn_cifar10 import NN_CIFAR10
from models.defense import resnet
from resnet import ResNet18
# from models.defense.small_cnn import NN_SmallCNN, SmallCNN
from copy import deepcopy
from sub_selection import RandomSubset, ClassRandomSubset
import socket

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='cifar100', type=str)
parser.add_argument('--classifier_type', '-ct', type=str, default='nn_cifar100')
parser.add_argument('--attack_model_type', '-amt', type=str, default='pgd')
parser.add_argument('--epoch', type=int, default=500)
parser.add_argument('--base_train_epochs', type=int, default=120)
#parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--rho', type=float, default=0.5)   # was 0.01 
parser.add_argument('--lam', type=float, default=1)
parser.add_argument('-s', '--seed', nargs='+', default=0, type=int)
# parser.add_argument('-K', '--K', default=20, type=int, help='cardinality of the set S')
parser.add_argument('-eta', '--eta', default=10, type=float, help='percentage of the dataset on which to attack')
parser.add_argument('-T', '--T', default=100, type=int, help='Number of iterations for the defense algorithm')
# parser.add_argument('-num', '--num_intervals', default=10, type=int)
parser.add_argument('-p', '--plot', action='store_true')
parser.add_argument('-debug', '--debug', action='store_true')
parser.add_argument('--save_dir', type=str, default='saved_models')
parser.add_argument('--save_dir_base_model', type=str, default='models/defense/CIFAR100_models')
parser.add_argument('--data_dir', type=str, default='data')
parser.add_argument('--load', type=bool, default=True)
parser.add_argument('--save-freq', '-sf', default=1, type=int, metavar='N', help='save frequency')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--adjust-lr', '-adjust-lr', type=str, default='Epoch') # Epoch, Timestep

## arguments for distorted greedy
parser.add_argument('-g', '--gamma', default=0.01, type=int, help='gamma parameter for distorted greedy')
parser.add_argument('--dg_batch_size', type=int, default=1024, help='batch size for distorted greedy')
parser.add_argument('--eps', type=int, default=0.1, help='error threshold for distorted greedy')

## arguments for base classifier
parser.add_argument('--base_epochs', type=int, default=120)
parser.add_argument('--base_lr', type=float, default=0.1)
parser.add_argument('--base_batch_size', type=int, default=128)
parser.add_argument('--base_weight_decay', type=float, default=2e-4)
parser.add_argument('--base_momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
parser.add_argument('--base_optimizer', '-optimizer', type=str, default='SGD')

## arguments for PGD attack
parser.add_argument('--pgd_eps', type=float, default=0.031)
parser.add_argument('--pgd_step_size', type=float, default=0.007)
parser.add_argument('--pgd_num_steps', type=int, default=20)


##### Parameters for ADV GAN on CIFAR10
parser.add_argument('--num_labels', type=int, default=100)
parser.add_argument('--num_channels', type=int, default=3)
parser.add_argument('--min_clamp', type=float, default=0)
parser.add_argument('--max_clamp', type=float, default=1)
parser.add_argument('--adv_lr', type=float, default=0.001)
parser.add_argument('--adv_epsilon', type=float, default=0.031)
parser.add_argument('--adv_gan_batch_size', type=int, default=128)
parser.add_argument('--adv_gan_train_epochs', type=int, default=60)
parser.add_argument('--adv_gan_retrain_timestep', type=int, default=1)
parser.add_argument('--adv_gan_retraining_epochs', type=int, default=15)
parser.add_argument('--train_GAN', action='store_true')

parser.add_argument('-tbc', type=bool, default=True, help='Trades Base Classifier')
parser.add_argument('-ite', '--init_trades_epochs', type=int, default=20)
parser.add_argument('-tt', '--trades_training', type=bool, default=True)
parser.add_argument('-tes', '--trades_early_stop', action='store_true')

parser.add_argument('-pgdat_base_classifier', type=bool, default=False, help='PGDAT Base Classifier')
parser.add_argument('-pgdat_training', '--pgdat_training', type=bool, default=False)
parser.add_argument('-init_pgdat_epochs', '--init_pgdat_epochs', type=int, default=20)

parser.add_argument('-mbc', '--mart_base_classifier', type=bool, default=True, help='MART Base Classifier')
parser.add_argument('-mt', '--mart_training', type=bool, default=True)
parser.add_argument('-init_mart_epochs', '--init_mart_epochs', type=int, default=20)

parser.add_argument('-lr', '--lr', type=float, default=0.001)
parser.add_argument('-gs', '--gradient_steps', type=int, default=10)
parser.add_argument('--ts_batch_size', type=int, default=128, help='batch size for the train step while updating classifier')
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
parser.add_argument('-lazy_attack_update', '-lazy_attack_update', default=-1, type=int, help='update attack after few timesteps')

args = parser.parse_args()

# Dataset-specific PGD / epoch settings. These unconditionally overwrite the
# corresponding command-line values (--pgd_eps, --pgd_num_steps, --pgd_step_size).
# NOTE(review): the parser defines --epoch, while this sets args.epochs;
# presumably args.epochs is read by downstream code — TODO confirm.
dataset_name = args.data.lower()
if dataset_name in ('cifar10', 'cifar100'):
    args.pgd_eps = 0.031
    args.pgd_num_steps = 20
    args.pgd_step_size = 0.007
    args.epochs = 120
elif dataset_name == 'fmnist':
    print('Setting fmnist args...')
    args.pgd_eps = 0.3
    args.pgd_num_steps = 40
    args.pgd_step_size = 0.01
    args.epochs = 100
    args.classifier_type = 'nn_mnist'
else:
    # Name the offending value so a typo in --data is easy to spot.
    raise NotImplementedError(
        f"Unsupported dataset {args.data!r}; expected one of: cifar10, cifar100, fmnist"
    )

if torch.cuda.is_available():
    # New float tensors are created on the GPU by default from here on.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    device = torch.device('cuda')
    # Make 'cuda' (the current device) refer to the GPU chosen with -gpu.
    torch.cuda.set_device(args.gpu)
    print('Using Device: ', torch.cuda.get_device_name())
else:
    device = torch.device('cpu')
    
print('Using device:',device)

args.device = device
# NOTE(review): seeds are fixed to 0 here; the --seed argument is not consulted.
torch.manual_seed(0)
np.random.seed(0)
# get_dataset(args) yields the training set at [0] and the test set at [1];
# call it once rather than constructing every split twice.
_splits = get_dataset(args)
dataset = _splits[0]
test_ds = _splits[1]

# Timestamp identifying this run; its str() form names the output folders
# (directory naming can be changed).
# NOTE(review): str(datetime) contains ':' which is not valid in Windows paths.
id_ = datetime.datetime.now()
loss_plot_dir = os.path.join('..', 'results', str(id_), 'loss_plots')
os.makedirs(loss_plot_dir)

# Checkpoints go under <save_dir>/<data>/<timestamp truncated to minutes>,
# and args.save_dir is updated to point there.
args.save_dir = save_dir = os.path.join(args.save_dir, args.data, str(id_)[:16])
os.makedirs(save_dir, exist_ok=True)
args.rho_dir = os.path.join(save_dir, f"rho_{args.rho}")
os.makedirs(args.rho_dir, exist_ok=True)

# Final models are also written to a results folder that has no timestamp,
# so it is shared by runs with the same dataset / attack type / rho.
final_save_dir = '../../results/{}/Ours{}_{}'.format(
    args.data,
    "_AdvGAN" if args.attack_model_type == "advgan" else "",
    args.rho,
)
os.makedirs(final_save_dir, exist_ok=True)
print("Model save dir is:", save_dir)
print("Model final save dir is:", final_save_dir)

# From here on, all print() output goes (line-buffered) to <rho_dir>/log.txt.
log_path = f"{args.rho_dir}/log.txt"
sys.stdout = open(log_path, mode='w', buffering=1)

# Run-level bookkeeping: result/loss containers and the wall-clock start time.
results = {}
begin = time.time()
recon_loss_dict, defender_loss_dict, total_loss_dict = {}, {}, {}

# Set up a defender and train base classifier
# Set up the defender (base-classifier loading/training happens further down).
defender = Defender(classifier_type=args.classifier_type, dataset=dataset, args=args)

# Replace the defender's train/validation data with a fixed split loaded from
# a pickle. That pickle was produced by dumping
# {"train_ds": defender.dataset, "val_ds": defender.val_ds}.
# CIFAR-10 uses 'dataset_split.pkl'; other datasets use 'dataset_<data>_split.pkl'.
# Unpickling can execute arbitrary code: only load split files you created.
split_path = f'dataset{"_"+args.data if args.data!="cifar10" else ""}_split.pkl'
with open(split_path, 'rb') as f:
    dat = pickle.load(f)
print("Loading datasets")
defender.dataset = deepcopy(dat["train_ds"])
print("Training ds size is: ", len(defender.dataset))
defender.val_ds = deepcopy(dat["val_ds"])
print("Validation ds size is: ", len(defender.val_ds))

# K (the size of the attacked set S) is eta percent of the training set.
# The -K option is disabled, so K is always derived here. It is computed
# unconditionally so that --eta 0 gives K == 0 instead of leaving args.K
# undefined and crashing on the print below.
args.K = int(len(defender.dataset)*args.eta/100)
print(f"Size of S is {args.eta}% of the dataset")
print("Value of K is ", args.K)
# Split the test set into an attacked part and an unattacked part
# (split_ds is given K; exact semantics live in sub_selection.RandomSubset).
test_set = RandomSubset(test_ds)
test_attacked, test_unattacked = test_set.split_ds(args.K)
defender.set_test_ds(test_attacked, test_unattacked)

# Obtain the starting base classifier: either load a pretrained checkpoint
# (TRADES, MART, or a plain checkpoint) or train one from scratch.
# map_location=device lets checkpoints saved on a GPU load on the device
# selected above (including CPU-only hosts).
if args.load:
    if args.tbc:
        # TRADES-trained model + optimizer state.
        # load_path = f"saved_models/baselines/cifar10/TRADES/TRADES_SGD_40k/model-sgd-best.pt"
        load_path = "TRADES_model_resnet18_cifar100/model-wideres-epoch76.pt"
        # load_path = 'model-fmnist/model-final.pt'
        sd = torch.load(load_path, map_location=device)
        print("TRADES!!!!!")
        print(f"Adjust lr rate at: {args.adjust_lr}")
        print(f"Init trades training with epochs: {args.init_trades_epochs}")
        print(defender.classifier.model)
        defender.classifier.model.load_state_dict(sd)
        defender.init_optimizer(args.base_optimizer)
        # opt_load_path = f"saved_models/baselines/cifar10/TRADES/TRADES_SGD_40k/opt-model-sgd-best.tar"
        opt_load_path = "TRADES_model_resnet18_cifar100/opt-wideres-checkpoint_epoch76.tar"
        # opt_load_path = 'model-fmnist/opt-nn-checkpoint_epoch100.tar'
        opt_dict = torch.load(opt_load_path, map_location=device)
        defender.optimizer.load_state_dict(opt_dict)
        # Resume from the learning rate stored in the first param group.
        defender.lr = defender.optimizer.param_groups[0]['lr']
        print(f"Learning rate set to {defender.lr}")
        print("Loaded state dict and optimizer dict for TRADES init")
    elif args.mart_base_classifier:
        # MART-trained model; only weights are loaded, the optimizer is fresh.
        # NOTE(review): this path is hard-coded to CIFAR-10 regardless of --data.
        print("MART base classifier!!!!!!!!")
        load_path = "saved_models/baselines/cifar10/MART/model-final"
        sd = torch.load(load_path, map_location=device)
        print(f"Loaded base classifier from {load_path}")
        # NOTE(review): unlike the other branches this loads into
        # defender.classifier, not defender.classifier.model — confirm intended.
        defender.classifier.load_state_dict(sd)
        defender.init_optimizer(args.base_optimizer)
        # Force the lr to the last value used while training MART.
        defender.optimizer.param_groups[0]['lr'] = 1e-5
        defender.lr = 1e-5
        print(f"Learning rate set to {defender.lr}")
        print("Loaded state dict for MART init (fresh optimizer, lr overridden)")
    else:
        # Plain checkpoint dict with "state_dict", "lr" and "optimizer" entries.
        load_path = f"models/defense/CIFAR10_models/model-120-checkpoint"
        print("Loading from path:", load_path)
        checkpoint = torch.load(load_path, map_location=device)
        defender.classifier.model.load_state_dict(checkpoint["state_dict"])
        defender.init_optimizer(args.base_optimizer)
        defender.lr = checkpoint["lr"]
        print(f"Learning rate (initial) is {defender.lr}")
        defender.optimizer.load_state_dict(checkpoint["optimizer"])
else:
    print("Loading train and val datasets!!")
    # Use the same dataset-specific split file as the initial split load
    # (this was hard-coded to the CIFAR-10 'dataset_split.pkl', which
    # replaced e.g. the CIFAR-100 split with CIFAR-10 data).
    split_path = f'dataset{"_"+args.data if args.data!="cifar10" else ""}_split.pkl'
    with open(split_path, 'rb') as f:
        dat = pickle.load(f)
    print(dat)
    defender.dataset = deepcopy(dat["train_ds"])
    defender.val_ds = deepcopy(dat["val_ds"])
    defender.train_base_classifier(defender.dataset, args, defender.val_ds)
    defender.lr = defender.classifier.curr_lr
    defender.optimizer = defender.classifier.optimizer

# Serialize arguments
# Serialize arguments: record the run configuration as "key : value" lines
# in <rho_dir>/info.txt, together with the absolute save dir and hostname.
args.full_path = os.path.abspath(args.save_dir)
args.hostname = socket.gethostname()
info_path = args.rho_dir+"/info.txt"

a_dict = vars(args)
# Context manager guarantees the file is closed even if a write fails.
with open(info_path, 'w') as fp:
    for k, v in a_dict.items():
        fp.write("%s : %s\n" % (k, v))
print("Model save dir is:", save_dir)


# Build the adversary. For AdvGAN, either train the generator against the
# current classifier (--train_GAN) or load a pretrained generator.
adversary = Adversary(dataset=defender.dataset, defender=defender, args=args, attack_model_type=args.attack_model_type)
if args.attack_model_type == "advgan":
    if args.train_GAN:
        print("Training AdvGAN!")
        adversary.attack_model.set_base_model(defender.classifier)
        adversary.attack_model.train_advgan(defender.dataset, args, args.adv_gan_train_epochs)
        print("AdvGAN training complete!")
    else:
        print('Loading AdvGAN!')
        advgan_load_path = f'saved_models/{args.data}/AdvGAN/AdvGANnetG_epoch_60.pth'
        # map_location lets a GPU-saved generator load on the selected device.
        sd = torch.load(advgan_load_path, map_location=device)
        adversary.attack_model.netG.load_state_dict(sd)
        print('Loaded succesfully from ', advgan_load_path)

# Attach the adversary, then update the classifier against it.
defender.set_adversary(adversary)
defender.update_classifier(dataset=defender.dataset, train_args=args)

# NOTE(review): fname is not used below; the defender pickle is written to 'defender.pkl'.
fname = "defender_model_K_{}.pkl".format(args.K)
def_save_path = os.path.join(save_dir, "model_final")
def_final_save_path = os.path.join(final_save_dir, "model-final")

# Persist the whole defender object, then the classifier in both the per-run
# and the shared final results directory.
with open(os.path.join(save_dir, 'defender.pkl'), mode='wb') as handle:
    pickle.dump(defender, handle)
for target in (def_save_path, def_final_save_path):
    defender.classifier.save_model(target)
