## Copyright (C) 2019, Huan Zhang <huan@huan-zhang.com>
##                     Hongge Chen <chenhg@mit.edu>
##                     Chaowei Xiao <xiaocw@umich.edu>
## 
## This program is licenced under the BSD 2-Clause License,
## contained in the LICENCE file in this directory.
##
from __future__ import division

import sys
import copy
import torch
from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss, BCELoss
import numpy as np
from datasets import loaders
from bound_layers import BoundSequential, BoundLinear, BoundConv2d, BoundDataParallel
import torch.optim as optim
# from gpu_profile import gpu_profile
import time
from datetime import datetime
from convex_adversarial import DualNetwork
from eps_scheduler import EpsilonScheduler
from config import load_config, get_path, config_modelloader, config_dataloader, config_dataloader_index, update_dict
from argparser import argparser
import random
# sys.settrace(gpu_profile)

from torch.autograd import Variable

from mixup_utils import mixup_data, mixup_criterion, mixup_data_fixed_lam, mixup_process, to_one_hot, get_lambda, BoundIBPLargeModel
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')
if sys.version_info[0] < 3:
    import cPickle as pickle
else:
    import _pickle as pickle
from collections import OrderedDict

import comixup.mixup
from comixup.load_data import load_data_subset
from comixup.utils import distance
from comixup.mixup_parallel import MixupProcessParallel


