import os
import sys
import math
import random
import time
import argparse
from tqdm import tqdm
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import torch.autograd.profiler as profiler

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import imageio

def debug(text, file="ffs.txt"):
    """Append a debug record to a text file.

    The record is ``str(text)``, a newline, a ``---`` separator and another
    newline. The file is opened in append mode, so earlier records are kept.

    Args:
        text: any object; it is converted with ``str``.
        file: path of the log file to append to.
    """
    with open(file, "a") as log:
        # One composed write instead of two consecutive ones; same bytes.
        log.write("".join((str(text), "\n---\n")))

def t_square(x):
    """Return ``x`` multiplied by itself (element-wise for tensors)."""
    squared = x * x
    return squared

def cnn_identity(shape, device):
    """Build a batched identity map over feature maps of the given shape.

    Args:
        shape: sequence ``(batch, d1, ..., dk)``. A list, tuple or
            ``torch.Size`` is accepted; callers in this file pass the
            4-element size of a convolution output.
        device: device on which the result is allocated.

    Returns:
        A tensor of shape ``shape + shape[1:]``. For each batch entry it is
        the ``prod(shape[1:])``-sized identity matrix reshaped so that the
        leading and trailing feature dims each index one feature map.
    """
    dims = [int(d) for d in shape]
    feature_dims = dims[1:]
    flat = int(np.prod(feature_dims))
    # Allocate directly on the target device instead of building on the host
    # and copying afterwards.
    eye = torch.eye(flat, device=device)
    batched = eye.unsqueeze(0).expand(dims[0], -1, -1)
    return batched.reshape(dims + feature_dims)

def img_compose(alpha, beta, img):
    """Evaluate the affine map ``alpha . img + beta``.

    ``alpha`` is contracted with ``img`` over its last three dims (einsum
    ``"bijkuvw,buvw->bijk"``), and ``beta`` is then added to the result.
    """
    contracted = torch.einsum("bijkuvw,buvw->bijk", alpha, img)
    return contracted + beta

