import argparse
import os
import re
from collections import OrderedDict
from contextlib import contextmanager

import numpy as np

from core.models import NormDistBase
from ell_inf_models import *
from utils import random_seed

# Command-line options. Many of these (loss, p schedule, optimizer settings)
# are not read by the evaluation code in this file; they are kept so the same
# command lines used for training are accepted here.
parser = argparse.ArgumentParser(description='Training SortNet')

parser.add_argument('--dataset', default='CIFAR10', type=str)
parser.add_argument('--auxiliary-dir', default=None, type=str)
parser.add_argument('--auxiliary', default=None, type=str)
parser.add_argument('--fraction', default=0.7, type=float)
parser.add_argument('--model', default='SortMLPModel(depth=6,width=5120,scalar=True)', type=str)
parser.add_argument('--loss', default='mixture', type=str)

parser.add_argument('--p-start', default=8.0, type=float)
parser.add_argument('--p-end', default=1000.0, type=float)

parser.add_argument('--eps-train', default=None, type=float)
parser.add_argument('--eps-test', default=None, type=float)
parser.add_argument('--eps-smooth', default=0, type=float)

# Comma-separated epoch milestones: eps_start, eps_end, p_start, p_end, total.
parser.add_argument('--epochs', default='0,0,200,2950,3000', type=str)
parser.add_argument('--decays', default=None, type=str)
parser.add_argument('-b', '--batch-size', default=512, type=int)
parser.add_argument('--lr', default=0.02, type=float)
parser.add_argument('--wd', default=0.02, type=float)

parser.add_argument('--beta1', default=0.9, type=float)
parser.add_argument('--beta2', default=0.99, type=float)
parser.add_argument('--epsilon', default=1e-10, type=float)

parser.add_argument('--start-epoch', default=0, type=int)

parser.add_argument('--checkpoint', type=str, help='path to checkpoint file')
parser.add_argument('--logfile', type=str, help='target path to store predictions/labels')

parser.add_argument('--gpu', default=-1, type=int, help='GPU id to use')
parser.add_argument('--dist-url', default='tcp://localhost:23456')
# type=int: both values are used in integer arithmetic (seed offset, batch split).
parser.add_argument('--world-size', default=1, type=int)
parser.add_argument('--rank', default=0, type=int)

parser.add_argument('-p', '--print-freq', default=200, type=int, metavar='N', help='print frequency')
parser.add_argument('--result-dir', default='result', type=str)
parser.add_argument('--filter-name', default='', type=str)
parser.add_argument('--seed', default=2021, type=int)
parser.add_argument('--visualize', action='store_true')


def cal_acc(outputs, targets):
    """Return the fraction of rows of *outputs* whose argmax equals *targets*.

    Args:
        outputs: 2-D scores, one row per sample (torch tensor or array-like
            such as a numpy array).
        targets: 1-D integer labels (torch tensor or array-like).

    Returns:
        Accuracy as a Python float (``nan`` for empty input, as ``mean`` of an
        empty tensor).
    """
    # as_tensor accepts tensors and numpy arrays alike; detach replaces the
    # old ``.data`` access, which does not mean the same thing on numpy arrays.
    outputs = torch.as_tensor(outputs).detach()
    targets = torch.as_tensor(targets, device=outputs.device)
    predicted = torch.max(outputs, 1)[1]
    return (predicted == targets).float().mean().item()


def parallel_reduce(*argv):
    """Average each positional scalar over all ranks of the default process group.

    The scalars are packed into one CUDA float tensor, summed across ranks
    with ``torch.distributed.all_reduce`` and divided by the world size.

    Returns:
        A list of floats, one per positional argument.
    """
    summed = torch.FloatTensor(argv).cuda()
    torch.distributed.all_reduce(summed)  # in place; default reduce op is SUM
    world_size = torch.distributed.get_world_size()
    averaged = summed.cpu().div(world_size)
    return averaged.tolist()


