import argparse
import os
import pickle
import torchattacks
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from copy import deepcopy
import random
import numpy as np


from ncp_acl.builder import get_base_model


parser = argparse.ArgumentParser(description='NCP-ACL AutoAttack Analysis')

parser.add_argument('--backbone', type=str, default='resnet18',
                    help='encoder architecture passed to get_base_model')

parser.add_argument('--data_dir', type=str, default='',
                    help='dataset root directory (datasets are not downloaded)')
# The evaluation only supports these datasets; restricting the choices makes
# argparse reject anything else up front (the old '' default was always invalid).
parser.add_argument('--dataset', type=str, required=True,
                    choices=['cifar10', 'cifar100', 'stl10'],
                    help='evaluation dataset')

parser.add_argument('--seed', type=int, default=0,
                    help='seed for Python, NumPy and PyTorch RNGs')
# An empty string is not a valid torch device, so pick a working default.
parser.add_argument('--device', type=str,
                    default='cuda' if torch.cuda.is_available() else 'cpu',
                    help='torch device (default: cuda if available, else cpu)')
parser.add_argument('--multi_bn', action='store_true',
                    help='encoder checkpoint stores multi-BN (bn_list) layers; '
                         'branch 1 is extracted before loading')

parser.add_argument('--autoattack_eps', type=float, default=8/255,
                    help='Linf perturbation radius for AutoAttack')

parser.add_argument('--use_normalize', action='store_true',
                    help='forwarded to get_base_model(use_normalize=...)')

parser.add_argument('--encoder_ckpt_path', type=str, required=True,
                    help='checkpoint of the pre-trained encoder')
parser.add_argument('--classifier_ckpt_path', type=str, required=True,
                    help='checkpoint of the linear classifier')

