import os
import sys
import math
import random
import time
import argparse
from tqdm import tqdm
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm


import save_nlayer_weights as nl
from setup_mnist import MNIST
from setup_cifar import CIFAR
from utils import generate_data
from bounds.crown.get_bounds_ours import get_weights_list

import monte_carlo

# Module-level list, initially empty. Nothing in the visible part of this file
# reads or appends to it — NOTE(review): possibly unused; confirm before removing.
layer = []

def t_square(x):
    """Return ``x`` multiplied by itself (elementwise for tensors/arrays)."""
    squared = x * x
    return squared

def t_relu(x):
    """Apply ReLU to tensor ``x`` by multiplying it with its positivity mask."""
    positive = (x > 0).float()
    return positive * x
    
class MLP:
    """Wrapper around a saved N-layer model exposing its weights to PyTorch.

    The constructor locates the model file under ``models/``, loads it through
    ``nl.NLayerModel`` inside a TensorFlow session, generates the evaluation
    images with ``generate_data`` and extracts per-layer weights and biases
    with ``get_weights_list``.  The numpy weights are then served as torch
    tensors on ``device`` and the network can be evaluated via ``__call__``.
    """

    def __init__(self, args, device=None):
        """Load the model and data described by ``args``.

        Args:
            args: parsed command-line namespace; this method reads ``hidden``,
                ``quad``, ``activation``, ``method``, ``targettype``,
                ``modeltype``, ``model``, ``numlayer``, ``numimage`` and
                ``startimage``.
            device: torch device on which tensors are created; defaults to the
                CPU when omitted.

        Raises:
            RuntimeError: on an unknown ``targettype`` or ``model``, or when no
                model file can be found.
        """
        if device is None:
            device = torch.device("cpu")

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            nhidden = args.hidden
            # quadratic bound only works for ReLU
            assert ((not args.quad) or args.activation == "relu")
            # for all activations we can use general framework
            assert args.method == "general" or args.activation == "relu"

            # Bit flags passed to generate_data as target_type.
            target_types = {
                "least": 0b0100,
                "top2": 0b0001,
                "random": 0b0010,
                "untargeted": 0b10000,
            }
            if args.targettype not in target_types:
                # Previously this fell through and failed later with an
                # unbound-variable error.
                raise RuntimeError("unknown target type: " + str(args.targettype))
            target_type = target_types[args.targettype]
            targeted = args.targettype != "untargeted"

            suffix = "" if args.modeltype == "vanilla" else "_" + args.modeltype

            activation = args.activation
            prefix = "models/" + args.model + "_" + str(args.numlayer) + "layer_" + activation + "_"
            # first choice, e.g. models/mnist_3layer_relu_1024
            modelfile = prefix + str(nhidden) + suffix

            print("Obtaining", modelfile)

            # Fallback names, tried in order.
            # NOTE(review): the second candidate appends ``suffix`` twice and
            # the third keeps a trailing "_" before ``suffix``; both are kept
            # exactly as before so existing model files still resolve.
            candidates = [
                modelfile,
                # e.g. models/mnist_3layer_relu_1024_1024
                modelfile + ("_"+str(nhidden))*(args.numlayer-2) + suffix,
                # e.g. models/mnist_3layer_relu_
                prefix + suffix,
                # e.g. models/mnist_3layer_relu_1024_best
                modelfile + "_best",
            ]
            for candidate in candidates:
                if os.path.isfile(candidate):
                    modelfile = candidate
                    break
            else:
                raise(RuntimeError("cannot find model file"))

            if args.model == "mnist":
                data = MNIST()
                model = nl.NLayerModel([nhidden] * (args.numlayer - 1), modelfile, activation=activation)
            elif args.model == "cifar":
                data = CIFAR()
                model = nl.NLayerModel([nhidden] * (args.numlayer - 1), modelfile, image_size=32, image_channel=3, activation=activation)
            else:
                raise(RuntimeError("unknown model: "+args.model))

            self.inputs, self.targets, self.true_labels, self.true_ids, self.img_info = generate_data(data, samples=args.numimage, targeted=targeted, random_and_least_likely = True, target_type = target_type, predictor=model.model.predict, start=args.startimage)
            # get the logit layer predictions
            self.preds = model.model.predict(self.inputs)

            # the weights and bias are saved in lists: weights and bias
            # weights[i-1] gives the ith layer of weight and so on
            weights, biases = get_weights_list(model)

            self.model = model
            self.modelfile = modelfile
            self.weights = weights
            self.biases = biases
            self.activation = activation
            self.args = args
            self.device = device

    def get_data(self):
        """Return the generated evaluation data as a dict of arrays."""
        return {
            "inputs": self.inputs,
            "targets": self.targets,
            "true_labels": self.true_labels,
            "img_info": self.img_info,
            "preds": self.preds
        }

    def get_dims(self):
        """Return one width per layer: ``args.hidden`` everywhere, 1 for the last."""
        dims = [self.args.hidden] * self.args.numlayer
        dims[-1] = 1
        return dims

    def get_weights(self):
        """Return the layer weight matrices as torch tensors on ``self.device``."""
        return [torch.from_numpy(weight).to(self.device) for weight in self.weights]

    def get_biases(self):
        """Return the layer biases as torch tensors on ``self.device``."""
        return [torch.from_numpy(bias).to(self.device) for bias in self.biases]

    def __call__(self, x):
        """Evaluate the network on a batch.

        Args:
            x: tensor whose first dimension is the batch; it is given a
                trailing singleton dimension and multiplied by each weight
                matrix in turn (assumes shape ``(batch, features)``).

        Returns:
            The output of the last layer with a trailing singleton dimension.
            ReLU is applied after every layer except the last.
        """
        batch_num = x.size(0)
        x = x.unsqueeze(2)
        # Broadcast every layer over the batch so ``@`` is a batched matmul.
        weights = [w.unsqueeze(0).expand(batch_num, -1, -1) for w in self.get_weights()]
        # Biases become column vectors with a leading batch dimension.
        biases = [b.unsqueeze(0).expand(batch_num, -1, -1).transpose(1, 2) for b in self.get_biases()]
        last = len(weights) - 1
        for i, (weight, bias) in enumerate(zip(weights, biases)):
            x = weight @ x + bias
            if i != last:
                x = x * (x > 0).float()
        return x

