import torch.nn as nn
from module import Sequential, UnifiedConv2d, UnifiedLinear
from typing import Optional, List, Tuple, Union


class MLP(nn.Module):
    """Four-layer perceptron (256 -> 128 -> 64 -> 32 -> 10) built from ``UnifiedLinear``.

    ``module_w_para`` holds the parameterised layers, each wrapped in the
    project ``Sequential`` container; ``module_wo_para`` holds the GELU
    activations applied after the first three layers. The final layer's
    output is returned without an activation.
    """

    def __init__(self, default_mode="logit"):
        """Build the layers.

        Args:
            default_mode: Forwarded unchanged as ``default_mode`` to every
                ``UnifiedLinear``; its meaning is defined by that class.
        """
        super().__init__()
        self.module_w_para = nn.Sequential(
            Sequential(UnifiedLinear(256, 128, default_mode=default_mode)),
            Sequential(UnifiedLinear(128, 64, default_mode=default_mode)),
            Sequential(UnifiedLinear(64, 32, default_mode=default_mode)),
            Sequential(UnifiedLinear(32, 10, default_mode=default_mode)),
        )
        self.module_wo_para = nn.Sequential(
            nn.Sequential(nn.GELU()),
            nn.Sequential(nn.GELU()),
            nn.Sequential(nn.GELU()),
        )

    def forward(self, x, add_noise=None):
        """Run the network on ``x``.

        Args:
            x: Input whose elements are regrouped into rows of 256 via
                ``reshape(-1, 256)``.
            add_noise: Per-layer flags passed as the second positional
                argument to each parameterised layer. ``None`` means
                ``False`` for every layer; a single ``bool`` is broadcast to
                every layer; otherwise an indexable with at least one entry
                per layer.

        Returns:
            The output of the last parameterised layer.
        """
        x = x.reshape(-1, 256)
        if add_noise is None:
            add_noise = False
        if isinstance(add_noise, bool):
            # Mirror dp_controller: a scalar applies to every layer.
            add_noise = [add_noise] * len(self.module_w_para)
        # Hidden layers: parameterised layer followed by its activation.
        for idx, activation in enumerate(self.module_wo_para):
            x = activation(self.module_w_para[idx](x, add_noise[idx]))
        # Output layer has no activation.
        last = len(self.module_wo_para)
        return self.module_w_para[last](x, add_noise[last])

    def backward(self, loss, grad_sample=False, clip_threshold=None, batch_size=None, loss0=None):
        """Call ``backward`` on every parameterised layer with the same arguments.

        ``batch_size`` and ``loss0`` are passed by keyword so the call does
        not depend on their positional order in the layer's signature.
        """
        for seq in self.module_w_para:
            seq.backward(loss, grad_sample, clip_threshold, batch_size=batch_size, loss0=loss0)

    def fetch_gradient(self):
        """Return a list with each parameterised layer's ``fetch_gradient()`` result."""
        return [seq.fetch_gradient() for seq in self.module_w_para]

    def set_sigma(self, new_sigma):
        """Call ``set_sigma`` on each layer with the matching entry of ``new_sigma``.

        Entries are paired with layers via ``zip``, so surplus entries (or
        layers without a matching entry) are silently skipped.
        """
        for sigma, seq in zip(new_sigma, self.module_w_para):
            seq.set_sigma(sigma)

    def dp_controller(
        self,
        sigma_0: Union[float, list],
        repeat_time_K: Union[int, list],
        total_clip_threshold_C: Union[float, int],
        loss,
    ) -> Tuple[list, list, list]:
        """Call ``dp_controller`` on every layer and collect the results.

        Args:
            sigma_0: One value per layer as a ``list``, or a single value
                that is reused for every layer.
            repeat_time_K: One value per layer as a ``list``, or a single
                value that is reused for every layer.
            total_clip_threshold_C: Passed unchanged to every layer.
            loss: Passed unchanged to every layer.

        Returns:
            ``(modes, sigma_bounds, sigmas)``: three lists built from the
            3-tuples returned by the layers. Layers that return ``None`` are
            skipped, so the lists may be shorter than the number of layers.
        """
        num_layers = len(self.module_w_para)
        if not isinstance(sigma_0, list):
            sigma_0 = [sigma_0] * num_layers
        if not isinstance(repeat_time_K, list):
            repeat_time_K = [repeat_time_K] * num_layers
        modes, sigma_bounds, sigmas = [], [], []
        for i, seq in enumerate(self.module_w_para):
            result = seq.dp_controller(sigma_0[i], repeat_time_K[i], total_clip_threshold_C, loss)
            if result is not None:
                mode, sigma_bound, sigma = result
                modes.append(mode)
                sigma_bounds.append(sigma_bound)
                sigmas.append(sigma)
        return modes, sigma_bounds, sigmas

    def turn_off_antivariable(self):
        """Call ``turn_off_antivariable()`` on every parameterised layer."""
        for seq in self.module_w_para:
            seq.turn_off_antivariable()


