#########################################################################
##   This file is part of the auto_LiRPA library, a core part of the   ##
##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##
##   by the α,β-CROWN Team                                             ##
##                                                                     ##
##   Copyright (C) 2020-2024 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com>                ##
##                     Zhouxing Shi <zshi@cs.ucla.edu>                 ##
##                     Kaidi Xu <kx46@drexel.edu>                      ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
from .base import *

class BoundDropout(Bound):
    def __init__(self, attr=None, inputs=None, output_index=0, options=None):
        super().__init__(attr, inputs, output_index, options)
        if 'ratio' in attr:
            self.ratio = attr['ratio']
            self.dynamic = False
        else:
            self.ratio = None
            self.dynamic = True
        self.clear()

    def clear(self):
        self.mask = None

    def forward(self, *inputs):
        x = inputs[0]
        if not self.training:
            return x
        if self.dynamic:
            # Inputs: data, ratio (optional), training_mode (optional)
            # We assume ratio must exist in the inputs.
            # We ignore training_mode, but will use self.training which can be
            # changed after BoundedModule is built.
            assert (inputs[1].dtype == torch.float32 or
                    inputs[1].dtype == torch.float64)
            self.ratio = inputs[1]
        if self.ratio >= 1:
            raise ValueError('Ratio in dropout should be less than 1')
        self.mask = torch.rand(x.shape, device=self.ratio.device) > self.ratio
        return x * self.mask / (1 - self.ratio)

    def _check_forward(self):
        """ If in the training mode, a forward pass should have been called."""
        if self.training and self.mask is None:
            raise RuntimeError('For a model with dropout in the training mode, '\
                'a clean forward pass must be called before bound computation')

    def bound_backward(self, last_lA, last_uA, *args, **kwargs):
        empty_A = [(None, None)] * (len(args) -1)
        if not self.training:
            return [(last_lA, last_uA), *empty_A], 0, 0
        self._check_forward()
        def _bound_oneside(last_A):
            if last_A is None:
                return None
            return last_A * self.mask / (1 - self.ratio)
        lA = _bound_oneside(last_lA)
        uA = _bound_oneside(last_uA)
        return [(lA, uA), *empty_A], 0, 0

    def bound_forward(self, dim_in, x, *args):
        if not self.training:
            return x
        self._check_forward()
        lw = x.lw * self.mask.unsqueeze(1) / (1 - self.ratio)
        lb = x.lb * self.mask / (1 - self.ratio)
        uw = x.uw * self.mask.unsqueeze(1) / (1 - self.ratio)
        ub = x.ub * self.mask / (1 - self.ratio)
        return LinearBound(lw, lb, uw, ub)

    def interval_propagate(self, *v):
        if not self.training:
            return v[0]
        self._check_forward()
        h_L, h_U = v[0]
        lower = h_L * self.mask / (1 - self.ratio)
        upper = h_U * self.mask / (1 - self.ratio)
        return lower, upper