class Box:
    """Axis-aligned box described by two opposite corners.

    The corners are stored as given; no shape or ordering check is performed.
    """

    def __init__(self, corner_a, corner_b):
        self.corner_a, self.corner_b = corner_a, corner_b

class BoundLinear(nn.Module):
    """Bounds an affine function ``a x + b`` over a box of inputs.

    The box is converted to a centre ``x_0`` and a per-coordinate half-width
    ``eps``.  Without ``q`` the bounds are the worst case over the box; where
    an entry of ``q`` is positive, the worst-case term is replaced by
    ``sqrt(2 * ln(1/q)) * ||eps * a||_2``.
    """

    def __init__(self, box):
        super().__init__()
        lower, upper = box.corner_a, box.corner_b
        self.x_0 = (lower + upper)/2
        half_width = (upper - lower)/2
        # Move the coordinate axis last so eps broadcasts against rows of ``a``.
        self.eps = half_width.squeeze(-1).unsqueeze(1)

    def _centre_and_radius(self, a, b):
        # Value at the box centre and the worst-case deviation from it.
        centre = a @ self.x_0 + b
        radius = torch.sum(torch.abs(self.eps * a), dim=2).unsqueeze(2)
        return centre, radius

    def _probabilistic_radius(self, a, q, selected):
        # Deviation used for the entries selected by the mask (q > 0).
        dualnorm = torch.norm(self.eps * a, dim=2).unsqueeze(2)
        return torch.sqrt(-torch.log(q[selected]) * 2 * (dualnorm * dualnorm)[selected])

    #Return max_{x in Q} ax + b with probability >= 1-q
    def get_upper_bound(self, a, b, q=None):
        centre, radius = self._centre_and_radius(a, b)
        bound = centre + radius
        if q is None:
            return bound
        q = q.unsqueeze(2)
        selected = q > 0
        # Entries with q <= 0 keep the worst-case bound.
        bound[selected] = centre[selected] + self._probabilistic_radius(a, q, selected)
        return bound

    #Return min_{x in Q} ax + b with probability >= 1-q
    def get_lower_bound(self, a, b, q=None):
        centre, radius = self._centre_and_radius(a, b)
        bound = centre - radius
        if q is None:
            return bound
        q = q.unsqueeze(2)
        selected = q > 0
        # Entries with q <= 0 keep the worst-case bound.
        bound[selected] = centre[selected] - self._probabilistic_radius(a, q, selected)
        return bound

class BoundLinearDelta(nn.Module):
    """Linear bounder over a doubled copy of the input coordinates.

    The box corners are repeated twice along dimension 1.  A new box is built
    around the centre of the doubled box whose half-width is zero on the
    leading ``idx.size(1)`` coordinates and ``0.5 * idx * (corner_b - corner_a)``
    on the remaining ones; bounds are delegated to a ``BoundLinear`` over it.
    """

    def __init__(self, box, idx):
        super().__init__()
        self.idx = idx
        doubled_a = box.corner_a.repeat(1, 2, 1)
        doubled_b = box.corner_b.repeat(1, 2, 1)
        midpoint = 0.5*(doubled_a + doubled_b)
        split = idx.size(1)
        half_width = torch.zeros_like(midpoint)
        half_width[:,split:,:] = 0.5*idx*((doubled_b - doubled_a)[:,split:,:])

        self.bounder = BoundLinear(Box(midpoint-half_width, midpoint+half_width))

    @staticmethod
    def _stack_coefficients(a, b):
        # Coefficients over the doubled input: (a - b) on the first copy and
        # b on the second.
        width = a.size(2)
        stacked = torch.zeros_like(a).repeat(1, 1, 2)
        stacked[:,:,:width] = a - b
        stacked[:,:,width:] = b
        return stacked

    def get_upper_bound(self, a, b, c):
        """Upper bound of the stacked linear form with offset ``c``."""
        return self.bounder.get_upper_bound(self._stack_coefficients(a, b), c)

    def get_lower_bound(self, a, b, c):
        """Lower bound of the stacked linear form with offset ``c``."""
        return self.bounder.get_lower_bound(self._stack_coefficients(a, b), c)

