import torch
import numpy as np
import math
import time
import torch.nn as nn
from attacks import adv_whitebox

from AIDomains.zonotope import HybridZonotope
import AIDomains.abstract_layers as abs_layers
import AIDomains.concrete_layers as conc_layers
from AIDomains.wrapper import propagate_abs
from AIDomains.ai_util import construct_C, construct_C_for_lf
from typing import Callable, Iterable, List, Dict, Optional
from utils import project_to_bounds
import logging
import sys


# Make the prima4complete checkout and its ELINA Python interface importable.
# Both entries are relative paths, so they resolve against the current
# working directory of the process.
sys.path.append('prima4complete/')
sys.path.append('prima4complete/ELINA/python_interface/')
# MN-BaB is an optional dependency: if any of its modules cannot be imported,
# a message is printed and MNBAB_available is set to False so that code
# depending on it can check the flag.
try:
    from src.mn_bab_optimizer import MNBabOptimizer
    from src.abstract_layers.abstract_network import AbstractNetwork as MNBABNetwork
    import src.concrete_layers as MNBAB_conc_layers
    from src.utilities.config import MNBabOptimizerConfig, BacksubstitutionConfig, IntermediateBoundsMethod
    MNBAB_available = True
except ImportError:
    print("MNBAB not found, please install prima4complete")
    MNBAB_available = False

def set_value_between(self, name, value, lower, upper, dtype):
    '''
    Range-checked property setter.

    Casts `value` with `dtype`, asserts that the cast result lies in the
    closed interval [lower, upper], and stores it on `self` under `name`.
    Nothing is stored when the assertion fails.
    '''
    converted = dtype(value)
    within_range = lower <= converted <= upper
    assert within_range, f"{name} should be between {lower} and {upper}."
    setattr(self, name, converted)

def set_value_typecheck(self, name, value, dtype):
    '''
    Type-checked property setter.

    Asserts that `value` is an instance of `dtype` (no casting is done) and
    then stores it unchanged on `self` under `name`.
    '''
    has_expected_type = isinstance(value, dtype)
    assert has_expected_type, f"{name} should be of type {dtype}."
    setattr(self, name, value)

def set_value_typecast(self, name, value, dtype, constraint=None, msg:str=None):
    '''
    Casting property setter.

    If a (truthy) `constraint` callable is given, it is evaluated on the raw,
    un-cast `value` and must return a truthy result, otherwise an assertion
    with message `msg` is raised. The value is then cast with `dtype` and
    stored on `self` under `name`.
    '''
    if constraint:
        satisfied = constraint(value)
        assert satisfied, msg
    converted = dtype(value)
    setattr(self, name, converted)

class BasicModelWrapper(nn.Module):
    '''
    Base class of all training wrappers.

    Holds the network, the loss function and the training arguments, computes
    the natural (clean) loss/accuracy and combines it with a robust loss that
    subclasses provide through `get_robust_stat_from_bounds`.
    '''
    def __init__(self, net:Callable, loss_fn:Callable, input_dim, device, args, summary_accu_stat:bool=True, data_range=(0,1)):
        '''
        @param
            net: the wrapped network; it is iterated layer by layer to collect BatchNorm layers.
            loss_fn: callable mapping (outputs, labels) to a loss.
            input_dim, device, args: stored as-is for later use.
            summary_accu_stat: if True, `format_return` omits the per-sample correctness masks.
            data_range: (min, max) used to clamp perturbed inputs.
        '''
        super().__init__()
        self.net = net
        self.BNs = [layer for layer in self.net if isinstance(layer, abs_layers._BatchNorm)]
        self.freeze_BN = False
        self.loss_fn = loss_fn
        self.args = args
        self.input_dim = input_dim
        self.device = device
        # this optimizer never steps; it is only used to call zero_grad on the net parameters
        self.grad_cleaner = torch.optim.SGD(self.net.parameters(), lr=1)
        self.summary_accu_stat = summary_accu_stat
        self.data_min, self.data_max = data_range[0], data_range[1]
        self.robust_weight = None
        self.natural = False
        self.name = ''

    def forward(self, x):
        '''Run the wrapped network on x.'''
        return self.net(x)

    def Get_Performance(self, x, y, use_model=None):
        '''
        Compute loss, accuracy and the per-sample correctness mask on the clean input.
        Uses `use_model` instead of this wrapper's forward pass when it is provided.
        '''
        predictor = self.forward if use_model is None else use_model
        outputs = predictor(x)
        y = y.to(torch.int64)
        loss = self.loss_fn(outputs, y)
        accu, pred_correct = self._Get_Accuracy(outputs, y)
        return loss, accu, pred_correct

    def _Get_Accuracy(self, outputs, y):
        '''
        Return (fraction of rows whose argmax over dim 1 equals the label, boolean mask of correct rows).
        '''
        assert len(outputs) == len(y), 'prediction and label should match.'
        pred_correct = outputs.argmax(dim=1) == y
        accuracy = pred_correct.sum() / len(y)
        return accuracy, pred_correct

    def _set_BN(self, BN_layers, update_stat:bool=None):
        '''
        @param
            update_stat: can be combined with training=True; will use the existing BN stat instead in this case.
                         When None, the layers are left untouched.
        '''
        if update_stat is None:
            return
        for bn_layer in BN_layers:
            bn_layer.update_stat = update_stat

    def compute_nat_loss_and_set_BN(self, x, y, **kwargs):
        '''
        Compute the clean statistics while letting BN update its stat (unless frozen),
        then switch the BN stat update off for whatever is computed afterwards.
        '''
        self._set_BN(self.BNs, update_stat=not self.freeze_BN)
        clean_stats = self.Get_Performance(x, y)
        self._set_BN(self.BNs, update_stat=False)
        nat_loss, nat_accu, is_nat_accu = clean_stats
        return nat_loss, nat_accu, is_nat_accu

    def get_robust_stat_from_bounds(self, lb, ub, y, **kwargs):
        '''To be implemented by subclasses: robust loss/accuracy for the input box [lb, ub].'''
        raise NotImplementedError

    def get_robust_stat_from_input_noise(self, eps, x, y, **kwargs):
        '''
        Build the eps-box around x, clamped to the data range, and delegate to
        `get_robust_stat_from_bounds`.
        '''
        lower = (x - eps).clamp(min=self.data_min)
        upper = (x + eps).clamp(max=self.data_max)
        return self.get_robust_stat_from_bounds(lower, upper, y, **kwargs)

    def combine_loss(self, nat_loss, robust_loss):
        '''Convex combination of natural and robust loss, weighted by `robust_weight`.'''
        weight = self.robust_weight
        return (1 - weight) * nat_loss + weight * robust_loss

    def grad_postprocess(self):
        '''
        Will be called right before optimizer.step(); can be used to modify the grad.
        Here: clip the gradient norm to args.grad_clip.
        '''
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip)

    def param_postprocess(self):
        '''
        Will be called right after optimizer.step(); can be used to modify the parameters.
        Here: clamp the weights of the first args.num_nonneg_layer layers to be non-negative.
        '''
        if not self.args.num_nonneg_layer > 0:
            return
        for layer in self.net[:self.args.num_nonneg_layer]:
            if hasattr(layer, "weight"):
                layer.weight.data.clamp_(min=0)

    def format_return(self, loss, nat_loss, nat_accu, is_nat_accu, robust_loss, robust_accu, is_robust_accu):
        '''
        Pack the statistics into (losses, accuracies) and, when `summary_accu_stat`
        is False, additionally the per-sample correctness masks.
        '''
        losses = (loss, nat_loss, robust_loss)
        accuracies = (nat_accu.item(), robust_accu.item())
        if self.summary_accu_stat:
            return losses, accuracies
        return losses, accuracies, (is_nat_accu, is_robust_accu)

    def compute_model_stat(self, x, y, eps, **kwargs):
        '''Compute natural and robust statistics for perturbation size eps and return them formatted.'''
        self.current_eps = eps
        nat_loss, nat_accu, is_nat_accu = self.compute_nat_loss_and_set_BN(x, y)
        cert_loss, cert_accu, is_cert_accu = self.get_robust_stat_from_input_noise(eps, x, y, **kwargs)
        total_loss = self.combine_loss(nat_loss, cert_loss)
        return self.format_return(total_loss, nat_loss, nat_accu, is_nat_accu, cert_loss, cert_accu, is_cert_accu)
        

