import argparse
import os
import pickle
import torchattacks
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from copy import deepcopy
import random
import numpy as np

from ncp_acl.builder import get_base_model

# Command-line interface of the PGD robust-accuracy evaluation script.
parser = argparse.ArgumentParser(description='NCP-ACL Robust Accuracy Analysis')

parser.add_argument('--backbone', type=str, default='resnet18',
                    help='backbone architecture name forwarded to get_base_model')

parser.add_argument('--data_dir', type=str, default='',
                    help='root directory of the dataset (it is not downloaded automatically)')
# The default is left empty on purpose: argparse does not validate defaults
# against `choices`, so an omitted --dataset is still rejected later in main().
parser.add_argument('--dataset', default='', type=str,
                    choices=['cifar10', 'cifar100', 'stl10'],
                    help='dataset whose test split is evaluated')

parser.add_argument('--seed', type=int, default=0,
                    help='seed applied to the Python, NumPy and PyTorch RNGs')
parser.add_argument('--device', type=str, default='cuda:0',
                    help='device on which the model and the attack are run')
parser.add_argument('--multi_bn', action='store_true',
                    help='the encoder checkpoint stores multiple batch-norm branches '
                         '(bn_list); the bn_list.1 branch is extracted before loading')

parser.add_argument('--pgd_eps', type=float, default=8/255,
                    help='eps passed to torchattacks.PGD')
parser.add_argument('--pgd_num_steps', type=int, default=20,
                    help='steps passed to torchattacks.PGD')
parser.add_argument('--pgd_step_size', type=float, default=2/255,
                    help='alpha (per-step size) passed to torchattacks.PGD')

parser.add_argument('--use_normalize', action='store_true',
                    help='forwarded to get_base_model as use_normalize')

# Both checkpoints are mandatory: without them torch.load() would be called
# with None and fail with an unhelpful error.
parser.add_argument('--encoder_ckpt_path', type=str, required=True,
                    help='path of the pre-trained encoder checkpoint')
parser.add_argument('--classifier_ckpt_path', type=str, required=True,
                    help='path of the linear classifier checkpoint')