class BoundActivation(nn.Module):
    """Linear relaxations of an activation given pre-activation bounds.

    Only ``"relu"`` is implemented.  For bounds ``l <= x <= u`` the upper
    relaxation of an unstable neuron is the chord ``u/(u-l) * (x - l)``; the
    lower relaxation has zero offset and slope 1 when ``|u| > |l|``, else 0
    (``forward`` uses the chord slope for the lower line instead).  Neurons
    with ``u < 0`` get slope 0 and neurons with ``l > 0`` get slope 1.
    """

    def __init__(self, activation_type):
        super(BoundActivation, self).__init__()
        self.activation_type = activation_type

    def _relu_relaxation(self, l, u, adaptive_lower):
        """Return ``(m_l, r_l, m_u, r_u)``: slopes ``m`` and reference points ``r``.

        Args:
            l, u: tensors of lower/upper pre-activation bounds (same shape).
            adaptive_lower: when True the lower slope is 1 where
                ``|u| > |l|`` and 0 elsewhere; when False it equals the
                chord slope.

        Raises:
            NotImplementedError: if the activation is not ``"relu"``.
        """
        if self.activation_type != "relu":
            # An explicit exception survives ``python -O``, unlike ``assert``.
            raise NotImplementedError(
                "unsupported activation: " + str(self.activation_type))

        # Guard the chord slope against u == l (zero-width interval).  When
        # u == l == 0 neither mask below applies and 0/0 would leave a NaN;
        # with the guard the slope is 0 there.
        width = u - l
        safe_width = torch.where(width == 0, torch.ones_like(width), width)

        if adaptive_lower:
            m_l = (u.abs() > l.abs()).float() * torch.ones_like(l)
        else:
            m_l = u / safe_width
        r_l = torch.zeros_like(l)
        m_u = u / safe_width
        r_u = l.detach().clone()

        always_off = u < 0
        always_on = l > 0
        m_l[always_off] = 0
        r_l[always_off] = 0
        m_u[always_off] = 0
        r_u[always_off] = 0
        m_l[always_on] = 1
        r_l[always_on] = 0
        m_u[always_on] = 1
        r_u[always_on] = 0
        return m_l, r_l, m_u, r_u

    def get_upper_bound_plus(self, l, u):
        """Return ``(alpha_u, beta_u)`` with ``relu(x) <= alpha_u * x + beta_u``."""
        _, _, alpha_u, r_u = self._relu_relaxation(l, u, adaptive_lower=True)
        beta_u = -alpha_u * r_u
        return alpha_u, beta_u

    def get_lower_bound_plus(self, l, u):
        """Return ``(alpha_l, beta_l)`` with ``relu(x) >= alpha_l * x + beta_l``."""
        alpha_l, r_l, _, _ = self._relu_relaxation(l, u, adaptive_lower=True)
        beta_l = -alpha_l * r_l
        return alpha_l, beta_l

    def get_upper_bound(self, l, u):
        """Like ``get_upper_bound_plus`` but with the slopes as a diagonal matrix."""
        alpha_u, beta_u = self.get_upper_bound_plus(l, u)
        return torch.diag_embed(alpha_u.squeeze(2)), beta_u

    def get_lower_bound(self, l, u):
        """Like ``get_lower_bound_plus`` but with the slopes as a diagonal matrix."""
        alpha_l, beta_l = self.get_lower_bound_plus(l, u)
        return torch.diag_embed(alpha_l.squeeze(2)), beta_l

    def forward(self, l, u):
        """Return ``(m_l, r_l, m_u, r_u)`` using the chord slope for both lines."""
        return self._relu_relaxation(l, u, adaptive_lower=False)

