

import copy
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from sklearn.metrics import mean_squared_error

import torch.nn.functional as F
from torch import nn


class AEWrapper(nn.Module):
    """
    Autoencoder wrapper: an encoder/decoder pair plus a two-layer projection head.
    """

    def __init__(self, options):
        """

        Args:
            options (dict): Configuration dictionary. Keys read by this class:
                "shallow_architecture" (selects the shallow or the regular encoder/decoder pair),
                "dims" (its last entry is the width of the projection head), and, in forward(),
                "normalize" and "p_norm".
        """
        super(AEWrapper, self).__init__()
        # Work on a private copy so the sub-modules cannot modify the caller's dictionary
        self.options = copy.deepcopy(options)

        # Choose the encoder/decoder classes that belong together
        if options["shallow_architecture"]:
            encoder_cls, decoder_cls = ShallowEncoder, ShallowDecoder
        else:
            encoder_cls, decoder_cls = Encoder, Decoder
        self.encoder = encoder_cls(self.options)
        self.decoder = decoder_cls(self.options)

        # Projection head: two square linear layers as wide as the last entry of "dims".
        # The non-linearity between them is applied in forward().
        proj_dim = self.options["dims"][-1]
        self.linear_layer1 = nn.Linear(proj_dim, proj_dim)
        self.linear_layer2 = nn.Linear(proj_dim, proj_dim)

    def forward(self, x):
        """Encode x, project the latent, and reconstruct from the latent.

        Args:
            x: Input passed to the encoder.

        Returns:
            tuple: (z, latent, x_recon) where latent is the encoder output, z is the projection of
            latent (normalized along dim 1 when options["normalize"] is truthy) and x_recon is the
            decoder output computed from latent.
        """
        # Representation produced by the encoder
        latent = self.encoder(x)

        # Projection: linear -> leaky ReLU -> linear, so that z is decoupled from latent
        hidden = F.leaky_relu(self.linear_layer1(latent))
        z = self.linear_layer2(hidden)

        # Optional p-norm normalization of each row of z
        if self.options["normalize"]:
            z = F.normalize(z, p=self.options["p_norm"], dim=1)

        # The decoder works on the latent, not on the projection
        x_recon = self.decoder(latent)
        return z, latent, x_recon


class Mask(nn.Module):
    """
    Mask network: a stack of linear layers with leaky ReLU in between and a sigmoid on the output.

    The depth is chosen by the first truthy flag among options["linear"], options["relu"],
    options["relux2"], options["relux3"] and options["relux5"] (checked in that order). Every layer
    has options["dims"][0] input and output features.
    """

    # Architecture flags in priority order, paired with the number of linear layers each one builds.
    _ARCHITECTURES = (("linear", 1), ("relu", 2), ("relux2", 3), ("relux3", 4), ("relux5", 6))

    def __init__(self, options):
        """

        Args:
            options (dict): Configuration dictionary. Reads "dims" and the architecture flags
                listed in the class docstring.

        Raises:
            KeyError: If a flag that has to be inspected is missing from options.
        """
        super(Mask, self).__init__()
        self.options = copy.deepcopy(options)
        # The mask has the same width as the first entry of "dims"
        input_dim = self.options["dims"][0]
        # Hidden width is tied to the input width
        hdim = input_dim

        # Layers are registered as linear_layer1..linear_layerN, in that order, so that state_dict
        # keys and the order of parameter initialisation stay the same as before.
        for idx in range(1, self._n_layers() + 1):
            if idx == 1:
                layer = nn.Linear(input_dim, hdim)
            else:
                layer = nn.Linear(hdim, input_dim)
            setattr(self, "linear_layer{}".format(idx), layer)

    def _n_layers(self):
        """Return the number of linear layers for the first truthy architecture flag, or 0."""
        for flag, depth in self._ARCHITECTURES:
            if self.options[flag]:
                return depth
        return 0

    def forward(self, z):
        """Compute the mask for z.

        Args:
            z: Input passed to the first linear layer.

        Returns:
            Sigmoid of the last linear layer's output.

        Raises:
            ValueError: If none of the architecture flags is truthy, i.e. no layers were selected.
        """
        depth = self._n_layers()
        if depth == 0:
            # Previously this case fell through to an unbound local variable
            raise ValueError(
                "Mask: no architecture selected; set one of {} to True in options".format(
                    ", ".join(repr(flag) for flag, _ in self._ARCHITECTURES)
                )
            )

        # Every layer but the last is followed by a leaky ReLU
        for idx in range(1, depth):
            z = F.leaky_relu(getattr(self, "linear_layer{}".format(idx))(z))
        mask = getattr(self, "linear_layer{}".format(depth))(z)

        return mask.sigmoid()