def fix_seed(SEED):
    """Seed every RNG used by the evaluation and force deterministic cuDNN.

    Args:
        SEED: integer seed given to Python's ``random``, NumPy, PyTorch (CPU)
            and all CUDA devices.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(SEED)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def measure_adv_autoattack(model, device, args, test_dataloader, norm='Linf', eps=8/255, version='standard', num_classes=10):
    """Return the accuracy (percent) of ``model`` on AutoAttack adversarial examples.

    Args:
        model: network mapping an image batch to logits. It is moved to
            ``device`` and switched to eval mode in place.
        device: device the model, the attack and the predictions run on.
        args: parsed CLI arguments; not read here, kept for call compatibility.
        test_dataloader: iterable of ``(images, labels)`` batches.
        norm, eps, version, num_classes: forwarded to
            ``torchattacks.AutoAttack`` (``num_classes`` as ``n_classes``).

    Returns:
        0-dim CPU float tensor: correct adversarial predictions / evaluated
        samples * 100.

    Raises:
        ValueError: if ``test_dataloader`` yields no samples.
    """
    model.to(device)
    model.eval()

    attacker = torchattacks.AutoAttack(model, norm=norm, eps=eps, version=version, n_classes=num_classes)
    attacker.device = device

    correct = 0
    total = 0
    for imgs, labels in test_dataloader:
        # The attack needs input gradients, so it must run outside no_grad.
        adv_images = attacker(imgs, labels)
        labels = labels.to(device)
        # Scoring the adversarial batch needs no autograd graph.
        with torch.no_grad():
            adv_preds = model(adv_images)
        correct += (adv_preds.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError('test_dataloader yielded no samples')

    # Keep the historical return type: a 0-dim float tensor in percent.
    return torch.tensor(correct / total * 100)


def _build_model(args, num_classes):
    """Build encoder + linear head, load their checkpoints, and chain them.

    Args:
        args: parsed CLI arguments; reads dataset, backbone, use_normalize,
            multi_bn, encoder_ckpt_path and classifier_ckpt_path.
        num_classes: number of outputs of the linear head.

    Returns:
        ``nn.Sequential(encoder, classifier)`` with both parts in eval mode.
    """
    encoder = get_base_model(args.dataset,
                             args.backbone,
                             hidden_dim=num_classes,
                             use_normalize=args.use_normalize,
                             zero_init_residual=True,
                             use_projector=False,
                             multi_bn=False)

    # NOTE(security): torch.load unpickles arbitrary Python objects; only
    # load checkpoints that come from a trusted source.
    encoder_checkpoint = torch.load(args.encoder_ckpt_path, map_location='cpu')
    classifier_checkpoint = torch.load(args.classifier_ckpt_path, map_location='cpu')

    classifier_state_dict = classifier_checkpoint['state_dict']
    # Size the head from the saved weight instead of hard-coding 512, so
    # backbones with a different feature width load as well. The strict
    # load below still rejects a checkpoint whose output size != num_classes.
    in_features = classifier_state_dict['weight'].shape[1]
    classifier = nn.Linear(in_features, num_classes)

    if args.multi_bn:
        # The encoder is built with single BN layers, so collapse the
        # checkpoint's bn_list layers to branch 1 first.
        encoder_state_dict = cvt_state_dict(encoder_checkpoint['state_dict'], use_pgd_bn=True)
    else:
        encoder_state_dict = encoder_checkpoint['state_dict']

    encoder_load_msg = encoder.load_state_dict(encoder_state_dict, strict=True)
    classifier_load_msg = classifier.load_state_dict(classifier_state_dict, strict=True)

    print(f'Encoder Load State Dict MSG {encoder_load_msg}')
    print(f'Classifier Load State Dict MSG {classifier_load_msg}')

    encoder.eval()
    classifier.eval()

    return nn.Sequential(encoder, classifier)


def main():
    """Evaluate a pre-trained encoder + linear classifier under Linf AutoAttack.

    Parses the CLI arguments, builds the test loader and model, runs the
    standard AutoAttack evaluation and prints the robust accuracy.

    Returns:
        dict mapping ``'AutoAttack_Accuracy'`` to the value returned by
        ``measure_adv_autoattack``.
    """
    args = parser.parse_args()

    # Validate CLI input with parser.error rather than `assert`, which is
    # stripped when Python runs with -O.
    if args.dataset not in ('cifar10', 'cifar100', 'stl10'):
        parser.error('Only CIFAR10, CIFAR100 and STL10 are supported')
    if not args.encoder_ckpt_path or not args.classifier_ckpt_path:
        parser.error('--encoder_ckpt_path and --classifier_ckpt_path are required')

    # An empty --device is not a valid torch device; choose one automatically.
    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')

    fix_seed(args.seed)

    test_dataloader, num_classes = set_loader(args)

    model = _build_model(args, num_classes)

    robust_acc = measure_adv_autoattack(model,
                                        device,
                                        args,
                                        test_dataloader,
                                        norm='Linf',
                                        eps=args.autoattack_eps,
                                        version='standard',
                                        num_classes=num_classes)

    print(f'AutoAttack Accuracy: {robust_acc}')
    return {'AutoAttack_Accuracy': robust_acc}


def set_loader(args):
    """Create the evaluation DataLoader for ``args.dataset``.

    Images are only converted to tensors (no augmentation or normalisation).
    The dataset must already exist under ``args.data_dir``; nothing is
    downloaded.

    Args:
        args: parsed CLI arguments; reads ``dataset`` and ``data_dir``.

    Returns:
        ``(test_dataloader, num_classes)``: an unshuffled loader with batch
        size 256, and 100 for CIFAR-100 or 10 otherwise.
    """
    test_transforms = transforms.Compose([transforms.ToTensor()])
    dataset_cls = vars(torchvision.datasets)[args.dataset.upper()]

    # CIFAR classes select the split with a boolean; STL10 takes a split name.
    if args.dataset in ['cifar10', 'cifar100']:
        split_kwargs = {'train': False}
    else:
        split_kwargs = {'split': 'test'}

    test_dataset = dataset_cls(args.data_dir,
                               download=False,
                               transform=test_transforms,
                               **split_kwargs)

    loader_kwargs = dict(batch_size=256,
                         shuffle=False,
                         num_workers=4,
                         pin_memory=True,
                         drop_last=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

    num_classes = {'cifar100': 100}.get(args.dataset, 10)

    return test_dataloader, num_classes
  
    
def cvt_state_dict(state_dict, use_pgd_bn=True):
    """Collapse a multi-BN state dict into a single-BN one.

    Every key containing ``'bn'`` must belong to a ``bn_list``. Only the
    entries of branch 1 (``use_pgd_bn=True``) or branch 0 are kept, with the
    ``.bn_list.<k>`` segment removed. Keys are also renamed from
    ``downsample.conv`` / ``downsample.bn`` to ``downsample.0`` /
    ``downsample.1``, presumably to match an ``nn.Sequential`` downsample in
    the target model (confirm against get_base_model).

    Args:
        state_dict: mapping of parameter names to values.
        use_pgd_bn: keep BN branch 1 when True, branch 0 otherwise.

    Returns:
        New dict with converted keys, in the input's iteration order.

    Raises:
        ValueError: if a ``'bn'`` key is not part of a ``bn_list``.
    """
    bn_index = 1 if use_pgd_bn else 0
    # Include the trailing dot so branch 1 does not also match branch 10, 11, ...
    branch_tag = f'.bn_list.{bn_index}.'

    converted = {}
    for name, value in state_dict.items():
        if 'bn' in name:
            if 'bn_list' not in name:
                raise ValueError('state_dict does not consist of the multi batch normalization layer')
            if branch_tag not in name:
                continue  # statistics of the other BN branch are discarded
            name = name.replace(branch_tag, '.')

        if 'downsample.conv' in name:
            name = name.replace('downsample.conv', 'downsample.0')
        elif 'downsample.bn' in name:
            name = name.replace('downsample.bn', 'downsample.1')

        converted[name] = value

    return converted


# Run the AutoAttack evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()