class InnerCROWN(nn.Module):
    """Layer-by-layer bound propagation for a ReLU network at a fixed radius.

    The network is read from ``model`` through ``get_weights``, ``get_biases``
    and ``get_dims``.  Each layer ``i`` carries a tensor ``q[i]`` of shape
    ``(batch, dims[i], 2)`` whose last axis holds the value passed as ``q`` to
    the lower (index 0) and upper (index 1) bound of ``BoundLinear``.

    Args:
        model: object providing ``get_weights()``, ``get_biases()``,
            ``get_dims()`` and a ``device`` attribute.
        args: stored unchanged; not otherwise used by this class.
        q_tot: ``None`` or a per-sample tensor; when a tensor, its length sets
            the batch size and each sample's ``q`` entries sum to it before
            small entries are clamped (see ``get_q``).
        steps: number of optimisation steps when ``q_mode == "opt"``.
        q_mode: ``"last"`` puts all weight on entry ``[0, 0]`` of the last
            layer, ``"opt"`` makes the base weights trainable, anything else
            starts from uniform weights.
        batch_num: initial batch size when ``q_tot`` is not a tensor.
    """

    def __init__(self, model, args, q_tot=None, steps=1, q_mode="all", batch_num=1):
        super(InnerCROWN, self).__init__()
        self.model = model
        self.weights = self.model.get_weights()
        self.biases = self.model.get_biases()
        self.activation = "relu"#self.model.activation
        self.true_dims = list(self.model.get_dims())
        self.dims = list(self.true_dims)
        self.args = args
        self.q_tot = q_tot
        self.steps = steps
        self.q_mode = q_mode
        self.BATCH_NUM = batch_num
        if torch.is_tensor(self.q_tot):
            self.BATCH_NUM = self.q_tot.size(0)
        self.reset_q()

    def reset_q(self):
        """Re-create the base weights ``q_base`` for the current batch size and refresh ``q``."""
        self.q_base = []
        for i in range(len(self.dims)):
            if self.q_mode == "last":
                self.q_base.append(-1e-9+torch.zeros([self.BATCH_NUM, self.dims[i], 2], dtype=torch.float32, device=self.model.device))
                if i == len(self.dims)-1:
                    self.q_base[i][:,0,0] = 1
            else:
                self.q_base.append(torch.ones([self.BATCH_NUM,self.dims[i], 2], dtype=torch.float32, device=self.model.device, requires_grad=(self.q_mode=="opt")))
        self.get_q()

    def get_q(self):
        """Compute ``self.q`` from ``self.q_base``.

        With ``q_tot is None`` every entry is ``-1e-9``, which makes
        ``BoundLinear`` fall back to worst-case bounds.  Otherwise the squared
        base weights are rescaled per sample so that they sum to ``q_tot[j]``,
        and entries below ``1e-9`` are replaced by ``-1e-9``.
        """
        if self.q_tot is None:
            self.q = [-1e-9+torch.zeros_like(self.q_base[i]) for i in range(len(self.dims))]
            return
        normalization = torch.zeros((self.BATCH_NUM,), dtype=torch.float32, device=self.model.device)
        for j in range(self.BATCH_NUM):
            normalization[j] = self.q_tot[j]/sum([t_square(self.q_base[i][j,:,:]).sum() for i in range(len(self.dims))])
        self.q = [t_square(self.q_base[i]) * normalization.unsqueeze(1).unsqueeze(2) for i in range(len(self.dims))]

        for i in range(len(self.dims)):
            self.q[i][self.q[i] < 1e-9] = -1e-9

    def _prepare_layers(self, true_label, target_label):
        """Return batch-expanded ``(weights, biases)`` with the last layer as a margin.

        Every weight gets a leading batch dimension and every bias becomes a
        batched column vector.  If ``target_label`` is given the last layer
        becomes the single row ``W[true] - W[target]``; otherwise, if
        ``true_label`` is given, it becomes ``W[true] - W[k]`` for every row
        ``k``.  With both ``None`` the last layer is left unchanged.
        """
        weights = [w.unsqueeze(0).expand(self.BATCH_NUM, -1, -1) for w in self.weights]
        biases = [b.unsqueeze(0).expand(self.BATCH_NUM, -1, -1).transpose(1, 2) for b in self.biases]
        rows = torch.arange(0, self.BATCH_NUM)
        if target_label is not None:
            weights[-1] = (weights[-1][rows, true_label, :] - weights[-1][rows, target_label, :]).unsqueeze(1)
            biases[-1] = (biases[-1][rows, true_label, :] - biases[-1][rows, target_label, :]).unsqueeze(1)
        elif true_label is not None:
            weights[-1] = weights[-1][rows, true_label, :].unsqueeze(1) - weights[-1][rows, :, :]
            biases[-1] = biases[-1][rows, true_label, :].unsqueeze(1) - biases[-1][rows, :, :]
        return weights, biases

    def _interval_bounds(self, weights, biases, x_0, eps):
        """Propagate the interval ``[x_0 - eps, x_0 + eps]`` through all layers."""
        l = [x_0 - eps]
        u = [x_0 + eps]
        for i in range(len(self.dims)):
            mask = weights[i] >= 0
            l_relu = l[-1]
            u_relu = u[-1]
            if i != 0:
                # ReLU of the previous layer's interval.
                l_relu = l[-1] * (l[-1] >= 0)
                u_relu = u[-1] * (u[-1] >= 0)
            # Non-negative weights pair with the matching end of the interval,
            # negative weights with the opposite end.
            l_nxt = torch.bmm(mask * weights[i], l_relu) + torch.bmm((~mask) * weights[i], u_relu) + biases[i]
            u_nxt = torch.bmm(mask * weights[i], u_relu) + torch.bmm((~mask) * weights[i], l_relu) + biases[i]
            l.append(l_nxt)
            u.append(u_nxt)
            assert not (l_nxt != l_nxt).any()
        return l[-1].squeeze(2).squeeze(1), u[-1].squeeze(2).squeeze(1)

    def compute_ibp(self, true_label, target_label, x_0, eps, requires_grad=False):
        """Interval bounds of the last layer.

        ``x_0`` is used as given (assumes a trailing singleton dimension, i.e.
        ``(batch, features, 1)``).  ``requires_grad`` is accepted but unused.
        Returns ``(lower, upper)``.
        """
        self.BATCH_NUM = x_0.size(0)
        weights, biases = self._prepare_layers(true_label, target_label)
        return self._interval_bounds(weights, biases, x_0, eps)

    def crown_ibp(self, true_label, target_label, x_0, eps, requires_grad=False):
        """Same as ``compute_ibp`` after flattening ``x_0`` and ``eps`` per sample."""
        x_0 = x_0.view(x_0.size(0), -1).unsqueeze(2)
        eps = eps.view(eps.size(0), -1).unsqueeze(2)
        return self.compute_ibp(true_label, target_label, x_0, eps, requires_grad)

    def compute_ll(self, true_label, target_label, x_0, eps, requires_grad=False):
        """Bounds of the last layer via back-substituted linear relaxations.

        Resets ``q`` for the batch size of ``x_0``, then for each layer
        back-substitutes the activation relaxations from ``BoundActivation``
        down to the input and concretises with ``BoundLinear`` using that
        layer's ``q``.  Returns ``(lower, upper)``.
        """
        self.BATCH_NUM = x_0.size(0)
        # reset_q() already ends by calling get_q().
        self.reset_q()
        input_bounder = BoundLinear(Box(x_0-eps, x_0+eps))
        activation_bounder = BoundActivation(self.activation)
        weights, biases = self._prepare_layers(true_label, target_label)

        l = [x_0 - eps]
        u = [x_0 + eps]
        # Index 0 is a placeholder so that list index j refers to the
        # activation feeding layer j.
        alpha_l = [torch.diag_embed(torch.ones_like(x_0).squeeze(2))]
        alpha_u = [torch.diag_embed(torch.ones_like(x_0).squeeze(2))]
        beta_l = [torch.zeros_like(x_0)]
        beta_u = [torch.zeros_like(x_0)]
        for i in range(len(self.dims)):
            if i == 0:
                l_nxt = input_bounder.get_lower_bound(weights[i], biases[i], self.q[i][:,:,0])
                u_nxt = input_bounder.get_upper_bound(weights[i], biases[i], self.q[i][:,:,1])
                l.append(l_nxt)
                u.append(u_nxt)
                continue
            activation_lb = activation_bounder.get_lower_bound(l[-1], u[-1])
            alpha_l.append(activation_lb[0])
            beta_l.append(activation_lb[1])
            activation_ub = activation_bounder.get_upper_bound(l[-1], u[-1])
            alpha_u.append(activation_ub[0])
            beta_u.append(activation_ub[1])

            leaf = biases[i].detach().clone()
            leaf.requires_grad = requires_grad
            b_l = leaf.clone()
            b_u = leaf.clone()
            a_l = [weights[i]]
            a_u = [weights[i]]

            for j in range(i, 0, -1):
                mask_l = a_l[-1] >= 0
                mask_u = a_u[-1] >= 0

                # Lower bound: non-negative coefficients take the lower
                # relaxation, negative ones the upper (and vice versa).
                mini_alpha_l = torch.bmm(mask_l.float() * a_l[-1], alpha_l[j].transpose(1, 2)) + torch.bmm((~mask_l).float() * a_l[-1], alpha_u[j].transpose(1, 2))
                mini_beta_l = torch.bmm(mask_l.float() * a_l[-1], beta_l[j]) + torch.bmm((~mask_l).float() * a_l[-1], beta_u[j])
                mini_alpha_u = torch.bmm(mask_u.float() * a_u[-1], alpha_u[j].transpose(1, 2)) + torch.bmm((~mask_u).float() * a_u[-1], alpha_l[j].transpose(1, 2))
                mini_beta_u = torch.bmm(mask_u.float() * a_u[-1], beta_u[j]) + torch.bmm((~mask_u).float() * a_u[-1], beta_l[j])

                b_l += mini_beta_l
                b_u += mini_beta_u

                b_l += (mini_alpha_l @ biases[j-1])
                b_u += (mini_alpha_u @ biases[j-1])

                a_l.append(mini_alpha_l @ weights[j-1])
                a_u.append(mini_alpha_u @ weights[j-1])

            l_nxt = input_bounder.get_lower_bound(a_l[-1], b_l, self.q[i][:,:,0])
            u_nxt = input_bounder.get_upper_bound(a_u[-1], b_u, self.q[i][:,:,1])
            l.append(l_nxt)
            u.append(u_nxt)
            assert not (l_nxt != l_nxt).any()
        return l[-1].squeeze(2).squeeze(1), u[-1].squeeze(2).squeeze(1)

    def compute_l(self, true_label, target_label, x_0, eps, requires_grad=False):
        """Bounds of the last layer using the current ``self.q``.

        Unlike ``compute_ll`` this neither updates ``BATCH_NUM`` nor resets
        ``q``, so gradients can flow back to ``q_base``; it uses the chord
        slope from ``BoundActivation.forward`` for both relaxation lines.
        Returns ``(lower, upper)``.

        NOTE(review): the inner ``expand`` uses ``self.dims[i]``, and
        ``get_dims`` of ``MLP`` sets the last entry to 1, so the untargeted
        case (``target_label is None``) looks unsupported here — confirm.
        """
        input_bounder = BoundLinear(Box(x_0-eps, x_0+eps))
        activation_bounder = BoundActivation(self.activation)
        weights, biases = self._prepare_layers(true_label, target_label)

        l = [x_0 - eps]
        u = [x_0 + eps]
        d_l = [torch.diag_embed(torch.ones_like(x_0).squeeze(2))]
        d_u = [torch.diag_embed(torch.ones_like(x_0).squeeze(2))]
        activation_bounds = [(torch.ones_like(x_0), torch.zeros_like(x_0), torch.ones_like(x_0), l[0])]
        for i in range(len(self.dims)):
            if i == 0:
                l_nxt = input_bounder.get_lower_bound(weights[i], biases[i], self.q[i][:,:,0])
                u_nxt = input_bounder.get_upper_bound(weights[i], biases[i], self.q[i][:,:,1])
                l.append(l_nxt)
                u.append(u_nxt)
                continue

            activation_bounds.append(activation_bounder(l[-1], u[-1]))
            d_l.append(torch.diag_embed(activation_bounds[-1][0].squeeze(2)))
            d_u.append(torch.diag_embed(activation_bounds[-1][2].squeeze(2)))

            leaf = biases[i].detach().clone()
            leaf.requires_grad = requires_grad
            b_l = leaf.clone()
            b_u = leaf.clone()
            a_l = [weights[i] @ d_l[i]]
            a_u = [weights[i] @ d_u[i]]

            for j in range(i, 0, -1):
                b_l += a_l[-1] @ biases[j-1]
                b_u += a_u[-1] @ biases[j-1]

                # h is zero where the coefficient is >= 0 and the upper
                # line's reference point elsewhere; t is the mirror image.
                h = torch.zeros_like(a_l[-1].transpose(1,2)).where(a_l[-1].transpose(1,2) >= 0, activation_bounds[j][3].expand(-1, -1, self.dims[i]))
                hm = (a_l[-1].transpose(1,2) * h).sum(dim=1).unsqueeze(2)
                t = torch.zeros_like(a_u[-1].transpose(1,2)).where(a_u[-1].transpose(1,2) <= 0, activation_bounds[j][3].expand(-1, -1, self.dims[i]))
                tm = (a_u[-1].transpose(1,2) * t).sum(dim=1).unsqueeze(2)

                b_l -= hm
                b_u -= tm

                a_l.append(a_l[-1] @ weights[j-1] @ d_l[j-1])
                a_u.append(a_u[-1] @ weights[j-1] @ d_u[j-1])

            l_nxt = input_bounder.get_lower_bound(a_l[-1], b_l, self.q[i][:,:,0])
            u_nxt = input_bounder.get_upper_bound(a_u[-1], b_u, self.q[i][:,:,1])
            l.append(l_nxt)
            u.append(u_nxt)
            assert not (l_nxt != l_nxt).any()
        return l[-1].squeeze(2).squeeze(1), u[-1].squeeze(2).squeeze(1)

    def forward(self, true_label, target_label, x_0, eps):
        """Return ``(lower_bound, info)`` for the margin at radius ``eps``.

        ``x_0`` is given a trailing singleton dimension here.  For
        ``q_mode != "opt"`` the bounds come from ``compute_ll`` and ``info``
        holds them under ``"l"`` and ``"u"``.  For ``"opt"`` the base weights
        are optimised with AdamW for ``steps`` iterations to maximise the
        lower bound from ``compute_l``; the best bound seen per sample is
        returned with an empty ``info``.
        """
        x_0 = x_0.unsqueeze(2)
        self.BATCH_NUM = x_0.size(0)
        if self.q_mode != "opt":
            # compute_ll resets q for this batch itself.
            vals = self.compute_ll(true_label, target_label, x_0, eps)
            return vals[0], {"l": vals[0], "u": vals[1]}

        self.reset_q()
        # Created on the model device so it can be combined with the bounds.
        best_l = -1e5 * torch.ones((self.BATCH_NUM,), dtype=torch.float32, device=self.model.device)
        optimizer = optim.AdamW(self.q_base, lr=1e-2)

        for _ in range(self.steps):
            self.get_q()

            for base in self.q_base:
                assert not (base != base).any()
            optimizer.zero_grad()
            lb = self.compute_l(true_label, target_label, x_0, eps, requires_grad=True)[0]
            loss = -lb.sum()
            # Detach so the running best does not keep every step's graph.
            best_l = torch.max(best_l, lb.detach())
            loss.backward()
            nn.utils.clip_grad_norm_(self.q_base, 1e2)
            optimizer.step()
        return best_l, {}