class Classifier(nn.Module):
    """
    Classification head: two hidden linear layers with leaky ReLU and dropout, followed by an
    output layer with a sigmoid.
    """

    def __init__(self, options):
        """

        Args:
            options (dict): Configuration dictionary. Reads "dims" (its last entry is the input
                width), "c_hdim" (hidden width) and "n_classes" (output width).
        """
        super(Classifier, self).__init__()
        self.options = copy.deepcopy(options)

        in_features = self.options["dims"][-1]
        hidden = self.options["c_hdim"]

        # Hidden layers; the activation and dropout are applied in forward()
        self.linear_layer1 = nn.Linear(in_features, hidden)
        self.linear_layer2 = nn.Linear(hidden, hidden)

        # Output layer. The attribute is named linear_layer4 (there is no linear_layer3).
        self.linear_layer4 = nn.Linear(hidden, self.options["n_classes"])

    def test(self, preds, labels):
        """Return the accuracy of preds against labels.

        Both tensors are detached, moved to the CPU and reduced to class indices by taking the
        argmax along axis 1 before they are compared.

        Args:
            preds: Tensor of per-class scores.
            labels: Tensor of per-class targets.

        Returns:
            The value of accuracy_score(label indices, predicted indices).
        """
        label_idx = np.argmax(labels.detach().cpu().numpy(), axis=1)
        pred_idx = np.argmax(preds.detach().cpu().numpy(), axis=1)
        return accuracy_score(label_idx, pred_idx)

    def forward(self, x):
        """Return sigmoid scores for x.

        Args:
            x: Input passed to the first linear layer.

        Returns:
            Sigmoid of the output layer.
        """
        z = x
        # Each hidden layer: linear -> leaky ReLU -> dropout (p=0.2, only active in training mode)
        for hidden_layer in (self.linear_layer1, self.linear_layer2):
            z = F.leaky_relu(hidden_layer(z))
            z = F.dropout(z, p=0.2, training=self.training)

        scores = self.linear_layer4(z)
        return scores.sigmoid()

    

# class Classifier(nn.Module):
#     """
#     Mask wrapper class
#     """

#     def __init__(self, options):
#         """

#         Args:
#             options (dict): Configuration dictionary.
#         """
#         super(Classifier, self).__init__()
#         self.options = copy.deepcopy(options)
#         # Get the last dimension of encoder. This will also be used as the dimension of projection
#         input_dim = self.options["dims"][-1]
#         hdim = input_dim
#         # Two-Layer Projection Network
#         # First linear layer, which will be followed with non-linear activation function in the forward()
#         self.linear_layer1 = nn.Linear(input_dim, hdim)
#         self.linear_layer2 = nn.Linear(hdim, hdim)
#         #self.bn1 = nn.BatchNorm1d(hdim)
#         #self.bn2 = nn.BatchNorm1d(hdim)

#         # Last linear layer for final projection
#         self.linear_layer3 = nn.Linear(hdim, self.options["n_classes"])


#     def test(self, preds, labels):
#         """ Gives ROC AUC and AP"""
        
#         labels, preds = labels.detach().cpu().numpy(), preds.detach().cpu().numpy()
        
#         labels = np.argmax(labels, axis=1)
#         preds = np.argmax(preds, axis=1)

#         return accuracy_score(labels, preds) #roc_auc_score(y, pred), average_precision_score(y, pred)        
        

