import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class UnifiedLinear(nn.Linear):
    """Linear layer whose gradients are estimated from random perturbations.

    ``forward(..., add_noise=True)`` injects Gaussian noise either into the
    layer output (``perturb_mode == "logit"``) or into a per-sample copy of
    the weight and bias (any other mode, conventionally ``"weight"``) and
    records the noise.  ``backward`` then converts a vector of per-sample
    scalar losses into ``weight.grad`` / ``bias.grad`` by correlating the
    losses with the recorded noise; autograd is not involved.
    """

    def __init__(self, in_features, out_features, init_std=1e-0, default_mode="logit", bias=True, device=None, dtype=None):
        """Build the layer.

        Args:
            in_features, out_features, bias, device, dtype: forwarded to
                ``nn.Linear``.
            init_std: initial noise standard deviation; stored as its log,
                one entry per output feature.
            default_mode: perturbation mode restored after every
                ``backward`` call.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.log_noise_std = nn.Parameter(torch.full((out_features,), np.log(init_std), device=device))
        self.input_buf = None
        self.epsilon_buf = None
        self.epsilon_buf_w = None
        self.epsilon_buf_b = None
        self.added_noise = False
        # switch the perturbation mode
        self.default_mode = default_mode
        self.perturb_mode = default_mode
        self.use_antivariable = True
        # Kept only so external code reading this legacy attribute still works.
        self.add_noise_to_jacobian = False
        # State written by dp_controller and consumed by backward.  It has to
        # exist from construction on, otherwise backward(grad_sample=True)
        # raises AttributeError whenever dp_controller was not called first.
        self.add_noise_to_estimated_gradient = False
        self.egval = None
        self.egvec = None
        self.covariance_needed = None

    def turn_off_antivariable(self):
        """Disable antithetic (mirrored) noise sampling."""
        self.use_antivariable = False

    def _sample_epsilon(self, shape, dtype):
        """Draw standard-normal noise of ``shape`` on the noise-std device.

        With antithetic sampling enabled, the second half of the leading
        dimension is the negation of the first half.

        Raises:
            ValueError: if antithetic sampling is enabled and the leading
                dimension is odd.
        """
        shape = tuple(shape)
        device = self.log_noise_std.device
        if not self.use_antivariable:
            return torch.randn(shape, device=device).to(dtype)
        half, remainder = divmod(shape[0], 2)
        if remainder:
            raise ValueError(
                "antithetic sampling needs an even leading dimension, got %d; "
                "call turn_off_antivariable() to use odd batch sizes" % shape[0]
            )
        first_half = torch.randn((half,) + shape[1:], device=device).to(dtype)
        return torch.cat((first_half, -first_half), dim=0)

    def forward(self, input: Tensor, add_noise=False):
        """Apply the layer, optionally with recorded random perturbations.

        Args:
            input: layer input; in weight mode with ``add_noise=True`` it is
                indexed as a 2-D ``(batch, in_features)`` tensor.
            add_noise: when true, perturb the output (logit mode) or the
                parameters (weight mode) and remember the noise for
                ``backward``.

        Returns:
            The (possibly perturbed) layer output.
        """
        if self.perturb_mode == "logit":
            logit_output = super().forward(input)
            # Stored even without noise: dp_controller reads it.
            self.input_buf = input
            if not add_noise:
                return logit_output
            epsilon = self._sample_epsilon(logit_output.shape, logit_output.dtype)
            self.epsilon_buf = epsilon
            self.added_noise = True
            return logit_output + epsilon * torch.exp(self.log_noise_std[None, :])

        if not add_noise:
            return super().forward(input)

        bs = input.shape[0]
        noise_std = torch.exp(self.log_noise_std)
        # One independently perturbed weight matrix per sample.
        epsilon_w = self._sample_epsilon((bs, self.out_features, self.in_features), self.weight.dtype)
        w = self.weight.unsqueeze(0) + epsilon_w * noise_std[None, :, None]
        logit_output = torch.bmm(w, input[:, :, None]).squeeze(-1)
        epsilon_b = None
        if self.bias is not None:
            epsilon_b = self._sample_epsilon((bs, self.out_features), self.bias.dtype)
            b = self.bias.unsqueeze(0) + epsilon_b * noise_std[None, :]
            logit_output = logit_output + b
        self.epsilon_buf_w = epsilon_w
        self.epsilon_buf_b = epsilon_b
        self.added_noise = True
        return logit_output

    def _reduce_grad_samples(self, weight_gs, bias_gs, clip_threshold, batch_size):
        """Average repeats, clip per-sample gradients and write ``.grad``.

        ``weight_gs`` / ``bias_gs`` carry one row per forward sample.  They
        are folded to ``(-1, batch_size, ...)`` and averaged over the first
        axis, each remaining row is scaled so its joint weight+bias L2 norm
        is at most ``clip_threshold`` (no scaling when it is ``None``), and
        the mean over rows becomes the parameter gradient.
        """
        weight_gs = weight_gs.reshape(-1, batch_size, *self.weight.shape).mean(dim=0)
        if bias_gs is not None:
            bias_gs = bias_gs.reshape(-1, batch_size, *self.bias.shape).mean(dim=0)

        # Sized from the reduced tensor: after averaging there are batch_size
        # rows, not one per forward sample.
        if clip_threshold is None:
            clip_coeff = weight_gs.new_ones(weight_gs.shape[0])
        else:
            sq_norms = torch.linalg.vector_norm(weight_gs.flatten(1), ord=2, dim=1) ** 2
            if bias_gs is not None:
                sq_norms = sq_norms + torch.linalg.vector_norm(bias_gs, ord=2, dim=1) ** 2
            clip_coeff = torch.clamp(clip_threshold / sq_norms.sqrt(), max=1.0)

        self.weight.grad_sample = weight_gs
        weight_scale = clip_coeff.reshape(-1, *([1] * (weight_gs.dim() - 1)))
        self.weight.grad = torch.mean(weight_gs * weight_scale, dim=0)
        if bias_gs is not None:
            self.bias.grad_sample = bias_gs
            self.bias.grad = torch.mean(bias_gs * clip_coeff[:, None], dim=0)

    def backward(self, loss, grad_sample=False, clip_threshold=None, batch_size=None, loss0=None):
        """Estimate parameter gradients from per-sample losses.

        Args:
            loss: 1-D tensor with one scalar loss per forward sample.
            grad_sample: when true, build per-sample gradients (stored in
                ``weight.grad_sample`` / ``bias.grad_sample``), average the
                repeats, clip, and then average into ``.grad``.
            clip_threshold: per-sample L2 clipping bound, or ``None``.
            batch_size: number of distinct samples; the forward batch is
                treated as ``repeats * batch_size`` rows.  Defaults to the
                full forward batch.
            loss0: accepted for interface compatibility; not used.

        Side effects: writes ``.grad``, clears the noise buffers and resets
        ``perturb_mode`` to ``default_mode``.
        """
        noise_std = torch.exp(self.log_noise_std)
        has_bias = self.bias is not None

        if self.perturb_mode == "logit":
            whole_bs = self.input_buf.shape[0]
            input_buf = self.input_buf.reshape(whole_bs, -1, self.in_features)
            epsilon_buf = self.epsilon_buf.reshape(whole_bs, -1, self.out_features)
            weighted_eps = loss[:, None, None] * epsilon_buf
            if not grad_sample:
                self.weight.grad = torch.einsum('nki,nkj->ji', input_buf, weighted_eps) / (noise_std[:, None] * whole_bs)
                if has_bias:
                    self.bias.grad = torch.sum(weighted_eps, dim=(0, 1)) / (noise_std * whole_bs)
            else:
                if batch_size is None:
                    batch_size = whole_bs
                weight_gs = torch.einsum('nki,nkj->nji', input_buf, weighted_eps) / noise_std[None, :, None]
                bias_gs = torch.sum(weighted_eps, dim=1) / noise_std[None, :] if has_bias else None
                self._reduce_grad_samples(weight_gs, bias_gs, clip_threshold, batch_size)

                if self.add_noise_to_estimated_gradient:
                    # Extra noise requested by dp_controller, shaped through
                    # the eigenvectors it stored.
                    scaled = torch.randn_like(self.weight.grad) * self.covariance_needed ** 0.5
                    self.weight.grad += (scaled @ self.egvec.T) / batch_size
                    self.add_noise_to_estimated_gradient = False
                    self.egval, self.egvec = None, None
                    self.covariance_needed = None
        else:
            whole_bs = loss.shape[0]
            weight_gs = loss[:, None, None] * self.epsilon_buf_w / noise_std[None, :, None]
            bias_gs = None
            if has_bias:
                bias_gs = loss[:, None] * self.epsilon_buf_b / noise_std[None, :]
            if not grad_sample:
                self.weight.grad = torch.mean(weight_gs, 0)
                if has_bias:
                    self.bias.grad = torch.mean(bias_gs, 0)
            else:
                if batch_size is None:
                    batch_size = whole_bs
                self._reduce_grad_samples(weight_gs, bias_gs, clip_threshold, batch_size)

        self.log_noise_std.grad = None
        # Reset to the default mode after backward; this also clears every
        # noise/input buffer and the added_noise flag.
        self.switch_mode(self.default_mode)

    def loss_times_jacobian_sum_eigenvalue_lower_bound(self, input, loss):
        """Eigen-decompose the loss-weighted Gram matrix of the inputs.

        Builds ``sum_n loss_n**2 * X_n^T X_n`` with ``X_n`` the input of
        sample ``n`` flattened to ``(-1, in_features)``.

        Returns:
            ``(bound, egval, egvec)`` where ``egval``/``egvec`` come from
            ``torch.linalg.eigh`` and ``bound`` is the smaller of the minimum
            eigenvalue and the bias term, floored at 0.
        """
        assert len(input.shape) >= 2
        weight_subjacobian = input.reshape(input.shape[0], -1, input.shape[-1])
        gram = torch.matmul(weight_subjacobian.transpose(1, 2), weight_subjacobian)
        weighted_gram = (loss[:, None, None] ** 2 * gram).sum(dim=0)
        egval, egvec = torch.linalg.eigh(weighted_gram)
        # Reuse the decomposition instead of running a second eigen-solver.
        weight_eigenvalue_lower_bound = torch.min(egval).item()
        bias_eigenvalue_lower_bound = (loss ** 2).sum().item() * weight_subjacobian.shape[1] ** 2
        return max(min(weight_eigenvalue_lower_bound, bias_eigenvalue_lower_bound), 0), egval, egvec

    def dp_controller(self, sigma_0, repeat_time_K, clip_threshold_C, loss):
        """Cap the perturbation std from the current batch statistics.

        Computes ``sigma = sqrt(min_eigenvalue / (K * C**2 * sigma_0**2))``
        and sets the layer's noise std to ``min(sigma, current std)``.  In
        logit mode, when ``sigma`` does not exceed the smallest current std,
        the std is left unchanged and ``backward(grad_sample=True)`` is
        instead told to add extra noise to the weight gradient.

        Returns:
            ``(is_logit_mode, sigma, first entry of the new std)``.

        Raises:
            ValueError: if ``perturb_mode`` is neither "logit" nor "weight".
        """
        current_sigma = torch.exp(self.log_noise_std)
        scale = repeat_time_K * clip_threshold_C ** 2 * sigma_0 ** 2
        if self.perturb_mode == "logit":
            assert self.input_buf is not None
            min_eigenvalue, egval, egvec = self.loss_times_jacobian_sum_eigenvalue_lower_bound(self.input_buf, loss)
            sigma = (min_eigenvalue / scale) ** 0.5
            if sigma <= current_sigma.min():
                self.add_noise_to_estimated_gradient = True
                self.egval, self.egvec = egval, egvec
                self.covariance_needed = clip_threshold_C ** 2 * sigma_0 ** 2 - (self.egval * (self.egval > 0.1) / (repeat_time_K * current_sigma.min() ** 2))
                # Negative entries are zeroed so the square root stays real.
                self.covariance_needed = self.covariance_needed * (self.covariance_needed > 0)
                sigma = current_sigma.min()
        elif self.perturb_mode == "weight":
            # NOTE(review): UnifiedConv2d uses (loss**2).sum() here instead of
            # the batch length - confirm which one is intended.
            min_eigenvalue = len(loss)
            sigma = (min_eigenvalue / scale) ** 0.5
        else:
            raise ValueError("unsupported perturb_mode: %r" % (self.perturb_mode,))

        if torch.is_tensor(sigma):
            sigma_tensor = sigma
        else:
            sigma_tensor = torch.tensor(sigma, dtype=current_sigma.dtype, device=current_sigma.device)
        new_sigma = torch.min(sigma_tensor, current_sigma)
        self.set_sigma(new_sigma)
        return (
            int(self.perturb_mode == "logit"),
            float(sigma.item() if isinstance(sigma, torch.Tensor) else sigma),
            float(new_sigma[0].item()),
        )

    def fetch_gradient(self):
        """Return a detached CPU copy of ``weight.grad``."""
        return self.weight.grad.detach().cpu()

    def set_sigma(self, new_sigma):
        """Overwrite the noise standard deviation.

        Args:
            new_sigma: a positive Python number applied to every output
                feature, or a positive tensor with the same shape as
                ``log_noise_std``.  Values of any other type are ignored.
        """
        if isinstance(new_sigma, (float, int)):
            assert new_sigma > 0
            self.log_noise_std.data = torch.log(torch.full_like(self.log_noise_std.data, new_sigma))
        elif isinstance(new_sigma, Tensor):
            assert torch.min(new_sigma) > 0
            assert new_sigma.shape == self.log_noise_std.data.shape
            self.log_noise_std.data = torch.log(new_sigma.detach())

    def switch_mode(self, mode="weight"):
        """Select the perturbation mode and drop all cached buffers."""
        self.input_buf = None
        self.epsilon_buf = None
        self.epsilon_buf_w = None
        self.epsilon_buf_b = None
        self.added_noise = False
        self.perturb_mode = mode


class UnifiedConv2d(nn.Conv2d):
    """2-D convolution whose gradients are estimated from random perturbations.

    ``forward(..., add_noise=True)`` injects Gaussian noise either into the
    layer output (``perturb_mode == "logit"``) or into a per-sample copy of
    the kernel and bias (any other mode, conventionally ``"weight"``) and
    records the noise.  ``backward`` then converts a vector of per-sample
    scalar losses into ``weight.grad`` / ``bias.grad`` by correlating the
    losses with the recorded noise; autograd is not involved.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, init_std=1e-0, default_mode="logit", bias=True, device=None, dtype=None):
        """Build the layer.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, bias,
                device, dtype: forwarded to ``nn.Conv2d``.
            init_std: initial noise standard deviation; stored as its log,
                one entry per output channel.
            default_mode: perturbation mode restored after every
                ``backward`` call.
        """
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, bias=bias, device=device, dtype=dtype)
        # nn.Unfold's positional order is (kernel_size, dilation, padding,
        # stride); keywords keep padding/stride from landing in the wrong slot.
        self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride)
        self.log_noise_std = nn.Parameter(torch.full((out_channels,), np.log(init_std), device=device))
        self.input_buf = None
        self.epsilon_buf = None
        self.epsilon_buf_w = None
        self.epsilon_buf_b = None
        self.added_noise = False
        # switch the perturbation mode
        self.default_mode = default_mode
        self.perturb_mode = default_mode
        self.use_antivariable = True
        # State written by dp_controller and consumed by backward.
        self.add_noise_to_estimated_gradient = False
        self.egval = None
        self.egvec = None
        self.covariance_needed = None

    def turn_off_antivariable(self):
        """Disable antithetic (mirrored) noise sampling."""
        self.use_antivariable = False

    def _sample_epsilon(self, shape, dtype):
        """Draw standard-normal noise of ``shape`` on the noise-std device.

        With antithetic sampling enabled, the second half of the leading
        dimension is the negation of the first half.

        Raises:
            ValueError: if antithetic sampling is enabled and the leading
                dimension is odd.
        """
        shape = tuple(shape)
        device = self.log_noise_std.device
        if not self.use_antivariable:
            return torch.randn(shape, device=device).to(dtype)
        half, remainder = divmod(shape[0], 2)
        if remainder:
            raise ValueError(
                "antithetic sampling needs an even leading dimension, got %d; "
                "call turn_off_antivariable() to use odd batch sizes" % shape[0]
            )
        first_half = torch.randn((half,) + shape[1:], device=device).to(dtype)
        return torch.cat((first_half, -first_half), dim=0)

    def forward(self, input, add_noise=False):
        """Apply the convolution, optionally with recorded perturbations.

        Args:
            input: 4-D tensor ``(N, C_in, H_in, W_in)``.
            add_noise: when true, perturb the output (logit mode) or the
                parameters (weight mode) and remember the noise for
                ``backward``.

        Returns:
            The (possibly perturbed) layer output.
        """
        if self.perturb_mode == "logit":
            logit_output = super().forward(input)
            # Stored even without noise: dp_controller reads it.
            self.input_buf = input
            if not add_noise:
                return logit_output
            epsilon = self._sample_epsilon(logit_output.shape, logit_output.dtype)
            self.epsilon_buf = epsilon
            self.added_noise = True
            return logit_output + epsilon * torch.exp(self.log_noise_std[None, :, None, None])

        if not add_noise:
            return super().forward(input)

        C_out, C_in, KH, KW = self.weight.shape
        N, _, H_in, W_in = input.shape
        # Rows of the stacked kernels are ordered sample-major, so the
        # per-channel std is tiled once per sample.
        noise_std = torch.exp(self.log_noise_std).repeat(N)
        epsilon_w = self._sample_epsilon((N * C_out, C_in, KH, KW), self.weight.dtype)
        w = self.weight.repeat(N, 1, 1, 1) + epsilon_w * noise_std[:, None, None, None]
        epsilon_b = None
        b = None
        if self.bias is not None:
            epsilon_b = self._sample_epsilon((N * C_out,), self.bias.dtype)
            b = self.bias.repeat(N) + epsilon_b * noise_std
        self.epsilon_buf_w = epsilon_w
        self.epsilon_buf_b = epsilon_b
        # A grouped convolution with one group per sample applies each
        # sample's own perturbed kernel in a single call.
        logit_output = F.conv2d(input.reshape(1, N * C_in, H_in, W_in), w, b, stride=self.stride, padding=self.padding, groups=N)
        _, _, H_out, W_out = logit_output.shape
        logit_output = logit_output.reshape(N, C_out, H_out, W_out)
        self.added_noise = True
        return logit_output

    def _reduce_grad_samples(self, weight_gs, bias_gs, clip_threshold, batch_size):
        """Average repeats, clip per-sample gradients and write ``.grad``.

        ``weight_gs`` / ``bias_gs`` carry one row per forward sample.  They
        are folded to ``(-1, batch_size, ...)`` and averaged over the first
        axis, each remaining row is scaled so its joint weight+bias L2 norm
        is at most ``clip_threshold`` (no scaling when it is ``None``), and
        the mean over rows becomes the parameter gradient.
        """
        weight_gs = weight_gs.reshape(-1, batch_size, *self.weight.shape).mean(dim=0)
        if bias_gs is not None:
            bias_gs = bias_gs.reshape(-1, batch_size, *self.bias.shape).mean(dim=0)

        # Sized from the reduced tensor: after averaging there are batch_size
        # rows, not one per forward sample.
        if clip_threshold is None:
            clip_coeff = weight_gs.new_ones(weight_gs.shape[0])
        else:
            sq_norms = torch.linalg.vector_norm(weight_gs.flatten(1), ord=2, dim=1) ** 2
            if bias_gs is not None:
                sq_norms = sq_norms + torch.linalg.vector_norm(bias_gs, ord=2, dim=1) ** 2
            clip_coeff = torch.clamp(clip_threshold / sq_norms.sqrt(), max=1.0)

        self.weight.grad_sample = weight_gs
        weight_scale = clip_coeff.reshape(-1, *([1] * (weight_gs.dim() - 1)))
        self.weight.grad = torch.mean(weight_gs * weight_scale, dim=0)
        if bias_gs is not None:
            self.bias.grad_sample = bias_gs
            self.bias.grad = torch.mean(bias_gs * clip_coeff[:, None], dim=0)

    def backward(self, loss, grad_sample=False, clip_threshold=None, batch_size=None, loss0=None):
        """Estimate parameter gradients from per-sample losses.

        Args:
            loss: 1-D tensor with one scalar loss per forward sample.
            grad_sample: when true, build per-sample gradients (stored in
                ``weight.grad_sample`` / ``bias.grad_sample``), average the
                repeats, clip, and then average into ``.grad``.
            clip_threshold: per-sample L2 clipping bound, or ``None``.
            batch_size: number of distinct samples; the forward batch is
                treated as ``repeats * batch_size`` rows.  Defaults to the
                full forward batch.
            loss0: accepted for interface compatibility; not used.

        Side effects: writes ``.grad``, clears the noise buffers and resets
        ``perturb_mode`` to ``default_mode``.
        """
        noise_std = torch.exp(self.log_noise_std)
        has_bias = self.bias is not None
        KH, KW = self.weight.shape[2], self.weight.shape[3]

        if self.perturb_mode == "logit":
            N, C_in, _, _ = self.input_buf.shape
            _, C_out, H_out, W_out = self.epsilon_buf.shape
            output_grad = self.epsilon_buf * loss[:, None, None, None]

            if not grad_sample:
                # Kernel gradient as a correlation of the input with the
                # output-side signal: batch and channel axes are swapped and
                # the layer stride becomes the dilation.
                grad = F.conv2d(self.input_buf.transpose(0, 1), output_grad.transpose(0, 1), None,
                                dilation=self.stride, padding=self.padding)
                # The correlation can be larger than the kernel; crop it.
                grad = grad[:, :, :KH, :KW]
                self.weight.grad = grad.transpose(0, 1) / (N * noise_std[:, None, None, None])
                if has_bias:
                    self.bias.grad = torch.sum(output_grad, (0, 2, 3)) / (N * noise_std)
            else:
                if batch_size is None:
                    batch_size = N
                # Same correlation, with one group per sample so the
                # per-sample results stay separate.
                per_sample = F.conv2d(self.input_buf.transpose(0, 1), output_grad.reshape(-1, 1, H_out, W_out), None,
                                      groups=N, dilation=self.stride, padding=self.padding)
                per_sample = per_sample[:, :, :KH, :KW]
                weight_gs = per_sample.reshape(C_in, N, C_out, KH, KW).permute(1, 2, 0, 3, 4) / noise_std[None, :, None, None, None]
                bias_gs = torch.sum(output_grad, (2, 3)) / noise_std[None, :] if has_bias else None
                self._reduce_grad_samples(weight_gs, bias_gs, clip_threshold, batch_size)

                if self.add_noise_to_estimated_gradient:
                    # Extra noise requested by dp_controller, shaped through
                    # the eigenvectors it stored.
                    scaled = torch.randn_like(self.weight.grad).reshape(C_out, -1) * self.covariance_needed ** 0.5
                    random_noise = (scaled @ self.egvec.T) / batch_size
                    self.weight.grad += random_noise.reshape(self.weight.grad.shape)
                    self.add_noise_to_estimated_gradient = False
                    self.egval, self.egvec = None, None
                    self.covariance_needed = None
        else:
            whole_bs = loss.shape[0]
            # Undo the sample-major stacking used in forward.
            eps_w = self.epsilon_buf_w.reshape(-1, *self.weight.shape)
            eps_b = self.epsilon_buf_b.reshape(-1, self.out_channels) if has_bias else None
            if not grad_sample:
                weighted = loss[:, None, None, None, None] * eps_w
                self.weight.grad = torch.sum(weighted, dim=0) / (whole_bs * noise_std[:, None, None, None])
                if has_bias:
                    self.bias.grad = torch.sum(loss[:, None] * eps_b, 0) / (whole_bs * noise_std)
            else:
                if batch_size is None:
                    batch_size = whole_bs
                weight_gs = loss[:, None, None, None, None] * eps_w / noise_std[None, :, None, None, None]
                bias_gs = loss[:, None] * eps_b / noise_std[None, :] if has_bias else None
                self._reduce_grad_samples(weight_gs, bias_gs, clip_threshold, batch_size)

        self.log_noise_std.grad = None
        # Reset to the default mode after backward; this also clears every
        # noise/input buffer and the added_noise flag.
        self.switch_mode(self.default_mode)

    def loss_times_jacobian_sum_eigenvalue_lower_bound(self, input, loss):
        """Eigen-decompose the loss-weighted Gram matrix of unfolded inputs.

        Builds ``sum_n loss_n**2 * X_n^T X_n`` with ``X_n`` the patch matrix
        that ``self.unfold`` extracts from sample ``n`` (one row per sliding
        position).

        Returns:
            ``(bound, egval, egvec)`` where ``egval``/``egvec`` come from
            ``torch.linalg.eigh`` and ``bound`` is the smaller of the minimum
            eigenvalue and the bias term, floored at 0.
        """
        assert len(input.shape) == 4
        input_unfold = self.unfold(input)
        weight_subjacobian = input_unfold.transpose(1, 2)
        gram = torch.matmul(weight_subjacobian.transpose(1, 2), weight_subjacobian)
        weighted_gram = (loss[:, None, None] ** 2 * gram).sum(dim=0)
        egval, egvec = torch.linalg.eigh(weighted_gram)
        # Reuse the decomposition instead of running a second eigen-solver.
        weight_eigenvalue_lower_bound = torch.min(egval).item()
        bias_eigenvalue_lower_bound = (loss ** 2).sum().item() * input_unfold.shape[-1] ** 2
        return max(min(weight_eigenvalue_lower_bound, bias_eigenvalue_lower_bound), 0), egval, egvec

    def dp_controller(self, sigma_0, repeat_time_K, clip_threshold_C, loss):
        """Cap the perturbation std from the current batch statistics.

        Computes ``sigma = sqrt(min_eigenvalue / (K * C**2 * sigma_0**2))``
        and sets the layer's noise std to ``min(sigma, current std)``.  In
        logit mode, when ``sigma`` does not exceed the smallest current std,
        the std is left unchanged and ``backward(grad_sample=True)`` is
        instead told to add extra noise to the weight gradient.

        Returns:
            ``(is_logit_mode, sigma, first entry of the new std)``.

        Raises:
            ValueError: if ``perturb_mode`` is neither "logit" nor "weight".
        """
        current_sigma = torch.exp(self.log_noise_std)
        scale = repeat_time_K * clip_threshold_C ** 2 * sigma_0 ** 2
        if self.perturb_mode == "logit":
            assert self.input_buf is not None
            min_eigenvalue, egval, egvec = self.loss_times_jacobian_sum_eigenvalue_lower_bound(self.input_buf, loss)
            sigma = (min_eigenvalue / scale) ** 0.5
            if sigma <= current_sigma.min():
                self.add_noise_to_estimated_gradient = True
                self.egval, self.egvec = egval, egvec
                self.covariance_needed = clip_threshold_C ** 2 * sigma_0 ** 2 - (self.egval * (self.egval > 0.1) / (repeat_time_K * current_sigma.min() ** 2))
                # Negative entries are zeroed so the square root stays real.
                self.covariance_needed = self.covariance_needed * (self.covariance_needed > 0)
                sigma = current_sigma.min()
        elif self.perturb_mode == "weight":
            min_eigenvalue = (loss ** 2).sum(dim=0)
            sigma = (min_eigenvalue / scale) ** 0.5
        else:
            raise ValueError("unsupported perturb_mode: %r" % (self.perturb_mode,))

        if torch.is_tensor(sigma):
            sigma_tensor = sigma
        else:
            sigma_tensor = torch.tensor(sigma, dtype=current_sigma.dtype, device=current_sigma.device)
        new_sigma = torch.min(sigma_tensor, current_sigma)
        self.set_sigma(new_sigma)
        return (
            int(self.perturb_mode == "logit"),
            float(sigma.item() if isinstance(sigma, torch.Tensor) else sigma),
            float(new_sigma[0].item()),
        )

    def fetch_gradient(self):
        """Return a detached CPU copy of ``weight.grad``."""
        return self.weight.grad.detach().cpu()

    def set_sigma(self, new_sigma):
        """Overwrite the noise standard deviation.

        Args:
            new_sigma: a positive Python number applied to every output
                channel, or a positive tensor with the same shape as
                ``log_noise_std``.  Values of any other type are ignored.
        """
        if isinstance(new_sigma, (float, int)):
            assert new_sigma > 0
            self.log_noise_std.data = torch.log(torch.full_like(self.log_noise_std.data, new_sigma))
        elif isinstance(new_sigma, Tensor):
            assert torch.min(new_sigma) > 0
            assert new_sigma.shape == self.log_noise_std.data.shape
            self.log_noise_std.data = torch.log(new_sigma.detach())

    def switch_mode(self, mode="weight"):
        """Select the perturbation mode and drop all cached buffers."""
        self.input_buf = None
        self.epsilon_buf = None
        self.epsilon_buf_w = None
        self.epsilon_buf_b = None
        self.added_noise = False
        self.perturb_mode = mode