class ResNetMini(nn.Module):
    """Small convolutional network with one residual connection.

    ``module_w_para`` holds four 3x3 ``UnifiedConv2d`` layers
    (3 -> 8 -> 16 -> 32 -> 32 channels) followed by a ``UnifiedLinear``
    head with 10 outputs, each wrapped in the project ``Sequential``
    container. ``module_wo_para`` holds the parameter-free stages
    (GELU, pooling, flatten) that go with them.
    """

    def __init__(self, default_mode="logit"):
        """Build the layers.

        Args:
            default_mode: Forwarded as ``default_mode`` to every
                ``UnifiedConv2d``; its meaning is defined by that class.
        """
        super().__init__()
        self.module_w_para = nn.Sequential(
            Sequential(UnifiedConv2d(3, 8, (3, 3), (1, 1), 1, default_mode=default_mode)),
            Sequential(UnifiedConv2d(8, 16, (3, 3), (1, 1), 1, default_mode=default_mode)),
            Sequential(UnifiedConv2d(16, 32, (3, 3), (1, 1), 1, default_mode=default_mode)),
            Sequential(UnifiedConv2d(32, 32, (3, 3), (1, 1), 1, default_mode=default_mode)),
            # NOTE(review): the head hard-codes 'logit' and ignores the
            # ``default_mode`` argument -- confirm this is intentional.
            Sequential(UnifiedLinear(32 * 4 * 4, 10, default_mode='logit')),
        )
        self.module_wo_para = nn.Sequential(
            nn.Sequential(nn.GELU()),
            nn.Sequential(nn.GELU(), nn.AvgPool2d(2)),
            nn.Sequential(nn.GELU()),
            nn.Sequential(nn.GELU()),
            # Pool to 4x4 and flatten: 32 channels * 4 * 4 features for the head.
            nn.Sequential(nn.AdaptiveAvgPool2d(4), nn.Flatten()),
        )

    def forward(self, x, add_noise=None):
        """Run the network on ``x``.

        Args:
            x: Input passed to the first convolution (3 input channels).
            add_noise: Per-layer flags passed as the second positional
                argument to each parameterised layer. ``None`` means
                ``False`` for every layer; a single ``bool`` is broadcast to
                every layer; otherwise an indexable with at least one entry
                per layer.

        Returns:
            The output of the final ``UnifiedLinear`` layer.
        """
        if add_noise is None:
            add_noise = False
        if isinstance(add_noise, bool):
            # Mirror dp_controller: a scalar applies to every layer.
            add_noise = [add_noise] * len(self.module_w_para)
        layers, stages = self.module_w_para, self.module_wo_para
        x = stages[0](layers[0](x, add_noise[0]))
        x = stages[1](layers[1](x, add_noise[1]))
        x = stages[2](layers[2](x, add_noise[2]))
        # Residual connection: the fourth block's input is added to its output.
        x = stages[3](layers[3](x, add_noise[3])) + x
        # Head: the parameter-free stage (pool + flatten) runs BEFORE the linear layer.
        x = layers[4](stages[4](x), add_noise[4])
        return x

    def backward(self, loss, grad_sample=False, clip_threshold=None, batch_size=None, loss0=None):
        """Call ``backward`` on every parameterised layer with the same arguments."""
        for seq in self.module_w_para:
            seq.backward(loss, grad_sample, clip_threshold, batch_size=batch_size, loss0=loss0)

    def fetch_gradient(self):
        """Return a list with each parameterised layer's ``fetch_gradient()`` result."""
        return [seq.fetch_gradient() for seq in self.module_w_para]

    def set_sigma(self, new_sigma):
        """Call ``set_sigma`` on each layer with the matching entry of ``new_sigma``.

        Entries are paired with layers via ``zip``, so surplus entries (or
        layers without a matching entry) are silently skipped.
        """
        for sigma, seq in zip(new_sigma, self.module_w_para):
            seq.set_sigma(sigma)

    def dp_controller(
        self,
        sigma_0: Union[float, list],
        repeat_time_K: Union[int, list],
        total_clip_threshold_C: Union[float, int],
        loss,
    ) -> Tuple[list, list, list]:
        """Call ``dp_controller`` on every layer and collect the results.

        Args:
            sigma_0: One value per layer as a ``list``, or a single value
                that is reused for every layer.
            repeat_time_K: One value per layer as a ``list``, or a single
                value that is reused for every layer.
            total_clip_threshold_C: Passed unchanged to every layer.
            loss: Passed unchanged to every layer.

        Returns:
            ``(modes, sigma_bounds, sigmas)``: three lists built from the
            3-tuples returned by the layers. Layers that return ``None`` are
            skipped, so the lists may be shorter than the number of layers.
        """
        num_layers = len(self.module_w_para)
        if not isinstance(sigma_0, list):
            sigma_0 = [sigma_0] * num_layers
        if not isinstance(repeat_time_K, list):
            repeat_time_K = [repeat_time_K] * num_layers
        modes, sigma_bounds, sigmas = [], [], []
        for i, seq in enumerate(self.module_w_para):
            result = seq.dp_controller(sigma_0[i], repeat_time_K[i], total_clip_threshold_C, loss)
            if result is not None:
                mode, sigma_bound, sigma = result
                modes.append(mode)
                sigma_bounds.append(sigma_bound)
                sigmas.append(sigma)
        return modes, sigma_bounds, sigmas

    def turn_off_antivariable(self):
        """Call ``turn_off_antivariable()`` on every parameterised layer."""
        for seq in self.module_w_para:
            seq.turn_off_antivariable()