class PGDModelWrapper(BasicModelWrapper):
    '''
    Implements PGD training
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, **kwargs):
        '''
        Extra keyword arguments are accepted for signature compatibility with the
        other wrappers but are not forwarded to the base class.
        '''
        super().__init__(net, loss_fn, input_dim, device, args)
        # change robust_weight directly during steps instead of modifying args
        self.num_steps = args.train_steps
        self.name = 'pgd'

    def get_robust_stat_from_bounds(self, lb, ub, y, **kwargs):
        '''
        Compute PGD loss and accuracy for the input box [lb, ub].

        The attack starts from the box center. `**kwargs` is accepted (and
        ignored) so that the signature matches the base class, whose
        `get_robust_stat_from_input_noise` forwards its keyword arguments here.

        @return
            (adversarial loss, adversarial accuracy, per-sample correctness mask)
        '''
        # NOTE(review): retain_graph is requested whenever the net contains BN layers;
        # its exact use is up to adv_whitebox.
        retain_graph = len(self.BNs) > 0
        # guard the division so that num_steps == 0 does not raise ZeroDivisionError;
        # for num_steps >= 1 this is identical to max(0.25, 2/num_steps)
        step_size = max(0.25, 2 / max(self.num_steps, 1))
        xadv = adv_whitebox(self.net, (lb+ub)/2, y, lb, ub, self.device, self.num_steps, ODI_num_steps=0, step_size=step_size, lossFunc="pgd", retain_graph=retain_graph)
        yadv = self.net(xadv)
        adv_accu, is_adv_accu = self._Get_Accuracy(yadv, y)
        adv_loss = self.loss_fn(yadv, y)
        return adv_loss, adv_accu, is_adv_accu


    def compute_model_stat(self, x, y, eps, **kwargs):
        '''
        Use clean stat for BN. May be improved via separating adv and clean BN stat, as in {ref}.

        When eps is not positive or `self.natural` is set, the adversarial
        statistics are simply the natural ones.
        '''
        self.current_eps = eps
        # compute natural loss
        nat_loss, nat_accu, is_nat_accu = self.compute_nat_loss_and_set_BN(x, y)
        # compute PGD loss
        if eps > 0 and not self.natural:
            adv_loss, adv_accu, is_adv_accu = self.get_robust_stat_from_input_noise(eps, x, y)
        else:
            adv_accu, is_adv_accu, adv_loss = nat_accu, is_nat_accu, nat_loss

        loss = self.combine_loss(nat_loss, adv_loss)
        return self.format_return(loss, nat_loss, nat_accu, is_nat_accu, adv_loss, adv_accu, is_adv_accu)



class BoxModelWrapper(BasicModelWrapper):
    '''
    Implements IBP training
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, store_box_bounds:bool=False, **kwargs):
        '''
        @param
            store_box_bounds: when False, the bounds stored in the abstract net are reset
                              after every bound computation.
        Extra keyword arguments are accepted but not forwarded to the base class.
        '''
        super().__init__(net, loss_fn, input_dim, device, args)
        self.current_eps = self.args.test_eps
        self.store_box_bounds = store_box_bounds
        self.name = 'ibp'

    def get_IBP_bounds(self, abs_net, input_lb, input_ub, y=None):
        '''
        Propagate the input box through abs_net in the 'box' domain.

        Without y: returns the concretized (lower, upper) output bounds.
        With y: uses the final layer elision trick and returns (pseudo bounds, pseudo labels).
        '''
        box_input = HybridZonotope.construct_from_bounds(input_lb, input_ub, domain='box')
        if y is None:
            first, second = abs_net(box_input).concretize()
        else:
            first, second = propagate_abs(abs_net, "box", box_input, y)
        if not self.store_box_bounds:
            abs_net.reset_bounds()
        return first, second

    def get_robust_stat_from_bounds(self, lb, ub, y, **kwargs):
        '''Certified loss, accuracy and per-sample mask computed from the IBP pseudo bounds.'''
        bounds, labels = self.get_IBP_bounds(self.net, lb, ub, y)
        cert_loss = self.loss_fn(bounds, labels)
        cert_accu, is_cert_accu = self._Get_Accuracy(bounds, labels)
        return cert_loss, cert_accu, is_cert_accu