class Sequential(nn.Sequential):
    """``nn.Sequential`` that forwards the perturbation API to its children.

    Children that implement the perturbation interface (``add_noise``
    argument, ``backward``, ``set_sigma``, ...) receive the corresponding
    calls; ordinary modules are invoked with the input only and are skipped
    by the other helpers.
    """

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _call_module(module, input, *extras):
        """Call ``module`` with as many of ``extras`` as its forward accepts.

        The longest prefix of ``extras`` that binds to the signature of
        ``module.forward`` is passed positionally.  Deciding from the
        signature, rather than retrying after a ``TypeError``, means a
        ``TypeError`` raised inside the module surfaces instead of silently
        triggering a second call without the extra arguments.
        """
        import inspect

        try:
            signature = inspect.signature(module.forward)
        except (TypeError, ValueError):
            # Signature not introspectable: fall back to trial calls.
            signature = None

        for count in range(len(extras), 0, -1):
            args = extras[:count]
            if signature is None:
                try:
                    return module(input, *args)
                except TypeError:
                    continue
            try:
                signature.bind(input, *args)
            except TypeError:
                continue
            return module(input, *args)
        return module(input)

    def forward(self, input, add_noise=False, epsilon_buf=None):
        """Run the children in order.

        Args:
            input: input to the first child.
            add_noise: passed to every child whose forward accepts a second
                positional argument.
            epsilon_buf: when given together with ``add_noise is True``, it
                is additionally offered to the last child.

        Returns:
            The output of the last child.
        """
        last_index = len(self) - 1
        for index, module in enumerate(self):
            if add_noise is True and epsilon_buf is not None and index == last_index:
                extras = (add_noise, epsilon_buf)
            else:
                extras = (add_noise,)
            input = self._call_module(module, input, *extras)
        return input

    def backward(self, loss, grad_sample=False, clip_threshold=None, batch_size=None, loss0=None):
        """Call ``backward`` on every child that recorded noise.

        Children without an ``added_noise`` flag or a ``backward`` method are
        skipped; errors raised inside a child's ``backward`` propagate.
        """
        for module in self:
            if not getattr(module, "added_noise", False):
                continue
            handler = getattr(module, "backward", None)
            if handler is None:
                continue
            handler(loss, grad_sample, clip_threshold, batch_size, loss0)

    def fetch_gradient(self):
        """Collect ``fetch_gradient()`` from the children.

        Children that raise ``AttributeError`` (no such method, or no
        gradient stored yet) are skipped.

        Returns:
            The single gradient when exactly one child contributed,
            otherwise a list (possibly empty).
        """
        gradient_list = []
        for module in self:
            try:
                gradient = module.fetch_gradient()
            except AttributeError:
                continue
            gradient_list.append(gradient)
        if len(gradient_list) == 1:
            return gradient_list[0]
        return gradient_list

    def set_sigma(self, new_sigma):
        """Forward ``set_sigma`` to every child that provides it."""
        for module in self:
            setter = getattr(module, "set_sigma", None)
            if setter is not None:
                setter(new_sigma)

    def dp_controller(self, sigma_0, repeat_time_K, clip_threshold_C, loss):
        """Return ``dp_controller(...)`` of the first child that provides it.

        Later children are not consulted; ``None`` is returned when no child
        has a ``dp_controller``.
        """
        for module in self:
            controller = getattr(module, "dp_controller", None)
            if controller is not None:
                return controller(sigma_0, repeat_time_K, clip_threshold_C, loss)
        return None

    def turn_off_antivariable(self):
        """Forward ``turn_off_antivariable`` to every child that provides it."""
        for module in self:
            switch = getattr(module, "turn_off_antivariable", None)
            if switch is not None:
                switch()