class AverageMeter(object):
    """Track the most recent value and a running, count-weighted average.

    Attributes:
        val: the value passed to the most recent :meth:`update` call.
        sum: running total of ``val * n`` over all updates since the last reset.
        count: running total of the weights ``n`` since the last reset.
        avg: ``sum / count``; left unchanged while ``count`` is zero so that a
            zero-weight update never divides by zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the newly observed value (any type supporting ``*`` and ``/``
                with numbers, e.g. float or a numpy scalar/array).
            n: weight of this observation (callers in this file pass the batch
                size). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update on a fresh meter would otherwise divide by zero
        # (ZeroDivisionError for Python numbers, nan for numpy values).
        if self.count != 0:
            self.avg = self.sum / self.count

class Logger(object):
    """Echo messages to the console and optionally mirror them into a log file.

    Args:
        log_file: a writable file-like object accepted by ``print(file=...)``,
            or ``None`` (the default) to log to the console only.
    """

    def __init__(self, log_file = None):
        self.log_file = log_file

    def log(self, *args, **kwargs):
        """Print ``args`` (with ``print`` keyword options) and mirror them.

        The message always goes through ``print(*args, **kwargs)``, which is
        stdout unless ``kwargs`` overrides ``file``. If a truthy log file was
        given, the same message is printed there too and the file is flushed
        immediately. Passing ``file=`` in ``kwargs`` while a log file is set
        raises ``TypeError`` (duplicate keyword in the mirrored ``print``).
        """
        print(*args, **kwargs)
        sink = self.log_file
        if not sink:
            return
        print(*args, file = sink, **kwargs)
        # Flush so the log on disk stays current even if the run is killed.
        sink.flush()


def Train(model, t, loader, eps_scheduler, max_eps, norm, logger, verbose, train, opt, method, mixup_params = None, weight_params = None, trades_params = None, hinge_robust_loss = False, ifeval = False, **kwargs):
    # if train=True, use training mode
    # if train=False, use test mode, no back prop

    use_mixup = train and mixup_params and mixup_params["use_mixup"] and (t >= mixup_params["warmup_epochs"])
    
    if use_mixup and "end_epoch" in mixup_params.keys() and (t >= mixup_params["end_epoch"]):
        use_mixup = False
    
    num_class = 10
    losses = AverageMeter()
    l1_losses = AverageMeter()
    errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    relu_activities = AverageMeter()
    bound_bias = AverageMeter()
    bound_diff = AverageMeter()
    unstable_neurons = AverageMeter()
    dead_neurons = AverageMeter()
    alive_neurons = AverageMeter()
    batch_time = AverageMeter()
    batch_multiplier = kwargs.get("batch_multiplier", 1)  
    kappa = 1
    beta = 1
    if train:
        model.train() 
    else:
        model.eval()
    # pregenerate the array for specifications, will be used for scatter
    sa = np.zeros((num_class, num_class - 1), dtype = np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa) 
    batch_size = loader.batch_size * batch_multiplier
    if batch_multiplier > 1 and train:
        logger.log('Warning: Large batch training. The equivalent batch size is {} * {} = {}.'.format(batch_multiplier, loader.batch_size, batch_size))
    # per-channel std and mean
    std = torch.tensor(loader.std).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    mean = torch.tensor(loader.mean).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
 
    model_range = 0.0
    # end_eps = eps_scheduler.get_eps(t+1, 0)

    # if end_eps.nelement() == 1:
    #     end_eps = end_eps.item()
    # else:
    #     end_eps = end_eps.mean().item()
    # if end_eps < np.finfo(np.float32).tiny:
    #     logger.log('eps {} close to 0, using natural training'.format(end_eps))
    #     method = "natural"
    if t < eps_scheduler.schedule_start:
        logger.log('using natural training')
        method = "natural"
    if ifeval:
        margin_record = torch.zeros(len(loader.dataset))
    for i, (data, labels, index) in enumerate(loader): 
        start = time.time()
        eps = eps_scheduler.get_eps(t, int(i//batch_multiplier)) 
        if data.is_cuda:
            eps = eps.cuda()
        if train and i % batch_multiplier == 0:   
            opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0) 
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0),num_class-1,num_class))
        # scatter matrix to avoid compute margin to self
        sa_labels = sa[labels]
        # storing computed lower bounds after scatter
        lb_s = torch.zeros(data.size(0), num_class)
        ub_s = torch.zeros(data.size(0), num_class)


        # FIXME: Assume unnormalized data is from range 0 - 1
        if kwargs["bounded_input"]:
            if norm != np.inf:
                raise ValueError("bounded input only makes sense for Linf perturbation. "
                                 "Please set the bounded_input option to false.")
            data_max = torch.reshape((1. - mean) / std, (1, -1, 1, 1))
            data_min = torch.reshape((0. - mean) / std, (1, -1, 1, 1))
            if not train or eps_scheduler.epsnoise_type == 0:
                if eps.nelement() == 1:
                    data_ub = torch.min(data + (eps.item() / std), data_max)
                    data_lb = torch.max(data - (eps.item() / std), data_min)
                else:
                    eps_index = eps[index]
                    eps_dis= eps_index.view(-1,1,1,1).expand(-1,data.shape[1],data.shape[2],data.shape[3])
                    data_ub = torch.min(data + (eps_dis / std), data_max)
                    data_lb = torch.max(data - (eps_dis / std), data_min)
            elif eps_scheduler.epsnoise_type == 1:
                if eps.nelement() == 1:
                    eps_mean = eps.item()
                    noise = torch.zeros_like(data).normal_(mean = 0, std = min(eps_mean+0.0001, eps_scheduler.sigma))
                    eps_dis = torch.nn.ReLU()(eps.item()+noise)
                    data_ub = torch.min(data + (eps_dis / std), data_max)
                    data_lb = torch.max(data - (eps_dis / std), data_min)
                else:
                    eps_mean = torch.mean(eps).item()
                    noise = -torch.abs(torch.zeros_like(data).normal_(mean = 0, std = min(eps_mean+0.0001, eps_scheduler.sigma)))
                    eps_index = eps[index]
                    eps_dis= torch.nn.ReLU()(eps_index.view(-1,1,1,1).expand(-1,data.shape[1],data.shape[2],data.shape[3])+noise)
                    data_ub = torch.min(data + (eps_dis / std), data_max)
                    data_lb = torch.max(data - (eps_dis / std), data_min)
            elif eps_scheduler.epsnoise_type == 2:
                if eps.nelement() == 1:
                    eps_mean = eps.item()
                    noise = -torch.abs(torch.zeros_like(data).normal_(mean = 0, std = min(eps_mean+0.0001, eps_scheduler.sigma)))
                    eps_dis = torch.nn.ReLU()(eps.item()+noise)
                    data_ub = torch.min(data + (eps_dis / std), data_max)
                    data_lb = torch.max(data - (eps_dis / std), data_min)
                else:
                    eps_mean = torch.mean(eps).item()
                    noise = -torch.abs(torch.zeros_like(data).normal_(mean = 0, std = min(eps_mean+0.0001, eps_scheduler.sigma)))
                    eps_index = eps[index]
                    eps_dis= torch.nn.ReLU()(eps_index.view(-1,1,1,1).expand(-1,data.shape[1],data.shape[2],data.shape[3])+noise)
                    data_ub = torch.min(data + (eps_dis / std), data_max)
                    data_lb = torch.max(data - (eps_dis / std), data_min)
        else:
            if norm == np.inf:
                data_ub = data + (eps / std)
                data_lb = data - (eps / std)
            else:
                # For other norms, eps will be used instead.
                data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data = data.cuda()
            data_ub = data_ub.cuda()
            data_lb = data_lb.cuda()
            labels = labels.cuda()
            c = c.cuda()
            sa_labels = sa_labels.cuda()
            lb_s = lb_s.cuda()
            ub_s = ub_s.cuda()
        # convert epsilon to a tensor
        # eps_tensor = data.new(1)
        # eps_tensor[0] = eps
        
        if use_mixup:
            bce_loss = BCELoss().cuda()
            softmax = torch.nn.Softmax(dim=1).cuda()
            
            if mixup_params["mixup_type"] == "standard":
                alpha = mixup_params["alpha"]

                if mixup_params["loss_func"] == "CE":
                    mixed_data, targets_a, targets_b, lam = mixup_data(data, labels, alpha)
                    mixed_data = mixed_data.cuda()
                    # not one-hot encoded
                    targets_a = targets_a.cuda()
                    targets_b = targets_b.cuda()
                    
                    output = model(mixed_data, method_opt="forward", disable_multi_gpu = (method == "natural"))
                    
                    loss_func = mixup_criterion(targets_a, targets_b, lam)
                    regular_ce = loss_func(CrossEntropyLoss(), output)

                if mixup_params["loss_func"] == "BCE":
                    data_var, labels_var = Variable(data), Variable(labels)
                    lam = get_lambda(alpha)
                    lam = torch.from_numpy(np.array([lam]).astype('float32')).cuda()
                    lam = Variable(lam)

                    labels_reweighted = to_one_hot(labels_var, num_class)
                    labels_reweighted = labels_reweighted.cuda()

                    mixed_data, reweighted_target = mixup_process(data_var, labels_reweighted, lam=lam)
                    output = model(mixed_data, method_opt="forward", disable_multi_gpu = (method == "natural"))
                    regular_ce = bce_loss(softmax(output), reweighted_target)

            if mixup_params["mixup_type"] == "fixed_lambda":
                mixed_data, labels = mixup_data_fixed_lam(data, labels, mixup_params["lambda"])
                mixed_data = mixed_data.cuda()
                output = model(mixed_data, method_opt="forward", disable_multi_gpu = (method == "natural"))
                regular_ce = CrossEntropyLoss()(output, labels)
                
            if mixup_params["mixup_type"] == "manifold":
                assert mixup_params["loss_func"] == "BCE"
                alpha = mixup_params["alpha"]
                data_var, labels_var = Variable(data), Variable(labels)
                output, reweighted_target = model(data_var, labels_var, mixup_hidden= True, mixup_alpha = alpha,  method_opt="forward")
                regular_ce = bce_loss(softmax(output), reweighted_target)
                
            # Code taken from https://github.com/snu-mllab/Co-Mixup/blob/main/main.py
            # Please reference for input mixup parameters
            if mixup_params["mixup_type"] == "comixup":
                
                bce_loss = nn.BCELoss().cuda()
                bce_loss_sum = nn.BCELoss(reduction='sum').cuda()
                softmax = nn.Softmax(dim=1).cuda()
                criterion = nn.CrossEntropyLoss().cuda()
                criterion_batch = nn.CrossEntropyLoss(reduction='none').cuda()
                
                input_var = Variable(data, requires_grad=True)
                target_var = Variable(labels)
                A_dist = None
    
                # Calculate saliency (unary)
                if mixup_params["clean_lam"] == 0:
                    model.eval()
                    output = model(input_var, method_opt="forward", disable_multi_gpu = (method == "natural"))
                    loss_batch = criterion_batch(output, target_var)
                else:
                    model.train()
                    output = model(input_var, method_opt="forward", disable_multi_gpu = (method == "natural"))
                    loss_batch = 2 * mixup_params["clean_lam"] * criterion_batch(
                        output, target_var) / mixup_params["num_classes"]
                loss_batch_mean = torch.mean(loss_batch, dim=0)
                loss_batch_mean.backward(retain_graph=True)
                sc = torch.sqrt(torch.mean(input_var.grad**2, dim=1))
    
                # Here, we calculate distance between most salient location (Compatibility)
                # We can try various measurements
                with torch.no_grad():
                    z = F.avg_pool2d(sc, kernel_size=8, stride=1)
                    z_reshape = z.reshape(min(batch_size,sc.shape[0]), -1)
                    z_idx_1d = torch.argmax(z_reshape, dim=1)
                    z_idx_2d = torch.zeros((min(batch_size,sc.shape[0]), 2), device=z.device)
                    z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
                    z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
                    A_dist = distance(z_idx_2d, dist_type='l1')
    
                if mixup_params["clean_lam"] == 0:
                    model.train()
                    optimizer.zero_grad()
    
                # Perform mixup and calculate loss
                target_reweighted = comixup.utils.to_one_hot(labels, mixup_params["num_classes"])
                """
                if args.parallel:
                    device = input.device
                    out, target_reweighted = mpp(input.cpu(),
                                                 target_reweighted.cpu(),
                                                 args=args,
                                                 sc=sc.cpu(),
                                                 A_dist=A_dist.cpu())
                    out = out.to(device)
                    target_reweighted = target_reweighted.to(device)
    
                else:
                """
                out, target_reweighted = comixup.mixup.mixup_process(data,
                                                       target_reweighted,
                                                       mixup_params=mixup_params,
                                                       sc=sc,
                                                       A_dist=A_dist)
    
                out = model(out, method_opt="forward", disable_multi_gpu = (method == "natural"))
                regular_ce = bce_loss(softmax(out), target_reweighted)
                
                
            regular_ce_losses.update(regular_ce.cpu().detach().numpy(), data.size(0))
            
            errors.update(torch.sum(torch.argmax(output, dim=1)!=labels).cpu().detach().numpy()/data.size(0), data.size(0))
            # get range statistic
            model_range = output.max().detach().cpu().item() - output.min().detach().cpu().item()
        else:
            output = model(data, method_opt="forward", disable_multi_gpu = (method == "natural"))
            regular_ce = CrossEntropyLoss()(output, labels)
            if hinge_robust_loss and False:
                prob = torch.nn.Softmax(dim =1)(output)
                # prob = output
                output_ = torch.clone(output)
                output_max, output_max_idx = torch.max(output_, dim = 1)
                prob_max = torch.gather(prob, 1, output_max_idx.unsqueeze(-1)).squeeze(-1)
                if_correct = (output_max_idx==labels).float().cuda()
                output_[torch.arange(output_.shape[0]), output_max_idx] = -float("Inf")
                output_max2, output_max2_idx = torch.max(output_, dim = 1)
                prob_max2 = torch.gather(prob, 1, output_max2_idx.unsqueeze(-1)).squeeze(-1)
                prob_correct = torch.gather(prob, 1, labels.unsqueeze(-1)).squeeze(-1)
                margin = (prob_correct-prob_max)*(1-if_correct)+(prob_correct-prob_max2)*if_correct
                regular_hl = torch.mean(torch.nn.ReLU()(0.1-margin))
                # regular_hl = -torch.mean(torch.log(margin+1))
                # regular_hl =torch.mean( (10*-margin).exp())
                # regular_hl = -torch.mean(prob_correct)
                # regular_hl = -torch.mean(
                #     (torch.log(prob_correct)-torch.log(prob_max))*(1-if_correct)+(torch.log(prob_correct)-torch.log(prob_max2))*if_correct)
                regular_ce = regular_hl
                
            regular_ce_losses.update(regular_ce.cpu().detach().numpy(), data.size(0))
            errors.update(torch.sum(torch.argmax(output, dim=1)!=labels).cpu().detach().numpy()/data.size(0), data.size(0))
            # get range statistic
            model_range = output.max().detach().cpu().item() - output.min().detach().cpu().item()
        
        '''
        torch.set_printoptions(threshold=5000)
        print('prediction:  ', output)
        ub, lb, _, _, _, _ = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, method_opt="interval_range")
        lb = lb_s.scatter(1, sa_labels, lb)
        ub = ub_s.scatter(1, sa_labels, ub)
        print('interval ub: ', ub)
        print('interval lb: ', lb)
        ub, _, lb, _ = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, upper=True, lower=True, method_opt="backward_range")
        lb = lb_s.scatter(1, sa_labels, lb)
        ub = ub_s.scatter(1, sa_labels, ub)
        print('crown-ibp ub: ', ub)
        print('crown-ibp lb: ', lb) 
        ub, _, lb, _ = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, upper=True, lower=True, method_opt="full_backward_range")
        lb = lb_s.scatter(1, sa_labels, lb)
        ub = ub_s.scatter(1, sa_labels, ub)
        print('full-crown ub: ', ub)
        print('full-crown lb: ', lb)
        input()
        '''
        if eps.nelement() == 1:
            eps_mean = eps.item()
        else:
            eps_mean = torch.mean(eps).item()


        if verbose or method != "natural":
            if kwargs["bound_type"] == "convex-adv":
                # Wong and Kolter's bound, or equivalently Fast-Lin
                if kwargs["convex-proj"] is not None:
                    proj = kwargs["convex-proj"]
                    if norm == np.inf:
                        norm_type = "l1_median"
                    elif norm == 2:
                        norm_type = "l2_normal"
                    else:
                        raise(ValueError("Unsupported norm {} for convex-adv".format(norm)))
                else:
                    proj = None
                    if norm == np.inf:
                        norm_type = "l1"
                    elif norm == 2:
                        norm_type = "l2"
                    else:
                        raise(ValueError("Unsupported norm {} for convex-adv".format(norm)))
                if loader.std == [1] or loader.std == [1, 1, 1]:
                    convex_eps = eps
                else:
                    convex_eps = eps / np.mean(loader.std)
                    # for CIFAR we are roughly / 0.2
                    # FIXME this is due to a bug in convex_adversarial, we cannot use per-channel eps
                if norm == np.inf:
                    # bounded input is only for Linf
                    if kwargs["bounded_input"]:
                        # FIXME the bounded projection in convex_adversarial has a bug, data range must be positive
                        assert loader.std == [1,1,1] or loader.std == [1]
                        data_l = 0.0
                        data_u = 1.0
                    else:
                        data_l = -np.inf
                        data_u = np.inf
                else:
                    data_l = data_u = None
                f = DualNetwork(model, data, convex_eps, proj = proj, norm_type = norm_type, bounded_input = kwargs["bounded_input"], data_l = data_l, data_u = data_u)
                lb = f(c)
            elif kwargs["bound_type"] == "interval":
                ub, lb, relu_activity, unstable, dead, alive = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps_mean, C=c, method_opt="interval_range")
            elif kwargs["bound_type"] == "crown-full":
                _, _, lb, _ = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps_mean, C=c, upper=False, lower=True, method_opt="full_backward_range")
                unstable = dead = alive = relu_activity = torch.tensor([0])
            elif kwargs["bound_type"] == "crown-interval":
                # Enable multi-GPU only for the computationally expensive CROWN-IBP bounds, 
                # not for regular forward propagation and IBP because the communication overhead can outweigh benefits, giving little speedup. 
                ub, ilb, relu_activity, unstable, dead, alive = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps_mean, C=c, method_opt="interval_range")
                crown_final_beta = kwargs['final-beta']
                beta = eps_scheduler.get_beta(t, int(i//batch_multiplier) , crown_final_beta)
                if beta < 1e-5:
                    lb = ilb
                else:
                    if kwargs["runnerup_only"]:
                        # regenerate a smaller c, with just the runner-up prediction
                        # mask ground truthlabel output, select the second largest class
                        # print(output)
                        # torch.set_printoptions(threshold=5000)
                        masked_output = output.detach().scatter(1, labels.unsqueeze(-1), -100)
                        # print(masked_output)
                        # location of the runner up prediction
                        runner_up = masked_output.max(1)[1]
                        # print(runner_up)
                        # print(labels)
                        # get margin from the groud-truth to runner-up only
                        runnerup_c = torch.eye(num_class).type_as(data)[labels]
                        # print(runnerup_c)
                        # set the runner up location to -
                        runnerup_c.scatter_(1, runner_up.unsqueeze(-1), -1)
                        runnerup_c = runnerup_c.unsqueeze(1).detach()
                        # print(runnerup_c)
                        # get the bound for runnerup_c
                        _, _, clb, bias = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps_mean, C=c, method_opt="backward_range")
                        clb = clb.expand(clb.size(0), num_class - 1)
                    else:
                        # get the CROWN bound using interval bounds 
                        _, _, clb, bias = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps_mean, C=c, method_opt="backward_range")
                        bound_bias.update(bias.sum() / data.size(0))
                    # how much better is crown-ibp better than ibp?
                    diff = (clb - ilb).sum().item()
                    bound_diff.update(diff / data.size(0), data.size(0))
                    # lb = torch.max(lb, clb)
                    lb = clb * beta + ilb * (1 - beta)
            else:
                raise RuntimeError("Unknown bound_type " + kwargs["bound_type"]) 
            lb = lb_s.scatter(1, sa_labels, lb)

            logprob = torch.nn.LogSoftmax(dim =1)(-lb)
            loglikelihood = -torch.gather(logprob, 1, labels.unsqueeze(-1))
            sm = torch.nn.functional.softmax(-lb)
            llh = -torch.log(torch.gather(sm, 1, labels.unsqueeze(-1)))
            
            sumofexps = torch.sum(torch.exp(-lb),axis = 1)



            if weight_params and weight_params["weight_type"] != 0 and train and t >= eps_scheduler.schedule_start:
                if weight_params["weight_type"] == 1:
                    alpha = weight_params["alpha"]
                    gamma = weight_params["gamma"]
                    lb_ = torch.clone(lb).detach()
                    prob = torch.nn.Softmax(dim =1)(-lb_)
                    lb_min, lb_min_idx = torch.min(lb_, dim = 1)
                    prob_max = torch.gather(prob, 1, lb_min_idx.unsqueeze(-1)).squeeze(-1)
                    if_correct = (lb_min>=0).float().cuda()
                    lb_[torch.arange(lb.shape[0]), lb_min_idx] = float("Inf")
                    lb_min2, lb_min2_idx = torch.min(lb_, dim = 1)
                    prob_max2 = torch.gather(prob, 1, lb_min2_idx.unsqueeze(-1)).squeeze(-1)
                    prob_correct = torch.gather(prob, 1, labels.unsqueeze(-1)).squeeze(-1)
                    margin = (prob_correct-prob_max)*(1-if_correct)+(prob_correct-prob_max2)*if_correct
                    weight = torch.exp(-gamma*margin)+alpha


                elif weight_params["weight_type"] == 2:
                    alpha = weight_params["alpha"]
                    gamma = weight_params["gamma"]
                    lb_ = torch.clone(lb).detach()
                    prob = torch.nn.Softmax(dim =1)(-lb_)
                    lb_min, lb_min_idx = torch.min(lb_, dim = 1)
                    prob_max = torch.gather(prob, 1, lb_min_idx.unsqueeze(-1)).squeeze(-1)
                    if_correct = (lb_min>=0).float().cuda()
                    lb_[torch.arange(lb.shape[0]), lb_min_idx] = float("Inf")
                    lb_min2, lb_min2_idx = torch.min(lb_, dim = 1)
                    prob_max2 = torch.gather(prob, 1, lb_min2_idx.unsqueeze(-1)).squeeze(-1)
                    prob_correct = torch.gather(prob, 1, labels.unsqueeze(-1)).squeeze(-1)
                    margin = (prob_correct-prob_max)*(1-if_correct)+(prob_correct-prob_max2)*if_correct
                    weight = torch.exp(-gamma*torch.abs(margin))+alpha
            elif ifeval:
                    lb_ = torch.clone(lb).detach()
                    prob = torch.nn.Softmax(dim =1)(-lb_)
                    lb_min, lb_min_idx = torch.min(lb_, dim = 1)
                    prob_max = torch.gather(prob, 1, lb_min_idx.unsqueeze(-1)).squeeze(-1)
                    if_correct = (lb_min>=0).float().cuda()
                    lb_[torch.arange(lb.shape[0]), lb_min_idx] = float("Inf")
                    lb_min2, lb_min2_idx = torch.min(lb_, dim = 1)
                    prob_max2 = torch.gather(prob, 1, lb_min2_idx.unsqueeze(-1)).squeeze(-1)
                    prob_correct = torch.gather(prob, 1, labels.unsqueeze(-1)).squeeze(-1)
                    margin = (prob_correct-prob_max)*(1-if_correct)+(prob_correct-prob_max2)*if_correct
                    margin_record[index] = margin.cpu()
                    weight = torch.ones(labels.size(), dtype=torch.float32).cuda()

            else:
                weight = torch.ones(labels.size(), dtype=torch.float32).cuda()

            
            
            if weight_params and weight_params.get("rob_neglect_clean",False) and not ifeval:
                _, output_max_index = torch.max(output,dim = 1)
                if_clean_correct = torch.eq(labels,output_max_index).type(torch.float32).cuda()
                regular_prob = torch.nn.Softmax()(output)
                prob_true = torch.gather(regular_prob, 1, labels.unsqueeze(-1)).squeeze(-1)
                weight = weight*prob_true
            weight_norm =  weight / weight.sum()
            robust_ce = torch.dot(loglikelihood.squeeze(1), weight_norm)

            if hinge_robust_loss:
                lb_ = torch.clone(lb).detach()
                prob = torch.nn.Softmax(dim =1)(-lb_)
                # prob = -lb_
                lb_min, lb_min_idx = torch.min(lb_, dim = 1)
                prob_max = torch.gather(prob, 1, lb_min_idx.unsqueeze(-1)).squeeze(-1)
                if_correct = (lb_min>=0).float().cuda()
                lb_[torch.arange(lb.shape[0]), lb_min_idx] = float("Inf")
                lb_min2, lb_min2_idx = torch.min(lb_, dim = 1)
                prob_max2 = torch.gather(prob, 1, lb_min2_idx.unsqueeze(-1)).squeeze(-1)
                prob_correct = torch.gather(prob, 1, labels.unsqueeze(-1)).squeeze(-1)
                margin = (prob_correct-prob_max)*(1-if_correct)+(prob_correct-prob_max2)*if_correct
                # robust_hl = torch.nn.ReLU()(0.5-(torch.nn.ReLU()(margin+0.4)-0.4))
                # robust_hl = -torch.log(1+margin+0.01)
                robust_hl = -torch.log(prob_correct)
                # robust_hl = (10*-margin).exp()
                # robust_hl = -prob_correct
                robust_ce = torch.dot(robust_hl, weight_norm)


            # robust_ce = CrossEntropyLoss()(-lb, labels)
            if train and eps_scheduler.schedule_type in ['autoeps', "autoeps-linear", "autoeps-smooth"]:
                eps_scheduler.update_eps_sample(index, lb, t, labels)
            
            if kwargs["bound_type"] != "convex-adv":
                
                relu_activities.update(relu_activity.sum().detach().cpu().item() / data.size(0), data.size(0))
                unstable_neurons.update(unstable.sum().detach().cpu().item() / data.size(0), data.size(0))
                dead_neurons.update(dead.sum().detach().cpu().item() / data.size(0), data.size(0))
                alive_neurons.update(alive.sum().detach().cpu().item() / data.size(0), data.size(0))

        if method == "robust":
            loss = robust_ce
            if trades_params and "lambda" in trades_params:
                criterion_kl = nn.KLDivLoss(size_average=False)
                lambda_trade  = trades_params["lambda"]
                regular_prob = torch.nn.Softmax()(output)
                prob = torch.nn.Softmax(dim =1)(-lb)
                prob_log = prob.log()
                adv_log_prob = F.log_softmax(-lb, dim=1)
                natural_prob = F.softmax(output, dim=1)
                trade_regu = criterion_kl(adv_log_prob, natural_prob)/lb.shape[0]*lambda_trade
                # trade_regu = torch.sum(torch.abs(-regular_prob))/regular_prob.shape[0]*lambda_trade
                loss += trade_regu
        elif method == "robust_activity":
            loss = robust_ce + kwargs["activity_reg"] * relu_activity.sum()
        elif method == "natural":
            loss = regular_ce
        elif method == "robust_natural":
            natural_final_factor = kwargs["final-kappa"]
            # kappa = (max_eps - eps * (1.0 - natural_final_factor)) / max_eps
            kappa = eps_scheduler.get_kappa(t, int(i//batch_multiplier) ,natural_final_factor)
            loss = (1-kappa) * robust_ce + kappa * regular_ce
            if trades_params and "lambda" in trades_params:
                criterion_kl = nn.KLDivLoss(size_average=False)
                lambda_trade  = trades_params["lambda"]
                regular_prob = torch.nn.Softmax()(output)
                prob = torch.nn.Softmax(dim =1)(-lb)
                prob_log = prob.log()
                adv_log_prob = F.log_softmax(-lb, dim=1)
                natural_prob = F.softmax(output, dim=1)
                trade_regu = criterion_kl(adv_log_prob, natural_prob)/lb.shape[0]*lambda_trade
                # trade_regu = torch.sum(torch.abs(-regular_prob))/regular_prob.shape[0]*lambda_trade
                loss += trade_regu
            
        else:
            raise ValueError("Unknown method " + method)

        if train and kwargs["l1_reg"] > np.finfo(np.float32).tiny:
            reg = kwargs["l1_reg"]
            l1_loss = 0.0
            for name, param in model.named_parameters():
                if 'bias' not in name:
                    l1_loss = l1_loss + torch.sum(torch.abs(param))
            l1_loss = reg * l1_loss
            loss = loss + l1_loss
            l1_losses.update(l1_loss.cpu().detach().numpy(), data.size(0))
        if train:
            loss.backward()
            if i % batch_multiplier == 0 or i == len(loader) - 1:
                opt.step()

        losses.update(loss.cpu().detach().numpy(), data.size(0))

        if verbose or method != "natural":
            robust_ce_losses.update(robust_ce.cpu().detach().numpy(), data.size(0))
            # robust_ce_losses.update(robust_ce, data.size(0))
            robust_errors.update(torch.sum((lb<0).any(dim=1)).cpu().detach().numpy() / data.size(0), data.size(0))

        batch_time.update(time.time() - start)
        if i % 50 == 0 and train:
            logger.log(  '[{:2d}:{:4d}]: mean eps {:4f}  '
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
                    'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
                    'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
                    'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
                    'Err {errors.val:.4f} ({errors.avg:.4f})  '
                    'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
                    'Uns {unstable.val:.1f} ({unstable.avg:.1f})  '
                    'Dead {dead.val:.1f} ({dead.avg:.1f})  '
                    'Alive {alive.val:.1f} ({alive.avg:.1f})  '
                    'Tightness {tight.val:.5f} ({tight.avg:.5f})  '
                    'Bias {bias.val:.5f} ({bias.avg:.5f})  '
                    'Diff {diff.val:.5f} ({diff.avg:.5f})  '
                    'R {model_range:.3f}  '
                    'beta {beta:.3f} ({beta:.3f})  '
                    'kappa {kappa:.3f} ({kappa:.3f})  '.format(
                    t, i, eps.mean(), batch_time=batch_time,
                    loss=losses, errors=errors, robust_errors = robust_errors, l1_loss = l1_losses,
                    regular_ce_loss = regular_ce_losses, robust_ce_loss = robust_ce_losses, 
                    unstable = unstable_neurons, dead = dead_neurons, alive = alive_neurons,
                    tight = relu_activities, bias = bound_bias, diff = bound_diff,
                    model_range = model_range, 
                    beta=beta, kappa = kappa))
            
            if eps_scheduler.schedule_type in['autoeps', "autoeps-linear", "autoeps-smooth"]:
                logger.log('mean eps {:4f} min eps {:4f} max eps {:4f}'.format(torch.mean(eps), torch.min(eps), torch.max(eps)))

    
                    
    logger.log(  '[FINAL RESULT epoch:{:2d} mean eps:{:.4f}]: '
        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
        'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
        'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
        'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
        'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
        'Uns {unstable.val:.3f} ({unstable.avg:.3f})  '
        'Dead {dead.val:.1f} ({dead.avg:.1f})  '
        'Alive {alive.val:.1f} ({alive.avg:.1f})  '
        'Tight {tight.val:.5f} ({tight.avg:.5f})  '
        'Bias {bias.val:.5f} ({bias.avg:.5f})  '
        'Diff {diff.val:.5f} ({diff.avg:.5f})  '
        'Err {errors.val:.4f} ({errors.avg:.4f})  '
        'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
        'R {model_range:.3f}  '
        'beta {beta:.3f} ({beta:.3f})  '
        'kappa {kappa:.3f} ({kappa:.3f})  \n'.format(
        t, eps.mean(), batch_time=batch_time,
        loss=losses, errors=errors, robust_errors = robust_errors, l1_loss = l1_losses,
        regular_ce_loss = regular_ce_losses, robust_ce_loss = robust_ce_losses, 
        unstable = unstable_neurons, dead = dead_neurons, alive = alive_neurons,
        tight = relu_activities, bias = bound_bias, diff = bound_diff,
        model_range = model_range, 
        kappa = kappa, beta=beta))

    # TODO (marinazh): Hack to implement manifold mixup
    if not (mixup_params and mixup_params["mixup_type"] == "manifold"):
        for i, l in enumerate(model if isinstance(model, BoundSequential) else model.module):
            if isinstance(l, BoundLinear) or isinstance(l, BoundConv2d):
                norm = l.weight.data.detach().view(l.weight.size(0), -1).abs().sum(1).max().cpu()
                logger.log('layer {} norm {}'.format(i, norm))
    if ifeval:
        return robust_errors.avg, errors.avg, margin_record
    if method == "natural":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg
    

def _seed_everything(seed):
    """Seed every RNG used during training and make cuDNN deterministic.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``,
    disables cuDNN autotuning and requests deterministic cuDNN kernels.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def _build_optimizer(name, params, lr, weight_decay):
    """Create the optimizer selected by the config.

    Args:
        name: "adam" or "sgd" (SGD uses Nesterov momentum 0.9).
        params: iterable of parameters to optimize.
        lr: initial learning rate.
        weight_decay: L2 weight decay.

    Raises:
        ValueError: if ``name`` is not a supported optimizer.
    """
    if name == "adam":
        return optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == "sgd":
        return optim.SGD(params, lr=lr, momentum=0.9, nesterov=True, weight_decay=weight_decay)
    raise ValueError("Unknown optimizer")


