import torch
import torch.nn as nn


class MLP(nn.Module):
    """Fully connected network with a configurable stack of hidden layers.

    Layer layout when ``hidden_dims`` is non-empty::

        Linear(input_dim, hidden_dims[0]) -> act
        [Linear(h[i], h[i+1]) -> (BatchNorm1d) -> (Dropout) -> act] * (len(hidden_dims) - 1)
        Linear(hidden_dims[-1], output_dim)

    Note that ``bn`` and ``dropout`` only take effect from the second hidden
    layer on; the first hidden layer is always ``Linear -> act``.

    With an empty ``hidden_dims`` the network is a single
    ``Linear(input_dim, output_dim)``.

    Args:
        input_dim: Size of the input features.
        hidden_dims: Sizes of the hidden layers (any indexable sequence).
        output_dim: Size of the output; defaults to ``input_dim`` when None.
        act: Name of an activation class in ``torch.nn`` (e.g. ``'ReLU'``).
            ``'LeakyReLU'`` is built with ``negative_slope=0.2, inplace=True``;
            every other activation is built with no arguments.
        bn: If true, insert ``BatchNorm1d(affine=False)`` after each hidden
            linear layer except the first.
        dropout: If true, insert ``nn.Dropout()`` after each hidden linear
            layer except the first.

    Attributes:
        network: ``nn.Sequential`` holding the full model.
        extractor: Everything in ``network`` before the last activation and
            the final linear layer (an identity map when there are no hidden
            layers). Shares its modules with ``network``.
        final_layer: The output linear layer (only set when ``hidden_dims``
            is non-empty).
        network_modules: Plain Python list of the modules in ``network``.
    """

    def __init__(self, input_dim, hidden_dims=(200, 500), output_dim=None, act='ReLU', bn=False, dropout=False):
        super(MLP, self).__init__()
        output_dim = input_dim if output_dim is None else output_dim

        if len(hidden_dims) > 0:
            network_modules = [
                nn.Linear(input_dim, hidden_dims[0]),
                self._make_activation(act),
            ]
            for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
                network_modules.append(nn.Linear(in_dim, out_dim))
                if bn:
                    network_modules.append(nn.BatchNorm1d(out_dim, affine=False))
                if dropout:
                    network_modules.append(nn.Dropout())
                network_modules.append(self._make_activation(act))
            # The extractor stops before the trailing activation, i.e. it
            # returns the pre-activation output of the last hidden layer.
            self.extractor = nn.Sequential(*network_modules[:-1])
            self.final_layer = nn.Linear(hidden_dims[-1], output_dim)
            network_modules.append(self.final_layer)
        else:
            network_modules = [nn.Linear(input_dim, output_dim)]
            # Without hidden layers the features feeding the output layer are
            # the inputs themselves. nn.Identity has no parameters, so the
            # state_dict keys stay exactly as before while forward(extract=True)
            # no longer fails on a missing attribute.
            self.extractor = nn.Identity()

        self.network_modules = network_modules
        self.network = nn.Sequential(*network_modules)

    @staticmethod
    def _make_activation(act):
        """Instantiate the ``torch.nn`` activation class named ``act``."""
        activation_cls = getattr(nn, act)
        if act == 'LeakyReLU':
            return activation_cls(negative_slope=0.2, inplace=True)
        return activation_cls()

    def forward(self, x, extract=False):
        """Run the model on ``x``.

        Args:
            x: Input tensor passed straight to the underlying modules.
            extract: If true, return ``self.extractor(x)`` instead of the
                full network output.

        Returns:
            ``self.network(x)``, or ``self.extractor(x)`` when ``extract``
            is true.
        """
        if extract:
            return self.extractor(x)
        return self.network(x)

    def freeze(self):
        """No-op in the base class; subclasses override this."""
        pass

    def melt(self):
        """No-op in the base class; subclasses override this."""
        pass