def cnn_compose(alpha, beta, conv, next_shape):
    """Back-substitute a linear form through a 2-D convolution.

    If ``y = alpha . z + beta`` and ``z = conv(x)``, the returned pair
    satisfies ``y = alpha_nxt . x + beta_nxt``.

    Args:
        alpha: 7-D coefficient tensor (einsum layout ``"bijkuvw"``); the last
            three dims index the convolution's output feature map.
        beta: constant term; it is added to the einsum result ``"bijk"``.
        conv: object exposing ``weight``, ``bias``, ``stride``, ``padding``
            and ``kernel_size`` like ``nn.Conv2d``. ``dilation`` and
            ``groups`` are honoured when present; ``bias`` may be ``None``.
        next_shape: size of the convolution's input. Entries ``[2]`` and
            ``[3]`` give the spatial extent the transposed convolution must
            reproduce.

    Returns:
        ``(alpha_nxt, beta_nxt)``, where the last three dims of
        ``alpha_nxt`` index the convolution's input feature map.
    """
    alpha_shape = list(alpha.size())
    feature_shape = alpha_shape[-3:]
    weight = conv.weight
    bias = conv.bias
    stride = list(conv.stride)
    padding = list(conv.padding)
    dilation = list(getattr(conv, "dilation", (1,) * len(stride)))
    groups = getattr(conv, "groups", 1)

    # The transposed convolution yields a "natural" size per spatial axis;
    # output_padding makes up the difference to the true input size. It has
    # to be computed per axis: reusing the height value for the width is
    # wrong for non-square inputs, kernels or strides.
    output_padding = []
    for axis in range(len(stride)):
        kernel_span = dilation[axis] * (conv.kernel_size[axis] - 1) + 1
        natural = (feature_shape[1 + axis] - 1) * stride[axis] + kernel_span - 2 * padding[axis]
        output_padding.append(int(next_shape[2 + axis]) - natural)

    # Fold all leading dims into the batch so a single transposed conv
    # handles every output coordinate at once.
    alpha_r = alpha.reshape([-1] + feature_shape)
    alpha_nxt = F.conv_transpose2d(
        alpha_r,
        weight,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
    alpha_nxt = alpha_nxt.reshape(alpha_shape[:-3] + list(alpha_nxt.size())[1:])

    if bias is None:
        # Copy so that in-place edits by callers never alias their input.
        beta_nxt = beta.clone()
    else:
        # The bias is constant over the spatial dims, so contract it as a
        # per-channel vector instead of expanding it to a full feature map.
        beta_nxt = beta + torch.einsum("bijkuvw,u->bijk", alpha, bias)
    return alpha_nxt, beta_nxt

class Box:
    """Pair of corners describing an input region.

    Consumers in this file treat the region as the per-coordinate interval
    ``[corner_a, corner_b]`` (centre ``(a + b) / 2``, radius ``(b - a) / 2``).
    The corners are stored as given; nothing is copied or validated.
    """

    def __init__(self, corner_a, corner_b):
        self.corner_a, self.corner_b = corner_a, corner_b

class InputInequality:
    """Deterministic bounds of affine forms ``a . x + b`` over an input box.

    ``x_0`` is the box centre and ``eps`` its per-coordinate radius. ``eps``
    gets three singleton dims inserted after the batch dim so it broadcasts
    against the 7-D coefficient tensors used by the bound methods.
    """

    def __init__(self, box):
        low, high = box.corner_a, box.corner_b
        self.x_0 = (low + high)/2
        radius = (high - low)/2
        for _ in range(3):
            radius = radius.unsqueeze(1)
        self.eps = radius

    def _radius_term(self, a):
        """Sum of ``|eps * a|`` over the three input dims (4, 5, 6)."""
        return torch.sum(torch.abs(self.eps * a), dim=[4, 5, 6])

    def get_upper_bound(self, a, b):
        """Largest value of ``a . x + b`` over the box."""
        at_center = img_compose(a, b, self.x_0)
        return at_center + self._radius_term(a)

    def get_lower_bound(self, a, b):
        """Smallest value of ``a . x + b`` over the box."""
        at_center = img_compose(a, b, self.x_0)
        return at_center - self._radius_term(a)

class GaussianInequality:
    """Probabilistic bounds of affine forms around a centre ``x``.

    The margin used is ``erfinv(1 - 2q) * sqrt(2 * sum((eps * a)^2) + 1e-8)``,
    i.e. ``eps`` acts as a per-coordinate scale.
    NOTE(review): this looks like the ``1 - q`` quantile of a Gaussian with
    standard deviation ``eps`` per coordinate -- confirm against the caller.
    """

    def __init__(self, x, eps):
        # Cloned so later in-place edits by the caller do not leak in.
        self.x_0 = x.clone()
        self.eps = eps.clone()

    def _quantile_margin(self, a, q, mask):
        """Margin for the entries selected by ``mask`` (a 1-D tensor)."""
        dualnorm_sq = torch.sum(t_square(self.eps * a), dim=[4, 5, 6])
        feps = 1e-8  # keeps sqrt well-behaved when the norm is zero
        return torch.erfinv(1-2*q[mask]) * torch.sqrt(2 * dualnorm_sq[mask] + feps)

    #Return max_{x in Q} ax + b with probability >= 1-q
    def get_upper_bound(self, a, b, q):
        """Upper bound; entries whose ``q`` is not positive are set to ``1e8``."""
        center = img_compose(a, b, self.x_0)
        active = q > 0
        ub = torch.zeros_like(center) + 1e8
        ub[active] = center[active] + self._quantile_margin(a, q, active)
        return ub

    #Return min_{x in Q} ax + b with probability >= 1-q
    def get_lower_bound(self, a, b, q):
        """Lower bound; entries whose ``q`` is not positive are set to ``-1e8``."""
        center = img_compose(a, b, self.x_0)
        active = q > 0
        lb = torch.zeros_like(center) - 1e8
        lb[active] = center[active] - self._quantile_margin(a, q, active)
        return lb

class ProbabilisticInputInequality:
    """Box bounds of ``a . x + b``, optionally tightened probabilistically.

    ``x_0`` / ``eps`` are the centre and (broadcast-ready) radius of the box.
    When a failure probability ``q`` is supplied, entries with ``q > 0`` use
    the margin ``sqrt(2 * log(1/q) * sum((eps * a)^2) + 1e-8)``, and the
    result is never looser than the deterministic box bound.
    ``total`` accumulates how many entries took the probabilistic path.
    """

    def __init__(self, box):
        low, high = box.corner_a, box.corner_b
        self.x_0 = (low + high)/2
        radius = (high - low)/2
        for _ in range(3):
            radius = radius.unsqueeze(1)
        self.eps = radius
        self.total = 0

    def _box_margin(self, a):
        """Deterministic margin: sum of ``|eps * a|`` over dims 4, 5, 6."""
        return torch.sum(torch.abs(self.eps * a), dim=[4, 5, 6])

    def _sq_norm(self, a):
        """Sum of ``(eps * a)^2`` over dims 4, 5, 6."""
        return torch.sum(t_square(self.eps * a), dim=[4, 5, 6])

    def _prob_margin(self, a, q, mask):
        """Probabilistic margin for the entries selected by ``mask``."""
        dualnorm_sq = self._sq_norm(a)
        feps = 1e-8  # keeps sqrt well-behaved when the norm is zero
        return torch.sqrt(-torch.log(q[mask]) * 2 * dualnorm_sq[mask] + feps)

    #Return max_{x in Q} ax + b with probability >= 1-q
    def get_upper_bound(self, a, b, q=None):
        """Upper bound; purely deterministic when ``q`` is ``None``."""
        center = img_compose(a, b, self.x_0)
        ub_raw = center + self._box_margin(a)
        if q is None:
            return ub_raw
        active = q > 0
        self.total += active.sum()
        ub = torch.clone(ub_raw)
        ub[active] = center[active] + self._prob_margin(a, q, active)
        # Never report something looser than the box bound.
        return torch.min(ub, ub_raw)

    #Return min_{x in Q} ax + b with probability >= 1-q
    def get_lower_bound(self, a, b, q=None):
        """Lower bound; purely deterministic when ``q`` is ``None``."""
        center = img_compose(a, b, self.x_0)
        lb_raw = center - self._box_margin(a)
        if q is None:
            return lb_raw
        active = q > 0
        self.total += active.sum()
        lb = torch.clone(lb_raw)
        lb[active] = center[active] - self._prob_margin(a, q, active)
        # Never report something looser than the box bound.
        return torch.max(lb, lb_raw)

    #Return upper bound on probability that ax + b < 0
    def get_q(self, a, b):
        """Per-entry ``exp(-u^2 / (2 * sum((eps * a)^2) + 1e-6))`` where the
        centre value ``u`` is positive, and ``1`` everywhere else."""
        center = img_compose(a, b, self.x_0)
        positive = center > 0
        dualnorm_sq = self._sq_norm(a)
        q = torch.ones_like(center)
        q[positive] = torch.exp(-t_square(center[positive]) / (2 * dualnorm_sq[positive] + 1e-6))
        return q

class DeltaInputInequality:
    """Bounds over a doubled input built from a box and an index mask.

    The box corners are repeated twice along dim 1. The first copy is
    collapsed to the box centre (zero radius); the second copy keeps half
    the box width scaled by ``idx``. Bounds are delegated to an
    ``InputInequality`` over that doubled box.

    NOTE(review): the three-index slicing here suggests 3-D tensors, while
    ``InputInequality`` in this file contracts 7-D coefficients -- confirm
    this class is still in use and with which shapes.
    """

    def __init__(self, box, idx):
        self.idx = idx
        split = idx.size(1)
        corner_a = box.corner_a.repeat(1, 2, 1)
        corner_b = box.corner_b.repeat(1, 2, 1)
        center = 0.5*(corner_a + corner_b)
        width = corner_b - corner_a
        eps = torch.zeros_like(center)
        eps[:, split:, :] = 0.5*idx*(width[:, split:, :])
        self.bounder = InputInequality(Box(center-eps, center+eps))

    @staticmethod
    def _stack_coefficients(a, b):
        """Return ``[a - b, b]`` laid side by side along dim 2."""
        cut = a.size(2)
        stacked = torch.zeros_like(a).repeat(1, 1, 2)
        stacked[:, :, :cut] = a - b
        stacked[:, :, cut:] = b
        return stacked

    def get_upper_bound(self, a, b, c):
        """Upper bound of the stacked form with constant term ``c``."""
        return self.bounder.get_upper_bound(self._stack_coefficients(a, b), c)

    def get_lower_bound(self, a, b, c):
        """Lower bound of the stacked form with constant term ``c``."""
        return self.bounder.get_lower_bound(self._stack_coefficients(a, b), c)

class ConvWrapper:
    """Bound-propagation adapter around one ``nn.Conv2d`` layer.

    Without a target the wrapper propagates through the wrapped layer as is.
    After ``set_target(..., true_label=...)`` every batch row with true label
    ``i`` is propagated through a "margin" layer whose output channel ``k``
    is ``channel_i - channel_k`` of the wrapped layer.

    NOTE(review): the margin layers get their values through ``.data``
    assignment, so gradients presumably do not flow back to the wrapped
    layer's parameters through them -- confirm this is acceptable if the
    bounds are used in a training loss.
    """

    def __init__(self, conv):
        self.module = conv
        self.a = self.module.weight[:]
        self.b = self.module.bias[:]
        self.targeted = False
        self.partial_target = False

    def set_target(self, l, u, true_label=None, target_label=None):
        """Switch to margin propagation relative to ``true_label``.

        Args:
            l, u: accepted for interface symmetry; not read here.
            true_label: per-row integer labels, or ``None`` to leave the
                wrapper untargeted.
            target_label: not supported.

        Raises:
            NotImplementedError: if ``target_label`` is given. An explicit
                exception is used rather than ``assert`` so the check
                survives ``python -O``.
        """
        if target_label is not None:
            raise NotImplementedError("target_label is not supported by ConvWrapper.set_target")
        if true_label is None:
            return
        self.partial_target = True
        self.true_label = true_label
        self.split = []
        conv = self.module
        for i in range(conv.weight.size(0)):
            margin = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding)
            margin.weight.data = conv.weight[i, :, :, :].unsqueeze(0) - conv.weight[:]
            margin.bias.data = conv.bias[i].unsqueeze(0) - conv.bias[:]
            self.split.append(margin)

    def compute_bounds(self, l, u):
        """Record the layer's weight/bias as both lower and upper relaxation
        (a convolution is linear, so ``l`` and ``u`` are not needed)."""
        self.a_l = self.a[:]
        self.a_u = self.a[:]
        self.b_l = self.b[:]
        self.b_u = self.b[:]

    def _label_masks(self):
        """Yield ``(label, row_mask)`` for every label present in the batch."""
        for i in range(self.module.weight.size(0)):
            mask = self.true_label == i
            if mask.sum().item() == 0:
                continue
            yield i, mask

    def _zero_output(self, x):
        """Zero tensor with the shape the wrapped layer produces for ``x``."""
        return F.conv2d(torch.zeros_like(x), torch.zeros_like(self.module.weight), stride=self.module.stride, padding=self.module.padding)

    def _interval_conv(self, l, u, weight, bias):
        """Interval arithmetic through one convolution; returns (low, high)."""
        stride, padding = self.module.stride, self.module.padding
        t1 = F.conv2d(0.5*(l+u), weight, bias, stride=stride, padding=padding)
        # (l - u) is non-positive, so t2 is minus the output radius.
        t2 = F.conv2d(0.5*(l-u), torch.abs(weight), stride=stride, padding=padding)
        return t1+t2, t1-t2

    def propagate_intervals(self, l, u):
        """Map element-wise input bounds ``[l, u]`` to output bounds."""
        if not self.partial_target:
            return self._interval_conv(l, u, self.module.weight, self.module.bias)
        zl = self._zero_output(l)
        zu = torch.zeros_like(zl)
        for i, mask in self._label_masks():
            zl[mask], zu[mask] = self._interval_conv(l[mask], u[mask], self.split[i].weight, self.split[i].bias)
        return zl, zu

    def propagate_uncertainty(self, alpha_l, beta_l, alpha_u, beta_u):
        """Apply ``propagate_intervals`` to the alpha pair and the beta pair."""
        nxt_alpha_l, nxt_alpha_u = self.propagate_intervals(alpha_l, alpha_u)
        nxt_beta_l, nxt_beta_u = self.propagate_intervals(beta_l, beta_u)
        return nxt_alpha_l, nxt_beta_l, nxt_alpha_u, nxt_beta_u

    def propagate_bounds(self, alpha, beta, lower, next_shape):
        """Back-substitute ``(alpha, beta)`` through the layer.

        ``lower`` is accepted for interface symmetry with ``ReluWrapper`` but
        has no effect, because the layer is linear. ``next_shape`` is the
        size of this layer's input.
        """
        alpha_nxt, beta_nxt = cnn_compose(alpha, beta, self.module, next_shape)
        if not self.partial_target:
            return alpha_nxt, beta_nxt
        # Rows with a known label are overwritten with the margin-layer result.
        for i, mask in self._label_masks():
            alpha_nxt[mask], beta_nxt[mask] = cnn_compose(alpha[mask], beta[mask], self.split[i], next_shape)
        return alpha_nxt, beta_nxt

    def propagate_d(self, d):
        """Convolve ``d`` with the absolute weights (no bias)."""
        stride, padding = self.module.stride, self.module.padding
        if not self.partial_target:
            return F.conv2d(d, torch.abs(self.module.weight), stride=stride, padding=padding)
        nd = self._zero_output(d)
        for i, mask in self._label_masks():
            nd[mask] = F.conv2d(d[mask], torch.abs(self.split[i].weight), stride=stride, padding=padding)
        return nd
    