def _build_lr_scheduler(opt, lr_decay_step, lr_decay_milestones, lr_decay_factor):
    """Create the learning-rate scheduler for ``opt``.

    A truthy ``lr_decay_step`` selects StepLR (decay by ``lr_decay_factor``
    every ``lr_decay_step`` epochs) and takes precedence; otherwise
    ``lr_decay_milestones`` selects MultiStepLR.

    Raises:
        ValueError: if neither ``lr_decay_step`` nor ``lr_decay_milestones`` is set.
    """
    if lr_decay_step:
        return optim.lr_scheduler.StepLR(opt, step_size=lr_decay_step, gamma=lr_decay_factor)
    if lr_decay_milestones:
        return optim.lr_scheduler.MultiStepLR(opt, milestones=lr_decay_milestones, gamma=lr_decay_factor)
    raise ValueError("one of lr_decay_step and lr_decay_milestones must be not empty.")


def _current_lrs(opt):
    """Return the learning rate currently set on each parameter group of ``opt``.

    Reading ``param_groups`` directly is correct on every PyTorch version,
    unlike ``scheduler.get_lr()``, which is only meant to be called inside
    ``step()`` and can report a re-decayed value at milestone epochs.
    """
    return [group["lr"] for group in opt.param_groups]


def _mean_epoch_eps(eps_scheduler, epoch):
    """Return, as a Python float, the mean of the eps at the start of ``epoch`` and of ``epoch + 1``.

    ``get_eps`` may return a scalar (Python/NumPy number) or a tensor; for a
    multi-element tensor the element-wise midpoints are averaged.
    """
    start = eps_scheduler.get_eps(epoch, 0)
    end = eps_scheduler.get_eps(epoch + 1, 0)
    midpoint = 0.5 * (start + end)
    if torch.is_tensor(midpoint):
        return midpoint.mean().item()
    return float(np.mean(midpoint))


