import argparse
import copy
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.utils as torch_utils
import torch.optim as optim
import torchattacks
import torchvision
from sklearn.metrics import accuracy_score
from torchvision import transforms
from tqdm import tqdm

from models import *
from utils import Logger, save_checkpoint, torch_accuracy, AverageMeter, custom_datasets


def _str_to_bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is unsuitable for argparse: ``bool("False")`` is ``True``,
    so any explicit value would switch the flag on. This accepts the usual
    spellings (``true/false``, ``yes/no``, ``1/0``, ``t/f``, ``y/n``),
    case-insensitively.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='ROAD')
parser.add_argument('--epochs', type=int, default=120, metavar='N', help='number of epochs to train')
parser.add_argument('--arch', type=str, default="WRN",
                    choices=['smallcnn', 'resnet18', 'mobilenet', 'preactresnet18', 'WRN'],
                    help="decide which network to use")
# 10 selects CIFAR-10, 100 selects CIFAR-100; other values have no matching dataset.
parser.add_argument('--num_classes', type=int, default=100, choices=[10, 100])
parser.add_argument('--lr', default=0.1, type=float)

parser.add_argument('--epsilon', type=float, default=8/255, help='perturbation bound')
parser.add_argument('--num-steps', type=int, default=10, help='maximum perturbation step')
parser.add_argument('--step-size', type=float, default=2/255, help='step size')

# `--eval`, `--eval True` -> True; `--eval False` -> False; omitted -> False.
parser.add_argument('--eval', type=_str_to_bool, nargs='?', const=True, default=False,
                    help='skip training and only evaluate the saved best checkpoint')
parser.add_argument('--out-dir',type=str, default='./logs',help='dir of output')

parser.add_argument('--beta',type=float, default=6.0, help='robustness factor')
parser.add_argument('--gamma', type=float, default=3.0, help='guidance factor')

args = parser.parse_args()

# Training settings
# Normalise the output directory to end with a path separator.
args.out_dir = os.path.join(args.out_dir, '')
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(args.out_dir, exist_ok=True)

weight_decay = 5e-4
seed = 1
# Seed the torch CPU, numpy and all-CUDA-device RNGs.
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
# NOTE(review): cudnn.benchmark=True lets cuDNN autotune convolution algorithms,
# which may make runs non-reproducible even with deterministic=True. Set
# benchmark=False if bit-exact reproducibility matters more than speed.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# Data loaders: random-crop + horizontal-flip augmentation for training,
# plain tensor conversion for testing.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
transform_test = transforms.Compose([transforms.ToTensor()])

# --num_classes 10 selects CIFAR-10; every other value selects CIFAR-100.
_cifar_dataset = (
    custom_datasets.Custom_CIFAR10 if args.num_classes == 10 else custom_datasets.Custom_CIFAR100
)

