import argparse
import os
import pickle
import torchattacks
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from copy import deepcopy
import random
import numpy as np

from ncp_acl.builder import get_base_model

def _build_parser():
    """Create the command-line parser for the standard-accuracy evaluation script.

    Returns:
        argparse.ArgumentParser: Parser exposing the backbone, dataset, seed,
        device, batch-norm/normalization switches and the two checkpoint paths.
    """
    cli = argparse.ArgumentParser(description='NCP-ACL Standard Accuracy')

    # (flag, add_argument keyword arguments). Registered in this exact order so
    # the generated usage/help text lists the options in the same sequence.
    option_specs = (
        # Encoder architecture name.
        ('--backbone', {'type': str, 'default': 'resnet18'}),
        # Dataset root directory and dataset name.
        ('--data_dir', {'type': str, 'default': ''}),
        ('--dataset', {'type': str, 'default': ''}),
        # Random seed and device string used for evaluation.
        ('--seed', {'type': int, 'default': 0}),
        ('--device', {'type': str, 'default': ''}),
        # Set when the encoder checkpoint stores multiple batch-norm branches.
        ('--multi_bn', {'action': 'store_true'}),
        # Forwarded to the model builder as ``use_normalize``.
        ('--use_normalize', {'action': 'store_true'}),
        # Checkpoint files for the encoder and the linear classifier
        # (no default, so they are None when omitted).
        ('--encoder_ckpt_path', {'type': str}),
        ('--classifier_ckpt_path', {'type': str}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli


parser = _build_parser()


def fix_seed(SEED):
    """Seed every random number generator the script uses and pin cuDNN settings.

    Args:
        SEED: Seed value passed unchanged to Python's ``random``, NumPy,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    # Apply the same seed to each generator: Python, NumPy, torch, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(SEED)

    # Turn on the deterministic flag and turn off cuDNN benchmarking so
    # repeated runs use the same kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def cvt_state_dict(state_dict, use_pgd_bn=True):
    """Convert a multi-batch-norm state dict into a single-batch-norm layout.

    Every key containing ``'bn'`` must also contain ``'bn_list'``. Only the
    entries of the selected branch are kept, with the ``.bn_list.<index>``
    segment stripped from their names; all other entries are copied as-is.
    Afterwards ``downsample.conv`` / ``downsample.bn`` segments are renamed to
    ``downsample.0`` / ``downsample.1``.

    Args:
        state_dict: Mapping of parameter names to values. It is not modified.
        use_pgd_bn: Keep branch index 1 when true, branch index 0 otherwise.

    Returns:
        dict: A new dictionary holding the converted entries.

    Raises:
        AssertionError: If a key contains ``'bn'`` but not ``'bn_list'``.
    """
    branch_tag = f'.bn_list.{1 if use_pgd_bn else 0}'

    converted = {}
    for name, value in state_dict.items():
        if 'bn' not in name:
            # Not a batch-norm entry: carry over untouched.
            converted[name] = value
            continue
        assert 'bn_list' in name, 'state_dict does not consist of the multi batch normalization layer'
        # Entries belonging to the other branch are dropped.
        if branch_tag in name:
            converted[name.replace(branch_tag, '')] = value

    # First matching pattern wins, mirroring an if/elif chain.
    renames = (
        ('downsample.conv', 'downsample.0'),
        ('downsample.bn', 'downsample.1'),
    )
    # Iterate over a snapshot of the keys because entries are moved in place.
    for key in list(converted):
        for old_part, new_part in renames:
            if old_part in key:
                converted[key.replace(old_part, new_part)] = converted.pop(key)
                break

    return converted

def set_loader(args):
    """Build the test-split dataloader for the dataset named in ``args``.

    Args:
        args: Namespace providing ``dataset`` (lower-case dataset name) and
            ``data_dir`` (dataset root directory).

    Returns:
        tuple: ``(test_dataloader, num_classes)`` where ``num_classes`` is 100
        for ``'cifar100'`` and 10 for every other dataset name.
    """
    dataset_name = args.dataset

    # Only tensor conversion is applied; no augmentation or normalization here.
    test_transforms = transforms.Compose([transforms.ToTensor()])

    # The torchvision dataset class is looked up by its upper-cased name.
    dataset_cls = vars(torchvision.datasets)[dataset_name.upper()]

    # The CIFAR classes select the test split with ``train=False``; any other
    # dataset is constructed with ``split='test'``.
    if dataset_name in ('cifar10', 'cifar100'):
        split_kwargs = {'train': False}
    else:
        split_kwargs = {'split': 'test'}

    test_dataset = dataset_cls(
        args.data_dir,
        download=True,
        transform=test_transforms,
        **split_kwargs,
    )

    loader_options = {
        'batch_size': 512,
        'shuffle': False,
        'num_workers': 4,
        'pin_memory': True,
        'drop_last': False,
    }
    test_dataloader = torch.utils.data.DataLoader(test_dataset, **loader_options)

    num_classes = {'cifar100': 100}.get(dataset_name, 10)

    return test_dataloader, num_classes

def test_accuracy(model, test_dataloader, device):
    """Compute the top-1 accuracy (in percent) of ``model`` on ``test_dataloader``.

    As a side effect the model is moved to ``device`` and put in eval mode.

    Args:
        model: Module mapping a batch of inputs to per-class scores; the argmax
            over the last dimension is taken as the predicted label.
        test_dataloader: Iterable of ``(imgs, labels)`` batches that exposes a
            sized ``dataset`` attribute.
        device: Device the model and every batch are moved to.

    Returns:
        dict: ``{'acc': value}`` where ``value`` is a 0-dim CPU tensor equal to
        ``100 * correct / len(test_dataloader.dataset)``.

    Raises:
        ValueError: If ``test_dataloader.dataset`` contains no samples.
    """
    num_samples = len(test_dataloader.dataset)
    if num_samples == 0:
        raise ValueError('test_dataloader.dataset is empty; accuracy is undefined')

    model.to(device)
    model.eval()

    # Start from a tensor (not the int 0) so the result is still a tensor when
    # the loader yields no batches, e.g. drop_last=True with a small dataset.
    correct = torch.zeros((), device=device)
    with torch.no_grad():
        for imgs, labels in test_dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            predictions = model(imgs).argmax(dim=-1)
            correct += (predictions == labels).float().sum()

    return {'acc': correct.detach().cpu() / num_samples * 100}


# The ``test_`` prefix would otherwise make pytest collect this helper as a
# test case whenever it is imported into a test module.
test_accuracy.__test__ = False

def main():
    """Evaluate a pre-trained encoder plus linear classifier on the clean test split.

    Parses the command line, builds the test dataloader, restores both
    checkpoints, and prints the standard (non-adversarial) accuracy.
    """
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped
    # under ``python -O`` and would let bad input through.
    if args.dataset not in ('cifar10', 'cifar100', 'stl10'):
        parser.error('Only CIFAR10, CIFAR100, STL10 are supported')
    if not args.encoder_ckpt_path or not args.classifier_ckpt_path:
        parser.error('--encoder_ckpt_path and --classifier_ckpt_path are required')

    # The default --device is the empty string, which is not a valid torch
    # device; fall back to CUDA when available, else CPU.
    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')

    fix_seed(args.seed)

    test_dataloader, num_classes = set_loader(args)

    # Path of the pre-trained encoder.
    encoder_path = args.encoder_ckpt_path
    # Path of the linear classifier.
    classifier_path = args.classifier_ckpt_path

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    encoder_checkpoint = torch.load(encoder_path, map_location='cpu')
    classifier_checkpoint = torch.load(classifier_path, map_location='cpu')

    # The model is always built with a single batch-norm branch; checkpoints
    # saved with multiple branches are converted below.
    encoder = get_base_model(args.dataset,
                             args.backbone,
                             hidden_dim=num_classes,
                             use_normalize=args.use_normalize,
                             zero_init_residual=True,
                             use_projector=False,
                             multi_bn=False)

    # Size the linear head from the stored weight so backbones whose feature
    # width is not 512 also work; keep 512 when the weight is unavailable so
    # load_state_dict reports the problem as before.
    classifier_state_dict = classifier_checkpoint['state_dict']
    classifier_weight = classifier_state_dict.get('weight')
    if classifier_weight is not None and classifier_weight.dim() == 2:
        in_features = classifier_weight.shape[1]
    else:
        in_features = 512
    classifier = nn.Linear(in_features, num_classes)

    if args.multi_bn:
        # Keep the branch selected by use_pgd_bn=True and strip the
        # ``bn_list`` naming so keys match the single-BN encoder.
        encoder_state_dict = cvt_state_dict(encoder_checkpoint['state_dict'], use_pgd_bn=True)
    else:
        encoder_state_dict = encoder_checkpoint['state_dict']

    encoder_load_msg = encoder.load_state_dict(encoder_state_dict, strict=True)
    classifier_load_msg = classifier.load_state_dict(classifier_state_dict, strict=True)

    print(f'Encoder Load State Dict MSG {encoder_load_msg}')
    print(f'Classifier Load State Dict MSG {classifier_load_msg}')

    encoder.eval()
    classifier.eval()

    model = nn.Sequential(encoder, classifier)

    standard_acc = test_accuracy(model, test_dataloader, device=device)['acc']

    print(f'Dataset {args.dataset} Standard Accuracy: {standard_acc}')


if __name__ == '__main__':
    main()