import sys
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np


def export(fn):
    """Decorator: add ``fn.__name__`` to its defining module's ``__all__``.

    The module's ``__all__`` list is created on first use and appended to on
    later uses. ``fn`` itself is returned unchanged.
    """
    owner = sys.modules[fn.__module__]
    if not hasattr(owner, '__all__'):
        owner.__all__ = [fn.__name__]
    else:
        owner.__all__.append(fn.__name__)
    return fn

def parameter_count(module):
    """Return the total number of scalar elements over ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += int(param.numel())
    return total

def _weights_init(m):
    """Kaiming-normal initialise the weight of ``nn.Linear`` / ``nn.Conv2d`` modules.

    Takes a single module, so it can be passed to ``nn.Module.apply``. Any other
    module type is left untouched, and biases are never modified.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)


####### Export #########################################################################################
@export
def MLP(input_dim, h_vec, output_dim, activation, **kwargs):
    """Build and return an ``MLP`` network; extra ``kwargs`` are ignored.

    NOTE(review): ``class MLP`` later in this module rebinds the module-level
    name ``MLP``. After import this factory is therefore shadowed, and the
    name used in the body below also resolves to the class at call time.
    """
    net = MLP(input_dim=input_dim, h_vec=h_vec, output_dim=output_dim, activation=activation)
    return net

@export
def MLP_DLF(input_dim, h_vec, mu_h_vec, phi_h_vec, latent_dim, output_dim, num_ens, activation, **kwargs):
    """Build and return an ``MLP_DLF`` network; extra ``kwargs`` are ignored.

    NOTE(review): ``class MLP_DLF`` later in this module rebinds the name
    ``MLP_DLF``, so this factory is shadowed after import.
    """
    return MLP_DLF(
        input_dim=input_dim,
        h_vec=h_vec,
        mu_h_vec=mu_h_vec,
        phi_h_vec=phi_h_vec,
        latent_dim=latent_dim,
        output_dim=output_dim,
        activation=activation,
        num_ens=num_ens,
    )

@export
def WRN_16_1(num_classes=10, activation='ReLU', **kwargs):
    """Build a ``WideResNet`` with depth 16 and widen factor 1.

    ``kwargs['droprate']`` is the dropout probability inside the residual
    blocks. It defaults to 0.0 (``WideResNet``'s own default) when omitted.
    Other kwargs are ignored.
    """
    droprate = kwargs.get('droprate', 0.0)
    return WideResNet(num_classes, depth=16, widen_factor=1, droprate=droprate, activation=activation)

@export
def WRN_20_1(num_classes=10, activation='ReLU', **kwargs):
    """Build a ``WideResNet`` with depth 20 and widen factor 1.

    NOTE(review): ``WideResNet`` asserts ``(depth - 4) % 6 == 0``, which
    depth 20 does not satisfy, so this factory fails that assertion.

    ``kwargs['droprate']`` defaults to 0.0 when omitted; other kwargs are ignored.
    """
    droprate = kwargs.get('droprate', 0.0)
    return WideResNet(num_classes, depth=20, widen_factor=1, droprate=droprate, activation=activation)

@export
def WRN_24_1(num_classes=10, activation='ReLU', **kwargs):
    """Build a ``WideResNet`` with depth 24 and widen factor 1.

    NOTE(review): ``WideResNet`` asserts ``(depth - 4) % 6 == 0``, which
    depth 24 does not satisfy, so this factory fails that assertion.

    ``kwargs['droprate']`` defaults to 0.0 when omitted; other kwargs are ignored.
    """
    droprate = kwargs.get('droprate', 0.0)
    return WideResNet(num_classes, depth=24, widen_factor=1, droprate=droprate, activation=activation)

@export
def WRN_28_1(num_classes=10, activation='ReLU', **kwargs):
    """Build a ``WideResNet`` with depth 28 and widen factor 1.

    ``kwargs['droprate']`` defaults to 0.0 when omitted; other kwargs are ignored.
    """
    droprate = kwargs.get('droprate', 0.0)
    return WideResNet(num_classes, depth=28, widen_factor=1, droprate=droprate, activation=activation)

