import logging
import os.path as osp

from autoattack import AutoAttack
from robustbench.data import load_cifar10
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model

from dent import Dent
from conf import cfg, load_cfg_fom_args
import numpy as np
from utils.torch import set_seed, reorder_data_points
import torch

# Module-level logger. NOTE(review): not used by the code visible here, which
# reports progress via print(); keep or route output through it deliberately.
logger = logging.getLogger(__name__)


def _batch_accuracy(model, x_test, y_test, bs):
    """Return the fraction of ``x_test`` that ``model`` classifies correctly.

    Examples are fed to ``model`` in consecutive, non-shuffled batches of at
    most ``bs`` items, in the order of ``x_test``. The batch boundaries are
    kept deterministic because a test-time-adaptation model may depend on
    which examples share a batch.

    Args:
        model: Callable mapping an input batch to per-class scores whose
            class axis is dim 1 (predictions are taken with ``argmax(dim=1)``).
        x_test: Input tensor; dim 0 indexes examples.
        y_test: Label tensor aligned with ``x_test`` along dim 0.
        bs: Maximum batch size; must be positive.

    Returns:
        Accuracy in ``[0, 1]`` as a Python float.

    Raises:
        ValueError: If ``x_test`` is empty or ``bs`` is not positive.
    """
    n_examples = x_test.shape[0]
    if n_examples == 0:
        raise ValueError("cannot compute accuracy on an empty test set")
    if bs <= 0:
        raise ValueError(f"batch size must be positive, got {bs}")

    n_correct = 0
    for start in range(0, n_examples, bs):
        # Defensive copy so the model can never modify x_test in place.
        x = x_test[start:start + bs].clone()
        y = y_test[start:start + bs]
        with torch.no_grad():
            output = model(x)
        n_correct += y.eq(output.argmax(dim=1)).sum().item()
    return n_correct / n_examples


def evaluate(description):
    """Evaluate a DENT-wrapped CIFAR-10 model, optionally under AutoAttack.

    Parses the configuration with ``load_cfg_fom_args(description)``, loads
    the RobustBench ``cfg.MODEL.ARCH`` model for the Linf threat model, wraps
    it in ``Dent`` and draws a seeded random subset of the data returned by
    ``load_cifar10``. The subset is optionally reordered with
    ``reorder_data_points``. Unless ``cfg.ATTACK.SKIP_ATTACK`` is set,
    AutoAttack (standard version, Linf, eps=8/255) is run with its log
    written to ``cfg.SAVE_DIR/cfg.LOG_DEST``. Finally the accuracy of the
    DENT model on the selected *clean* examples is printed.

    Args:
        description: Description string forwarded to the config/argument
            parser.

    Raises:
        ValueError: If the config selects a dataset other than ``cifar10``,
            an adaptation method other than ``"dent"``, non-episodic DENT,
            or if no test examples are selected.
    """
    load_cfg_fom_args(description)

    # Validate the configuration before doing any expensive work. Explicit
    # raises (not asserts) so the checks survive ``python -O``.
    if cfg.CORRUPTION.DATASET != 'cifar10':
        raise ValueError(
            f"only CORRUPTION.DATASET == 'cifar10' is supported, "
            f"got {cfg.CORRUPTION.DATASET!r}")
    if cfg.MODEL.ADAPTATION != "dent":
        raise ValueError(
            f"only MODEL.ADAPTATION == 'dent' is supported, "
            f"got {cfg.MODEL.ADAPTATION!r}")
    if not cfg.MODEL.EPISODIC:
        raise ValueError("DENT evaluation requires MODEL.EPISODIC to be set")

    base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
                            'cifar10', ThreatModel.Linf).cuda()
    dent_model = Dent(base_model, cfg.OPTIM)

    # Only evaluate on a seeded random subset of the loaded data points.
    # The call order (seed -> load -> permutation) fixes which subset is drawn.
    set_seed(0)
    origin_x_test, origin_y_test = load_cifar10(cfg.CORRUPTION.NUM_EX, cfg.DATA_DIR)
    n_total = origin_x_test.shape[0]
    # NOTE(review): the loader is capped by cfg.CORRUPTION.NUM_EX while the
    # subset size uses cfg.NUM_EX -- confirm both config keys are intended.
    indices = np.random.permutation(np.arange(n_total))[:cfg.NUM_EX]
    x_test = origin_x_test[indices]
    y_test = origin_y_test[indices]
    if x_test.shape[0] == 0:
        raise ValueError("no test examples selected; check NUM_EX settings")
    bs = cfg.TEST.BATCH_SIZE

    if cfg.ATTACK.REORDER:
        print("Reorder data points.")
        x_test, y_test = reorder_data_points(x_test, y_test, 2 * bs)

    x_test, y_test = x_test.cuda(), y_test.cuda()

    if not cfg.ATTACK.SKIP_ATTACK:
        print("Evaluate DENT under Auto-Attack (DENT-AA)")
        adversary = AutoAttack(
            dent_model, norm='Linf', eps=8. / 255., version='standard',
            log_path=osp.join(cfg.SAVE_DIR, cfg.LOG_DEST))
        # The returned adversarial examples are discarded; the accuracy
        # printed below is measured on the clean x_test.
        adversary.run_standard_evaluation(x_test, y_test, bs=bs)

    accuracy = _batch_accuracy(dent_model, x_test, y_test, bs)
    print(f'Accuracy of DENT: {accuracy:.2%}')


if __name__ == '__main__':
    # Script entry point; the string is the description handed to the
    # config/argument parser (stray leading quote removed).
    evaluate('CIFAR-10 AutoAttack Linf 8/255.')
