import os, pdb
import argparse
import torch
from rslad_loss import *
from attacks import *
from cifar100_models import *
import torchvision
from torchvision import datasets, transforms
import time
# Fix the torch CPU and CUDA RNGs to seed 0 and force deterministic cuDNN
# kernels, so repeated runs on the same computer give consistent results.
# NOTE(review): Python's `random` module and NumPy are not seeded by these lines.
_SEED = 0
torch.manual_seed(_SEED)
torch.cuda.manual_seed_all(_SEED)
torch.backends.cudnn.deterministic = True

from robustbench.utils import load_model
from torch.utils.data import Dataset
#########################################################################################################

from argparse import ArgumentParser
from status import ProgressBar
from args import create_parser
try:
    import wandb
except ImportError:
    wandb = None
from autoattack import AutoAttack
import torchattacks

parser = create_parser()
# parse_known_args: unrecognised command-line flags are ignored rather than rejected.
args = parser.parse_known_args()[0]

print(args)

# Parent of the directory that contains this script; used as wandb's base path.
basepath = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if not args.nowand:
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would turn a missing package into an AttributeError on `None`.
    if wandb is None:
        raise ImportError("Wandb not installed, please install it or run without wandb")
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), name=args.wandb_name, tags=[args.wandb_tags])
    args.wandb_url = wandb.run.get_url()
    # Upload a copy of the training script to the run.
    # NOTE(review): the 'resnet18_<method>cifar100.py' name may not match this
    # Tiny-ImageNet script — confirm the intended file.
    wandb.save(basepath+'/resnet18_'+str(args.method)+'cifar100.py', base_path=basepath)

##########################################################################################################################################

prefix = 'resnet18-CIFAR100_rsladdistill'
epochs = args.epochs
batch_size = args.batch
epsilon = 8/255.0

# Teacher: WideResNet (depth 34, widen factor 10) for 64x64 inputs, 200 classes,
# wrapped in DataParallel before its weights are loaded.
teacher = torch.nn.DataParallel(
    WideResNet(image_size=64, depth=34, widen_factor=10, num_classes=200)
)


def _select_state_dict(ckpt):
    """Return the weights nested in *ckpt* under a known wrapper key.

    The keys "net", "state_dict" and "model" are tried in that order; when none
    is present, *ckpt* itself is treated as the state dict.
    """
    for wrapper_key in ("net", "state_dict", "model"):
        if wrapper_key in ckpt.keys():
            return ckpt[wrapper_key]
    return ckpt


# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
# The weights go into the DataParallel wrapper, so their keys presumably carry a
# 'module.' prefix — confirm against the checkpoint.
checkpoint = torch.load('./models/WRN34_Tiny.pth')
teacher.load_state_dict(_select_state_dict(checkpoint))

teacher = teacher.cuda()
teacher.eval()




class TinyImageNet(Dataset):
    """One Tiny-ImageNet split, read from disk via ``torchvision.datasets.ImageFolder``.

    Args:
        dataset_type: name of the split sub-directory under *root*
            (this script uses ``"train"`` and ``"val"``).
        transform: optional callable applied to each loaded image.
        root: dataset root directory; defaults to ``DEFAULT_ROOT``, the
            previously hard-coded location.

    Items are ``(image, class_index)`` pairs as produced by ImageFolder, with
    *transform* applied to the image when given.

    NOTE(review): ImageFolder derives labels from class sub-directories. The stock
    Tiny-ImageNet ``val`` split keeps all images in ``val/images`` with labels in
    ``val_annotations.txt``; unless it was reorganised into per-class folders,
    every val image would get the same label. Confirm the on-disk layout.
    """

    DEFAULT_ROOT = '../dataset/tiny-imagenet-200/'

    def __init__(self, dataset_type, transform=None, root=DEFAULT_ROOT):
        self.root = root
        data_path = os.path.join(self.root, dataset_type)
        self.dataset = torchvision.datasets.ImageFolder(root=data_path)
        self.transform = transform

    def __getitem__(self, index):
        img, target = self.dataset[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.dataset)



# No Normalize transform is applied here: ToTensor leaves images in [0, 1], and
# mean/std normalization is done by the Normalize layer prepended to the models.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(64, padding=8),  # pad-and-crop shift for 64x64 images
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
transform_test = transforms.Compose([transforms.ToTensor()])