class HZonoModelWrapper(BoxModelWrapper):
    '''
    Implements training with HybridZonotope bounds in a configurable domain (default 'hbox')
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, store_box_bounds:bool=False, domain='hbox', **kwargs):
        '''
        @param
            domain: abstract domain handed to HybridZonotope / propagate_abs; also used as the wrapper name.
        '''
        super().__init__(net, loss_fn, input_dim, device, args, store_box_bounds=store_box_bounds, **kwargs)
        self.domain = domain
        self.name = domain

    def get_HZono_bounds(self, abs_net, input_lb, input_ub, y=None):
        '''
        Propagate the input box through abs_net in `self.domain`.

        Without y: returns the concretized (lower, upper) output bounds.
        With y: uses the final layer elision trick and returns (pseudo bounds, pseudo labels).
        '''
        abs_input = HybridZonotope.construct_from_bounds(input_lb, input_ub, domain=self.domain)
        if y is None:
            first, second = abs_net(abs_input).concretize()
        else:
            first, second = propagate_abs(abs_net, self.domain, abs_input, y)
        if not self.store_box_bounds:
            abs_net.reset_bounds()
        return first, second

    def get_robust_stat_from_bounds(self, lb, ub, y, **kwargs):
        '''
        Certified loss, accuracy and per-sample mask from the pseudo bounds.
        When args.subbatch_size is set, the batch is processed in consecutive
        chunks of that size and the results are concatenated.
        '''
        bs = self.args.subbatch_size
        if bs is None:
            pseudo_bound, pseudo_labels = self.get_HZono_bounds(self.net, lb, ub, y)
        else:
            # divide the batch
            num_chunks = math.ceil(len(lb) / bs)
            chunk_results = [
                self.get_HZono_bounds(self.net, lb[k*bs:(k+1)*bs], ub[k*bs:(k+1)*bs], y[k*bs:(k+1)*bs])
                for k in range(num_chunks)
            ]
            pseudo_bound = torch.cat([res[0] for res in chunk_results])
            pseudo_labels = torch.cat([res[1] for res in chunk_results])

        cert_loss = self.loss_fn(pseudo_bound, pseudo_labels)
        cert_accu, is_cert_accu = self._Get_Accuracy(pseudo_bound, pseudo_labels)
        return cert_loss, cert_accu, is_cert_accu

            
class MNBaBDeepPolyModelWrapper(BasicModelWrapper):
    '''
    Training wrapper that obtains bounds from MN-BaB's DeepPoly implementation
    (`MNBabOptimizer.bound_minimum_with_deep_poly`).

    The abstract net is converted to a concrete net and then to an MN-BaB
    abstract network; the constructor asserts that all three share parameter memory.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, **kwargs):
        '''
        Requires MN-BaB to be importable and `net` to be an `abs_layers.Sequential`.
        '''
        super().__init__(net, loss_fn, input_dim, device, args, **kwargs)
        self.name = 'mndp'
        self.current_eps = self.args.test_eps
        assert MNBAB_available
        assert isinstance(net, abs_layers.Sequential)
        concrete_net = net.to_concrete()
        assert isinstance(concrete_net, torch.nn.Sequential)

        # ensure we really use the same memory
        parc_mem_locations = {k: v.data_ptr() for (k, v) in net.named_parameters()}
        for name, param in concrete_net.named_parameters():
            layer_idx = int(name.split('.')[0])
            if isinstance(concrete_net[layer_idx], conc_layers.Normalization): continue
            # parameters of the abstract net are looked up under a 'layers.' prefix; the same
            # key must be used in the message, otherwise a failing check raises KeyError
            # instead of the intended AssertionError
            abs_name = 'layers.' + name
            assert parc_mem_locations[abs_name] == param.data_ptr(), f"Memory location of {name} is different between abstract ({parc_mem_locations[abs_name]}) and concrete networks ({param.data_ptr()})."

        if isinstance(concrete_net[0], conc_layers.Normalization):
            # replace the leading normalization layer by MN-BaB's own Normalize layer
            normalize = MNBAB_conc_layers.normalize.Normalize(concrete_net[0].mean.flatten().tolist(),
                                           concrete_net[0].std.flatten().tolist(),
                                           channel_dim=1).to(device)
            concrete_net = torch.nn.Sequential(normalize, *concrete_net[1:])

        self.anet = MNBABNetwork.from_concrete_module(concrete_net, input_dim, device=device)
        self.sync_keys = []
        for name, param in self.anet.named_parameters():
            layer_idx = int(name.split('.')[1])
            if isinstance(self.anet.layers[layer_idx], MNBAB_conc_layers.normalize.Normalize): continue
            assert parc_mem_locations[name] == param.data_ptr(), f"Memory location of {name} is different between abstract ({parc_mem_locations[name]}) and mnbab-abstract networks ({param.data_ptr()})."
            self.sync_keys.append(name)

        self.optimizer = MNBabOptimizer(MNBabOptimizerConfig(dict()),
                                        BacksubstitutionConfig({"intermediate_bounds_method": "dp",
                                                                "box_pass": True}))

    def _sync_net(self):
        '''Copy the state dict of the abstract net into the MN-BaB net (non-strict).'''
        self.anet.load_state_dict(self.net.state_dict(), strict=False)

    def get_DP_bounds(self, lb, ub, y, reuse_bounds: Optional[Dict] = None, return_bounds: bool = False,
                      reuse_bound_mode: Optional[str] = "standard", loss_fusion:bool = False):
        '''
        Compute DeepPoly lower bounds for the input box [lb, ub].

        @param
            reuse_bounds: previously computed intermediate bounds to reuse; with loss_fusion,
                          its "output" entry is used to build the query.
            return_bounds: also request the intermediate bounds from the optimizer.
            reuse_bound_mode: forwarded to the optimizer only when reuse_bounds is given.
            loss_fusion: switches the query construction / post-processing to the loss fusion variant.
        @return
            (DeepPoly lower bounds, IBP stability, DeepPoly stability, intermediate bounds or None)
        '''
        fused_with_reuse = loss_fusion and reuse_bounds is not None
        if fused_with_reuse:
            query_matrix, query_offset = construct_C_for_lf(self.net.output_dim[-1], y, reuse_bounds["output"])
            two_sided = False
        else:
            query_matrix = construct_C(self.net.output_dim[-1], y)
            two_sided = loss_fusion
        self.anet.reset_input_bounds()
        extra_args = {}
        if reuse_bounds is not None:
            extra_args['reuse_bounds'] = reuse_bounds
            extra_args['reuse_bound_mode'] = reuse_bound_mode
        else:
            self.optimizer.backsubstitution_config.intermediate_bounds_method = IntermediateBoundsMethod["dp"]
        if return_bounds:
            extra_args['return_bounds'] = return_bounds
        res = self.optimizer.bound_minimum_with_deep_poly(
            self.optimizer.backsubstitution_config,
            lb,
            ub,
            self.anet,
            query_matrix,
            return_tensors=True,
            ibp_pass=self.optimizer.backsubstitution_config.box_pass,
            reset_input_bounds=False,
            get_stability=True,
            two_sided=two_sided,
            **extra_args,
        )
        deep_poly_lbs = res[0]

        if fused_with_reuse:
            deep_poly_lbs = torch.log(torch.clip(query_offset - deep_poly_lbs.flatten(), min=1+1e-20))

        # the stability statistics (and optionally the intermediate bounds) are the trailing entries of res
        if return_bounds:
            ibp_stability, dp_stability, intermediate_bounds = res[-3:]
        else:
            ibp_stability, dp_stability = res[-2:]
            intermediate_bounds = None
        if reuse_bounds is not None:
            # mark stability as unavailable with -1 when it was not computed
            if ibp_stability is None:
                ibp_stability = [-1,-1]
            if dp_stability is None:
                dp_stability = [-1,-1]

        ibp_stab = ibp_stability[1] if ibp_stability is not None else 0
        dp_stab = dp_stability[1] if dp_stability is not None else 0
        return deep_poly_lbs, ibp_stab, dp_stab, intermediate_bounds

    def get_robust_stat_from_bounds(self, lb, ub, y, return_all=False, intermediate_bounds: Optional[Dict]=None, compute_bounds:bool=False):
        '''
        Certified statistics from DeepPoly bounds for the input box [lb, ub].

        The batch is split into sub-batches of args.subbatch_size unless that is None
        or intermediate bounds are provided.

        @return
            return_all:      (loss, cert_accu, is_cert_accu, ibp_stability, dp_stability, intermediate_bounds)
            compute_bounds:  only the intermediate bounds
            otherwise:       (loss, cert_accu, is_cert_accu)
        '''
        self._sync_net()
        assert not (compute_bounds and intermediate_bounds is not None), "providing bounds and computing them is not compatible"
        if self.args.subbatch_size is None or intermediate_bounds is not None:
            deep_poly_lbs, ibp_stability, dp_stability, intermediate_bounds = self.get_DP_bounds(lb, ub, y, intermediate_bounds, compute_bounds,
                                                                                                 self.args.reuse_bound_mode, loss_fusion=self.args.loss_fusion)
        else:
            # divide the batch
            bs = self.args.subbatch_size
            lbs_list = []
            ibp_stability, dp_stability = [], []
            intermediate_bound_list = []
            for i in range(math.ceil(len(lb) / bs)):
                cur_lb, cur_ub, cur_y = lb[i*bs:(i+1)*bs], ub[i*bs:(i+1)*bs], y[i*bs:(i+1)*bs]
                dp_lbs, ibp_stability_i, dp_stability_i, intermediate_bounds_i = self.get_DP_bounds(cur_lb, cur_ub, cur_y, intermediate_bounds, compute_bounds, loss_fusion=self.args.loss_fusion)
                lbs_list.append(dp_lbs)
                ibp_stability.append(ibp_stability_i)
                dp_stability.append(dp_stability_i)
                if intermediate_bounds_i is not None:
                    intermediate_bound_list.append(intermediate_bounds_i)
            deep_poly_lbs = torch.cat(lbs_list)
            # stability is averaged over sub-batches (unweighted)
            ibp_stability = sum(ibp_stability)/len(ibp_stability)
            dp_stability = sum(dp_stability)/len(dp_stability)
            if compute_bounds:
                assert len(intermediate_bound_list) > 0
                # stitch the per-sub-batch bounds back together along the batch dimension
                intermediate_bounds = {"layer_ids": intermediate_bound_list[0]["layer_ids"]}
                for layer_id in intermediate_bound_list[0].keys():
                    if layer_id == "layer_ids": continue
                    intermediate_bounds[layer_id] = (torch.concat([x[layer_id][0] for x in intermediate_bound_list], 0),
                                                     torch.concat([x[layer_id][1] for x in intermediate_bound_list], 0) if intermediate_bound_list[0][layer_id][1] is not None else None)
            else:
                intermediate_bounds = None

        if self.args.loss_fusion and deep_poly_lbs.dim() == 1:
            # fused loss: the bound itself is the loss; no accuracy is available (reported as -1)
            loss = deep_poly_lbs.mean()
            cert_accu = torch.tensor(-1)
            is_cert_accu = None
        else:
            # prepend a zero column so that class 0 of the pseudo bound plays the role of the target
            deep_poly_lbs_padded = torch.cat((torch.zeros(size=(deep_poly_lbs.size(0), 1), dtype=deep_poly_lbs.dtype, device=deep_poly_lbs.device), deep_poly_lbs), dim=1)
            pseudo_bound = -deep_poly_lbs_padded
            pseudo_labels = torch.zeros(size=(deep_poly_lbs.size(0),), dtype=torch.int64, device=deep_poly_lbs.device)

            loss = self.loss_fn(pseudo_bound, pseudo_labels)
            cert_accu, is_cert_accu = self._Get_Accuracy(pseudo_bound, pseudo_labels)

        if return_all:
            return loss, cert_accu, is_cert_accu, ibp_stability, dp_stability, intermediate_bounds
        elif compute_bounds:
            return intermediate_bounds
        else:
            return loss, cert_accu, is_cert_accu

    def compute_model_stat(self, x, y, eps, **kwargs):
        '''
        Same as the base implementation, but any extra values returned by the robust
        computation (e.g. with return_all=True) are appended to the formatted result.
        '''
        self.current_eps = eps
        nat_loss, nat_accu, is_nat_accu = self.compute_nat_loss_and_set_BN(x, y)
        cert_loss, cert_accu, is_cert_accu, *rest = self.get_robust_stat_from_input_noise(eps, x, y, **kwargs)
        loss = self.combine_loss(nat_loss, cert_loss)
        basic = self.format_return(loss, nat_loss, nat_accu, is_nat_accu, cert_loss, cert_accu, is_cert_accu)
        return (*basic, *rest)