@contextmanager
def eval(model):
    """Temporarily switch *model* and all of its submodules to evaluation mode.

    The ``training`` flag of every module is recorded on entry and restored
    on exit, including when the ``with`` body raises. (The name shadows the
    builtin ``eval`` in this module; it is kept for existing callers.)
    """
    # Capture the module list once so the restore loop pairs flags with the
    # same modules they were read from.
    modules = list(model.modules())
    state = [m.training for m in modules]
    model.eval()
    try:
        yield
    finally:
        for m, s in zip(modules, state):
            m.train(s)


@torch.no_grad()
def certified_test(net, eps, up, down, epoch, test_loader, gpu, logfile):
    """Compute clean and certified accuracy of *net* over *test_loader*.

    For the duration of the evaluation, the ``p`` attribute of every
    ``NormDistBase`` submodule is set to ``inf``. The previous values are
    restored afterwards, even if evaluation raises.

    Args:
        net: model called as ``net(inputs, targets=..., eps=..., up=..., down=...)``
            and returning ``(outputs, worst_outputs)``.
        eps: perturbation radius forwarded to the model.
        up, down: upper/lower input bounds forwarded to the model.
        epoch: unused; kept for interface compatibility.
        test_loader: iterable of ``(inputs, targets)`` batches.
        gpu: CUDA device index the batches are moved to.
        logfile: path passed to ``np.savez_compressed`` for the outputs,
            worst-case outputs and labels; if falsy, nothing is saved.

    Returns:
        ``(correct, certified)``: accuracy of the argmax of ``outputs`` and of
        ``worst_outputs`` against the labels.

    Raises:
        ValueError: if *test_loader* yields no batches.
    """
    normdist_models = get_normdist_models(net)
    saved_p = [m.p for m in normdist_models]
    for m in normdist_models:
        m.p = float('inf')

    outputs, worst_outputs, labels = [], [], []
    try:
        with eval(net):
            for inputs, targets in test_loader:
                inputs = inputs.cuda(gpu, non_blocking=True)
                targets = targets.cuda(gpu, non_blocking=True)
                output, worst = net(inputs, targets=targets, eps=eps, up=up, down=down)
                outputs.append(output)
                worst_outputs.append(worst)
                labels.append(targets)
    finally:
        for m, p in zip(normdist_models, saved_p):
            m.p = p

    if not labels:
        raise ValueError('test_loader yielded no batches')
    outputs = torch.cat(outputs, dim=0)
    worst_outputs = torch.cat(worst_outputs, dim=0)
    labels = torch.cat(labels, dim=0)

    # Accuracy is computed on the tensors; cal_acc is torch-based and the
    # numpy copies below are only for the saved log.
    correct = cal_acc(outputs, labels)
    certified = cal_acc(worst_outputs, labels)

    if logfile:
        np.savez_compressed(logfile,
                            outputs=outputs.cpu().numpy(),
                            worst_outputs=worst_outputs.cpu().numpy(),
                            labels=labels.cpu().numpy())
    return correct, certified


_SPECIAL_VALUES = {'True': True, 'False': False, 'None': None}