train_dataset = TinyImageNet("train", transform_train)
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

testset = TinyImageNet("val", transform_test)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,  # fixed order, so evaluation passes see the same sequence
    num_workers=2,
)

# Build the student network selected on the command line.
if args.student == "RES-18":
    student = pResNet18(num_classes=200)
elif args.student == "MN-V2":
    # NOTE(review): unlike the ResNet branch, no num_classes is passed here. If
    # mobilenet_v2() does not default to 200 outputs, the KL terms against the
    # 200-class teacher will fail on a shape mismatch — confirm its default.
    student = mobilenet_v2()
else:
    # Previously an unknown value fell through and crashed later with an opaque
    # NameError on `student`.
    raise ValueError(
        f"Unsupported args.student value {args.student!r}; expected 'RES-18' or 'MN-V2'"
    )
student = torch.nn.DataParallel(student)
student = student.cuda()
student.train()
optimizer = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


def kl_loss(a,b):
    """Return the element-wise terms ``b * log(b + 1e-5) - a * b`` (no reduction).

    Presumably ``a`` holds student log-probabilities and ``b`` target
    probabilities, making the sum of the result a KL divergence — TODO confirm.
    The 1e-5 offset keeps ``log`` finite where ``b`` is zero.

    NOTE(review): not called anywhere in this script's visible code; the loss
    that is actually optimized uses ``nn.KLDivLoss``.
    """
    log_target = torch.log(b + 1e-5)
    return log_target * b - a * b


class Normalize(nn.Module):
    """Per-channel ``(input - mean) / std`` normalization as a network layer.

    Placing this in front of a model lets the model accept unnormalized images.

    Args:
        mean: per-channel means, one value per input channel.
        std: per-channel standard deviations, same length as *mean*.

    ``mean`` and ``std`` are registered as float32 buffers of shape ``(C,)``, so
    they follow ``.cuda()`` / ``.to()`` on the module and appear in
    ``state_dict()`` under the same keys and shapes as before.
    """

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        # Created on CPU and moved with the module (e.g. by a parent's `.cuda()`),
        # so constructing the layer no longer requires a GPU.
        self.register_buffer('mean', torch.as_tensor(mean, dtype=torch.float32))
        self.register_buffer('std', torch.as_tensor(std, dtype=torch.float32))

    def forward(self, input):
        # View the stats as (1, C, 1, 1) so they broadcast over an NCHW batch, and
        # place them on the input's device instead of unconditionally on CUDA.
        mean = self.mean.to(input.device).view(1, -1, 1, 1)
        std = self.std.to(input.device).view(1, -1, 1, 1)
        return (input - mean) / std


# The commonly used ImageNet per-channel statistics. One Normalize instance is
# shared by both wrappers, so student and teacher normalize inputs identically.
norm_layer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# From here on both models take raw [0, 1] images; normalization is their first
# layer. Note the student's state_dict keys now start with "0." (norm buffers)
# and "1." (the DataParallel-wrapped network).
student = nn.Sequential(norm_layer, student)
student = student.cuda()

teacher = nn.Sequential(norm_layer, teacher)
teacher = teacher.cuda()
teacher.eval()


progress_bar = ProgressBar()
XENT_loss = nn.CrossEntropyLoss()
# Training-time attack. The budget is pinned explicitly (eps = epsilon = 8/255,
# step 2/255, 10 steps) instead of relying on the installed torchattacks defaults.
torchattackPGD = torchattacks.PGD(student, eps=epsilon, alpha=2/255, steps=10)
criterion_kl = nn.KLDivLoss(reduction="batchmean")