class DeepPolyModelWrapper(BasicModelWrapper):
    '''
    Training wrapper using the 'deeppoly' (or 'deeppoly_box') domain of propagate_abs.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, use_dp_box:bool=False, relu_type='original', box_pass=True, **kwargs):
        '''
        @param
            use_dp_box: select the "deeppoly_box" domain instead of "deeppoly".
            relu_type: forwarded to propagate_abs.
            box_pass: forwarded to propagate_abs; always disabled when relu_type is 'zero'.
        '''
        super().__init__(net, loss_fn, input_dim, device, args, **kwargs)
        self.current_eps = self.args.test_eps
        self.use_dp_box = use_dp_box
        self.name = 'dpbox' if self.use_dp_box else 'dp'
        self.relu_type = relu_type
        # Honour the caller's box_pass (it used to be silently overwritten) while keeping
        # the rule that relu_type == 'zero' never uses a box pass. With the default
        # box_pass=True this yields exactly the previous value.
        self.box_pass = bool(box_pass) and relu_type != 'zero'
        print('init.relu',self.relu_type,self.box_pass)

    def get_robust_stat_from_bounds(self, lb, ub, y,**kwargs):
        '''Certified loss, accuracy and per-sample mask from the DeepPoly pseudo bounds.'''
        x_abs = HybridZonotope.construct_from_bounds(lb, ub, domain='box')
        domain = "deeppoly_box" if self.use_dp_box else "deeppoly"
        pseudo_bound, pseudo_labels = propagate_abs(self.net, domain, x_abs, y, relu_type=self.relu_type, box_pass=self.box_pass)
        loss = self.loss_fn(pseudo_bound, pseudo_labels)
        cert_accu, is_cert_accu = self._Get_Accuracy(pseudo_bound, pseudo_labels)
        return loss, cert_accu, is_cert_accu

class TAPSModelWrapper(BoxModelWrapper):
    '''
    Implements TAPS training

    The net is split into two blocks: bounds are propagated with IBP through
    the first block, and a PGD-based estimate is used for the second block.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, store_box_bounds:bool=False, block_sizes=None, **kwargs):
        '''
        @param
            block_sizes: two layer counts (first block, second block); one of them may be -1,
                         meaning "all remaining layers". The passed object is not modified.
        '''
        super().__init__(net=net, loss_fn=loss_fn, input_dim=input_dim, device=device, args=args, store_box_bounds=store_box_bounds, **kwargs)
        self.current_eps = self.args.test_eps
        self.net_blocks_abs = self._split_net_to_blocks(block_sizes)
        # BN layers inside the last (PGD) block
        self.volatile_BNs = [layer for layer in self.net_blocks_abs[-1] if isinstance(layer, abs_layers._BatchNorm)]
        self.soft_thre = args.soft_thre
        self.tol = 1e-5
        self.num_steps = args.train_steps
        self.disable_TAPS = False # when true, TAPS is equivalent to IBP
        self.TAPS_grad_scale = args.TAPS_grad_scale
        self.name = 'taps'

    def _split_net_to_blocks(self, block_sizes):
        '''
        Split self.net into consecutive abs_layers.Sequential blocks of the given sizes.
        A size of -1 is replaced by the number of layers not covered by the other block.
        '''
        assert block_sizes is not None and len(block_sizes) == 2, f"TAPS assume two blocks: the first uses IBP, the second uses PGD."
        # work on a private copy: the -1 placeholders are resolved below and must not leak
        # into the caller's list (this also allows tuples to be passed)
        block_sizes = list(block_sizes)
        assert block_sizes[0] > 0 or block_sizes[1] > 0
        if block_sizes[0] == -1: block_sizes[0] = len(self.net) - block_sizes[1]
        if block_sizes[1] == -1: block_sizes[1] = len(self.net) - block_sizes[0]
        assert len(self.net) == sum(block_sizes), f"Provided block splits have {sum(block_sizes)} layers, but the net has {len(self.net)} layers."

        start = 0
        blocks = []
        for size in block_sizes:
            end = start + size
            abs_block = abs_layers.Sequential(*self.net[start:end])
            abs_block.output_dim = abs_block[-1].output_dim
            blocks.append(abs_block)
            start = end
        return blocks

    def get_robust_stat_from_bounds(self, lb, ub, y, **kwargs):
        '''
        TAPS loss, accuracy and per-sample mask for the input box [lb, ub].
        Falls back to the plain IBP computation when `disable_TAPS` is set.
        '''
        if self.disable_TAPS:
            return super().get_robust_stat_from_bounds(lb, ub, y)

        # propagate the bound block-wisely
        for block_id, block in enumerate(self.net_blocks_abs):
            if block_id + 1 < len(self.net_blocks_abs):
                lb, ub = self.get_IBP_bounds(block, lb, ub)
            else:
                # prepare PGD bounds, Box bounds for y_i - y_t
                PGD_bound = self.get_TAPS_bounds(block, lb, ub, self.num_steps, y)
                Box_bound, pseudo_labels = self.get_IBP_bounds(block, lb, ub, y)
                pseudo_bound = PGD_bound

        # TODO: perform smoothed gradient scaling from 0 - TAPS_grad_scale, i.e., IBP -> TAPS
        # the loss is the product of the (gradient-scaled) PGD-bound loss and the Box-bound loss
        loss = GradExpander.apply(self.loss_fn(pseudo_bound, pseudo_labels), self.TAPS_grad_scale) * self.loss_fn(Box_bound, pseudo_labels)

        cert_accu, is_cert_accu = self._Get_Accuracy(pseudo_bound, pseudo_labels)
        return loss, cert_accu, is_cert_accu

    def get_TAPS_bounds(self, block, input_lb, input_ub, num_steps, y=None):
        '''
        Estimate the output bounds of `block` from pivotal points found by PGD inside
        [input_lb, input_ub], and connect them to the input bounds for backpropagation.
        '''
        C = construct_C(block.output_dim[-1], y) if y is not None else None

        # the search for pivotal points itself is not differentiated
        with torch.no_grad():
            pts = self._get_pivotal_points(block, input_lb, input_ub, num_steps, C)

        # Establish gradient link between pivotal points and bound
        # via rectified linear link
        pts = torch.transpose(pts, 0, 1)
        pts = RectifiedLinearGradientLink.apply(input_lb.unsqueeze(0), input_ub.unsqueeze(0), pts, self.args.soft_thre, self.tol)
        pts = torch.transpose(pts, 0, 1)
        bounds = self._get_bound_estimation_from_pts(block, pts, C)

        return bounds


    def _get_bound_estimation_from_pts(self, block, pts, C=None):
        '''
        Evaluate `block` on all pivotal points at once and return the estimated
        upper bounds of y_i - y_t, with a leading zero column.
        '''
        assert C is not None, "PGD estimation is supposed to be used for margins."
        # # main idea: convert the 9 adv inputs into one batch to compute the bound at the same time; involve many reshaping
        batch_C = C.unsqueeze(1).expand(-1, pts.shape[1], -1, -1).reshape(-1, *(C.shape[1:])) # may need shape adjustment
        batch_pts = pts.reshape(-1, *(pts.shape[2:]))
        out_pts = block(batch_pts, C=batch_C)
        out_pts = out_pts.reshape(*(pts.shape[:2]), *(out_pts.shape[1:]))
        out_pts = - out_pts # the out is the lower bound of yt - yi, transform it to the upper bound of yi - yt
        # the out_pts should be in shape (batch_size, n_class - 1, n_class - 1)
        ub = torch.diagonal(out_pts, dim1=1, dim2=2) # shape: (batch_size, n_class - 1)
        estimated_bounds = torch.cat([torch.zeros(size=(ub.shape[0],1), dtype=ub.dtype, device=ub.device), ub], dim=1) # shape: (batch_size, n_class)

        return estimated_bounds


    def _get_pivotal_points_one_batch(self, block, lb, ub, num_steps, C):
        '''
        Get adversarial examples in the latent space.

        Runs args.restarts randomly initialized signed-gradient searches of num_steps
        steps each and keeps, per sample and output dimension, the point with the
        largest estimated bound seen so far.
        '''
        num_pivotal = block.output_dim[-1] - 1 # only need to estimate n_class - 1 dim for the final output

        def init_pts(input_lb, input_ub):
            # uniform random points inside the box, one per pivotal dimension
            rand_init = input_lb.unsqueeze(1) + (input_ub-input_lb).unsqueeze(1)*torch.rand(input_lb.shape[0], num_pivotal, *input_lb.shape[1:], device=self.device)
            return rand_init

        def select_schedule(num_steps):
            # more steps -> more learning rate decay milestones
            if num_steps >= 20 and num_steps <= 50:
                lr_decay_milestones = [int(num_steps*0.7)]
            elif num_steps > 50 and num_steps <= 80:
                lr_decay_milestones = [int(num_steps*0.4), int(num_steps*0.7)]
            elif num_steps > 80:
                lr_decay_milestones = [int(num_steps*0.3), int(num_steps*0.6), int(num_steps*0.8)]
            else:
                lr_decay_milestones = []
            return lr_decay_milestones

        # TODO: move this to args factory
        lr_decay_milestones = select_schedule(num_steps)
        lr_decay_factor = 0.1
        # guard the division so that num_steps == 0 (evaluate the random init only) does not
        # raise ZeroDivisionError; identical to max(0.2, 2/num_steps) for num_steps >= 1
        init_lr = max(0.2, 2 / max(num_steps, 1))

        retain_graph = True if len(self.volatile_BNs) > 0 else False
        pts = init_pts(lb, ub)
        variety = (ub - lb).unsqueeze(1).detach()
        best_estimation = -1e5*torch.ones(pts.shape[:2], device=pts.device)
        best_pts = torch.zeros_like(pts)
        with torch.enable_grad():
            for re in range(self.args.restarts):
                lr = init_lr
                pts = init_pts(lb, ub)
                # one extra iteration so that the points after the last step are evaluated too
                for it in range(num_steps+1):
                    pts.requires_grad = True
                    estimated_pseudo_bound = self._get_bound_estimation_from_pts(block, pts, C=C)
                    improve_idx = estimated_pseudo_bound[:, 1:] > best_estimation
                    best_estimation[improve_idx] = estimated_pseudo_bound[:, 1:][improve_idx].detach()
                    best_pts[improve_idx] = pts[improve_idx].detach()
                    # wants to maximize the estimated bound
                    if it != num_steps:
                        loss = - estimated_pseudo_bound.sum()
                        loss.backward(retain_graph=retain_graph)
                        new_pts = pts - pts.grad.sign() * lr * variety
                        pts = project_to_bounds(new_pts, lb.unsqueeze(1), ub.unsqueeze(1)).detach()
                        if (it+1) in lr_decay_milestones:
                            lr *= lr_decay_factor
        return best_pts

    def _get_pivotal_points(self, block, input_lb, input_ub, num_steps, C=None):
        '''
        This assumes the block net is fixed in this procedure. If a BatchNorm is involved, freeze its stat before calling this function.

        The inputs are processed in chunks of args.estimation_batch samples and the
        resulting pivotal points are concatenated along the batch dimension.
        '''
        assert C is not None # Should only estimate for the final block
        lb, ub = input_lb.clone().detach(), input_ub.clone().detach()

        pt_list = []
        # split into batches
        bs = self.args.estimation_batch
        lb_batches = [lb[i*bs:(i+1)*bs] for i in range(math.ceil(len(lb) / bs))]
        ub_batches = [ub[i*bs:(i+1)*bs] for i in range(math.ceil(len(ub) / bs))]
        C_batches = [C[i*bs:(i+1)*bs] for i in range(math.ceil(len(C) / bs))]
        for lb_one_batch, ub_one_batch, C_one_batch in zip(lb_batches, ub_batches, C_batches):
            pt_list.append(self._get_pivotal_points_one_batch(block, lb_one_batch, ub_one_batch, num_steps, C_one_batch))
        pts = torch.cat(pt_list, dim=0)
        return pts


    def get_robust_stat_from_input_noise(self, eps, x, y, **kwargs):
        '''
        Same as the parent implementation, followed by zeroing the parameter gradients.
        Keyword arguments are accepted but not forwarded.
        '''
        cert_loss, cert_accu, is_cert_accu = super().get_robust_stat_from_input_noise(eps, x, y)
        self.grad_cleaner.zero_grad() # clean the grad from PGD propagation
        return cert_loss, cert_accu, is_cert_accu