@export
def WRN_16_1_DLF(num_classes=10, latent_dim=8, activation='ReLU', **kwargs):
    """Build a depth-16, widen-factor-1 ``WideResNet_DLF``.

    Optional kwargs:
        num_ens (default 4): leading dimension of the head's ``z``; must be > 1.
        mu_h_vec (default []): hidden widths of the head's mean branch.
        phi_h_vec (default []): hidden widths of the head's basis branch.
        droprate (default 0.0): dropout probability inside residual blocks.
    Other kwargs are ignored.
    """
    num_ens = kwargs.get('num_ens', 4)
    assert num_ens > 1
    # depth must match the factory's name (16).
    return WideResNet_DLF(num_classes, depth=16, widen_factor=1,
                          mu_h_vec=kwargs.get('mu_h_vec', []),
                          phi_h_vec=kwargs.get('phi_h_vec', []),
                          droprate=kwargs.get('droprate', 0.0),
                          activation=activation, num_ens=num_ens, latent_dim=latent_dim)

@export
def WRN_20_1_DLF(num_classes=10, latent_dim=8, activation='ReLU', **kwargs):
    """Build a depth-20, widen-factor-1 ``WideResNet_DLF``.

    NOTE(review): ``WideResNet_DLF`` asserts ``(depth - 4) % 6 == 0``, which
    depth 20 does not satisfy.

    Optional kwargs:
        num_ens (default 4): leading dimension of the head's ``z``; must be > 1.
        mu_h_vec (default []): hidden widths of the head's mean branch.
        phi_h_vec (default []): hidden widths of the head's basis branch.
        droprate (default 0.0): dropout probability inside residual blocks.
    Other kwargs are ignored.
    """
    num_ens = kwargs.get('num_ens', 4)
    assert num_ens > 1
    return WideResNet_DLF(num_classes, depth=20, widen_factor=1,
                          mu_h_vec=kwargs.get('mu_h_vec', []),
                          phi_h_vec=kwargs.get('phi_h_vec', []),
                          droprate=kwargs.get('droprate', 0.0),
                          activation=activation, num_ens=num_ens, latent_dim=latent_dim)

@export
def WRN_24_1_DLF(num_classes=10, latent_dim=8, activation='ReLU', **kwargs):
    """Build a depth-24, widen-factor-1 ``WideResNet_DLF``.

    NOTE(review): ``WideResNet_DLF`` asserts ``(depth - 4) % 6 == 0``, which
    depth 24 does not satisfy.

    Optional kwargs:
        num_ens (default 4): leading dimension of the head's ``z``; must be > 1.
        mu_h_vec (default []): hidden widths of the head's mean branch.
        phi_h_vec (default []): hidden widths of the head's basis branch.
        droprate (default 0.0): dropout probability inside residual blocks.
    Other kwargs are ignored.
    """
    num_ens = kwargs.get('num_ens', 4)
    assert num_ens > 1
    return WideResNet_DLF(num_classes, depth=24, widen_factor=1,
                          mu_h_vec=kwargs.get('mu_h_vec', []),
                          phi_h_vec=kwargs.get('phi_h_vec', []),
                          droprate=kwargs.get('droprate', 0.0),
                          activation=activation, num_ens=num_ens, latent_dim=latent_dim)

@export
def WRN_28_1_DLF(num_classes=10, latent_dim=8, activation='ReLU', **kwargs):
    """Build a depth-28, widen-factor-1 ``WideResNet_DLF``.

    Optional kwargs:
        num_ens (default 4): leading dimension of the head's ``z``; must be > 1.
        mu_h_vec (default []): hidden widths of the head's mean branch.
        phi_h_vec (default []): hidden widths of the head's basis branch.
        droprate (default 0.0): dropout probability inside residual blocks.
    Other kwargs are ignored.
    """
    num_ens = kwargs.get('num_ens', 4)
    assert num_ens > 1
    return WideResNet_DLF(num_classes, depth=28, widen_factor=1,
                          mu_h_vec=kwargs.get('mu_h_vec', []),
                          phi_h_vec=kwargs.get('phi_h_vec', []),
                          droprate=kwargs.get('droprate', 0.0),
                          activation=activation, num_ens=num_ens, latent_dim=latent_dim)