#     def forward(self, x):
#         # Forward pass
#         #z = F.leaky_relu(self.bn1(self.linear_layer1(x)))
#         #z = z +F.leaky_relu(self.bn2(self.linear_layer2(z)))

# #         z = F.leaky_relu(self.bn1(self.linear_layer1(x)))
# #         z = F.dropout(z, p=0.5, training=self.training)
# #         z = F.leaky_relu(self.bn2(self.linear_layer2(z)))

#         z = F.leaky_relu(self.linear_layer1(x))
#         z = F.leaky_relu(self.linear_layer2(z))

#         #z = F.normalize(z, p=self.options["p_norm"], dim=1)
#         # Apply final linear layer
#         preds = self.linear_layer3(z)
#         # Return 
#         return preds.sigmoid()


class Regressor(nn.Module):
    """
    Regression head: two hidden linear layers with leaky ReLU followed by a single-output linear
    layer.
    """

    def __init__(self, options):
        """

        Args:
            options (dict): Configuration dictionary. Reads "dims"; its last entry is used as both
                the input width and the hidden width.
        """
        super(Regressor, self).__init__()
        self.options = copy.deepcopy(options)
        # Input width; the hidden layers keep the same width
        input_dim = self.options["dims"][-1]
        hdim = input_dim
        # Hidden layers; the activation is applied in forward()
        self.linear_layer1 = nn.Linear(input_dim, hdim)
        self.linear_layer2 = nn.Linear(hdim, hdim)

        # Output layer producing one value per sample
        self.linear_layer3 = nn.Linear(hdim, 1)

    def test(self, preds, targets):
        """Return the root-mean-squared error of preds against targets.

        Args:
            preds: Tensor of predictions.
            targets: Tensor of ground-truth values.

        Returns:
            Square root of mean_squared_error(targets, preds).
        """
        # Convert both tensors to numpy arrays. The earlier version referenced an undefined name
        # here and mixed up the two arguments.
        targets, preds = targets.detach().cpu().numpy(), preds.detach().cpu().numpy()
        mse = mean_squared_error(targets, preds)

        return mse ** 0.5

    def forward(self, x):
        """Return the regression output for x.

        Args:
            x: Input passed to the first linear layer.

        Returns:
            Output of the last linear layer (no activation applied).
        """
        # Hidden layers with leaky ReLU
        z = F.leaky_relu(self.linear_layer1(x))
        z = F.leaky_relu(self.linear_layer2(z))

        # Final linear layer
        preds = self.linear_layer3(z)
        return preds
    
    
class Encoder(nn.Module):
    def __init__(self, options):
        """Encoder model: hidden layers followed by a linear bottleneck layer.

        Args:
            options (dict): Configuration dictionary. Reads "dims", "n_subsets" and "overlap".
                The copy stored on this module has its first "dims" entry replaced by the width
                of one feature subset plus the overlap.
        """
        super(Encoder, self).__init__()
        # Private copy: the edits below must not leak back into the caller's dictionary
        self.options = copy.deepcopy(options)
        dims = self.options["dims"]

        # Width of one subset of columns when the features are split into "n_subsets" parts
        subset_width = int(dims[0] / self.options["n_subsets"])
        # Extra columns contributed by the overlap ratio between subsets
        overlap_width = int(self.options["overlap"] * subset_width)
        # The encoder input is a subset plus its overlapping columns
        dims[0] = subset_width + overlap_width

        # Hidden stack built from the adjusted dimensions
        self.hidden_layers = HiddenLayers(self.options)
        # Bottleneck: maps the last hidden width onto the latent width
        self.latent = nn.Linear(dims[-2], dims[-1])

    def forward(self, h):
        """Return the latent representation of h."""
        hidden = self.hidden_layers(h)
        return self.latent(hidden)