class AdvBoundGradientLink(torch.autograd.Function):
    '''
    Belongs to TAPS.

    Template for a gradient link between adversarial inputs and the input bounds.
    The forward pass is the identity on x; subclasses define the backward pass.
    '''
    @staticmethod
    def forward(ctx, lb, ub, x, c:float, tol:float):
        # remember the hyper-parameters and tensors the backward pass will need
        ctx.c, ctx.tol = c, tol
        ctx.save_for_backward(lb, ub, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        # must be provided by a concrete link
        raise NotImplementedError

class RectifiedLinearGradientLink(AdvBoundGradientLink):
    r'''
    Belongs to TAPS.

    Establish Rectified linear gradient link between the input bounds and the input point.
    Note that this is not a valid gradient w.r.t. the forward function
    Take ub as an example:
        For dims that x[dim] \in [lb, ub-c*(ub-lb)], the gradient w.r.t. ub is 0.
        For dims that x[dim] == ub, the gradient w.r.t. ub is 1.
        For dims that x[dim] \in [ub-c*(ub-lb), ub], the gradient is linearly interpolated between 0 and 1.

    x should be modified to shape (batch_size, *bound_dims) by reshaping.
    bounds should be of shape (1, *bound_dims)
    '''
    # NOTE: the docstring above is a raw string because it contains "\in", which is an
    # invalid escape sequence in a normal string literal.
    @staticmethod
    def backward(ctx, grad_x):
        lb, ub, x = ctx.saved_tensors
        c, tol = ctx.c, ctx.tol
        # width of the region next to each bound in which the gradient is non-zero
        slackness = c * (ub - lb)
        # handle grad w.r.t. ub
        thre = (ub - slackness)
        rectified_grad_mask = (x >= thre)
        # the clamps keep the ratio well-defined when the box (and hence the slackness) is degenerate
        grad_ub = (rectified_grad_mask * grad_x * (x - thre).clamp(min=0.5*tol) / slackness.clamp(min=tol)).sum(dim=0, keepdim=True)
        # handle grad w.r.t. lb
        thre = (lb + slackness)
        rectified_grad_mask = (x <= thre)
        grad_lb = (rectified_grad_mask * grad_x * (thre - x).clamp(min=0.5*tol) / slackness.clamp(min=tol)).sum(dim=0, keepdim=True)
        # we don't need grad w.r.t. x and param
        return grad_lb, grad_ub, None, None, None


class GradExpander(torch.autograd.Function):
    '''
    Belongs to TAPS.

    Identity in the forward pass; scales the incoming gradient by alpha in the backward pass.
    '''
    @staticmethod
    def forward(ctx, x, alpha:float=1):
        # keep the scaling factor for the backward pass
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_x):
        scaled_grad = ctx.alpha * grad_x
        # alpha itself receives no gradient
        return scaled_grad, None