####### MLP ############################################################################################
class MLP(nn.Module):
    """Fully connected network with BatchNorm and an optional linear head.

    Each width in ``h_vec`` contributes a ``Linear -> BatchNorm1d -> activation``
    stage (collected in ``common_layers``). When ``output_dim > 0``, a final
    ``nn.Linear`` (``heads``) maps to ``output_dim`` outputs.

    Args:
        input_dim: number of input features; ``forward`` reshapes its input
            to ``(-1, input_dim)``.
        h_vec: hidden-layer widths. May be empty, in which case ``heads`` maps
            the ``input_dim`` features directly.
        output_dim: head width. When ``<= 0`` no head is built, and the last
            hidden representation is returned.
        activation: ``'ReLU'``, ``'LeakyReLU'`` (slope 0.1), ``'SiLU'``,
            ``'hardtanh'``/``'Hardtanh'`` (range [-20, 20]) or ``'Softplus'``.
            One activation module instance is shared by all stages.

    Raises:
        ValueError: if ``activation`` is not one of the names above.
    """

    def __init__(self, input_dim : int, h_vec : list, output_dim=1, activation='Softplus'):
        super(MLP, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.L = len(h_vec)
        # Integer layer widths [input_dim, *h_vec]. The explicit dtype keeps
        # them integral even when h_vec is empty.
        self.p_vec = np.array([input_dim, *h_vec], dtype=np.int64)

        if activation == 'ReLU':
            self.activation = nn.ReLU()
        elif activation == 'LeakyReLU':
            self.activation = nn.LeakyReLU(negative_slope=0.1)
        elif activation == 'SiLU':
            self.activation = nn.SiLU()
        elif activation in ('hardtanh', 'Hardtanh'):
            # Accept both spellings: 'hardtanh' was this class's original one,
            # 'Hardtanh' is what the other models in this module use.
            self.activation = nn.Hardtanh(min_val=-20., max_val=20)
        elif activation == 'Softplus':
            self.activation = nn.Softplus()
        else:
            raise ValueError(f"Unsupported activation for MLP: {activation!r}")

        self.common_layers = self._make_layer()
        if self.output_dim > 0:
            # p_vec[-1] is the last hidden width, or input_dim when h_vec is empty.
            self.heads = nn.Linear(int(self.p_vec[-1]), self.output_dim)

    def _make_layer(self):
        """Return the ``Linear -> BatchNorm1d -> activation`` stages as one Sequential."""
        layers = []
        for l in range(self.L):
            width_in, width_out = int(self.p_vec[l]), int(self.p_vec[l + 1])
            layers.append(nn.Sequential(
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out, eps=1e-4, affine=True),
                self.activation,
            ))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and run the stages and the optional head."""
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.input_dim)
        x = self.common_layers(x)
        if self.output_dim > 0:
            x = self.heads(x)
        return x

    def freeze(self, name):
        """Disable gradients for every parameter whose qualified name contains ``name``."""
        for n, para in self.named_parameters():
            if name in n:
                para.requires_grad = False

    def melt(self, name):
        """Re-enable gradients for every parameter whose qualified name contains ``name``."""
        for n, para in self.named_parameters():
            if name in n:
                para.requires_grad = True

    def freeze_all(self):
        """Disable gradients for all parameters."""
        for para in self.parameters():
            para.requires_grad = False

    def melt_all(self):
        """Re-enable gradients for all parameters."""
        for para in self.parameters():
            para.requires_grad = True
                
           
class MLP_DLF(nn.Module):
    """Fully connected backbone followed by a ``Linear_DLF_reg`` head.

    The backbone is a ``Linear -> BatchNorm1d -> activation`` stage per width
    in ``h_vec``. Its output feeds ``self.fc`` (a ``Linear_DLF_reg``), and
    ``forward`` returns that head's 5-tuple unchanged.

    Args:
        input_dim: number of input features; inputs are reshaped to
            ``(-1, input_dim)``.
        h_vec: backbone hidden widths. May be empty, in which case the head
            sees the raw ``input_dim`` features.
        latent_dim: passed to the head as its ``latent_dim``.
        mu_h_vec, phi_h_vec: hidden widths of the head's two branches.
            ``None`` (the default) means no hidden layers.
        output_dim: passed to the head as ``out_features``.
        activation: ``'ReLU'``, ``'LeakyReLU'``, ``'SiLU'``, ``'Hardtanh'`` or
            ``'Softplus'``; also forwarded to the head.
        num_ens: passed to the head; leading dimension of its ``z`` parameter.

    Raises:
        ValueError: if ``activation`` is not one of the names above.
    """

    def __init__(self, input_dim : int, h_vec : list, latent_dim : int, mu_h_vec=None, phi_h_vec=None, output_dim=1, activation='ReLU', num_ens=4):
        super(MLP_DLF, self).__init__()
        # None replaces the former mutable [] defaults; both mean "no hidden layers".
        mu_h_vec = [] if mu_h_vec is None else mu_h_vec
        phi_h_vec = [] if phi_h_vec is None else phi_h_vec

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.L = len(h_vec)
        # Integer widths [input_dim, *h_vec]; integral even when h_vec is empty.
        self.p_vec = np.array([input_dim, *h_vec], dtype=np.int64)
        self.latent_dim = latent_dim

        if activation == 'ReLU':
            self.activation = nn.ReLU()
        elif activation == 'LeakyReLU':
            self.activation = nn.LeakyReLU(negative_slope=0.1)
        elif activation == 'SiLU':
            self.activation = nn.SiLU()
        elif activation == 'Hardtanh':
            self.activation = nn.Hardtanh(min_val=-20., max_val=20)
        elif activation == 'Softplus':
            self.activation = nn.Softplus()
        else:
            raise ValueError(f"Unsupported activation for MLP_DLF: {activation!r}")

        self.common_layers = self._make_layer()
        # p_vec[-1] is the last backbone width, or input_dim when h_vec is empty.
        self.fc = Linear_DLF_reg(in_features=int(self.p_vec[-1]), mu_h_vec=mu_h_vec, phi_h_vec=phi_h_vec, activation=activation, out_features=self.output_dim, num_ens=num_ens, latent_dim=latent_dim)

        # Re-initialise every BatchNorm1d (weight 1, bias 0) and every Linear
        # (Xavier-normal weight, zero bias), including those inside the head.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def _make_layer(self):
        """Return the backbone's ``Linear -> BatchNorm1d -> activation`` stages."""
        layers = []
        for l in range(self.L):
            width_in, width_out = int(self.p_vec[l]), int(self.p_vec[l + 1])
            layers.append(nn.Sequential(
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out, eps=1e-4, affine=True),
                self.activation,
            ))
        return nn.Sequential(*layers)

    def forward(self, x, z=None):
        """Run the backbone, then the head; ``z`` (optional) is forwarded to the head.

        Returns the head's tuple ``(out, mu_f, bias_f, phi, sigma_f)``.
        """
        x = x.reshape(-1, self.input_dim)
        x = self.common_layers(x)
        out, mu_f, bias_f, phi, sigma_f = self.fc(x, z)
        return out, mu_f, bias_f, phi, sigma_f

    def freeze(self, name):
        """Disable gradients for every parameter whose qualified name contains ``name``."""
        for n, para in self.named_parameters():
            if name in n:
                para.requires_grad = False

    def melt(self, name):
        """Re-enable gradients for every parameter whose qualified name contains ``name``."""
        for n, para in self.named_parameters():
            if name in n:
                para.requires_grad = True

    def freeze_all(self):
        """Disable gradients for all parameters."""
        for para in self.parameters():
            para.requires_grad = False

    def melt_all(self):
        """Re-enable gradients for all parameters."""
        for para in self.parameters():
            para.requires_grad = True
    