def main(args):
    """Train and evaluate every model listed in the configuration.

    For each entry of ``config["models"]`` this merges the global and
    per-model training/evaluation parameters, converts the model with
    ``BoundSequential.convert``, builds the optimizer, epsilon schedulers and
    learning-rate scheduler (restoring saved optimizer/scheduler state when
    resuming), then alternates one training pass and one evaluation pass per
    epoch. The latest checkpoint is saved every epoch. From epoch
    ``schedule_start`` onward, the checkpoint with the lowest evaluation
    error is also saved to the best-model path. A per-model summary line is
    appended to ``aaa_record.txt``.

    Args:
        args: parsed command-line arguments (from ``argparser()``); passed to
            ``load_config`` and read for ``args.seed``.
    """
    print("main start")

    config = load_config(args)
    global_train_config = config["training_params"]
    global_eval_config = config["eval_params"]
    seed = global_train_config.get("random_seed", 0)
    print("finish load config")
    # NOTE(review): 2019 is presumably argparser's default --seed; any other
    # value overrides the config seed and is appended to the first model id.
    if args.seed != 2019:
        seed = args.seed
        config["models"][0]["model_id"] = config["models"][0]["model_id"] + '_seed_' + str(seed)
    _seed_everything(seed)
    # Resume only when the flag compares equal to True (not for arbitrary truthy values).
    load_last_saved = global_train_config.get("load_last_saved", False) == True  # noqa: E712
    models, model_names, other_saved = config_modelloader(config, load_last_saved=load_last_saved)
    print("finish load model")

    for model, model_id, model_config, other_saved_ in zip(models, model_names, config["models"], other_saved):
        # make a copy of global training config, and update per-model config
        train_config = copy.deepcopy(global_train_config)
        if "training_params" in model_config:
            train_config = update_dict(train_config, model_config["training_params"])
        eval_config = copy.deepcopy(global_eval_config)
        if "eval_params" in model_config:
            eval_config.update(model_config["eval_params"])
        # State carried over from a previous run, as returned by config_modelloader.
        start_epoch = other_saved_["epoch"]
        eps_sample = other_saved_["eps_sample"]
        optimizer_saved = other_saved_["optimizer"]
        lr_scheduler_saved = other_saved_["lr_scheduler"]

        model = BoundSequential.convert(model, train_config["method_params"]["bound_opts"])

        # read training parameters from config file
        epochs = train_config["epochs"]
        lr = train_config["lr"]
        weight_decay = train_config["weight_decay"]
        schedule_length = train_config["eps_params"]["schedule_length"]
        schedule_start = train_config["eps_params"]["schedule_start"]

        optimizer = train_config["optimizer"]
        method = train_config["method"]
        verbose = train_config["verbose"]
        lr_decay_step = train_config["lr_decay_step"]
        lr_decay_milestones = train_config["lr_decay_milestones"]
        lr_decay_factor = train_config["lr_decay_factor"]
        multi_gpu = train_config["multi_gpu"]

        # parameters specific to a training method
        method_param = train_config["method_params"]
        eval_param = eval_config["method_params"]
        mixup_params = train_config["mixup_params"]
        trades_params = train_config.get("trades_params", None)
        weight_params = train_config.get("weight_params", None)
        hinge_robust_loss = train_config.get("hinge_robust_loss", False)
        norm = float(train_config["norm"])

        train_data, test_data = config_dataloader_index(config, **train_config["loader_params"])

        # TODO (marinazh): Hack for manifold mixup
        if mixup_params and mixup_params["use_mixup"] and mixup_params["mixup_type"] == "manifold":
            print("WARNING: hack, can only do 1 model at a time with manifold mixup")
            model = BoundIBPLargeModel(**config["models"][0]["model_params"])

        opt = _build_optimizer(optimizer, model.parameters(), lr, weight_decay)

        # Effective batch size per optimizer step: the loader batch times
        # batch_multiplier (presumably consumed by Train for gradient accumulation).
        batch_multiplier = train_config["method_params"].get("batch_multiplier", 1)
        batch_size = train_data.batch_size * batch_multiplier
        num_steps_per_epoch = int(np.ceil(1.0 * len(train_data.dataset) / batch_size))
        epsilon_scheduler = EpsilonScheduler(train_config.get("schedule_type", "standard-linear"), num_steps_per_epoch, train_config["eps_params"], eps_sample)
        test_epsilon_scheduler = EpsilonScheduler("test", num_steps_per_epoch, train_config["eps_params"])

        # Placeholder passed through to Train (originally marked "not used").
        max_eps = 0

        lr_scheduler = _build_lr_scheduler(opt, lr_decay_step, lr_decay_milestones, lr_decay_factor)
        if lr_scheduler_saved:
            lr_scheduler.load_state_dict(lr_scheduler_saved)
        model_name = get_path(config, model_id, "model", load=False)
        best_model_name = get_path(config, model_id, "best_model", load=False)
        model_log = get_path(config, model_id, "train_log")
        best_err = np.inf
        recorded_clean_err = np.inf
        timer = 0.0

        # The `with` block guarantees the per-model log file is closed, even if training raises.
        with open(model_log, "w") as log_file:
            logger = Logger(log_file)
            logger.log(model_name)
            logger.log("Command line:", " ".join(sys.argv[:]))
            logger.log("training configurations:", train_config)
            logger.log("Model structure:")
            logger.log(str(model))
            logger.log("data std:", train_data.std)

            if multi_gpu:
                logger.log("\nUsing multiple GPUs for computing CROWN-IBP bounds\n")
                model = BoundDataParallel(model)
            model = model.cuda()

            # Restore optimizer state when resuming. Without it, momentum/Adam
            # moments are lost. A resumed MultiStepLR run also keeps the initial lr,
            # because the scheduler only rescales group["lr"] at milestones.
            # Done after .cuda() so load_state_dict casts the saved state onto the
            # parameters' final device.
            # NOTE(review): assumes other_saved_["optimizer"] is the 'opt'
            # state_dict saved below -- confirm in config_modelloader.
            if optimizer_saved:
                opt.load_state_dict(optimizer_saved)

            for t in range(start_epoch, epochs):
                eps_mean = _mean_epoch_eps(epsilon_scheduler, t)
                logger.log("Epoch {}, learning rate {},mean epsilon {:.6g}".format(t, _current_lrs(opt), eps_mean))
                # with torch.autograd.detect_anomaly():
                start_time = time.time()

                Train(model, t, train_data, epsilon_scheduler, max_eps, norm, logger, verbose, True, opt, method, mixup_params, weight_params, trades_params, hinge_robust_loss=hinge_robust_loss, **method_param)
                if lr_decay_step:
                    # Use StepLR. The epoch number is set manually here, hence the +1 offset.
                    lr_scheduler.step(epoch=max(t - (schedule_start + schedule_length - 1) + 1, 0))
                else:
                    # MultiStepLR with milestones (_build_lr_scheduler guarantees one of the two).
                    lr_scheduler.step()
                epoch_time = time.time() - start_time
                timer += epoch_time
                logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
                logger.log("Evaluating...")
                with torch.no_grad():
                    # evaluate
                    err, clean_err = Train(model, t, test_data, test_epsilon_scheduler, max_eps, norm, logger, verbose, False, None, method, mixup_params, weight_params, hinge_robust_loss=hinge_robust_loss, **eval_param)

                logger.log('saving to', model_name)
                torch.save({
                        'state_dict' : model.module.state_dict() if multi_gpu else model.state_dict(),
                        'epoch' : t,
                        'eps_sample': epsilon_scheduler.eps_sample if hasattr(epsilon_scheduler, 'eps_sample') else 0,
                        'opt':opt.state_dict(),
                        'lr_scheduler':lr_scheduler.state_dict(),
                        }, model_name)

                # From epoch schedule_start on, keep the checkpoint with the lowest
                # evaluation error (ties go to the later epoch). The evaluation
                # scheduler is of type "test" -- presumably a fixed eps, so errors
                # are comparable across epochs; confirm in EpsilonScheduler.
                if t >= schedule_start and err <= best_err:
                    best_err = err
                    recorded_clean_err = clean_err
                    logger.log('Saving best model {} with error {}'.format(best_model_name, best_err))
                    torch.save({
                            'state_dict' : model.module.state_dict() if multi_gpu else model.state_dict(),
                            'robust_err' : err,
                            'clean_err' : clean_err,
                            'epoch' : t,
                            }, best_model_name)

            logger.log('Total Time: {:.4f}'.format(timer))
            logger.log('Model {} best err {}, clean err {}'.format(model_id, best_err, recorded_clean_err))

        with open("aaa_record.txt", "a+") as myfile:
            myfile.write('Model {} best err {}, clean err {}\n'.format(model_id, best_err, recorded_clean_err))


if __name__ == "__main__":
    # Script entry point: parse command-line options, then train/evaluate via main().
    # NOTE(review): `args` is kept as a module-level global on purpose; other
    # functions in this file may read it, so do not inline it into the call.
    args = argparser()
    main(args)