class SmallBoxModelWrapper(BoxModelWrapper):
    '''
    Implements SABR training.

    The robust statistics are computed on a shrunken box located around the adversarial example returned by adv_whitebox, instead of on the full eps-box around x.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, store_box_bounds:bool=False, eps_shrinkage:float=1, **kwargs):
        super().__init__(net=net, loss_fn=loss_fn, input_dim=input_dim, device=device, args=args, store_box_bounds=store_box_bounds, **kwargs)
        # NOTE: eps_shrinkage is deliberately not range-checked; with eps_shrinkage = 1 the small box equals the full box
        self.eps_shrinkage = eps_shrinkage
        logging.info(f"Using small box with eps shrinkage: {self.eps_shrinkage}")

        shrinkage = args.relu_shrinkage
        if shrinkage is not None:
            # push the configured shrinkage onto every abstract ReLU layer of the net
            relu_layers = (layer for layer in self.net if isinstance(layer, abs_layers.ReLU))
            for relu in relu_layers:
                relu.relu_shrinkage = shrinkage
            logging.info(f"Setting ReLU shrinkage to {shrinkage}")
        self.name = f's{self.name}_{eps_shrinkage}'

    def get_robust_stat_from_input_noise(self, eps, x, y, **kwargs):
        '''
        Build the small box inside the eps-box around x and return the robust statistics of that small box.
        '''
        with torch.no_grad():
            # full perturbation box, clipped to the data range on the respective side
            lb_box = (x-eps).clamp(min=self.data_min)
            ub_box = (x+eps).clamp(max=self.data_max)
            retain_graph = len(self.BNs) > 0
            # TODO: move hard-coded num_steps to args?
            adex = adv_whitebox(self.net, x, y, lb_box, ub_box, self.device, num_steps=10, ODI_num_steps=0, lossFunc="pgd", retain_graph=retain_graph)
            # half-width of the small box: a fraction eps_shrinkage of the full half-width
            eff_eps = (ub_box - lb_box) / 2 * self.eps_shrinkage
            # move the centre so that the small box stays inside the full box
            x_new = torch.clamp(adex, min=lb_box+eff_eps, max=ub_box-eff_eps)
            lb_new = x_new - eff_eps
            ub_new = x_new + eff_eps
            self.grad_cleaner.zero_grad() # clean grad from PGD

        return self.get_robust_stat_from_bounds(lb_new, ub_new, y, **kwargs)

class STAPSModelWrapper(TAPSModelWrapper, SmallBoxModelWrapper):
    '''
    Combines TAPSModelWrapper and SmallBoxModelWrapper through multiple inheritance.

    The constructor only forwards its arguments, all by keyword, to the next __init__ in the MRO.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, store_box_bounds: bool = False, block_sizes=None, eps_shrinkage=1, **kwargs):
        super().__init__(
            net=net,
            loss_fn=loss_fn,
            input_dim=input_dim,
            device=device,
            args=args,
            store_box_bounds=store_box_bounds,
            block_sizes=block_sizes,
            eps_shrinkage=eps_shrinkage,
            **kwargs,
        )

class SmallHZonoModelWrapper(HZonoModelWrapper, SmallBoxModelWrapper):
    '''
    Combines HZonoModelWrapper and SmallBoxModelWrapper through multiple inheritance.

    The constructor only forwards its arguments to the next __init__ in the MRO; eps_shrinkage is passed by keyword.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, eps_shrinkage=1, **kwargs):
        super().__init__(
            net, loss_fn, input_dim, device, args,
            eps_shrinkage=eps_shrinkage,
            **kwargs,
        )

class SmallMNDPModelWrapper(MNBaBDeepPolyModelWrapper, SmallBoxModelWrapper):
    '''
    Combines MNBaBDeepPolyModelWrapper and SmallBoxModelWrapper through multiple inheritance.

    The constructor only forwards its arguments to the next __init__ in the MRO; eps_shrinkage is passed by keyword.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, eps_shrinkage=1, **kwargs):
        super().__init__(
            net, loss_fn, input_dim, device, args,
            eps_shrinkage=eps_shrinkage,
            **kwargs,
        )

class SmallDPModelWrapper(DeepPolyModelWrapper, SmallBoxModelWrapper):
    '''
    Combines DeepPolyModelWrapper and SmallBoxModelWrapper through multiple inheritance.

    The constructor only forwards its arguments to the next __init__ in the MRO; eps_shrinkage is passed by keyword.
    '''
    def __init__(self, net, loss_fn, input_dim, device, args, eps_shrinkage=1, **kwargs):
        super().__init__(
            net, loss_fn, input_dim, device, args,
            eps_shrinkage=eps_shrinkage,
            **kwargs,
        )

class GradAccuModelWrapper(BasicModelWrapper):
    '''
    Implements gradient accumulation for all the defined model wrappers.

    It contains a concrete model wrapper inside, divides the batch into sub-batches of size args.grad_accu_batch, runs the wrapped wrapper on each sub-batch, and then merges the gradients together.

    If a BN model is provided, BN stat is set based on the whole batch instead of the divided batches to ensure consistency of results. Therefore, it is possible that memory cost gradually steps up during the accumulation.

    Attributes set here:
        accu_batch_size: sub-batch size, read from args.grad_accu_batch
        model_wrapper: the wrapped wrapper that does the actual computation
        named_grads: per-parameter-name accumulator; the float 0.0 marks "nothing accumulated yet"
        disable_accumulation: when True, compute_model_stat simply delegates to the wrapped wrapper
    '''
    def __init__(self, model_wrapper:BasicModelWrapper, args):
        # mirror the wrapped wrapper's net / loss / data range so that both share the same network
        super().__init__(model_wrapper.net, model_wrapper.loss_fn, model_wrapper.input_dim, model_wrapper.device, args, model_wrapper.summary_accu_stat, (model_wrapper.data_min, model_wrapper.data_max))
        self.accu_batch_size = args.grad_accu_batch
        self.model_wrapper = model_wrapper
        self.robust_weight = None
        self.named_grads = {} # used to keep the grad of each accumulation batch
        self.disable_accumulation = False
        for key, p in self.net.named_parameters():
            self.named_grads[key] = 0.0
        logging.info(f"Using gradient accumulation with batch size: {self.accu_batch_size}")

    def compute_model_stat(self, x, y, eps, **kwargs):
        '''
        Run the wrapped wrapper sub-batch by sub-batch and, in training mode, write the merged gradient directly into p.grad.

        The returned loss is detached (with requires_grad set to True), so calling backward on it does not touch the parameters.
        The result is whatever self.format_return builds from the collected statistics.
        '''
        self.model_wrapper.robust_weight = self.robust_weight
        if self.disable_accumulation:
            return self.model_wrapper.compute_model_stat(x, y, eps, **kwargs)
        # set BN stat based on the whole batch
        self.model_wrapper.freeze_BN = False
        nat_loss, nat_accu, is_nat_accu = self.model_wrapper.compute_nat_loss_and_set_BN(x, y, **kwargs)
        # NOTE(review): freeze_BN is left True on the wrapped wrapper after this method returns — confirm callers expect that
        self.model_wrapper.freeze_BN = True
        # split into batches
        num_accu_batches = math.ceil(len(x) / self.accu_batch_size)
        is_robust_accu = []
        robust_loss = []
        # presumably the graph must be retained because the sub-batch losses share the whole-batch BN computation — TODO confirm
        retain_graph = True if len(self.model_wrapper.BNs) > 0 else False
        # the 3-tuple unpacking below relies on summary_accu_stat being False on the wrapped wrapper;
        # NOTE(review): the flag is not restored if a sub-batch raises
        summary_accu_stat = self.model_wrapper.summary_accu_stat
        self.model_wrapper.summary_accu_stat = False
        for i in range(num_accu_batches):
            batch_x = x[i*self.accu_batch_size:(i+1)*self.accu_batch_size]
            batch_y = y[i*self.accu_batch_size:(i+1)*self.accu_batch_size]
            (loss, _, batch_robust_loss), _, (_, batch_is_robust_accu) = self.model_wrapper.compute_model_stat(batch_x, batch_y, eps, **kwargs)
            is_robust_accu.append(batch_is_robust_accu)
            robust_loss.append(batch_robust_loss.item())
            if self.net.training:
                # clear p.grad first so that it holds only this sub-batch's gradient after backward
                self.grad_cleaner.zero_grad()
                loss.backward(retain_graph=retain_graph)
                for key, p in self.net.named_parameters():
                    if p.grad is not None:
                        # NOTE(review): combined with the division by num_accu_batches below, the total weight is
                        # len(x) / (accu_batch_size * num_accu_batches), which is below 1 when the last sub-batch is short — confirm intended
                        self.named_grads[key] += p.grad * len(batch_x) / self.accu_batch_size # consider last batch not equal size
        self.model_wrapper.summary_accu_stat = summary_accu_stat
        is_robust_accu = torch.cat(is_robust_accu)
        if self.net.training:
            self.grad_cleaner.zero_grad()
            for key, p in self.net.named_parameters():
                # a float entry means no gradient was accumulated for this parameter
                if not isinstance(self.named_grads[key], float):
                    p.grad = self.named_grads[key] / num_accu_batches
                    self.named_grads[key] = 0.0
        # unweighted mean over sub-batches (a short last sub-batch counts as much as a full one)
        robust_loss = torch.mean(torch.tensor(robust_loss)).to(x.device) # no grad here
        robust_accu = (is_robust_accu.sum() / len(is_robust_accu)).to(x.device)
        loss = self.model_wrapper.combine_loss(nat_loss, robust_loss).detach()
        loss.requires_grad=True # dummy, backward will do nothing
        return self.format_return(loss, nat_loss, nat_accu, is_nat_accu, robust_loss, robust_accu, is_robust_accu)