class Linear_DLF_reg(nn.Module):
    """Two-branch regression head with per-member latent projections.

    ``mu_basis`` maps features to ``out_features`` values. ``phi_basis`` maps
    features to ``latent_dim`` basis values, which are projected by each of
    the ``num_ens`` matrices stacked in ``z`` (shape
    ``(num_ens, latent_dim, out_features)``).

    Each branch is a chain of ``nn.Linear`` layers through
    ``[in_features, *hidden, final]``. Every layer except the last is followed
    by ``BatchNorm1d`` and the shared activation.

    ``forward(x, z=None)`` returns
    ``(output_f, mu_f, bias_f, phi_f, exp(sigma_f))``. Assuming ``x`` is
    ``(B, in_features)`` and ``z`` is ``(M, latent_dim, out_features)``:
    ``mu_f`` is ``(B, out)``, ``phi_f`` is ``(B, latent)``, ``bias_f`` is
    ``(B, M, out)``, and ``output_f = mu_f[:, None] + bias_f``.

    Args:
        mu_h_vec, phi_h_vec: hidden widths of each branch. Any sequence
            (list, tuple, 1-D array) is accepted; empty or ``None`` means no
            hidden layers.
        activation: ``'ReLU'``, ``'LeakyReLU'``, ``'SiLU'``, ``'Hardtanh'`` or
            ``'Softplus'``. Other names are only an error (``ValueError``) when
            some branch has a hidden layer that needs the activation.
    """

    def __init__(self, in_features, mu_h_vec, phi_h_vec, activation='ReLU', out_features=1, num_ens=4, latent_dim=9):
        super(Linear_DLF_reg, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.num_ens = num_ens

        if activation == 'ReLU':
            self.activation = nn.ReLU()
        elif activation == 'LeakyReLU':
            self.activation = nn.LeakyReLU(negative_slope=0.1)
        elif activation == 'SiLU':
            self.activation = nn.SiLU()
        elif activation == 'Hardtanh':
            self.activation = nn.Hardtanh(min_val=-20., max_val=20)
        elif activation == 'Softplus':
            self.activation = nn.Softplus()
        else:
            # Only a problem if a hidden layer needs it; _make_layer reports that.
            self.activation = None

        # Build width lists directly; _make_layer casts them to int, so empty
        # tuples or numpy arrays cannot produce float widths.
        mu_hidden = [] if mu_h_vec is None else list(mu_h_vec)
        phi_hidden = [] if phi_h_vec is None else list(phi_h_vec)
        self.mu_basis = self._make_layer([in_features, *mu_hidden, out_features])
        self.phi_basis = self._make_layer([in_features, *phi_hidden, latent_dim])
        self.z = nn.Parameter(torch.Tensor(num_ens, latent_dim, self.out_features))
        self.sigma_f = nn.Parameter(torch.tensor([0.]))  # corresponding epsilon; forward returns its exp()
        self.init_parameters()

    def _make_layer(self, h_vec):
        """Chain ``nn.Linear`` layers through the widths in ``h_vec``.

        Every layer but the last is followed by ``BatchNorm1d`` and the shared
        activation.
        """
        dims = [int(d) for d in h_vec]
        layers = []
        for l in range(len(dims) - 1):
            layer = [nn.Linear(dims[l], dims[l + 1])]
            if l + 2 != len(dims):
                if self.activation is None:
                    raise ValueError('Linear_DLF_reg: unsupported activation for a hidden layer')
                layer.append(nn.BatchNorm1d(dims[l + 1], eps=1e-4, affine=True))
                layer.append(self.activation)
            layers.append(nn.Sequential(*layer))
        return nn.Sequential(*layers)

    def init_parameters(self):
        """Draw ``z`` from a standard normal distribution."""
        nn.init.normal_(self.z, mean=0, std=1)

    def forward(self, x, z=None):
        """Compute the head outputs; ``z`` defaults to the learned ``self.z``."""
        if z is None:
            z = self.z

        mu_f = self.mu_basis(x)
        phi_f = self.phi_basis(x)
        # (B, D) @ (M, D, C) broadcasts to (M, B, C); permute to (B, M, C).
        bias_f = (phi_f @ z).permute(1, 0, 2)
        output_f = mu_f.unsqueeze(1) + bias_f

        return output_f, mu_f, bias_f, phi_f, self.sigma_f.exp()

####### WideResNet #####################################################################################
class BasicBlock_WRN(nn.Module):
    """Residual block: BN -> act -> 3x3 conv -> BN -> act -> [dropout] -> 3x3 conv, plus shortcut.

    When ``in_planes == out_planes`` the shortcut is the identity, and conv1
    receives the BN+activation of the input. Otherwise the shortcut is a
    strided 1x1 convolution, and both it and conv1 receive either:
      * the BN+activation of the input, if ``activate_before_residual`` is truthy; or
      * the raw input, otherwise.

    Args:
        in_planes, out_planes: input and output channel counts.
        stride: stride of conv1 and of the 1x1 shortcut.
        droprate: dropout probability between the two convs (skipped when 0).
        activate_before_residual: see above; only consulted when the channel
            counts differ.
        activation: ``'ReLU'``, ``'LeakyReLU'`` (slope 0.1) or ``'SiLU'``,
            applied in place.

    Raises:
        ValueError: for any other ``activation``.
    """

    def __init__(self, in_planes, out_planes, stride, droprate=0.0, activate_before_residual=False, activation='ReLU'):
        super(BasicBlock_WRN, self).__init__()
        # Registration order is kept stable so parameter order, state_dict
        # layout and seeded initialisation are unchanged.
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.1)
        self.activation = activation
        if activation == 'ReLU':
            self.relu1 = nn.ReLU(inplace=True)
            self.relu2 = nn.ReLU(inplace=True)
        elif activation == 'LeakyReLU':
            self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif activation == 'SiLU':
            self.relu1 = nn.SiLU(inplace=True)
            self.relu2 = nn.SiLU(inplace=True)
        else:
            raise ValueError(f"Unsupported activation for BasicBlock_WRN: {activation!r}")
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.1)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.droprate = droprate
        self.Dropout = nn.Dropout(p=self.droprate)
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is only needed when the channel count changes.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
        self.activate_before_residual = activate_before_residual

    def forward(self, x):
        if not self.equalInOut and self.activate_before_residual:
            # Pre-activate once; conv1 and the 1x1 shortcut both consume it.
            x = self.relu1(self.bn1(x))
        else:
            # NOTE(review): when equalInOut is False this `out` is not used below
            # (conv1 and the shortcut read the raw `x`), yet bn1 still runs and,
            # in training mode, still updates its running statistics.
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = self.Dropout(out)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)