def _evaluate_pgd20(model, loader):
    """Return ``(clean_acc, robust_acc)`` of *model* over *loader*.

    Robust accuracy is measured on PGD-20 examples from ``attack_pgd``
    (step 2/255, epsilon 8/255). Each batch is first moved to the GPU, where the
    model and its Normalize statistics live; previously CPU batches were fed
    straight into the model, which fails with a device mismatch.
    """
    model.eval()
    n_correct = 0
    n_correct_adv = 0
    n_seen = 0
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()
        images_adv = attack_pgd(model, images, labels, attack_iters=20, step_size=2.0/255.0, epsilon=8.0/255.0)
        with torch.no_grad():
            logits = model(images)
            logits_adv = model(images_adv)
        # Running counts instead of concatenating Python lists (which was quadratic).
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_correct_adv += (logits_adv.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    # Avoid dividing by zero on an empty loader.
    n_seen = max(n_seen, 1)
    return n_correct / n_seen, n_correct_adv / n_seen


for epoch in range(1, epochs + 1):
    for step, (X, y) in enumerate(trainloader):
        student.train()
        X = X.float().cuda()
        y = y.cuda()
        optimizer.zero_grad()
        inputs_adv = torchattackPGD(X, y)
        with torch.no_grad():
            # Perturbation found by PGD; the teacher is queried at points moved
            # forward (beta) and backward (gamma) along it.
            delta = inputs_adv - X
            teacher_plus = teacher(X + args.beta * delta)
            teacher_logits = teacher(X)
            teacher_minus = teacher(X - args.gamma * delta)

        student_adv = student(inputs_adv)
        student_plus = student(X + args.beta * delta)
        # NOTE(review): student_logits feeds no loss term. The forward pass is kept
        # because, in train mode, it still updates BatchNorm running statistics
        # if the student has any; dropping it would change training.
        student_logits = student(X)
        student_minus = student(X - args.gamma * delta)

        # Distinct names so the module-level `kl_loss` function is not rebound.
        # Term 1: student's prediction on the adversarial input vs teacher's clean prediction.
        loss_kl_adv = criterion_kl(F.log_softmax(student_adv, dim=1), F.softmax(teacher_logits.detach(), dim=1))
        # Term 2: student's logit difference across the +beta/-gamma points vs the teacher's.
        loss_kl_diff = criterion_kl(F.log_softmax(student_plus - student_minus, dim=1), F.softmax((teacher_plus - teacher_minus).detach(), dim=1))

        # Term 2 is ramped in linearly over training: weight alpha * epoch / epochs.
        loss = loss_kl_adv + args.alpha * (epoch / epochs) * loss_kl_diff
        loss.backward()
        optimizer.step()

        progress_bar.prog(step, len(trainloader), epoch, loss.item())

    # Evaluate after every epoch (the former `epoch % 1 == 0` check was always true).
    test_acc, test_acc_adv = _evaluate_pgd20(student, testloader)
    print('PGD20 acc', test_acc_adv)

    if not args.nowand:
        d2 = {'clean_acc': test_acc, 'robust_acc': test_acc_adv}
        wandb.log(d2)

    # Step decay: multiply lr by 0.1 after epochs 50 and 80 (fixed, independent of `epochs`).
    if epoch in [50, 80]:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1


save_time = time.strftime('%Y-%m-%d', time.localtime(time.time()))
# Make sure the output directory exists so a finished run is not lost to a
# FileNotFoundError. Also tolerate a missing run name: `None + str` used to
# raise TypeError here.
os.makedirs('./result_models', exist_ok=True)
run_name = args.wandb_name or ''
torch.save(student.state_dict(), './result_models/' + run_name + save_time + str(args.student) + '_tinyimg.pt')

# Final evaluation: standard AutoAttack (L-inf, eps 8/255) on the whole test set.
student.eval()
autoattack = AutoAttack(student, norm='Linf', eps=8/255.0, version='standard')
# A single pass over the unshuffled test loader (previously it was iterated twice).
x_batches, y_batches = [], []
for x_batch, y_batch in testloader:
    x_batches.append(x_batch)
    y_batches.append(y_batch)
x_total = torch.cat(x_batches, 0)
y_total = torch.cat(y_batches, 0)

aa_result = autoattack.run_standard_evaluation(x_total, y_total)
if isinstance(aa_result, (tuple, list)):
    # The original code unpacked two values, so the installed AutoAttack
    # presumably returns (x_adv, robust_accuracy); keep that path.
    _, robust_acc = aa_result
else:
    # Upstream AutoAttack returns only the adversarial images; score them here.
    robust_acc = autoattack.clean_accuracy(aa_result, y_total)
print('final AA', robust_acc)

if not args.nowand:
    AA_d = {'RESULT_AA': robust_acc}
    wandb.log(AA_d)
    wandb.finish()