def _parse_scalar(text):
    """Convert *text* to int, then float, then True/False/None; else return it unchanged."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return _SPECIAL_VALUES.get(text, text)


def parse_function_call(s):
    """Parse a spec like ``'Name(a=1,b=2.5,c=True)'`` into ``(name, kwargs)``.

    A bare ``'Name'`` or ``'Name()'`` yields ``(name, {})``. Whitespace around
    names, keys and values is ignored, and empty entries (e.g. a trailing
    comma) are skipped. Values become int, float, True/False/None when they
    parse as such, and stay strings otherwise. Nested parentheses are not
    supported.

    Raises:
        ValueError: if the parentheses are unbalanced/nested or an argument
            is not of the form ``key=value``.
    """
    parts = re.split(r'[()]', s)
    if len(parts) == 1:
        return parts[0].strip(), {}
    if len(parts) != 3:
        raise ValueError(f'cannot parse function call {s!r}')
    name, raw_params, _ = parts

    params = {}
    for item in raw_params.split(','):
        item = item.strip()
        if not item:
            continue
        # partition splits on the first '=' only, so values may contain '='.
        key, sep, value = item.partition('=')
        if not sep:
            raise ValueError(f'expected key=value in {s!r}, got {item!r}')
        params[key.strip()] = _parse_scalar(value.strip())
    return name.strip(), params


def get_normdist_models(model):
    """Return every ``NormDistBase`` module reachable from *model*, itself included.

    Modules are returned in the order ``model.modules()`` yields them.
    """
    return list(filter(lambda module: isinstance(module, NormDistBase), model.modules()))


def _strip_module_prefix(state_dict):
    """Return *state_dict* with a leading ``'module.'`` removed from every key.

    The prefix is stripped only when every key carries it (presumably a
    checkpoint saved from a DataParallel-style wrapper; confirm for your
    checkpoints). Otherwise, including for an empty dict, it is returned as-is.
    """
    prefix = 'module.'
    if state_dict and all(k.startswith(prefix) for k in state_dict):
        return OrderedDict((k[len(prefix):], v) for k, v in state_dict.items())
    return state_dict


def main_worker(gpu, args):
    """Build the model, optionally load a checkpoint, and run ``certified_test``.

    Args:
        gpu: CUDA device index; a negative value (the ``--gpu`` default)
            selects the current CUDA device.
        args: namespace produced by ``parser``. ``args.eps_test`` /
            ``args.eps_train`` are filled in when None and then replaced by
            their values divided by the dataset ``std``.

    Returns:
        ``(correct, certified)`` as returned by ``certified_test``. Both
        values are also printed.

    Raises:
        ValueError: if the batch size is not divisible by the world size or
            the model name is unknown.
        FileNotFoundError: if ``--checkpoint`` does not point to a file.
    """
    torch.backends.cudnn.benchmark = True
    # Coerce explicitly: these may arrive as strings from the command line.
    world_size = int(args.world_size)
    rank = int(args.rank)
    random_seed(args.seed + rank)  # make data aug different for different processes
    if gpu < 0:
        gpu = torch.cuda.current_device()
    torch.cuda.set_device(gpu)

    if args.batch_size % world_size != 0:
        raise ValueError(f'batch size {args.batch_size} is not divisible by world size {world_size}')
    from dataset import load_data, get_statistics, default_eps, input_dim
    train_loader, test_loader = load_data(args.dataset, 'data/', args.batch_size // world_size, False,
                                          augmentation=True, auxiliary=None, fraction=args.fraction)
    mean, std = get_statistics(args.dataset)
    # The training loader is only consulted for its class list.
    num_classes = len(train_loader.dataset.classes)

    model_name, params = parse_function_call(args.model)
    try:
        model_cls = globals()[model_name]
    except KeyError:
        raise ValueError(f'unknown model {model_name!r}') from None
    model = model_cls(input_dim=input_dim[args.dataset], num_classes=num_classes, **params)
    model = model.cuda(gpu)

    if args.eps_test is None:
        args.eps_test = default_eps[args.dataset]
    if args.eps_train is None:
        args.eps_train = args.eps_test
    # Plain division (not /=): if eps_train and eps_test share one array
    # object, in-place division would divide it twice.
    args.eps_train = args.eps_train / std
    args.eps_test = args.eps_test / std

    if args.checkpoint:
        if not os.path.isfile(args.checkpoint):
            raise FileNotFoundError(f'checkpoint not found: {args.checkpoint}')
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda(gpu))
        model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))

    # Normalized images of raw input values 1 and 0, assuming mean/std are
    # the dataset normalization statistics.
    up = torch.FloatTensor((1 - mean) / std).view(-1, 1, 1).cuda(gpu)
    down = torch.FloatTensor((0 - mean) / std).view(-1, 1, 1).cuda(gpu)
    # --epochs is a comma-separated string whose last entry is the total.
    total_epochs = int(args.epochs.split(',')[-1])
    correct, certified = certified_test(model, args.eps_test, up, down, total_epochs,
                                        test_loader, gpu, args.logfile)
    print(f'clean accuracy: {correct:.4f}  certified accuracy: {certified:.4f}')
    return correct, certified


def main():
    """Script entry point: parse the command line and evaluate on ``--gpu``."""
    cli_args = parser.parse_args()
    main_worker(cli_args.gpu, cli_args)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()