class NetworkBlock_WRN(nn.Module):
    """One WideResNet stage: ``int(nb_layers)`` blocks built by ``block`` and run in sequence.

    The first block maps ``in_planes -> out_planes`` with ``stride``. Every
    later block maps ``out_planes -> out_planes`` with stride 1. ``block`` is
    called positionally as
    ``block(in_ch, out_planes, stride, droprate, activate_before_residual, activation)``.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, droprate=0.0, activate_before_residual=False, activation='ReLU'):
        super(NetworkBlock_WRN, self).__init__()
        self.activation = activation
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, droprate, activate_before_residual)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, droprate, activate_before_residual):
        stage = []
        for idx in range(int(nb_layers)):
            # Only the first block uses in_planes/stride; falsy values fall back
            # to out_planes/1.
            in_ch = in_planes if (idx == 0 and in_planes) else out_planes
            step = stride if (idx == 0 and stride) else 1
            stage.append(block(in_ch, out_planes, step, droprate, activate_before_residual, self.activation))
        return nn.Sequential(*stage)

    def forward(self, x):
        return self.layer(x)
    
class WideResNet(nn.Module):
    """Wide residual network classifier for 3-channel images.

    Structure: 3x3 stem conv (16 channels), then three ``NetworkBlock_WRN``
    stages (strides 1, 2, 2; widths 16/32/64 x ``widen_factor``), then
    BN + activation, global average pooling and a linear classifier.

    Args:
        num_classes: output width of the classifier.
        depth: total depth. Must satisfy ``(depth - 4) % 6 == 0``, giving
            ``(depth - 4) // 6`` blocks per stage.
        widen_factor: channel multiplier for the three stages.
        droprate: dropout probability inside each residual block.
        activation: ``'ReLU'``, ``'LeakyReLU'`` or ``'SiLU'``.

    Raises:
        AssertionError: if ``depth`` does not satisfy the constraint above.
        ValueError: for an unsupported ``activation``.
    """

    def __init__(self, num_classes, depth=28, widen_factor=2, droprate=0.0, activation='ReLU'):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6  # residual blocks per stage
        block = BasicBlock_WRN

        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock_WRN(n, nChannels[0], nChannels[1], block, 1, droprate, activation=activation)
        # 2nd block
        self.block2 = NetworkBlock_WRN(n, nChannels[1], nChannels[2], block, 2, droprate, activation=activation)
        # 3rd block
        self.block3 = NetworkBlock_WRN(n, nChannels[2], nChannels[3], block, 2, droprate, activation=activation)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.1)

        self.activation = activation
        if self.activation == 'ReLU':
            self.relu = nn.ReLU(inplace=True)
        elif self.activation == 'LeakyReLU':
            self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif self.activation == 'SiLU':
            self.relu = nn.SiLU(inplace=True)
        else:
            raise ValueError(f"Unsupported activation for WideResNet: {activation!r}")
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # He-style normal init for convs (fan computed from kernel size x out
        # channels), unit/zero BatchNorm, and Xavier-normal Linear with zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Global average pooling to 1x1. For 32x32 inputs the final map is 8x8,
        # where this equals F.avg_pool2d(out, 8). For other resolutions it still
        # yields exactly one feature vector per sample.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(-1, self.nChannels)
        out = self.fc(out)
        return out

    def freeze(self):
        """Disable gradients for all parameters."""
        for para in self.parameters():
            para.requires_grad = False

    def melt(self):
        """Re-enable gradients for all parameters."""
        for para in self.parameters():
            para.requires_grad = True

class LowTriangularMatrix3(nn.Module):
    """Learnable lower-triangular matrix of shape ``(n_dim, latent_dim)``.

    The main-diagonal entries and the strictly-lower entries are separate
    parameters; everything above the diagonal is a constant zero.
    ``init_parameters`` draws the diagonal from N(1, 0.01^2) and the lower
    entries from N(0, 0.01^2). ``forward()`` takes no input and assembles the
    matrix on the parameters' device and dtype.
    """

    def __init__(self, n_dim, latent_dim):
        super(LowTriangularMatrix3, self).__init__()
        self.n_dim = n_dim
        self.latent_dim = latent_dim
        # Boolean masks for the diagonal and the strictly-lower part. They are
        # non-persistent buffers: .to()/.cuda() moves them with the module, but
        # they do not appear in state_dict (keys stay the same as before).
        self.register_buffer('diag_idx', torch.eye(n_dim, latent_dim, dtype=torch.bool), persistent=False)
        self.register_buffer('lower_idx', ~torch.triu(torch.ones(n_dim, latent_dim, dtype=torch.bool)), persistent=False)

        # The main diagonal of an (n_dim, latent_dim) matrix has
        # min(n_dim, latent_dim) entries; this equals latent_dim when square.
        self.diag_elements = nn.Parameter(torch.randn(min(n_dim, latent_dim)))  # Diagonal elements
        self.lower_triangular_elements = nn.Parameter(torch.randn(int(self.lower_idx.sum())))  # Lower triangular elements
        self.init_parameters()

    def init_parameters(self):
        nn.init.normal_(self.diag_elements, 1, 0.01)
        nn.init.normal_(self.lower_triangular_elements, 0, 0.01)

    def forward(self):
        # Allocate on the parameters' device/dtype rather than hard-coding CUDA.
        matrix = self.diag_elements.new_zeros(self.n_dim, self.latent_dim)
        matrix[self.diag_idx] = self.diag_elements
        matrix[self.lower_idx] = self.lower_triangular_elements
        return matrix
    
class WideResNet_DLF(nn.Module):
    """WideResNet feature extractor followed by a ``Linear_DLF_cls`` head.

    The backbone matches ``WideResNet`` (3x3 stem, three stages with strides
    1/2/2, BN + activation, global average pooling). The pooled features go to
    ``self.fc``, and ``forward`` returns the head's 5-tuple
    ``(out, mu_f, bias_f, cov_basis, sigma_f)``.

    Args:
        num_classes: passed to the head as ``out_features``.
        depth: must satisfy ``(depth - 4) % 6 == 0``.
        widen_factor: channel multiplier for the three stages.
        mu_h_vec, phi_h_vec: hidden widths of the head's branches.
            ``None`` (the default) means no hidden layers.
        droprate: dropout probability inside each residual block.
        activation: ``'ReLU'``, ``'LeakyReLU'`` or ``'SiLU'`` for the backbone.
        num_ens, latent_dim: passed to the head.

    Raises:
        AssertionError: if ``depth`` does not satisfy the constraint above.
        ValueError: for an unsupported ``activation``.
    """

    def __init__(self, num_classes, depth=28, widen_factor=2, mu_h_vec=None, phi_h_vec=None, droprate=0.0, activation='ReLU', num_ens=4, latent_dim=8):
        super(WideResNet_DLF, self).__init__()
        # None replaces the former mutable [] defaults; both mean "no hidden layers".
        mu_h_vec = [] if mu_h_vec is None else mu_h_vec
        phi_h_vec = [] if phi_h_vec is None else phi_h_vec
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6  # residual blocks per stage
        block = BasicBlock_WRN

        self.num_ens = num_ens
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock_WRN(n, nChannels[0], nChannels[1], block, 1, droprate, activation=activation)
        # 2nd block
        self.block2 = NetworkBlock_WRN(n, nChannels[1], nChannels[2], block, 2, droprate, activation=activation)
        # 3rd block
        self.block3 = NetworkBlock_WRN(n, nChannels[2], nChannels[3], block, 2, droprate, activation=activation)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.1)

        if activation == 'ReLU':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'LeakyReLU':
            self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif activation == 'SiLU':
            self.activation = nn.SiLU(inplace=True)
        else:
            raise ValueError(f"Unsupported activation for WideResNet_DLF: {activation!r}")

        # NOTE(review): `activation` is not forwarded to the head, so any hidden
        # head layers use Linear_DLF_cls's own default ('ReLU'). Confirm this is intended.
        self.fc = Linear_DLF_cls(nChannels[3], num_classes, mu_h_vec, phi_h_vec, num_ens=self.num_ens, latent_dim=latent_dim)
        self.nChannels = nChannels[3]

        # He-style normal init for convs, unit/zero BatchNorm2d, and
        # Xavier-normal Linear with zero bias (including the head's Linears).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x, z=None):
        """Extract pooled features from ``x`` and apply the DLF head (``z`` is forwarded)."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.activation(self.bn1(out))
        # Global average pooling to 1x1. For 32x32 inputs this equals
        # F.avg_pool2d(out, 8), and it keeps one vector per sample for other
        # resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(-1, self.nChannels)
        out, mu_f, bias_f, cov_basis, sigma_f = self.fc(out, z)
        return out, mu_f, bias_f, cov_basis, sigma_f
    
    
    