class CROWN(nn.Module):
    """Binary search for the largest radius at which the lower bound stays positive."""

    def __init__(self, model, args):
        super(CROWN, self).__init__()
        self.model = model
        self.args = args

    def forward(self, true_label, target_label, x_0, q_tot=None, q_mode="all", eps_max=1.0, tol=1e-6):
        """Search a certified radius per sample.

        Args:
            true_label, target_label, x_0: forwarded unchanged to
                ``InnerCROWN``; the first dimension of ``x_0`` is the batch.
            q_tot: ``None`` or a scalar that is broadcast to one value per
                sample and passed to ``InnerCROWN``.
            q_mode: passed to ``InnerCROWN``.
            eps_max: upper end of the search interval (default 1.0).
            tol: the search stops once every sample's interval is at most
                this wide (default 1e-6).

        Returns:
            ``(eps_low, {})`` where ``eps_low`` has shape ``(batch, 1, 1)``
            and holds the largest radius found with a positive lower bound
            (0 if none was found).
        """
        BATCH_NUM = x_0.size(0)
        eps_low = torch.zeros((BATCH_NUM,1,1), dtype=torch.float32, device=x_0.device)
        if BATCH_NUM == 0:
            # Nothing to certify; ``.max()`` below would fail on an empty tensor.
            return eps_low, {}
        eps_high = eps_max * torch.ones((BATCH_NUM,1,1), dtype=torch.float32, device=x_0.device)

        q = None if q_tot is None else q_tot * torch.ones((BATCH_NUM,), dtype=torch.float32)
        inner = InnerCROWN(self.model, self.args, q_tot=q, q_mode=q_mode)

        while (eps_high - eps_low).max() > tol:
            eps_mid = (eps_low + eps_high)/2
            lb, other = inner(true_label, target_label, x_0, eps_mid)
            certified = lb > 0
            # Certified samples move the lower end up, the rest move the upper end down.
            eps_low[certified] = eps_mid[certified]
            eps_high[~certified] = eps_mid[~certified]
        return eps_low, {}