class ReluWrapper:
    """Linear relaxation of an element-wise ReLU for bound propagation.

    ``compute_bounds`` stores slopes ``a_l`` / ``a_u`` and offsets ``b_l`` /
    ``b_u`` such that the lower line is ``a_l * x + b_l`` and the upper line
    is ``a_u * x + b_u``. The ``propagate_*`` methods then push intervals or
    linear coefficients through the activation.
    """

    def __init__(self, activation, use_zero=False):
        self.module = activation
        self.use_zero = use_zero

    def compute_bounds(self, l, u):
        """Derive the relaxation from pre-activation bounds ``l <= x <= u``."""
        if self.use_zero:
            slope_lo = torch.zeros_like(l)
        else:
            # Lower slope is 1 where the positive side dominates, else 0.
            slope_lo = (u.abs() > l.abs()).float() * torch.ones_like(l)
        anchor_lo = torch.zeros_like(l)

        # Upper line passes through (l, 0) with slope u / (u - l); the slope
        # stays 1 where the interval is too narrow to divide safely.
        slope_hi = torch.ones_like(l)
        wide = u > l + 1e-8
        slope_hi[wide] = u[wide]/(u[wide]-l[wide])
        anchor_hi = l.detach().clone()

        # Neurons that are always off: everything is zero.
        always_off = u < 0
        for piece in (slope_lo, anchor_lo, slope_hi, anchor_hi):
            piece[always_off] = 0

        # Neurons that are always on: identity, no offset.
        always_on = l > 0
        slope_lo[always_on] = 1
        anchor_lo[always_on] = 0
        slope_hi[always_on] = 1
        anchor_hi[always_on] = 0

        self.a_l = slope_lo
        self.a_u = slope_hi
        self.b_l = -self.a_l * anchor_lo
        self.b_u = -self.a_u * anchor_hi

    def propagate_intervals(self, l, u):
        """Apply the activation to both ends of the interval."""
        lower_out = self.module(l)
        upper_out = self.module(u)
        return lower_out, upper_out

    def propagate_uncertainty(self, alpha_l, beta_l, alpha_u, beta_u):
        """Keep entries whose alpha is positive and zero the others.

        On the lower side both alpha and beta are masked by ``alpha_l > 0``.
        On the upper side only alpha is masked; ``beta_u`` is passed through
        as a copy.
        """
        keep_l = alpha_l > 0
        nxt_alpha_l = torch.zeros_like(alpha_l)
        nxt_beta_l = torch.zeros_like(beta_l)
        nxt_alpha_l[keep_l] = alpha_l[keep_l]
        nxt_beta_l[keep_l] = beta_l[keep_l]

        keep_u = alpha_u > 0
        nxt_alpha_u = torch.zeros_like(alpha_u)
        nxt_alpha_u[keep_u] = alpha_u[keep_u]
        nxt_beta_u = beta_u.clone()
        return nxt_alpha_l, nxt_beta_l, nxt_alpha_u, nxt_beta_u

    def _select(self, lower):
        """Relaxation pieces paired with positive / negative coefficients.

        For a lower bound, positive coefficients take the lower line and
        negative ones the upper line; for an upper bound it is the reverse.
        Returns ``(a_pos, a_neg, b_pos, b_neg)``.
        """
        if lower:
            return self.a_l, self.a_u, self.b_l, self.b_u
        return self.a_u, self.a_l, self.b_u, self.b_l

    def propagate_bounds_new(self, alpha, beta, lower):
        """Matrix-product variant of ``propagate_bounds_old`` (noted in the
        original as slightly slower)."""
        alpha_shape = list(alpha.size())
        beta_shape = list(beta.size())

        bs = alpha_shape[0]
        ijk = np.prod(alpha_shape[1:4])
        uvw = np.prod(alpha_shape[4:7])

        flat_alpha = alpha.view(bs, ijk, uvw)
        flat_beta = beta.view(bs, ijk)
        alpha_pos = flat_alpha.clamp(min=0)
        alpha_neg = flat_alpha.clamp(max=0)

        a_pos, a_neg, b_pos, b_neg = (piece.view(bs, uvw) for piece in self._select(lower))

        alpha_nxt = alpha_pos * a_pos.unsqueeze(1) + alpha_neg * a_neg.unsqueeze(1)
        offset = alpha_pos @ b_pos.unsqueeze(-1) + alpha_neg @ b_neg.unsqueeze(-1)
        beta_nxt = flat_beta + offset.squeeze(-1)

        return alpha_nxt.view(alpha_shape), beta_nxt.view(beta_shape)

    def propagate_bounds_old(self, alpha, beta, lower):
        """Back-substitute ``(alpha, beta)`` through the relaxation (einsum form)."""
        alpha_pos = alpha.clamp(min=0)
        alpha_neg = alpha.clamp(max=0)
        a_pos, a_neg, b_pos, b_neg = self._select(lower)

        # Insert the three output-index dims so the slopes broadcast.
        wide_pos = a_pos.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        wide_neg = a_neg.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        alpha_nxt = alpha_pos * wide_pos + alpha_neg * wide_neg
        beta_nxt = beta + torch.einsum("bijkuvw,buvw->bijk", alpha_pos, b_pos) + torch.einsum("bijkuvw,buvw->bijk", alpha_neg, b_neg)
        return alpha_nxt, beta_nxt

    def propagate_bounds(self, alpha, beta, lower):
        """Default back-substitution; currently the einsum implementation."""
        return self.propagate_bounds_old(alpha, beta, lower)

    def propagate_d(self, d):
        """Return ``|d|``."""
        return torch.abs(d)