class Linear_DLF_cls(nn.Module):
    """Two-branch classification head with a learned lower-triangular mixing matrix.

    ``mu_basis`` maps features to ``out_features`` values. ``phi_basis``
    (``cov_basis`` in ``forward``) maps features to ``latent_dim`` values.
    These are projected by each of the ``num_ens`` matrices in ``z`` and then
    right-multiplied by ``L_c().t()``.

    ``forward(x, z=None)`` returns
    ``(output_f, mu_f, bias_f, cov_basis, exp(sigma_f))``. Assuming ``x`` is
    ``(B, in_features)`` and ``z`` is ``(M, latent_dim, out_features)``:
    ``mu_f`` is ``(B, C)``, ``bias_f`` is ``(M, B, C)``, and
    ``output_f = mu_f[None] + bias_f``.

    Note the axis order differs from ``Linear_DLF_reg``, which returns
    ``(B, M, C)``.

    Args:
        mu_h_vec, phi_h_vec: hidden widths of each branch. Any sequence is
            accepted; empty or ``None`` means no hidden layers.
        activation: ``'ReLU'``, ``'LeakyReLU'`` or ``'SiLU'`` (in place). Other
            names are only an error (``ValueError``) when a hidden layer needs one.
    """

    def __init__(self, in_features, out_features, mu_h_vec, phi_h_vec, activation='ReLU', num_ens=4, latent_dim=8):
        super(Linear_DLF_cls, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.latent_dim = latent_dim
        self.num_ens = num_ens

        if activation == 'ReLU':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'LeakyReLU':
            self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif activation == 'SiLU':
            self.activation = nn.SiLU(inplace=True)
        else:
            # Only a problem if a hidden layer needs it; _make_layer reports that.
            self.activation = None

        # Build width lists directly; _make_layer casts them to int. This
        # avoids np.hstack turning an empty hidden list into float widths.
        mu_hidden = [] if mu_h_vec is None else list(mu_h_vec)
        phi_hidden = [] if phi_h_vec is None else list(phi_h_vec)
        self.mu_basis = self._make_layer([in_features, *mu_hidden, out_features])
        self.phi_basis = self._make_layer([in_features, *phi_hidden, latent_dim])

        self.L_c = LowTriangularMatrix3(out_features, out_features)
        self.z = nn.Parameter(torch.Tensor(num_ens, self.latent_dim, self.out_features))
        self.sigma_f = nn.Parameter(torch.tensor([0.]))  # corresponding epsilon; forward returns its exp()
        self.init_parameters()

    def _make_layer(self, h_vec):
        """Chain ``nn.Linear`` layers through ``h_vec``.

        Every layer but the last is followed by ``BatchNorm1d`` and the shared
        activation.
        """
        dims = [int(d) for d in h_vec]
        layers = []
        for l in range(len(dims) - 1):
            layer = [nn.Linear(dims[l], dims[l + 1])]
            if l + 2 != len(dims):
                if self.activation is None:
                    raise ValueError('Linear_DLF_cls: unsupported activation for a hidden layer')
                layer.append(nn.BatchNorm1d(dims[l + 1], eps=1e-4, affine=True))
                layer.append(self.activation)
            layers.append(nn.Sequential(*layer))
        return nn.Sequential(*layers)

    def init_parameters(self):
        """Draw ``z`` from a standard normal distribution."""
        nn.init.normal_(self.z, mean=0, std=1)

    def forward(self, x, z=None):
        """Compute the head outputs; ``z`` defaults to the learned ``self.z``."""
        if z is None:
            z = self.z

        cov_basis = self.phi_basis(x)
        L_c = self.L_c()
        # (B, D) @ (M, D, C) broadcasts over the leading dim of z to (M, B, C).
        # Right-multiplying by L_c^T gives the same result as stacking
        # cov_basis @ z[m] @ L_c.t() over m, without a Python loop.
        bias_f = torch.matmul(cov_basis, z) @ L_c.t()

        mu_f = self.mu_basis(x)

        output_f = mu_f.unsqueeze(0) + bias_f

        return output_f, mu_f, bias_f, cov_basis, self.sigma_f.exp()