class CROWNDataset(Dataset):
    """Dataset view over the dict returned by ``model.get_data()``.

    Each item is ``(true class index, target class index, flattened float32
    input)``; the class indices are the argmax of the stored label rows.
    """

    def __init__(self, model):
        self.data = model.get_data()

    def __len__(self):
        return len(self.data["true_labels"])

    def __getitem__(self, idx):
        records = self.data
        true_class = np.argmax(records["true_labels"][idx])
        target_class = np.argmax(records["targets"][idx])
        flat_input = records["inputs"][idx].flatten().astype(np.float32)
        return true_class, target_class, flat_input

def _write_table_row(f, values, fmt):
    """Write ``values`` to ``f`` as one " & "-separated line, each formatted with ``fmt``."""
    f.write(" & ".join(fmt.format(value) for value in values))
    f.write("\n")


def sampling_main(args):
    """Compare certified radii of the "all", "last" and "sampling" modes.

    For each mode and each total failure probability, the mean radius over the
    dataset is computed (via ``CROWN`` for "all"/"last", via
    ``monte_carlo.get_radius_monte_carlo`` for "sampling"), timed, written to
    ``rebuttal_logs/<model name>.txt`` and plotted with error bars to
    ``rebuttal_logs/<model name>.png``.

    Args:
        args: parsed command-line namespace, passed to ``MLP`` and ``CROWN``.
    """
    #python get_bounds_torch.py --model mnist --hidden 20 --numlayer 3 --targettype least --norm i --numimage 10 --activation relu --method ours
    #python get_bounds_torch.py --model mnist --hidden 20 --numlayer 3 --targettype top2 --norm i --numimage 1 --activation relu --method ours
    #python get_bounds_torch.py --model cifar --hidden 1024 --numlayer 7 --targettype top2 --norm i --numimage 10 --activation relu --method ours

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MLP(args, device)
    dataset = CROWNDataset(model)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)

    certifier = CROWN(model, args)
    savefile = "rebuttal_logs/" + model.modelfile.split("/")[1]
    # Make sure the output directory exists before opening the log file.
    os.makedirs(os.path.dirname(savefile), exist_ok=True)
    with open(savefile + ".txt", "w") as f:
        plt.figure()
        # Raw strings: "\i" and "\e" are not valid escape sequences.
        plt.title(r"Naive Sampling on 3-Layer MNIST Classifier (Uniform Noise on $L_\infty$ Ball)")
        plt.xlabel("Total failure probability $Q$")
        plt.ylabel(r"Certified radius $\epsilon$")

        for q_mode, col, label in zip(["all", "last", "sampling"], ["blue", "red", "green"], ["I-PROVEN", "PROVEN", "Sampling"]):
            plot_y = []
            plot_x = []
            y_err = []
            timings = []
            print("START", q_mode)
            for q_tot in tqdm([1e-4, 1e-2, 5e-2, 25e-2, 50e-2, 95e-2]):
                rad_tot = 0
                items = []
                start_time = time.time()
                for idx, batch in enumerate(dataloader):
                    true_labels, targets, inputs = (item.to(device) for item in batch)
                    if q_mode != "sampling":
                        rad, _ = certifier(true_labels, targets, inputs, q_tot=q_tot, q_mode=q_mode)
                    else:
                        # Value passed twice to the Monte Carlo routine.
                        t = int(1/q_tot * math.log(1/0.0001))
                        rad, _ = monte_carlo.get_radius_monte_carlo(model, true_labels, targets, inputs, t, t)
                    rad_tot += rad.sum().item()
                    items += rad.flatten().tolist()
                end_time = time.time()
                print(len(items), "ITEMS")
                # Three standard errors of the mean radius.
                y_err.append(3 * np.std(np.array(items)) / math.sqrt(len(items)))
                plot_y.append(rad_tot/len(dataset))
                plot_x.append(0 if q_tot is None else q_tot)
                timings.append(end_time-start_time)

            plt.errorbar(plot_x, plot_y, yerr=y_err, fmt='o', color = col, label=label)
            f.write(q_mode + " " + col + "\n")
            f.write(str(plot_x))
            f.write("\n")
            f.write(str(plot_y))
            f.write("\n")
            f.write(str(timings))
            f.write("\n")
            _write_table_row(f, plot_y, "{:.4f}")
            # Radii as a percentage of the first entry; "nan" when that entry is 0.
            baseline = plot_y[0]
            ratios = [100 * y / baseline if baseline != 0 else float("nan") for y in plot_y]
            _write_table_row(f, ratios, "{:.0f}")
            _write_table_row(f, timings, "{:.4f}")
            f.write("\n")

    plt.ylim(bottom = 0.0)
    plt.legend()
    plt.savefig(savefile + ".png")
    plt.show()