class OTmap_MLP(MLP):
    """MLP whose output is optionally passed through a final squashing function.

    Args:
        input_dim, hidden_dims, output_dim, act, bn, dropout: Forwarded
            unchanged to ``MLP.__init__``.
        last_act: Controls the output transform applied in ``forward``:
            ``'identity'`` returns the raw network output; ``'tanh'`` returns
            ``(tanh(x) + 1) / 2`` (``T`` is not used in this branch); any
            other value is lower-cased, looked up as a function on ``torch``
            and applied to ``x / T``.
        T: Divisor applied to the network output before the ``last_act``
            function in the generic branch.
    """

    def __init__(self, input_dim, hidden_dims=(200, 500), output_dim=None, act='ReLU', bn=False, dropout=False, last_act='identity', T=1.0):
        super(OTmap_MLP, self).__init__(input_dim, hidden_dims, output_dim, act, bn, dropout)

        # MLP.__init__ has already built self.network. Building a second MLP
        # here and overwriting self.network would leave the inherited
        # extractor/final_layer attributes bound to a different, orphaned set
        # of weights that stays registered on the module but is never used by
        # forward() nor touched by freeze()/melt().
        self.last_act = last_act
        self.T = T

    def forward(self, x):
        """Apply the network, then the transform selected by ``last_act``."""
        x = self.network(x)
        if self.last_act == 'identity':
            return x

        squash = getattr(torch, self.last_act.lower())
        if self.last_act == 'tanh':
            # Affine shift of tanh's (-1, 1) output into (0, 1).
            return (squash(x) + 1) / 2.0
        return squash(x / self.T)

    def freeze(self):
        """Disable gradients for every parameter of ``self.network``."""
        for param in self.network.parameters():
            param.requires_grad = False

    def melt(self):
        """Re-enable gradients for every parameter of ``self.network``."""
        for param in self.network.parameters():
            param.requires_grad = True


class Classifier_Linear(nn.Module):
    """Linear classifier: one ``nn.Linear`` layer followed by a softmax.

    Args:
        input_dim: Size of the input features.
        output_dim: Number of outputs of the linear layer (default 10).

    Attributes:
        network: The ``nn.Linear(input_dim, output_dim)`` layer.
        extractor: Identity map; there are no layers before the linear head,
            so the "extracted features" are the inputs themselves.
    """

    def __init__(self, input_dim, output_dim=10):
        super(Classifier_Linear, self).__init__()

        self.network = nn.Linear(input_dim, output_dim)
        # forward(extract=True) reads self.extractor; define it so that call
        # works instead of raising AttributeError. nn.Identity is
        # parameter-free, so state_dict keys are unaffected.
        self.extractor = nn.Identity()

    def forward(self, x, extract=False):
        """Return softmax probabilities, or the extractor output.

        Args:
            x: Input tensor; the softmax is taken over ``dim=1``
                (assumes x is laid out as (batch, features) — confirm with callers).
            extract: If true, return ``self.extractor(x)`` instead of
                probabilities.
        """
        if extract:
            return self.extractor(x)
        return torch.softmax(self.network(x), dim=1)

    def freeze(self):
        """Disable gradients for every parameter of ``self.network``."""
        for param in self.network.parameters():
            param.requires_grad = False

    def melt(self):
        """Re-enable gradients for every parameter of ``self.network``."""
        for param in self.network.parameters():
            param.requires_grad = True


class Classifier_MLP(MLP):
    """MLP classifier: the base MLP followed by a softmax over ``dim=1``.

    Args:
        input_dim: Size of the input features.
        hidden_dims: Sizes of the hidden layers, forwarded to ``MLP``.
        output_dim: Number of outputs of the final layer (default 10).
        act: Name of the ``torch.nn`` activation class, forwarded to ``MLP``.
    """

    def __init__(self, input_dim, hidden_dims=(200,), output_dim=10, act='ReLU'):
        # MLP.__init__ builds self.network together with the extractor that
        # shares its layers. Replacing self.network with a freshly built MLP
        # afterwards would disconnect the two: forward(extract=True) would
        # return features from weights unrelated to the ones that produce the
        # predictions, and freeze()/melt() would not reach them.
        super(Classifier_MLP, self).__init__(input_dim, hidden_dims, output_dim, act)

    def forward(self, x, extract=False):
        """Return softmax probabilities, or the extractor output.

        Args:
            x: Input tensor; the softmax is taken over ``dim=1``
                (assumes x is laid out as (batch, features) — confirm with callers).
            extract: If true, return ``self.extractor(x)`` (the extractor set
                up by ``MLP.__init__``) instead of probabilities.
        """
        if extract:
            return self.extractor(x)
        return torch.softmax(self.network(x), dim=1)

    def freeze(self):
        """Disable gradients for every parameter of ``self.network``."""
        for param in self.network.parameters():
            param.requires_grad = False

    def melt(self):
        """Re-enable gradients for every parameter of ``self.network``."""
        for param in self.network.parameters():
            param.requires_grad = True