import argparse
import copy
import logging
import os
import time
import math
from shutil import copyfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex import amp
from torch.utils.data import DataLoader, TensorDataset

from lip_convnets import LipConvNet
from utils import *

def robust_statistics(losses_arr, correct_arr, certificates_arr,
                      epsilon_list=(36., 72., 108., 144., 180., 216., 0)):
    """Summarise per-sample certification results into aggregate statistics.

    Args:
        losses_arr: per-sample loss values.
        correct_arr: per-sample correctness flags (bool, or 0/1 numbers);
            coerced to a boolean mask.
        certificates_arr: per-sample certified radii, compared against
            ``epsilon / 255``.
        epsilon_list: radii on the 0-255 pixel scale at which certified
            robust accuracy is reported. The default is a tuple so it cannot
            be mutated across calls; any iterable of numbers is accepted.

    Returns:
        ``(mean_loss, mean_acc, mean_certs, robust_acc_list)`` where
        ``mean_certs`` is the mean certificate over correctly classified
        samples (NaN if none are correct) and ``robust_acc_list[i]`` is the
        fraction of all samples that are correct AND have a certificate
        strictly greater than ``epsilon_list[i] / 255`` (NaN if there are no
        samples).
    """
    # A boolean mask makes `&` valid even for float 0/1 inputs.
    correct = np.asarray(correct_arr, dtype=bool)
    certificates = np.asarray(certificates_arr)

    mean_loss = np.mean(losses_arr)
    mean_acc = np.mean(correct)

    # Mask rather than multiply, so nan/inf certificates on misclassified
    # samples cannot leak into the mean (nan * 0 == nan).
    num_correct = correct.sum()
    if num_correct > 0:
        mean_certs = certificates[correct].sum() / num_correct
    else:
        mean_certs = np.nan

    num_samples = correct.shape[0]
    robust_acc_list = []
    for epsilon in epsilon_list:
        robust_correct = (certificates > (epsilon / 255.)) & correct
        if num_samples:
            robust_acc_list.append(robust_correct.sum() / num_samples)
        else:
            robust_acc_list.append(np.nan)
    return mean_loss, mean_acc, mean_certs, robust_acc_list


# Standalone evaluation: certify a trained LipConvNet and print certified
# robust accuracies. Guarded so that importing this module (e.g. to reuse
# `robust_statistics`) does not load data or a CUDA model as a side effect.
if __name__ == "__main__":
    # Dataset / model configuration. The commented alternatives switch the
    # script to TinyImageNet.
    # train_loader, test_loader = get_loaders('./tiny-imagenet-200', 32, 'TinyImageNet', normalize=True)
    train_loader, test_loader = get_loaders('./cifar-data', 32, 'cifar10', normalize=True)

    # model_test = LipConvNet('bcop', 'maxmin', 32, block_size = 7, num_classes=200, input_side=64, lln='lln').cuda()
    model_test = LipConvNet('con_orth', 'maxmin', 32, block_size=1, num_classes=10,
                            input_side=32, lln='lln').cuda()
    model_test.load_state_dict(torch.load(
        './LipConvnet_7.27_cifar10_1_con_orth_oni=7_32_maxmin_cr0.0_lln/best.pth'))
    model_test.float()
    model_test.eval()

    # std = TinyImageNet_std
    std = cifar10_std
    std = torch.tensor(std).cuda()
    # Scaling factor passed to evaluate_certificates, derived from the
    # dataset's normalization std.
    # NOTE(review): if inputs are normalized as (x - mean) / std, that map's
    # Lipschitz constant is 1/min(std), not 1/max(std); confirm how
    # evaluate_certificates uses L before trusting the reported radii.
    L = 1/torch.max(std)

    # Certificates are computed on the training split; swap in the
    # commented line to evaluate the test split instead.
    losses_arr, correct_arr, certificates_arr = evaluate_certificates(train_loader, model_test, L)
    # losses_arr, correct_arr, certificates_arr = evaluate_certificates(test_loader, model_test, L)

    # Named train_* because these statistics come from train_loader.
    train_loss, train_acc, train_cert, train_robust_acc_list = robust_statistics(
        losses_arr, correct_arr, certificates_arr)

    # With robust_statistics' default epsilon list, indices 0, 1, 2 and 6
    # correspond to epsilon = 36, 72, 108 and 0 (on the 0-255 scale).
    print(train_robust_acc_list[0], train_robust_acc_list[1],
          train_robust_acc_list[2], train_robust_acc_list[6])