def main(args):
    """Certify every sample with CROWN and plot mean radius against failure probability.

    For each series (``q_mode`` "all" and "last") and each total failure
    probability ``q_tot`` in a fixed sweep, every batch of the dataloader is
    passed to the certifier, the per-sample radii are collected, and the mean
    radius plus an error bar of three standard errors is recorded. ``q_tot``
    of ``None`` is plotted at x = 0. Results are written to
    ``logs/<model name>.txt`` (raw lists followed by ``&``-separated rows of
    radii and of percentages relative to the first sweep point) and the
    figure is saved to ``logs/<model name>.png`` and shown.

    Args:
        args: parsed command-line namespace, forwarded to ``MLP`` and ``CROWN``.
    """
    #python get_bounds_torch.py --model mnist --hidden 20 --numlayer 3 --targettype least --norm i --numimage 10 --activation relu --method ours
    #python get_bounds_torch.py --model mnist --hidden 20 --numlayer 3 --targettype top2 --norm i --numimage 1 --activation relu --method ours
    #python get_bounds_torch.py --model cifar --hidden 1024 --numlayer 7 --targettype top2 --norm i --numimage 10 --activation relu --method ours

    # MLP.__init__ takes (args, device); calling it with args alone raises TypeError.
    # NOTE(review): the device choice below is the usual CUDA-if-available default;
    # confirm it matches what the other entry point passes, and that CROWNDataset /
    # CROWN (not visible here) need no device argument of their own.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MLP(args, device)
    dataset = CROWNDataset(model)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    certifier = CROWN(model, args)
    savefile = "logs/" + model.modelfile.split("/")[1]
    # open() does not create missing parent directories.
    os.makedirs(os.path.dirname(savefile), exist_ok=True)

    q_sweep = [None, 1e-4, 1e-2, 5e-2, 25e-2, 50e-2, 95e-2]
    series = zip(["all", "last"], ["blue", "red"], ["CIFAR ({1,2,3,4,5,6,7})", "PROVEN ({7})"])

    with open(savefile + ".txt", "w") as f:
        plt.figure()
        plt.title("Improved PROVEN on 7-Layer CIFAR-10 Classifier")
        plt.xlabel("Total failure probability $Q$")
        # Raw string: "\e" is not a valid escape sequence in a normal literal.
        plt.ylabel(r"Certified radius $\epsilon$")

        for q_mode, col, label in series:
            plot_y = []
            plot_x = []
            y_err = []
            print("START", q_mode)
            for q_tot in tqdm(q_sweep):
                rad_tot = 0
                items = []
                for idx, batch in enumerate(dataloader):
                    print(idx, "ok")
                    rad, _ = certifier(*batch, q_tot=q_tot, q_mode=q_mode)
                    rad_tot += rad.sum().item()
                    items += rad.flatten().tolist()
                print(len(items), "ITEMS")
                # Error bar: three standard errors of the per-sample radii.
                y_err.append(3 * np.std(np.array(items)) / math.sqrt(len(items)))
                plot_y.append(rad_tot/len(dataset))
                plot_x.append(0 if q_tot is None else q_tot)

            #plt.plot(plot_x, plot_y, 'o', color = col, label=label)
            plt.errorbar(plot_x, plot_y, yerr=y_err, fmt='o', color = col, label=label)
            f.write(q_mode + " " + col + "\n")
            f.write(str(plot_x))
            f.write("\n")
            f.write(str(plot_y))
            f.write("\n")
            # Table row of mean radii.
            f.write(" & ".join("{:.4f}".format(y) for y in plot_y))
            f.write("\n")
            # Table row of radii as a percentage of the first sweep point;
            # a zero baseline yields "nan" instead of a ZeroDivisionError.
            baseline = plot_y[0]
            ratios = [100 * y / baseline if baseline else float("nan") for y in plot_y]
            f.write(" & ".join("{:.0f}".format(r) for r in ratios))
            f.write("\n")

    plt.ylim(bottom = 0.0)
    plt.legend()
    plt.savefig(savefile + ".png")
    plt.show()