class TorchCNN(nn.Module):
    def __init__(self, channels, kernel, stride, padding, activation):
        super(TorchCNN, self).__init__()
        self.architecture = "cnn"
        self.activation = activation
        self.channels = channels
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        for i in range(1, len(self.channels)):
            self.__setattr__("conv"+str(i), nn.Conv2d(channels[i-1], channels[i], kernel[i], stride[i], padding[i]))
        self.flatten = nn.Flatten()

    def forward(self, x):
        for i in range(1, len(self.channels)):
            x = self.__getattr__("conv"+str(i))(x)
            if i+1 != len(self.channels):
                x = self.activation(x)
        x = self.flatten(x)
        return x

    
    def CROWN(self, x, eps, true_label, target_label, use_zero=False):
        """Bound the network output over the box ``[x - eps, x + eps]``.

        For every layer ``i`` an identity linear form on that layer's output
        is back-substituted through layers ``i .. 1``, using the bounds
        already obtained for earlier layers to relax the activations, and is
        then concretised over the input box.

        ``true_label`` / ``target_label`` are forwarded to the last layer's
        ``ConvWrapper.set_target``. ``use_zero`` is forwarded to
        ``ReluWrapper``.

        Returns:
            ``(lower, upper)``: flattened bounds of the last layer.
        """
        depth = len(self.channels)
        input_bounder = InputInequality(Box(x-eps, x+eps))
        l, u = [x-eps], [x+eps]

        # Push a zero input through the stack once to learn every layer's
        # output shape; shapes[k] is a zero tensor shaped like layer k's output.
        probe = torch.zeros_like(x)
        shapes = [torch.zeros_like(probe)]
        for k in range(1, depth):
            probe = getattr(self, "conv"+str(k))(probe)
            shapes.append(torch.zeros_like(probe))

        for i in range(1, depth):
            alpha_l = cnn_identity(list(shapes[i].size()), x.device)
            beta_l = torch.zeros_like(shapes[i])
            alpha_u = cnn_identity(list(shapes[i].size()), x.device)
            beta_u = torch.zeros_like(shapes[i])
            for j in range(i, 0, -1):
                lo_in, hi_in = l[j-1], u[j-1]
                in_size = shapes[j-1].size()

                conv_step = ConvWrapper(getattr(self, "conv"+str(j)))
                if j+1 == depth:
                    conv_step.set_target(lo_in, hi_in, true_label, target_label)
                else:
                    conv_step.set_target(lo_in, hi_in)
                conv_step.compute_bounds(lo_in, hi_in)
                alpha_l, beta_l = conv_step.propagate_bounds(alpha_l, beta_l, lower=True, next_shape=in_size)
                alpha_u, beta_u = conv_step.propagate_bounds(alpha_u, beta_u, lower=False, next_shape=in_size)

                if j == 1:
                    # No activation sits in front of the first layer.
                    continue
                relu_step = ReluWrapper(self.activation, use_zero=use_zero)
                relu_step.compute_bounds(lo_in, hi_in)
                alpha_l, beta_l = relu_step.propagate_bounds(alpha_l, beta_l, lower=True)
                alpha_u, beta_u = relu_step.propagate_bounds(alpha_u, beta_u, lower=False)
            l.append(input_bounder.get_lower_bound(alpha_l, beta_l))
            u.append(input_bounder.get_upper_bound(alpha_u, beta_u))
        return self.flatten(l[-1]), self.flatten(u[-1])

    def IBP_plus(self, x, eps, true_label, target_label):
        """Interval bound propagation, returning the bounds of every layer.

        Returns:
            ``(l, u)``: two lists. Entry 0 is the input box ``x -/+ eps`` and
            entry ``k`` holds the bounds after convolution ``k`` (before any
            activation). The last layer is propagated with
            ``true_label`` / ``target_label`` passed to ``set_target``.
        """
        l, u = [x-eps], [x+eps]
        final = len(self.channels) - 1
        for i in range(1, final + 1):
            lo, hi = l[-1], u[-1]
            if i > 1:
                # Layers after the first see the activated previous output.
                lo, hi = ReluWrapper(self.activation).propagate_intervals(lo, hi)
            conv_step = ConvWrapper(getattr(self, "conv"+str(i)))
            if i == final:
                conv_step.set_target(lo, hi, true_label, target_label)
            else:
                conv_step.set_target(lo, hi)
            lo, hi = conv_step.propagate_intervals(lo, hi)
            l.append(lo)
            u.append(hi)
        return l, u
    
    def CROWN_IBP(self, x, eps, true_label, target_label, use_zero=False):
        """Bound the last layer by back-substitution, using interval bounds
        (``IBP_plus``) for all intermediate layers.

        Returns:
            ``(lower, upper)``: flattened bounds of the last layer over the
            box ``[x - eps, x + eps]``.
        """
        l, u = self.IBP_plus(x, eps, true_label, target_label)
        input_bounder = InputInequality(Box(x-eps, x+eps))
        depth = len(self.channels)

        # Zero tensors shaped like each layer's output (index 0 = input).
        probe = torch.zeros_like(x)
        shapes = [torch.zeros_like(probe)]
        for k in range(1, depth):
            probe = getattr(self, "conv"+str(k))(probe)
            shapes.append(torch.zeros_like(probe))

        l_final, u_final = None, None
        last = depth - 1
        if last >= 1:
            # Only the final layer is back-substituted.
            alpha_l = cnn_identity(list(shapes[last].size()), x.device)
            beta_l = torch.zeros_like(shapes[last])
            alpha_u = cnn_identity(list(shapes[last].size()), x.device)
            beta_u = torch.zeros_like(shapes[last])
            for j in range(last, 0, -1):
                lo_in, hi_in = l[j-1], u[j-1]
                in_size = shapes[j-1].size()

                conv_step = ConvWrapper(getattr(self, "conv"+str(j)))
                if j == last:
                    conv_step.set_target(lo_in, hi_in, true_label, target_label)
                else:
                    conv_step.set_target(lo_in, hi_in)
                conv_step.compute_bounds(lo_in, hi_in)
                alpha_l, beta_l = conv_step.propagate_bounds(alpha_l, beta_l, lower=True, next_shape=in_size)
                alpha_u, beta_u = conv_step.propagate_bounds(alpha_u, beta_u, lower=False, next_shape=in_size)

                if j == 1:
                    continue
                relu_step = ReluWrapper(self.activation, use_zero=use_zero)
                relu_step.compute_bounds(lo_in, hi_in)
                alpha_l, beta_l = relu_step.propagate_bounds(alpha_l, beta_l, lower=True)
                alpha_u, beta_u = relu_step.propagate_bounds(alpha_u, beta_u, lower=False)
            l_final = input_bounder.get_lower_bound(alpha_l, beta_l)
            u_final = input_bounder.get_upper_bound(alpha_u, beta_u)
        return self.flatten(l_final), self.flatten(u_final)
    
    def IBP(self, x, eps, true_label, target_label):
        l, u = self.IBP_plus(x, eps, true_label, target_label)
        return self.flatten(l[-1]), self.flatten(u[-1])
    
    def PROVEN_IBP_alpha(self, x, eps, true_label, target_label, use_zero=False):
        """Per-output value of ``ProbabilisticInputInequality.get_q`` for the
        last layer's lower linear bound.

        Intermediate layers are relaxed with interval bounds from
        ``IBP_plus``. The last layer's lower-bound linear form is
        back-substituted to the input and handed to ``get_q``.

        Only the lower-bound coefficients are propagated: ``get_q`` reads
        nothing else, so carrying upper-bound coefficients through every
        layer would be discarded work.

        Returns:
            The flattened ``q`` tensor.
        """
        l, u = self.IBP_plus(x, eps, true_label, target_label)
        input_bounder = ProbabilisticInputInequality(Box(x-eps, x+eps))
        depth = len(self.channels)

        # Zero tensors shaped like each layer's output (index 0 = input).
        probe = torch.zeros_like(x)
        shapes = [torch.zeros_like(probe)]
        for k in range(1, depth):
            probe = self.__getattr__("conv"+str(k))(probe)
            shapes.append(torch.zeros_like(probe))

        q = None
        last = depth - 1
        if last >= 1:
            alpha_l = cnn_identity(list(shapes[last].size()), x.device)
            beta_l = torch.zeros_like(shapes[last])
            for j in range(last, 0, -1):
                wrapper = ConvWrapper(self.__getattr__("conv"+str(j)))
                if j == last:
                    wrapper.set_target(l[j-1], u[j-1], true_label, target_label)
                else:
                    wrapper.set_target(l[j-1], u[j-1])
                wrapper.compute_bounds(l[j-1], u[j-1])
                alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True, next_shape=shapes[j-1].size())
                if j != 1:
                    # The activation in front of layer j acts on layer j-1's output.
                    wrapper = ReluWrapper(self.activation, use_zero=use_zero)
                    wrapper.compute_bounds(l[j-1], u[j-1])
                    alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True)
            q = input_bounder.get_q(alpha_l, beta_l)
        return self.flatten(q)

    def PROVEN_IBP(self, x, eps, true_label, target_label, q, use_zero=False):
        """Probabilistic bounds on the last layer, with interval bounds
        (``IBP_plus``) relaxing the intermediate layers.

        ``q`` is broadcast to every output entry and passed to
        ``ProbabilisticInputInequality`` as the per-bound failure probability.

        Returns:
            ``(lower, upper)``: flattened bounds of the last layer.
        """
        l, u = self.IBP_plus(x, eps, true_label, target_label)
        input_bounder = ProbabilisticInputInequality(Box(x-eps, x+eps))
        depth = len(self.channels)

        # Zero tensors shaped like each layer's output (index 0 = input).
        probe = torch.zeros_like(x)
        shapes = [torch.zeros_like(probe)]
        for k in range(1, depth):
            probe = getattr(self, "conv"+str(k))(probe)
            shapes.append(torch.zeros_like(probe))

        l_final, u_final = None, None
        last = depth - 1
        if last >= 1:
            # Only the final layer is back-substituted.
            alpha_l = cnn_identity(list(shapes[last].size()), x.device)
            beta_l = torch.zeros_like(shapes[last])
            alpha_u = cnn_identity(list(shapes[last].size()), x.device)
            beta_u = torch.zeros_like(shapes[last])
            for j in range(last, 0, -1):
                lo_in, hi_in = l[j-1], u[j-1]
                in_size = shapes[j-1].size()

                conv_step = ConvWrapper(getattr(self, "conv"+str(j)))
                if j == last:
                    conv_step.set_target(lo_in, hi_in, true_label, target_label)
                else:
                    conv_step.set_target(lo_in, hi_in)
                conv_step.compute_bounds(lo_in, hi_in)
                alpha_l, beta_l = conv_step.propagate_bounds(alpha_l, beta_l, lower=True, next_shape=in_size)
                alpha_u, beta_u = conv_step.propagate_bounds(alpha_u, beta_u, lower=False, next_shape=in_size)

                if j == 1:
                    continue
                relu_step = ReluWrapper(self.activation, use_zero=use_zero)
                relu_step.compute_bounds(lo_in, hi_in)
                alpha_l, beta_l = relu_step.propagate_bounds(alpha_l, beta_l, lower=True)
                alpha_u, beta_u = relu_step.propagate_bounds(alpha_u, beta_u, lower=False)
            q_lower = torch.zeros_like(beta_l)+q
            l_final = input_bounder.get_lower_bound(alpha_l, beta_l, q_lower)
            q_upper = torch.zeros_like(beta_u)+q
            u_final = input_bounder.get_upper_bound(alpha_u, beta_u, q_upper)
        return self.flatten(l_final), self.flatten(u_final)

    def UBP_plus(self, x, eps, true_label, target_label):
        """Propagate the four quantities ``(a_l, b_l, a_u, b_u)`` through the
        stack with each wrapper's ``propagate_uncertainty``.

        The start values are ``a_l = a_u = x``, ``b_l = -eps`` and
        ``b_u = eps``.

        Returns:
            Four lists ``(a_l, b_l, a_u, b_u)``; entry ``k`` holds the values
            after convolution ``k`` and entry 0 the start values.
        """
        a_l, b_l, a_u, b_u = [x], [-eps], [x], [eps]
        histories = (a_l, b_l, a_u, b_u)
        depth = len(self.channels)
        for i in range(1, depth):
            state = (a_l[i-1], b_l[i-1], a_u[i-1], b_u[i-1])
            if i != 1:
                state = ReluWrapper(self.activation).propagate_uncertainty(*state)
            cur_a_l, cur_b_l, cur_a_u, cur_b_u = state

            conv_step = ConvWrapper(getattr(self, "conv"+str(i)))
            if i+1 == depth:
                conv_step.set_target(cur_a_l, cur_a_u, true_label, target_label)
            else:
                conv_step.set_target(cur_a_l, cur_a_u)
            cur_a_l, cur_b_l, cur_a_u, cur_b_u = conv_step.propagate_uncertainty(cur_a_l, cur_b_l, cur_a_u, cur_b_u)

            for history, value in zip(histories, (cur_a_l, cur_b_l, cur_a_u, cur_b_u)):
                history.append(value)
        return a_l, b_l, a_u, b_u

    def UBP_test(self, x, eps, true_label, target_label):
        a_l, b_l, a_u, b_u = self.UBP_plus(x, eps, true_label, target_label)
        l, u = self.IBP_plus(x, eps, true_label, target_label)
        print(len(a_l), len(b_l), len(a_u), len(b_u), len(l), len(u), "HEY")
        for i in range(len(l)):
            print(l[i][0,0,0,0].item(), u[i][0,0,0,0].item(), a_l[i][0,0,0,0].item(), a_u[i][0,0,0,0].item(), (a_l[i] + 5*b_l[i])[0,0,0,0].item(), (a_u[i] + 5*b_u[i])[0,0,0,0].item())
        return self.flatten(l[-1]), self.flatten(u[-1])
    
    def PROVEN_gaussian(self, x, eps, true_label, target_label, q_tot, use_zero=False):
        """Layer-by-layer probabilistic bounds using ``GaussianInequality``.

        The total failure budget ``q_tot`` is split evenly over the lower and
        upper bound of every neuron of every layer,
        ``q = q_tot / (2 * num_neurons)``. Each layer's identity linear form
        is then back-substituted to the input and concretised with that
        per-bound ``q``.

        Returns:
            ``(lower, upper)``: flattened bounds of the last layer.
        """
        input_bounder = GaussianInequality(x, eps)
        # Entry 0 is a placeholder: the first convolution needs no input
        # bounds, and the activation in front of layer j only reads entry
        # j-1 for j >= 2.
        l, u = [None], [None]
        wrapper = None

        # Zero tensors shaped like each layer's output (index 0 = input).
        shapes = []
        false_x = torch.zeros_like(x)
        shapes.append(torch.zeros_like(false_x))
        for i in range(1, len(self.channels)):
            false_x = self.__getattr__("conv"+str(i))(false_x)
            shapes.append(torch.zeros_like(false_x))

        num_neurons = 0
        for i in range(len(self.channels)-1, 0, -1):
            # np.prod, not np.product: the alias was removed in NumPy 2.0.
            num_neurons += np.prod(shapes[i].size()[1:])
        q = q_tot / (2 * num_neurons)

        for i in range(1, len(self.channels)):
            alpha_l = cnn_identity(list(shapes[i].size()), x.device)
            beta_l = torch.zeros_like(shapes[i])
            alpha_u = cnn_identity(list(shapes[i].size()), x.device)
            beta_u = torch.zeros_like(shapes[i])
            for j in range(i, 0, -1):
                wrapper = ConvWrapper(self.__getattr__("conv"+str(j)))
                if j+1 == len(self.channels):
                    wrapper.set_target(l[j-1], u[j-1], true_label, target_label)
                else:
                    wrapper.set_target(l[j-1], u[j-1])
                wrapper.compute_bounds(l[j-1], u[j-1])
                alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True, next_shape=shapes[j-1].size())
                alpha_u, beta_u = wrapper.propagate_bounds(alpha_u, beta_u, lower=False, next_shape=shapes[j-1].size())
                if j != 1:
                    wrapper = ReluWrapper(self.activation, use_zero=use_zero)
                    wrapper.compute_bounds(l[j-1], u[j-1])
                    alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True)
                    alpha_u, beta_u = wrapper.propagate_bounds(alpha_u, beta_u, lower=False)
            l.append(input_bounder.get_lower_bound(alpha_l, beta_l, torch.zeros_like(beta_l)+q))
            u.append(input_bounder.get_upper_bound(alpha_u, beta_u, torch.zeros_like(beta_u)+q))
        return self.flatten(l[-1]), self.flatten(u[-1])
    
    def PROVEN(self, x, eps, true_label, target_label, q_tot, use_zero=False):
        """Layer-by-layer bounds over the box ``[x - eps, x + eps]`` that are
        probabilistic for the trailing layers only.

        Layers are counted from the output backwards until more than 100
        neurons have been accumulated. Those layers (index > ``threshold``)
        get the per-bound failure probability
        ``q = q_tot / (2 * num_neurons)``; earlier layers use ``q = 0``,
        which ``ProbabilisticInputInequality`` treats as the deterministic
        box bound. If the count never exceeds 100, every layer is
        deterministic.

        Returns:
            ``(lower, upper)``: flattened bounds of the last layer.
        """
        input_bounder = ProbabilisticInputInequality(Box(x-eps, x+eps))
        l, u = [x-eps], [x+eps]
        wrapper = None

        # Zero tensors shaped like each layer's output (index 0 = input).
        shapes = []
        false_x = torch.zeros_like(x)
        shapes.append(torch.zeros_like(false_x))
        for i in range(1, len(self.channels)):
            false_x = self.__getattr__("conv"+str(i))(false_x)
            shapes.append(torch.zeros_like(false_x))

        num_neurons = 0
        threshold = 0
        q = 0
        for i in range(len(self.channels)-1, 0, -1):
            # np.prod, not np.product: the alias was removed in NumPy 2.0.
            num_neurons += np.prod(shapes[i].size()[1:])
            if num_neurons > 100:
                threshold = i-1
                q = q_tot / (2 * num_neurons)
                break

        for i in range(1, len(self.channels)):
            alpha_l = cnn_identity(list(shapes[i].size()), x.device)
            beta_l = torch.zeros_like(shapes[i])
            alpha_u = cnn_identity(list(shapes[i].size()), x.device)
            beta_u = torch.zeros_like(shapes[i])
            for j in range(i, 0, -1):
                wrapper = ConvWrapper(self.__getattr__("conv"+str(j)))
                if j+1 == len(self.channels):
                    wrapper.set_target(l[j-1], u[j-1], true_label, target_label)
                else:
                    wrapper.set_target(l[j-1], u[j-1])
                wrapper.compute_bounds(l[j-1], u[j-1])
                alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True, next_shape=shapes[j-1].size())
                alpha_u, beta_u = wrapper.propagate_bounds(alpha_u, beta_u, lower=False, next_shape=shapes[j-1].size())
                if j != 1:
                    wrapper = ReluWrapper(self.activation, use_zero=use_zero)
                    wrapper.compute_bounds(l[j-1], u[j-1])
                    alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True)
                    alpha_u, beta_u = wrapper.propagate_bounds(alpha_u, beta_u, lower=False)
            # Layers at or below the threshold stay deterministic (q = 0).
            q_temp = q if i > threshold else 0
            l.append(input_bounder.get_lower_bound(alpha_l, beta_l, torch.zeros_like(beta_l)+q_temp))
            u.append(input_bounder.get_upper_bound(alpha_u, beta_u, torch.zeros_like(beta_u)+q_temp))
        return self.flatten(l[-1]), self.flatten(u[-1])
    
    def PROVEN_hm(self, x, eps, true_label, target_label, q_tot, use_zero=False):
        """Back-propagate linear bounds through every conv layer and return
        the last layer's linear coefficients instead of concrete bounds.

        The computation mirrors a layer-by-layer bound propagation: for each
        layer ``i`` an identity relaxation (``cnn_identity``) is pushed
        backwards through conv ``i`` ... conv ``1`` (with a ReLU relaxation
        between consecutive convs) and concretised on the input box
        ``[x - eps, x + eps]`` so that later layers can use the result.

        Args:
            x: input batch; the perturbation box is ``Box(x - eps, x + eps)``.
            eps: perturbation radius added to / subtracted from ``x``.
            true_label: forwarded to ``ConvWrapper.set_target`` on the last layer.
            target_label: forwarded to ``ConvWrapper.set_target`` on the last layer.
            q_tot: total budget split evenly over the ``2 * num_neurons``
                bounds of the trailing layers (see the threshold loop below).
                NOTE(review): presumably a failure probability -- confirm
                against ``ProbabilisticInputInequality``.
            use_zero: forwarded unchanged to ``ReluWrapper``.

        Returns:
            ``(alpha_l, beta_l, alpha_u, beta_u)`` as propagated back to the
            input for the last layer, or ``None`` when ``self.channels``
            describes no conv layer (the main loop never runs).
        """
        input_bounder = ProbabilisticInputInequality(Box(x-eps, x+eps))
        # l[k] / u[k] hold the bounds stored for the output of conv k;
        # index 0 is the input box itself.
        l, u = [x-eps], [x+eps]
        num_layers = len(self.channels)

        # Push a zero tensor through the convs alone to record the output
        # shape of every layer; only the shapes of these tensors are used.
        shapes = []
        false_x = torch.zeros_like(x)
        shapes.append(torch.zeros_like(false_x))
        for i in range(1, num_layers):
            false_x = getattr(self, "conv"+str(i))(false_x)
            shapes.append(torch.zeros_like(false_x))

        # Walk from the last layer towards the first, accumulating the
        # per-sample neuron count.  As soon as it exceeds 100, every layer
        # with index > threshold receives an equal share q of q_tot (one share
        # per lower bound and one per upper bound); earlier layers get 0.
        # If the count never exceeds 100, q stays 0 for all layers.
        num_neurons = 0
        threshold = 0
        q = 0
        for i in range(num_layers-1, 0, -1):
            # np.prod replaces np.product, which was removed in NumPy 2.0.
            num_neurons += np.prod(shapes[i].size()[1:])
            if num_neurons > 100:
                threshold = i-1
                q = q_tot / (2 * num_neurons)
                break

        for i in range(1, num_layers):
            # Start from the identity relaxation of layer i's own output.
            alpha_l = cnn_identity(list(shapes[i].size()), x.device)
            beta_l = torch.zeros_like(shapes[i])
            alpha_u = cnn_identity(list(shapes[i].size()), x.device)
            beta_u = torch.zeros_like(shapes[i])
            for j in range(i, 0, -1):
                wrapper = ConvWrapper(getattr(self, "conv"+str(j)))
                if j+1 == num_layers:
                    # Only the final conv is given the label pair.
                    wrapper.set_target(l[j-1], u[j-1], true_label, target_label)
                else:
                    wrapper.set_target(l[j-1], u[j-1])
                wrapper.compute_bounds(l[j-1], u[j-1])
                alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True, next_shape=shapes[j-1].size())
                alpha_u, beta_u = wrapper.propagate_bounds(alpha_u, beta_u, lower=False, next_shape=shapes[j-1].size())
                if j != 1:
                    # A ReLU sits in front of every conv except the first; it
                    # is relaxed with the bounds stored for conv j-1's output.
                    wrapper = ReluWrapper(self.activation, use_zero=use_zero)
                    wrapper.compute_bounds(l[j-1], u[j-1])
                    alpha_l, beta_l = wrapper.propagate_bounds(alpha_l, beta_l, lower=True)
                    alpha_u, beta_u = wrapper.propagate_bounds(alpha_u, beta_u, lower=False)
            q_temp = q if i > threshold else 0
            # Concretise on the input box; kept for the last layer as well so
            # the calls made on input_bounder are the same as before.
            l.append(input_bounder.get_lower_bound(alpha_l, beta_l, torch.zeros_like(beta_l)+q_temp))
            u.append(input_bounder.get_upper_bound(alpha_u, beta_u, torch.zeros_like(beta_u)+q_temp))
            if i+1 == num_layers:
                # Last layer reached: hand back its linear coefficients.
                return alpha_l, beta_l, alpha_u, beta_u
        return None

    def BD(self, x, eps, true_label, target_label):
        """Unimplemented stub: ignores every argument and returns ``None``."""

    def ANOTHER(self, x, eps, true_label, target_label):
        """Bound the output by pushing a deviation tensor through the convs.

        A tensor shaped like ``x`` and filled with ``eps`` is sent through
        ``ConvWrapper.propagate_d`` for conv 1, conv 2, ...; before every conv
        except the first its absolute value is taken.  The last conv wrapper
        is additionally given ``true_label`` / ``target_label``.

        Returns:
            ``(y - d, y + d)`` where ``y = self.forward(x)`` and ``d`` is the
            flattened propagated deviation.
        """
        layer_count = len(self.channels)
        radius = torch.ones_like(x) * eps
        for layer in range(1, layer_count):
            if layer > 1:
                # Between layers only the magnitude of the deviation is kept.
                radius = torch.abs(radius)
            conv = ConvWrapper(self.__getattr__(f"conv{layer}"))
            is_last = layer + 1 == layer_count
            if is_last:
                conv.set_target(None, None, true_label, target_label)
            else:
                conv.set_target(None, None)
            radius = conv.propagate_d(radius)
        radius = self.flatten(radius)
        center = self.forward(x)
        return center - radius, center + radius