class MultiFacetModelWrapper(BasicModelWrapper):
    '''
    Use loss = sum(Li(eps * eps_ratios[i]) * weights[i]) to train model
    This will call get_robust_stat_from_input_noise in each individual wrapper.
    The returned statistics are all based on the last model wrapper (not to be used in this case).

    @param
        model_wrappers: List[BasicModelWrapper]; at least two wrappers that all wrap the same net, input_dim, device and data range
        eps_ratios: List[float]; per-wrapper multiplier of the current eps; the first entry must be 1.0
        weight_ratios: List[float]; per-wrapper weight of the robust loss
        args: forwarded to BasicModelWrapper.__init__
    '''
    def __init__(self, model_wrappers:List[BasicModelWrapper], eps_ratios:List[float], weight_ratios:List[float], args):
        # checked before model_wrappers[0] is dereferenced, so that an empty list reports this message instead of an IndexError
        assert len(model_wrappers) >= 2, "At least two facets should be provided."
        base = model_wrappers[0]
        super().__init__(base.net, base.loss_fn, base.input_dim, base.device, args, base.summary_accu_stat, (base.data_min, base.data_max))
        # checks for consistency
        for i, wrapper in enumerate(model_wrappers):
            assert wrapper.net is self.net, f"Wrapper {i} does not wrap the same net."
            assert wrapper.input_dim == self.input_dim, f"Wrapper {i} does not wrap the same input_dim."
            assert wrapper.device == self.device, f"Wrapper {i} does not wrap the same device."
            assert wrapper.data_min == self.data_min, f"Wrapper {i} does not wrap the same data_min."
            assert wrapper.data_max == self.data_max, f"Wrapper {i} does not wrap the same data_max."
            assert not isinstance(wrapper, GradAccuModelWrapper), "Functional wrappers, e.g. gradient accumulation should not be used in multi-facet loss; instead, wrap multi-facet loss with a Gradient accumulation wrapper."
        assert len(eps_ratios)==len(weight_ratios)==len(model_wrappers), f"Provided lengths unmatch: wrappers [{len(model_wrappers)}], eps_ratios [{len(eps_ratios)}], weight_ratios [{len(weight_ratios)}]"
        # checks for the base wrapper (first in the list)
        assert eps_ratios[0] == 1.0, "The first element in eps_ratios must be 1, i.e., current_eps will be used for the first wrapper."
        self.wrappers = model_wrappers
        self.eps_ratios = eps_ratios
        self.weight_ratios = weight_ratios
        self.store_box_bounds = False

    def compute_model_stat(self, x, y, eps, **kwargs):
        '''
        Compute the natural loss once on the whole batch, then add up the weighted robust losses of all facets.

        robust_accu and is_robust_accu passed to format_return are those of the last wrapper in the list.
        '''
        for wrapper in self.wrappers:
            wrapper.store_box_bounds = self.store_box_bounds
        self.freeze_BN = False
        # set BN stat based on the whole batch
        nat_loss, nat_accu, is_nat_accu = self.compute_nat_loss_and_set_BN(x, y, **kwargs)
        self.freeze_BN = True
        # compute robust losses
        total_robust_loss = 0.0
        for i, wrapper in enumerate(self.wrappers):
            # NOTE(review): kwargs are not forwarded to the facets here — confirm this is intended
            robust_loss, robust_accu, is_robust_accu = wrapper.get_robust_stat_from_input_noise(eps * self.eps_ratios[i], x, y)
            total_robust_loss = total_robust_loss + robust_loss * self.weight_ratios[i]
        loss = self.combine_loss(nat_loss, total_robust_loss)
        return self.format_return(loss, nat_loss, nat_accu, is_nat_accu, total_robust_loss, robust_accu, is_robust_accu)

# Function wrappers
class BasicFunctionWrapper(BasicModelWrapper):
    '''
    Base class of the function wrappers, i.e. wrappers that decorate another model wrapper.

    The constructor copies net, loss_fn, input_dim, device, args, summary_accu_stat and the data range from the wrapped wrapper.
    '''
    def __init__(self, model_wrapper:BasicModelWrapper):
        # summary_accu_stat and data_range are passed by keyword: passing the data range as the sixth
        # positional argument would land in the summary_accu_stat slot of BasicModelWrapper.__init__
        super().__init__(
            model_wrapper.net,
            model_wrapper.loss_fn,
            model_wrapper.input_dim,
            model_wrapper.device,
            model_wrapper.args,
            summary_accu_stat=model_wrapper.summary_accu_stat,
            data_range=(model_wrapper.data_min, model_wrapper.data_max),
        )