if __name__ == "__main__":
    # Script entry point: build the command-line parser, fix all random seeds,
    # parse the arguments and run sampling_main.
    #### parser ####
    parser = argparse.ArgumentParser(description='compute activation bound for CIFAR and MNIST')
    parser.add_argument('--model',
                default="mnist",
                choices=["mnist", "cifar"],
                help='model to be used')
    parser.add_argument('--eps',
                default = 0.005,
                type = float,
                help = "epsilon for verification")
    parser.add_argument('--hidden',
                default = 20,
                type = int,
                help = "number of hidden neurons per layer")
    parser.add_argument('--numlayer',
                default = 3,
                type = int,
                help='number of layers in the model')
    parser.add_argument('--numimage',
                default = 1,
                type = int,
                help='number of images to run')
    parser.add_argument('--startimage',
                default = 0,
                type = int,
                help='start image')
    parser.add_argument('--norm',
                default = "i",
                type = str,
                choices = ["i", "1", "2"],
                help='perturbation norm: "i": Linf, "1": L1, "2": L2')
    parser.add_argument('--method',
                default = "ours",
                type = str,
                choices = ["general", "ours", "adaptive", "spectral", "naive", "green", "blue", "purple"],
                help='"ours": our proposed bound, "spectral": spectral norm bounds, "naive": naive bound')
    parser.add_argument('--lipsbnd',
                type = str,
                default = "disable",
                choices = ["disable", "fast", "naive", "both"],
                help='compute Lipschitz bound, after using some method to compute neuron lower/upper bounds')
    parser.add_argument('--lipsteps',
                type = int,
                default = 30,
                help='number of steps to use in lipschitz bound')
    parser.add_argument('--LP',
                action = "store_true",
                help='use LP to get bounds for final output')
    parser.add_argument('--LPFULL',
                action = "store_true",
                help='use FULL LP to get bounds for output')
    parser.add_argument('--quad',
                action = "store_true",
                help='use quadratic bound to improve 2nd layer output')
    parser.add_argument('--warmup',
                action = "store_true",
                help='warm up before the first iteration')
    parser.add_argument('--modeltype',
                default = "vanilla",
                choices = ["vanilla", "dropout", "distill", "adv_retrain"],
                help = "select model type")
    parser.add_argument('--targettype',
                default="least",
                choices = ["untargeted", "least", "top2", "random"],
                help='untargeted minimum distortion')
    parser.add_argument('--steps',
                default = 15,
                type = int,
                help = 'how many steps to binary search')
    parser.add_argument('--activation',
                default="relu",
                choices=["relu", "tanh", "sigmoid", "arctan", "elu", "hard_sigmoid", "softplus"],
                help='activation function of the model')

    # Seed every random number generator the script touches with the same value.
    random.seed(1215)
    np.random.seed(1215)
    tf.set_random_seed(1215)
    torch.manual_seed(1215)
    args = parser.parse_args()

    # Only sampling_main is run from the command line.
    sampling_main(args)