class Decoder(nn.Module):
    def __init__(self, options):
        """Decoder model: hidden layers followed by a linear output layer.

        Args:
            options (dict): Configuration dictionary. Reads "reconstruction", "reconstruct_subset",
                "dims" and (when both flags are truthy) "n_subsets". The copy stored on this module
                holds "dims" in reversed order.
        """
        super(Decoder, self).__init__()
        # Private copy: the edits below must not leak back into the caller's dictionary
        self.options = copy.deepcopy(options)
        cfg = self.options

        # When only a subset is reconstructed, the output width shrinks to the width of one subset;
        # otherwise the first entry of "dims" is kept as the output width.
        if cfg["reconstruction"] and cfg["reconstruct_subset"]:
            subset_width = int(cfg["dims"][0] / cfg["n_subsets"])
            cfg["dims"][0] = subset_width

        # Mirror the encoder by walking the dimensions backwards
        cfg["dims"] = cfg["dims"][::-1]

        # Hidden stack built from the reversed dimensions
        self.hidden_layers = HiddenLayers(cfg)
        # Output layer: last hidden width -> final (reconstruction) width
        self.logits = nn.Linear(cfg["dims"][-2], cfg["dims"][-1])

    def forward(self, h):
        """Return the decoder logits for h."""
        hidden = self.hidden_layers(h)
        return self.logits(hidden)

    
class ShallowEncoder(nn.Module):
    def __init__(self, options):
        """Shallow encoder model: hidden layers only, without a separate bottleneck layer.

        Args:
            options (dict): Configuration dictionary. Reads "dims", "n_subsets" and "overlap".
                The copy stored on this module has its first "dims" entry replaced by the width
                of one feature subset plus the overlap.
        """
        super(ShallowEncoder, self).__init__()
        # Private copy: the edits below must not leak back into the caller's dictionary
        self.options = copy.deepcopy(options)
        dims = self.options["dims"]

        # Width of one subset of columns when the features are split into "n_subsets" parts
        subset_width = int(dims[0] / self.options["n_subsets"])
        # Extra columns contributed by the overlap ratio between subsets
        overlap_width = int(self.options["overlap"] * subset_width)
        # The encoder input is a subset plus its overlapping columns
        dims[0] = subset_width + overlap_width

        # Hidden stack built from the adjusted dimensions
        self.hidden_layers = HiddenLayers(self.options)

    def forward(self, h):
        """Return the output of the hidden layers for h."""
        return self.hidden_layers(h)
    
    
class ShallowDecoder(nn.Module):
    def __init__(self, options):
        """Shallow decoder model: a single linear layer.

        Args:
            options (dict): Configuration dictionary. Reads "dims"; the layer maps the last entry
                (input width) onto the first entry (output width).
        """
        super(ShallowDecoder, self).__init__()
        # Keep a private copy of the configuration
        self.options = copy.deepcopy(options)
        dims = self.options["dims"]
        # Single linear layer from the last "dims" entry to the first one
        self.first_layer = nn.Linear(dims[-1], dims[0])

    def forward(self, z):
        """Return the output of the linear layer for z."""
        return self.first_layer(z)
    
    
class HiddenLayers(nn.Module):
    def __init__(self, options):
        """Class to add hidden layers to networks

        One block is built for every consecutive pair of entries in options["dims"], leaving out
        the last entry: Linear, optional BatchNorm1d, LeakyReLU, optional Dropout.

        Args:
            options (dict): Configuration dictionary. Reads "dims", "isBatchNorm", "isDropout" and
                (when dropout is enabled) "dropout_rate".
        """
        super(HiddenLayers, self).__init__()
        self.layers = nn.ModuleList()
        widths = options["dims"]

        # Pair each width with its successor; the final width is left for the caller's own layer
        for n_in, n_out in zip(widths[:-2], widths[1:-1]):
            block = [nn.Linear(n_in, n_out)]
            if options["isBatchNorm"]:
                block.append(nn.BatchNorm1d(n_out))
            block.append(nn.LeakyReLU(inplace=False))
            if options["isDropout"]:
                block.append(nn.Dropout(options["dropout_rate"]))
            self.layers.extend(block)

    def forward(self, x):
        """Apply every registered layer to x in order and return the result."""
        out = x
        for module in self.layers:
            out = module(out)
        return out