class WeightSmoothFunctionWrapper(BasicFunctionWrapper):
    '''
    Implements weight smoothing. Reference: https://arxiv.org/abs/2311.00521

    First add a random small perturbation to the model weight according to the scaled standard deviation of the weight,
    then do a normal step w.r.t. the perturbed model. After the step, remove the perturbation.
    Formally, w_{t+1} = w_t - eta * grad(L(w_t + noise)), where noise ~ N(0, std_scale * std(w_t)).

    @param
        model_wrapper: BasicModelWrapper; the wrapper whose statistics are computed on the perturbed weights
        std_scale: float; the scale factor for the perturbation

    @property
        std_scale: float; the scale factor for the perturbation; assigning a non-positive value fails an assertion

    @remark
        Expected to converge to a flatter minimum.
    '''
    def __init__(self, model_wrapper:BasicModelWrapper, std_scale:float):
        super().__init__(model_wrapper)
        self.wrapper = model_wrapper
        self._std_scale = float(std_scale)
        # generator state captured right before the last perturbation; None means "no noise currently applied"
        self._rng_state = None
        self._rng_generator = torch.Generator(device=self.device)
        self._current_stds = []
        self._layer_with_weights = [layer for layer in self.net if isinstance(layer, (nn.Linear, nn.Conv2d))]
    std_scale = property(fget=lambda self: self._std_scale, fset=lambda self, value: set_value_typecast(self, "_std_scale", value, float, lambda x: x>0, "std_scale must be a positive float."))

    def _sample_noise(self, weight, std):
        '''
        Draw the perturbation for one weight tensor from the private generator. Shared by the add and the remove path so that both use the same expression.
        '''
        return self.std_scale * std * torch.empty_like(weight).normal_(generator=self._rng_generator)

    def _perturb_weights(self):
        '''
        Add a Gaussian noise to the weights of Conv2d and Linear layers. The noise is scaled by the standard deviation of the weights and self.std_scale. This will modify the weights in-place. A generator state is kept to allow generating the same noise for the same layer later to reduce memory overhead.
        '''
        with torch.no_grad():
            # stds are taken before the noise is added and kept, because the weights change during the optimizer step
            self._current_stds = [torch.std(layer.weight) for layer in self._layer_with_weights]
            self._rng_state = self._rng_generator.get_state()
            for layer, std in zip(self._layer_with_weights, self._current_stds):
                layer.weight.data += self._sample_noise(layer.weight, std)

    def compute_model_stat(self, x, y, eps, **kwargs):
        '''
        Perturb the weights (only in training mode), then call the wrapped model to compute the model stat.
        '''
        self.wrapper.robust_weight = self.robust_weight
        self.wrapper.summary_accu_stat = self.summary_accu_stat
        self.wrapper.store_box_bounds = self.store_box_bounds
        if self.net.training:
            self._perturb_weights()
        return self.wrapper.compute_model_stat(x, y, eps, **kwargs)

    def param_postprocess(self):
        '''
        Remove the noise added by the last _perturb_weights call, then run the wrapped wrapper's param_postprocess.

        The generator is rewound to the state saved before the perturbation, so exactly the same noise is regenerated and subtracted.
        If no perturbation is pending (nothing was perturbed yet, or it was already removed), the weights are left untouched.
        '''
        if self._rng_state is not None:
            with torch.no_grad():
                self._rng_generator.set_state(self._rng_state)
                for layer, std in zip(self._layer_with_weights, self._current_stds):
                    # must be the very same expression as in _perturb_weights, otherwise a residual noise stays in the weights
                    layer.weight.data -= self._sample_noise(layer.weight, std)
            # mark the noise as removed so that a repeated call does not subtract it again
            self._rng_state = None
            self._current_stds = []
        self.wrapper.param_postprocess()

class SAMFunctionWrapper(BasicFunctionWrapper):
    '''
    Implements Sharpness-Aware Minimization (SAM) training

    Reference: https://arxiv.org/abs/2010.01412;

    In training mode each call to compute_model_stat does two passes through the wrapped wrapper:
    the first pass yields a gradient that is turned into a perturbation of norm rho per parameter tensor and added to the weights,
    the second pass computes the returned statistics on the perturbed weights. param_postprocess removes the perturbation again.

    @param
        model_wrapper: BasicModelWrapper; the wrapper used for both passes
        rho: float; the scale factor for the sharpness penalty

    @property
        rho: float; the scale factor for the sharpness penalty; assigning a non-positive value fails an assertion

    @remark
        Expected to converge to a flatter minimum.
    '''
    def __init__(self, model_wrapper:BasicModelWrapper, rho:float):
        super().__init__(model_wrapper)
        self.wrapper = model_wrapper
        self._rho = float(rho)
        # parameter name -> perturbation currently added to that parameter (0.0 once it has been removed)
        self.pert_dict = {}
    rho = property(fget=lambda self: self._rho, fset=lambda self, value: set_value_typecast(self, "_rho", value, float, lambda x: x>0, "rho must be a positive float."))

    def compute_model_stat(self, x, y, eps, **kwargs):
        '''
        In training mode, move the weights to the perturbed point first; then return the wrapped wrapper's statistics.
        '''
        self.wrapper.robust_weight = self.robust_weight
        self.wrapper.summary_accu_stat = self.summary_accu_stat
        self.wrapper.store_box_bounds = self.store_box_bounds
        if self.net.training:
            # the first pass is run in eval mode; train mode is restored even if the pass raises
            self.net.eval()
            try:
                loss = self.wrapper.compute_model_stat(x, y, eps, **kwargs)[0][0]
                loss.backward()
            finally:
                self.net.train()
            for k, p in self.net.named_parameters():
                if p.grad is None:
                    continue
                # NOTE(review): the norm is taken per parameter tensor, whereas the SAM paper normalizes
                # by the gradient norm over all parameters — confirm this is intended
                pert = p.grad / (torch.norm(p.grad) + 1e-12) * self.rho
                # apply the perturbation so that the second pass is evaluated at w + pert;
                # param_postprocess subtracts exactly this tensor again
                p.data += pert
                self.pert_dict[k] = pert
                # drop the first-pass gradient so that it does not leak into the real update
                p.grad = None
        result = self.wrapper.compute_model_stat(x, y, eps, **kwargs)
        return result

    def param_postprocess(self):
        '''
        Remove the perturbation added in compute_model_stat, then run the wrapped wrapper's param_postprocess.
        '''
        for k, p in self.net.named_parameters():
            if k in self.pert_dict:
                p.data -= self.pert_dict[k]
                # reset so that a repeated call subtracts nothing
                self.pert_dict[k] = 0.0
        self.wrapper.param_postprocess()


# Manual smoke test: builds a small MNIST network, wraps it and runs one compute_model_stat call on random data.
if __name__ == "__main__":
    '''
    Test init functions
    '''
    from networks import get_network
    from loaders import get_loaders
    import argparse
    from AIDomains.abstract_layers import Sequential
    logging.basicConfig(level=logging.INFO)

    device = "cpu"
    net = get_network("cnn_3layer_bn", "mnist", device)
    loss_fn = nn.CrossEntropyLoss()
    input_dim = (1, 28, 28)
    # convert the concrete network into its abstract counterpart
    net = Sequential.from_concrete_network(net, input_dim)
    # no arguments are registered on the parser; the attributes needed by the wrappers are attached by hand below
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    eps = 0.1
    args.train_eps = args.test_eps = 0.1

    # random inputs and labels standing in for one batch
    bs = 16
    x = torch.randn(bs, *input_dim).to(device)
    y = torch.randint(10, (bs, )).flatten().to(device)
    print(x.shape, y.shape)

    # alternative wrapper configurations, kept for manual switching:
    # args.pgd_weight = 0.5
    # model_wrapper = PGDModelWrapper(net, loss_fn, input_dim, device, args)

    # model_wrapper = BoxModelWrapper(net, loss_fn, input_dim, device, args, True)

    # args.soft_thre = 0.5
    # args.train_steps = 10
    # args.robust_weight = 0.5
    # args.estimation_batch = 16
    # args.TAPS_grad_scale = 0.5
    # model_wrapper = TAPSModelWrapper(net, loss_fn, input_dim, device, args, True, [6, 3])

    # args.relu_shrinkage = 0.8
    # args.robust_weight = 0.5
    # model_wrapper = SmallBoxModelWrapper(net, loss_fn, input_dim, device, args, True, 0.4)

    # active configuration: STAPS with block sizes [6,3] and eps_shrinkage 0.4
    args.relu_shrinkage = 0.8
    args.robust_weight = 0.5
    args.soft_thre = 0.5
    args.train_steps = 10
    args.estimation_batch = 16
    args.TAPS_grad_scale = 0.5
    model_wrapper = STAPSModelWrapper(net, loss_fn, input_dim, device, args, True, [6,3], 0.4)
    model_wrapper.disable_TAPS = True

    args.grad_accu_batch = 4
    gc_model_wrapper = GradAccuModelWrapper(model_wrapper, args)

    # NOTE(review): the call below goes through model_wrapper directly, not through gc_model_wrapper
    (loss, nat_loss, cert_loss), (nat_accu, cert_accu) = model_wrapper.compute_model_stat(x, y, eps)