trainset = _cifar_dataset(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = _cifar_dataset(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate of every param group for ``epoch``.

    Epochs 0-15: linear warm-up from 0.001 to ``args.lr``.
    Epochs 16-74: ``args.lr``.
    Epochs 75-89: ``args.lr * 0.1``.
    Epochs 90-99: ``args.lr * 0.01``.
    Epochs >= 100: ``args.lr * 0.001``.
    The milestones are absolute epoch numbers and do not scale with ``args.epochs``.
    """
    base = args.lr
    if epoch <= 15:
        new_lr = 0.001 + epoch / 15 * (base - 0.001)
    elif epoch >= 100:
        new_lr = base * 0.001
    elif epoch >= 90:
        new_lr = base * 0.01
    elif epoch >= 75:
        new_lr = base * 0.1
    else:
        new_lr = base

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def get_alpha(epoch):
    """Return the weight given to model predictions when building soft targets.

    Follows a quarter sine wave: 0 at epoch 0, rising to 0.8 at
    ``epoch == args.epochs``.
    """
    phase = epoch * (np.pi / (2 * args.epochs))
    return np.sin(phase) * 0.8

def train(epoch, model, model_nat, optimizer, optimizer_nat, all_predictions_adv, device, descrip_str) :
    """Run one ROAD training epoch over ``train_loader``.

    Jointly updates the robust ``model`` (on PGD adversarial examples) and the
    natural guide ``model_nat`` (on clean inputs).

    Args:
        epoch: epoch index; on epoch 1 the prediction memory is seeded with
            one-hot labels.
        model: robust classifier, trained on adversarial examples.
        model_nat: natural classifier, trained on clean inputs. Its detached
            predictions guide ``model`` through a KL term.
        optimizer: optimizer for ``model``.
        optimizer_nat: optimizer for ``model_nat``.
        all_predictions_adv: tensor indexed by dataset sample index holding the
            robust model's softmax output on its adversarial example from the
            previous epoch. Presumably (num_train_samples, num_classes) on CPU,
            as allocated by main; confirm for other callers.
        device: device the models live on.
        descrip_str: progress-bar description.

    Returns:
        ``all_predictions_adv``, updated in place with this epoch's predictions.
    """
    kl = torch.nn.KLDivLoss(reduction='batchmean')
    losses = AverageMeter()
    clean_accuracy = AverageMeter()
    adv_accuracy = AverageMeter()

    pbar = tqdm(train_loader)
    pbar.set_description(descrip_str)
    alpha = get_alpha(epoch)

    # Neither of these depends on the batch, so build them once per epoch.
    # The attack follows --epsilon / --step-size / --num-steps (defaults 8/255, 2/255, 10).
    pgd_atk = torchattacks.PGD(model, eps=args.epsilon, alpha=args.step_size, steps=args.num_steps)
    identity_matrix = torch.eye(len(train_loader.dataset.classes))

    for inputs, targets, input_indices in pbar:
        pbar_dic = OrderedDict()

        inputs, targets = inputs.to(device), targets.to(device)

        # Adversarial examples for the robust model.
        x_adv = pgd_atk(inputs, targets)
        x_adv = x_adv.to(device)

        # One-hot labels (CPU), later blended with model predictions.
        targets_one_hot = identity_matrix[targets.cpu().detach().numpy()]

        model.train()
        model_nat.train()
        optimizer.zero_grad()
        optimizer_nat.zero_grad()

        logit = model(x_adv)
        logit_nat = model(inputs)

        prob_nat = torch.softmax(logit_nat, dim=-1)
        prob_adv = torch.softmax(logit, dim=-1)

        # No stored predictions exist before the first epoch, so seed the memory with labels.
        if epoch == 1:
            all_predictions_adv[input_indices] = targets_one_hot

        # Soft labels: (1 - alpha) * one-hot + alpha * prediction.
        # Natural targets use the robust model's current clean prediction.
        # Adversarial targets use its stored adversarial prediction from the previous epoch.
        soft_targets_nat = ((1 - alpha) * targets_one_hot) + (alpha * prob_nat.cpu().detach())
        soft_targets_adv = ((1 - alpha) * targets_one_hot) + (alpha * all_predictions_adv[input_indices])

        soft_targets_nat = soft_targets_nat.to(device)
        soft_targets_adv = soft_targets_adv.to(device)

        # Natural guide model on clean inputs. Its probabilities are detached, so
        # the guidance term only produces gradients for `model`.
        nat_logit = model_nat(inputs)
        prob_nat_guide = torch.softmax(nat_logit, dim=-1).detach()

        log_pred_adv = F.log_softmax(logit, dim=-1)
        log_pred_nat = F.log_softmax(logit_nat, dim=-1)
        log_pred_nat_guide = F.log_softmax(nat_logit, dim=-1)

        # Soft-label cross-entropy on adversarial examples.
        loss_ce = (-soft_targets_adv * log_pred_adv).mean(0).sum()
        # Robustness: KL(clean prediction || adversarial prediction) of `model`.
        kl_loss = kl(log_pred_adv, prob_nat)
        # Guidance: pull `model`'s clean prediction towards the natural model's.
        loss_guide = kl(log_pred_nat, prob_nat_guide)

        loss = loss_ce + args.beta * kl_loss + args.gamma * loss_guide
        # `model_nat` receives gradients only from this soft-label cross-entropy.
        loss_nat = (-soft_targets_nat * log_pred_nat_guide).mean(0).sum()

        loss.backward()
        loss_nat.backward()

        optimizer.step()
        optimizer_nat.step()

        # Remember this epoch's adversarial predictions for next epoch's soft targets.
        all_predictions_adv[input_indices] = prob_adv.cpu().detach()

        losses.update(loss.item())
        # 'Acc' reports the natural guide model on clean inputs.
        # 'advAcc' reports the robust model on x_adv.
        clean_accuracy.update(torch_accuracy(nat_logit, targets, (1,))[0].item())
        adv_accuracy.update(torch_accuracy(logit, targets, (1,))[0].item())

        pbar_dic['loss'] = '{:.2f}'.format(losses.mean)
        pbar_dic['Acc'] = '{:.2f}'.format(clean_accuracy.mean)
        pbar_dic['advAcc'] = '{:.2f}'.format(adv_accuracy.mean)
        pbar.set_postfix(pbar_dic)
    return all_predictions_adv

def test(model, device):
    """Evaluate ``model`` on ``test_loader`` with clean and PGD-20 inputs.

    PGD uses eps=8/255, step size 2/255 and 20 steps.

    Args:
        model: classifier to evaluate (switched to eval mode).
        device: device ``model`` lives on.

    Returns:
        ``(clean_top1, pgd20_top1)`` accuracies in percent.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()

    # The attack settings are the same for every batch, so build it once.
    pgd_atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=20)

    y_true = []
    y_pred = []
    y_pred_adv = []

    for inputs, targets, _ in tqdm(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # The attack needs input gradients, so it runs outside no_grad.
        x_adv = pgd_atk(inputs, targets)
        x_adv = x_adv.to(device)

        # Plain inference: no autograd graph is needed for the final predictions.
        with torch.no_grad():
            logits = model(inputs)
            logits_adv = model(x_adv)

        y_true.extend(targets.cpu().tolist())
        y_pred_adv.extend(torch.max(logits_adv, dim=-1)[1].cpu().tolist())
        y_pred.extend(torch.max(logits, dim=-1)[1].cpu().tolist())

    if not y_true:
        raise ValueError('test_loader produced no samples')

    # Score once over the whole test set instead of re-scoring after every batch.
    top1 = accuracy_score(y_true, y_pred) * 100
    top1_adv = accuracy_score(y_true, y_pred_adv) * 100

    return top1, top1_adv


def attack(model, device, num_classes):
    """Evaluate ``model`` on ``test_loader`` under several attacks.

    Measures accuracy on clean inputs and on adversarial inputs from PGD-20,
    PGD-100, MI-FGSM (10 steps) and AutoAttack. All attacks use eps=8/255; the
    iterative ones use step size 2/255. Running accuracies are shown on the
    progress bar.

    Args:
        model: classifier to evaluate (switched to eval mode).
        device: device ``model`` lives on.
        num_classes: number of classes, forwarded to AutoAttack.

    Returns:
        ``(clean, pgd20, pgd100, mim, aa)`` accuracies in percent.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()

    # These attack settings are identical for every batch, so build them once.
    pgd20_atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=20)
    pgd100_atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=100)
    mim_atk = torchattacks.MIFGSM(model, eps=8/255, alpha=2/255, steps=10)

    keys = ('NAT', 'PGD_20', 'PGD_100', 'MIM', 'AA')
    num_correct = dict.fromkeys(keys, 0)
    total_num = 0

    pbar = tqdm(test_loader)
    pbar.set_description('attack')

    for inputs, target, _ in pbar:
        # NOTE(review): AutoAttack is still built per batch. With seed=None,
        # torchattacks appears to fix its sub-attack seeds at construction time,
        # so hoisting it would change random-restart behaviour. Verify before moving it.
        aa_atk = torchattacks.AutoAttack(model, eps=8/255, n_classes=num_classes)
        inputs, target = inputs.to(device), target.to(device)

        # Attacks need input gradients, so they run outside no_grad.
        # Generation order (PGD-20, PGD-100, MIM, AA) is kept fixed.
        eval_inputs = {
            'NAT': inputs,
            'PGD_20': pgd20_atk(inputs, target).to(device),
            'PGD_100': pgd100_atk(inputs, target).to(device),
            'MIM': mim_atk(inputs, target).to(device),
            'AA': aa_atk(inputs, target).to(device),
        }

        total_num += target.shape[0]
        pbar_dic = OrderedDict()
        # Plain inference: no autograd graph is needed for the final predictions.
        with torch.no_grad():
            for key in keys:
                preds = torch.max(model(eval_inputs[key]), dim=-1)[1]
                num_correct[key] += (preds == target).sum().item()
                # Running counts give the same percentage as re-scoring every
                # sample seen so far, without the quadratic cost.
                pbar_dic[key] = '{:.2f}'.format(num_correct[key] / total_num * 100)

        pbar.set_postfix(pbar_dic)

    if total_num == 0:
        raise ValueError('test_loader produced no samples')

    return tuple(num_correct[key] / total_num * 100 for key in keys)


def main():
    """Build the model, run ROAD training (unless ``--eval``), then evaluate.

    During training, the best checkpoints (by PGD-20 test accuracy) and the
    last-epoch checkpoints of both the robust and natural models are written
    to ``args.out_dir``. Afterwards the best robust checkpoint is reloaded and
    evaluated with ``attack``; the results are logged.

    Raises:
        ValueError: if ``args.arch`` names an unsupported architecture.
    """
    # -1 guarantees the first epoch writes a "best" checkpoint even at 0% PGD-20
    # accuracy, so the final reload below always has a file to load.
    best_acc_adv = -1.0
    start_epoch = 1

    # Constructors for each supported --arch value. Lambdas keep lookups lazy,
    # so only the selected model class has to exist.
    arch_builders = {
        'smallcnn': lambda: SmallCNN(),
        'resnet18': lambda: ResNet18(num_classes=args.num_classes),
        'mobilenet': lambda: MobileNetV2(num_classes=args.num_classes),
        'preactresnet18': lambda: PreActResNet18(num_classes=args.num_classes),
        'WRN': lambda: Wide_ResNet(depth=28, num_classes=args.num_classes, widen_factor=10, dropRate=0.0),
    }
    if args.arch not in arch_builders:
        raise ValueError('unknown --arch {!r}; choose from {}'.format(args.arch, ', '.join(arch_builders)))
    model = arch_builders[args.arch]()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per-sample memory of the robust model's adversarial softmax outputs,
    # shape (num_train_samples, num_classes); read and refreshed by train().
    all_predictions_adv = torch.zeros(len(train_loader.dataset), len(train_loader.dataset.classes), dtype=torch.float32)

    model = model.to(device)
    # The natural guide model starts from an exact copy of the robust model's initial weights.
    model_nat = copy.deepcopy(model).to(device)

    if not args.eval:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=weight_decay)
        optimizer_nat = optim.SGD(model_nat.parameters(), lr=args.lr, momentum=0.9, weight_decay=weight_decay)
        logger_test = Logger(os.path.join(args.out_dir, 'log_results_road_wrn_100.txt'), title='reweight')
        logger_test.set_names(['Epoch', 'Natural', 'PGD20'])

        for epoch in range(start_epoch, args.epochs + 1):
            adjust_learning_rate(optimizer, epoch)
            adjust_learning_rate(optimizer_nat, epoch)
            descrip_str = 'Training epoch:{}/{}'.format(epoch, args.epochs)

            all_predictions_adv = train(epoch, model, model_nat, optimizer, optimizer_nat, all_predictions_adv, device, descrip_str)
            nat_acc, pgd20_acc = test(model, device=device)
            logger_test.append([epoch, nat_acc, pgd20_acc])

            # Keep the checkpoints with the best PGD-20 accuracy seen so far.
            if pgd20_acc > best_acc_adv:
                print('==> Updating the best model..')
                best_acc_adv = pgd20_acc
                torch.save(model.state_dict(), os.path.join(args.out_dir, 'bestpoint_road_wrn_100.pth.tar'))
                torch.save(model_nat.state_dict(), os.path.join(args.out_dir, 'bestpoint_nat_road_wrn_100.pth.tar'))
            # Always overwrite the last-epoch checkpoints.
            torch.save(model.state_dict(), os.path.join(args.out_dir, 'lastpoint_road_wrn_100.pth.tar'))
            torch.save(model_nat.state_dict(), os.path.join(args.out_dir, 'lastpoint_nat_road_wrn_100.pth.tar'))
    else:
        # NOTE(review): this reopens the log file the training run wrote. Whether
        # it appends or truncates depends on utils.Logger; confirm.
        logger_test = Logger(os.path.join(args.out_dir, 'log_results_road_wrn_100.txt'), title='reweight')

    # map_location lets checkpoints saved on GPU be loaded on a CPU-only host.
    checkpoint_path = os.path.join(args.out_dir, 'bestpoint_road_wrn_100.pth.tar')
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    clean_accuracy, pgd20_accuracy, pgd100_accuracy, mim_accuracy, aa_accuracy = attack(model, device, args.num_classes)

    logger_test.set_names(['Epoch', 'clean', 'PGD20', 'PGD100', 'MIM', 'AA'])
    # NOTE(review): 151 appears to be a sentinel "epoch" marking the final-evaluation row.
    logger_test.append([151, clean_accuracy, pgd20_accuracy, pgd100_accuracy, mim_accuracy, aa_accuracy])
    logger_test.close()
    logger_test.close()


if __name__ == "__main__":
    # Launch training/evaluation only when executed as a script.
    main()