def fix_seed(SEED):
    """Seed every RNG used by this script and request deterministic cuDNN.

    ``SEED`` is applied to Python's ``random``, NumPy, and PyTorch (the
    default generator and all CUDA devices). cuDNN benchmarking is switched
    off and its deterministic flag is switched on.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(SEED)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def measure_adv_pgd(model, device, eps, alpha, steps, test_dataloader):
    """Measure the accuracy of ``model`` on PGD adversarial examples.

    Args:
        model: module mapping an image batch to per-class scores. It is moved
            to ``device`` and switched to eval mode in place.
        device: device on which the model and the attack run.
        eps: ``eps`` argument of ``torchattacks.PGD``.
        alpha: ``alpha`` argument of ``torchattacks.PGD``.
        steps: ``steps`` argument of ``torchattacks.PGD``.
        test_dataloader: iterable of ``(imgs, labels)`` batches that exposes
            a sized ``.dataset`` attribute.

    Returns:
        A 0-dim CPU tensor: ``100 * correct / len(test_dataloader.dataset)``.
    """
    model.to(device)
    model.eval()

    attacker = torchattacks.PGD(model, eps=eps, alpha=alpha, steps=steps, random_start=True)
    attacker.device = device

    # Tensor accumulator so the result is a tensor even if no batch is yielded.
    correct = torch.zeros((), device=device)

    # The loader is iterated directly: the previous `tqdm(...)` wrapper
    # referenced a name this file never imports and raised NameError.
    for imgs, labels in test_dataloader:
        adv_images = attacker(imgs, labels)
        labels = labels.to(device)
        # Gradients are only needed inside the attack, not for scoring.
        with torch.no_grad():
            adv_preds = model(adv_images)
        correct += (adv_preds.argmax(dim=-1) == labels).float().sum()

    return correct.cpu() / len(test_dataloader.dataset) * 100

def main():
    """Evaluate the PGD robust accuracy of an encoder + linear classifier.

    Parses the command line, seeds the RNGs, builds the test dataloader,
    restores the encoder and the linear classifier from their checkpoints,
    chains them into one model and prints the robust accuracy returned by
    ``measure_adv_pgd``.
    """
    args = parser.parse_args()

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if args.dataset not in ('cifar10', 'cifar100', 'stl10'):
        parser.error('Only CIFAR10, CIFAR100 and STL10 are supported')
    if not args.encoder_ckpt_path or not args.classifier_ckpt_path:
        parser.error('--encoder_ckpt_path and --classifier_ckpt_path are required')

    fix_seed(args.seed)

    test_dataloader, num_classes = set_loader(args)

    # path of pre-trained encoder
    encoder_path = args.encoder_ckpt_path
    # path of linear classifier
    classifier_path = args.classifier_ckpt_path

    # multi_bn is always False here: a multi-BN checkpoint is converted to a
    # single-BN state dict below before it is loaded.
    encoder = get_base_model(args.dataset,
                             args.backbone,
                             hidden_dim=num_classes,
                             use_normalize=args.use_normalize,
                             zero_init_residual=True,
                             use_projector=False,
                             multi_bn=False)

    # NOTE: torch.load unpickles the files; only load checkpoints from a
    # trusted source.
    encoder_checkpoint = torch.load(encoder_path, map_location='cpu')
    classifier_checkpoint = torch.load(classifier_path, map_location='cpu')

    # Take the classifier input width from the checkpoint so backbones whose
    # feature size is not 512 also load; fall back to 512 when the state dict
    # has no plain 'weight' entry (load_state_dict then reports the mismatch).
    classifier_state_dict = classifier_checkpoint['state_dict']
    classifier_weight = classifier_state_dict.get('weight')
    feature_dim = classifier_weight.shape[1] if classifier_weight is not None else 512
    classifier = nn.Linear(feature_dim, num_classes)

    if args.multi_bn:
        encoder_state_dict = cvt_state_dict(encoder_checkpoint['state_dict'], use_pgd_bn=True)
    else:
        encoder_state_dict = encoder_checkpoint['state_dict']

    encoder_load_msg = encoder.load_state_dict(encoder_state_dict, strict=True)
    classifier_load_msg = classifier.load_state_dict(classifier_state_dict, strict=True)

    print(f'Encoder Load State Dict MSG {encoder_load_msg}')
    print(f'Classifier Load State Dict MSG {classifier_load_msg}')

    encoder.eval()
    classifier.eval()

    model = nn.Sequential(encoder, classifier)

    robust_acc = measure_adv_pgd(model,
                                 args.device,
                                 args.pgd_eps,
                                 args.pgd_step_size,
                                 args.pgd_num_steps,
                                 test_dataloader)

    print(f'Robust Accuracy: {robust_acc}')
        
def set_loader(args):
    """Build the test dataloader for ``args.dataset``.

    The dataset class is looked up in ``torchvision.datasets`` by the
    upper-cased dataset name and instantiated on ``args.data_dir`` without
    downloading. Images are only converted with ``ToTensor``.

    Returns:
        ``(test_dataloader, num_classes)`` where ``num_classes`` is 100 for
        ``'cifar100'`` and 10 otherwise.
    """
    dataset_cls = vars(torchvision.datasets)[args.dataset.upper()]

    # The CIFAR classes select the test set with `train=False`; the other
    # supported dataset class takes a `split` name instead.
    if args.dataset in ['cifar10', 'cifar100']:
        split_kwargs = {'train': False}
    else:
        split_kwargs = {'split': 'test'}

    eval_transform = transforms.Compose([transforms.ToTensor()])
    test_set = dataset_cls(args.data_dir,
                           download=False,
                           transform=eval_transform,
                           **split_kwargs)

    loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    if args.dataset == 'cifar100':
        class_count = 100
    else:
        class_count = 10

    return loader, class_count
  
def cvt_state_dict(state_dict, use_pgd_bn=True):
    """Collapse a multi-batch-norm state dict into a single-batch-norm one.

    Entries whose name contains ``'bn'`` must belong to a ``bn_list``; only
    the branch with index 1 (``use_pgd_bn=True``) or index 0 is kept, with the
    ``.bn_list.<index>`` part removed from its name. All other entries are
    copied unchanged. Afterwards ``downsample.conv`` / ``downsample.bn`` name
    fragments are renamed to ``downsample.0`` / ``downsample.1``.

    Returns:
        A new dict; ``state_dict`` itself is not modified.
    """
    branch_tag = f'.bn_list.{1 if use_pgd_bn else 0}'

    converted = {}
    for key, value in state_dict.items():
        if 'bn' not in key:
            converted[key] = value
            continue
        assert 'bn_list' in key, 'state_dict does not consist of the multi batch normalization layer'
        if branch_tag in key:
            converted[key.replace(branch_tag, '')] = value

    # Checked in order; at most one rename is applied per entry.
    renames = (
        ('downsample.conv', 'downsample.0'),
        ('downsample.bn', 'downsample.1'),
    )
    # Iterate over a snapshot because entries are inserted and deleted.
    for key in list(converted):
        for old_part, new_part in renames:
            if old_part in key:
                converted[key.replace(old_part, new_part)] = converted[key]
                del converted[key]
                break

    return converted

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()