import math
import torch
import torch.nn.functional as F
import torchvision
from torch import nn, Tensor
from flcore.trainmodel.bilstm import *
from flcore.trainmodel.resnet import *
from flcore.trainmodel.alexnet import *
from flcore.trainmodel.mobilenet_v2 import *


# split an original model into a base and a head
class BaseHeadSplit(nn.Module):
    def __init__(self, args, cid):
        """
        Initialize the BaseHeadSplit model, splitting the base and head for heterogeneous federated learning.

        The base model is built by evaluating ``args.models[cid % len(args.models)]``.
        Its classification layer is the first attribute found among ``heads``,
        ``head``, ``fc`` and ``classifier``. That layer is replaced by
        ``nn.AdaptiveAvgPool1d(args.feature_dim)``, so the base emits
        ``feature_dim``-sized features. A separate head is then created:

        * ``eval(args.heads[cid % len(args.heads)])`` when ``args`` has ``heads``;
        * otherwise ``nn.Linear(args.feature_dim, args.num_classes, bias=False)``.

        The head weight is orthogonally initialised. A learnable scalar
        ``scaling`` multiplies the cosine logits.

        Args:
            args: Arguments containing model configurations. Reads ``models``,
                ``feature_dim``, ``num_classes`` and, if present, ``heads``.
            cid: Client ID, used to pick the model (and head) string round-robin.

        Raises:
            ValueError: If the base model has none of the recognised head attributes.
        """
        super().__init__()
        # SECURITY NOTE: ``eval`` executes arbitrary code; the model/head strings
        # must come from a trusted configuration, never from external input.
        self.base = eval(args.models[cid % len(args.models)])

        # ``head`` keeps the original classifier (and stays in scope for the
        # ``args.heads`` eval below); you may need more code to reuse
        # pre-existing heterogeneous heads.
        head = None
        for attr in ('heads', 'head', 'fc', 'classifier'):
            if hasattr(self.base, attr):
                head = getattr(self.base, attr)
                setattr(self.base, attr, nn.AdaptiveAvgPool1d(args.feature_dim))
                break
        else:
            # Raising a plain string is itself a TypeError in Python 3; raise a real exception.
            raise ValueError('The base model does not have a classification head.')

        if hasattr(args, 'heads'):
            self.head = eval(args.heads[cid % len(args.heads)])
        else:
            self.head = nn.Linear(args.feature_dim, args.num_classes, bias=False)

        # The rows of the head weight act as class prototypes in forward().
        nn.init.orthogonal_(self.head.weight)
        # Learnable temperature applied to the cosine-similarity logits.
        self.scaling = torch.nn.Parameter(torch.tensor([1.0]))

    def forward(self, x):
        """
        Forward pass through base and head, with feature normalization and scaling.

        Features and the rows of ``self.head.weight`` are both L2-normalised
        (norms clamped at 1e-12 to avoid division by zero). Logits are their dot
        products scaled by ``self.scaling``. Only the head's ``weight`` is used;
        the head module's own forward is not called.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tuple[Tensor, Tensor]: Feature embedding and logits.
        """
        out = self.base(x)
        norm = torch.norm(out, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        feature_embedding = torch.div(out, norm)

        prototype_norm = torch.norm(self.head.weight, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        normalized_prototype = torch.div(self.head.weight, prototype_norm)
        # Cosine similarity between each feature and each class prototype.
        logits = torch.matmul(feature_embedding, normalized_prototype.T)
        logits = self.scaling * logits

        return feature_embedding, logits

class Head(nn.Module):
    def __init__(self, num_classes=10, hidden_dims=[512]):
        """
        Initialize the Head module for classification.

        Builds ``Linear -> ReLU`` blocks between consecutive entries of
        ``hidden_dims``, followed by a final ``Linear(hidden_dims[-1], num_classes)``.
        The first entry of ``hidden_dims`` is therefore the input dimension. For
        example, ``[512]`` gives ``Linear(512, num_classes)``, and ``[512, 256]``
        gives ``Linear(512, 256) -> ReLU -> Linear(256, num_classes)``.

        Args:
            num_classes (int): Number of output classes.
            hidden_dims (list): Input dimension followed by any hidden layer
                dimensions. The list is not modified.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        super().__init__()
        # Work on a copy: appending to the argument would mutate the shared
        # default list (and the caller's list) across instantiations.
        dims = list(hidden_dims)
        if not dims:
            raise ValueError('hidden_dims must contain at least the input dimension.')

        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU(inplace=True))

        # Final classifier layer: last hidden dimension -> classes.
        classifier = nn.Linear(dims[-1], num_classes)
        layers.append(classifier)
        self.fc = nn.Sequential(*layers)

        # Expose the final layer's weight (same Parameter object as
        # ``self.fc[-1].weight``) so callers can use it as class prototypes.
        self.weight = classifier.weight

    def forward(self, rep):
        """
        Forward pass through the head layers.

        Args:
            rep (Tensor): Input representation.

        Returns:
            Tensor: Output logits.
        """
        out = self.fc(rep)
        return out

###########################################################

class CNN(nn.Module):
    def __init__(self, in_features=1, num_classes=10, height=28, 
                 num_cov=2, feature_dim=512, hidden_dims=[]):
        """
        Initialize a simple CNN for image classification.

        The network has ``max(num_cov, 1)`` blocks of
        ``Conv2d(5x5, no padding) -> ReLU -> MaxPool2d(2)``. The first block maps
        ``in_features`` to 32 channels and each further block doubles the channels.
        These are followed by a flatten, a ``Linear -> ReLU`` stack through
        ``hidden_dims`` ending at ``feature_dim`` (``self.fc1``), and a
        ``Linear(feature_dim, num_classes)`` classifier (``self.fc``).

        Args:
            in_features (int): Number of input channels.
            num_classes (int): Number of output classes.
            height (int): Height (and width) of the square input images.
            num_cov (int): Number of convolutional layers.
            feature_dim (int): Feature dimension for fully connected layers.
            hidden_dims (list): Hidden layer dimensions placed before the final
                ``feature_dim`` layer. The list is not modified.
        """
        super().__init__()

        def conv_block(in_channels, out_channels):
            # One 5x5 "valid" convolution followed by ReLU and 2x2 max-pooling.
            return nn.Sequential(
                nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=5,
                          padding=0,
                          stride=1,
                          bias=True),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=(2, 2)),
            )

        def shrink(size):
            # Spatial size after one block: the conv trims 4 pixels, the pool halves.
            size = int(size - 5 + 1)
            return int((size - 2) / 2 + 1)

        channels = 32
        convs = [conv_block(in_features, channels)]
        height = shrink(height)
        for _ in range(num_cov - 1):
            convs.append(conv_block(channels, channels * 2))
            channels *= 2
            height = shrink(height)
        self.conv = nn.Sequential(*convs)

        # Copy instead of appending: the default ``[]`` is shared between calls,
        # so appending to it would add an extra layer to every later CNN().
        dims = list(hidden_dims) + [feature_dim]

        layers = [nn.Flatten()]
        in_dim = height ** 2 * channels
        for out_dim in dims:
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU(inplace=True))
            in_dim = out_dim

        self.fc1 = nn.Sequential(*layers)
        # fc1 outputs ``feature_dim`` features, so the classifier must accept that size.
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """
        Forward pass through CNN layers.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output logits.
        """
        out = self.conv(x)
        out = self.fc1(out)
        out = self.fc(out)
        return out

# https://github.com/jindongwang/Deep-learning-activity-recognition/blob/master/pytorch/network.py
class HARCNN(nn.Module):
    def __init__(self, in_channels=9, dim_hidden=64*26, num_classes=6, conv_kernel_size=(1, 9), pool_kernel_size=(1, 2)):
        """
        Build the HAR CNN used for activity recognition.

        Two ``Conv2d -> ReLU -> MaxPool2d(stride=2)`` stages (32 then 64 channels)
        are followed by a three-layer MLP classifier. The default ``dim_hidden``
        (64 * 26) matches inputs of shape (N, 9, 1, 128) with the default kernels.

        Args:
            in_channels (int): Number of input channels.
            dim_hidden (int): Flattened size after the convolutional stages.
            num_classes (int): Number of output classes.
            conv_kernel_size (tuple): Kernel size of both convolutions.
            pool_kernel_size (tuple): Kernel size of both max-pool layers.
        """
        super().__init__()

        def conv_stage(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=conv_kernel_size),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_kernel_size, stride=2),
            )

        self.conv1 = conv_stage(in_channels, 32)
        self.conv2 = conv_stage(32, 64)
        self.fc = nn.Sequential(
            nn.Linear(dim_hidden, 1024), nn.ReLU(),
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """
        Run both convolutional stages, flatten per sample, and classify.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output logits.
        """
        feature_map = self.conv2(self.conv1(x))
        return self.fc(feature_map.flatten(1))


# https://github.com/FengHZ/KD3A/blob/master/model/digit5.py
class Digit5CNN(nn.Module):
    def __init__(self):
        """
        Build the Digit5 CNN for 10-class digit classification.

        ``encoder`` holds three 5x5 conv/BatchNorm/ReLU stages, with max-pooling
        after the first two. ``linear`` is a two-layer MLP (8192 -> 3072 -> 2048)
        with BatchNorm and dropout. ``fc`` is the final 2048 -> 10 classifier.
        The 8192 input size equals 128 * 8 * 8, which is what 32x32 inputs
        produce after the two stride-2 pools.
        """
        super(Digit5CNN, self).__init__()
        self.encoder = nn.Sequential()
        for name, layer in (
            ("conv1", nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)),
            ("bn1", nn.BatchNorm2d(64)),
            ("relu1", nn.ReLU()),
            ("maxpool1", nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)),
            ("conv2", nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
            ("bn2", nn.BatchNorm2d(64)),
            ("relu2", nn.ReLU()),
            ("maxpool2", nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)),
            ("conv3", nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)),
            ("bn3", nn.BatchNorm2d(128)),
            ("relu3", nn.ReLU()),
        ):
            self.encoder.add_module(name, layer)

        self.linear = nn.Sequential()
        for name, layer in (
            ("fc1", nn.Linear(8192, 3072)),
            ("bn4", nn.BatchNorm1d(3072)),
            ("relu4", nn.ReLU()),
            ("dropout", nn.Dropout()),
            ("fc2", nn.Linear(3072, 2048)),
            ("bn5", nn.BatchNorm1d(2048)),
            ("relu5", nn.ReLU()),
        ):
            self.linear.add_module(name, layer)

        self.fc = nn.Linear(2048, 10)

    def forward(self, x):
        """
        Encode, flatten per sample, run the MLP, then classify.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output logits.
        """
        flat = self.encoder(x).view(x.size(0), -1)
        return self.fc(self.linear(flat))
        

# https://github.com/FengHZ/KD3A/blob/master/model/amazon.py
class AmazonMLP(nn.Module):
    def __init__(self, feature_dim=[500]):
        """
        Initialize the AmazonMLP model for Amazon review classification.

        ``encoder`` chains ``Linear -> ReLU`` blocks 5000 -> *feature_dim -> 100.
        ``fc`` maps the 100-dim features to 2 classes. An empty ``feature_dim``
        connects 5000 directly to 100; previously it crashed with an unbound
        loop variable.

        Args:
            feature_dim (list): List of hidden layer dimensions (not modified).
        """
        super(AmazonMLP, self).__init__()
        self.in_features = 5000
        self.out_features = 100

        dims = [self.in_features, *feature_dim, self.out_features]
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())

        self.encoder = nn.Sequential(*layers)
        self.fc = nn.Linear(self.out_features, 2)

    def forward(self, x):
        """
        Forward pass through AmazonMLP layers.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output logits.
        """
        out = self.encoder(x)
        out = self.fc(out)
        return out
        

# # https://github.com/katsura-jp/fedavg.pytorch/blob/master/src/models/cnn.py
# class FedAvgCNN(nn.Module):
#     def __init__(self, in_features=1, num_classes=10, dim=1024):
#         super().__init__()
#         self.conv1 = nn.Conv2d(in_features,
#                                32,
#                                kernel_size=5,
#                                padding=0,
#                                stride=1,
#                                bias=True)
#         self.conv2 = nn.Conv2d(32,
#                                64,
#                                kernel_size=5,
#                                padding=0,
#                                stride=1,
#                                bias=True)
#         self.fc1 = nn.Linear(dim, 512)
#         self.fc = nn.Linear(512, num_classes)

#         self.act = nn.ReLU(inplace=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))

#     def forward(self, x):
#         x = self.act(self.conv1(x))
#         x = self.maxpool(x)
#         x = self.act(self.conv2(x))
#         x = self.maxpool(x)
#         x = torch.flatten(x, 1)
#         x = self.act(self.fc1(x))
#         x = self.fc(x)
#         return x

class FedAvgCNN(nn.Module):
    def __init__(self, in_features=1, num_classes=10, dim=1024):
        """
        Build the two-conv CNN from the FedAvg reference implementation.

        Each conv stage is ``Conv2d(5x5, no padding) -> ReLU -> MaxPool2d(2)``
        (32 then 64 channels). They are followed by ``Linear(dim, 512) -> ReLU``
        and a ``Linear(512, num_classes)`` classifier. The default ``dim`` of
        1024 (64 * 4 * 4) matches 28x28 inputs.

        Args:
            in_features (int): Number of input channels.
            num_classes (int): Number of output classes.
            dim (int): Flattened size after the convolutional stages.
        """
        super().__init__()

        def conv_pool(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=5, padding=0, stride=1, bias=True),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=(2, 2)),
            )

        self.conv1 = conv_pool(in_features, 32)
        self.conv2 = conv_pool(32, 64)
        self.fc1 = nn.Sequential(nn.Linear(dim, 512), nn.ReLU(inplace=True))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """
        Apply both conv stages, flatten per sample, and classify.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output logits.
        """
        flat = torch.flatten(self.conv2(self.conv1(x)), start_dim=1)
        return self.fc(self.fc1(flat))

# ====================================================================================================================

# https://github.com/katsura-jp/fedavg.pytorch/blob/master/src/models/mlp.py
class FedAvgMLP(nn.Module):
    def __init__(self, in_features=784, num_classes=10, hidden_dim=200):
        """
        Build the two-layer MLP from the FedAvg reference implementation.

        Args:
            in_features (int): Input feature dimension.
            num_classes (int): Number of output classes.
            hidden_dim (int): Hidden layer dimension.
        """
        super().__init__()
        hidden_layer = nn.Linear(in_features, hidden_dim)
        output_layer = nn.Linear(hidden_dim, num_classes)
        self.fc1 = hidden_layer
        self.fc2 = output_layer
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        Flatten 4-D image batches per sample, then apply ``fc1 -> ReLU -> fc2``.

        Inputs of any other rank are passed to ``fc1`` unchanged.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output logits.
        """
        if x.ndim == 4:
            x = x.view(x.shape[0], -1)
        return self.fc2(self.act(self.fc1(x)))

class Mclr_Logistic(nn.Module):
    def __init__(self, input_dim=1*28*28, num_classes=10):
        """
        Build a multiclass logistic-regression model (a single linear layer).

        Args:
            input_dim (int): Input feature dimension after flattening.
            num_classes (int): Number of output classes.
        """
        super(Mclr_Logistic, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """
        Flatten each sample, apply the linear layer, and return log-probabilities.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Log-softmax output over classes (dim 1).
        """
        scores = self.fc(x.flatten(start_dim=1))
        return F.log_softmax(scores, dim=1)

# ====================================================================================================================

class DNN(nn.Module):
    def __init__(self, input_dim=1*28*28, mid_dim=100, num_classes=10):
        """
        Build a one-hidden-layer fully connected network.

        Args:
            input_dim (int): Input feature dimension after flattening.
            mid_dim (int): Hidden layer dimension.
            num_classes (int): Number of output classes.
        """
        super(DNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.fc = nn.Linear(mid_dim, num_classes)

    def forward(self, x):
        """
        Flatten each sample, apply ``fc1 -> ReLU -> fc``, and return log-probabilities.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Log-softmax output over classes (dim 1).
        """
        hidden = F.relu(self.fc1(x.flatten(1)))
        return F.log_softmax(self.fc(hidden), dim=1)

# ====================================================================================================================

# cfg = {
#     'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
#     'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
#     'VGGbatch_size': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
#     'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
# }

# class VGG(nn.Module):
#     def __init__(self, vgg_name):
#         super(VGG, self).__init__()
#         self.features = self._make_layers(cfg[vgg_name])
#         self.classifier = nn.Sequential(
#             nn.Linear(512, 512),
#             nn.ReLU(True),
#             nn.Linear(512, 512),
#             nn.ReLU(True),
#             nn.Linear(512, 10)
#         )

#     def forward(self, x):
#         out = self.features(x)
#         out = out.view(out.size(0), -1)
#         out = self.classifier(out)
#         output = F.log_softmax(out, dim=1)
#         return output

#     def _make_layers(self, cfg):
#         layers = []
#         in_channels = 3
#         for x in cfg:
#             if x == 'M':
#                 layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
#             else:
#                 layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
#                            nn.BatchNorm2d(x),
#                            nn.ReLU(inplace=True)]
#                 in_channels = x
#         layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
#         return nn.Sequential(*layers)

# ====================================================================================================================

def init_weights(m):
    """
    Initialize weights for layers using appropriate strategies.

    Intended for ``module.apply(init_weights)``. Layers are matched by a
    substring of their class name:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform weight, zero bias.
    * ``BatchNorm``: weight drawn from N(1.0, 0.02), zero bias.
    * ``Linear``: Xavier-normal weight, zero bias.

    Missing parameters (``bias=False`` layers, non-affine BatchNorm) are
    skipped instead of crashing. Modules matching none of the names are left
    untouched.

    Args:
        m (nn.Module): Layer to initialize.
    """
    classname = m.__class__.__name__
    # Either may be None, e.g. Conv2d(bias=False) or BatchNorm(affine=False).
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv2d' in classname or 'ConvTranspose2d' in classname:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif 'Linear' in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        return

    if bias is not None:
        nn.init.zeros_(bias)

class LeNet(nn.Module):
    def __init__(self, feature_dim=50*4*4, bottleneck_dim=256, num_classes=10, iswn=None):
        """
        Build a LeNet-style network with a bottleneck and an optional weight-normalised classifier.

        Layout: two conv stages (``conv_params``), then a linear bottleneck
        (``feature_dim -> bottleneck_dim``), ``BatchNorm1d``, dropout, and a
        linear classifier ``fc``. ``init_weights`` is applied to the bottleneck
        and to the classifier.

        Args:
            feature_dim (int): Flattened size after ``conv_params``.
            bottleneck_dim (int): Bottleneck layer dimension.
            num_classes (int): Number of output classes.
            iswn (str, optional): ``"wn"`` wraps the classifier in weight normalisation.
        """
        super(LeNet, self).__init__()

        self.conv_params = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        self.dropout = nn.Dropout(p=0.5)
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)

        classifier = nn.Linear(bottleneck_dim, num_classes)
        if iswn == "wn":
            classifier = nn.utils.weight_norm(classifier, name="weight")
        classifier.apply(init_weights)
        self.fc = classifier

    def forward(self, x):
        """
        Run the conv stages, flatten, pass through bottleneck/BN/dropout, and classify.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Log-softmax output over classes (dim 1).
        """
        features = self.conv_params(x)
        flat = features.view(features.size(0), -1)
        embedded = self.dropout(self.bn(self.bottleneck(flat)))
        return F.log_softmax(self.fc(embedded), dim=1)

# ====================================================================================================================

# class CNNCifar(nn.Module):
#     def __init__(self, num_classes=10):
#         super(CNNCifar, self).__init__()
#         self.conv1 = nn.Conv2d(3, 6, 5)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.conv2 = nn.Conv2d(6, batch_size, 5)
#         self.fc1 = nn.Linear(batch_size * 5 * 5, 120)
#         self.fc2 = nn.Linear(120, 100)
#         self.fc3 = nn.Linear(100, num_classes)

#         # self.weight_keys = [['fc1.weight', 'fc1.bias'],
#         #                     ['fc2.weight', 'fc2.bias'],
#         #                     ['fc3.weight', 'fc3.bias'],
#         #                     ['conv2.weight', 'conv2.bias'],
#         #                     ['conv1.weight', 'conv1.bias'],
#         #                     ]
                            
#     def forward(self, x):
#         x = self.pool(F.relu(self.conv1(x)))
#         x = self.pool(F.relu(self.conv2(x)))
#         x = x.view(-1, batch_size * 5 * 5)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         x = F.log_softmax(x, dim=1)
#         return x

# ====================================================================================================================

class LSTMNet(nn.Module):
    def __init__(self, hidden_dim, num_layers=2, bidirectional=False, dropout=0.2, 
                padding_idx=0, vocab_size=98635, num_classes=10):
        """
        Initialize the LSTMNet model for sequence classification.

        Layout: embedding -> (packed) LSTM -> ReLU on the last valid time step ->
        dropout -> linear classifier -> log-softmax.

        Args:
            hidden_dim (int): Embedding size and LSTM hidden size.
            num_layers (int): Number of LSTM layers.
            bidirectional (bool): Whether to use bidirectional LSTM.
            dropout (float): Dropout rate (between LSTM layers and before ``fc``).
            padding_idx (int): Padding index for embeddings.
            vocab_size (int): Vocabulary size.
            num_classes (int): Number of output classes.
        """
        super().__init__()

        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx)
        self.lstm = nn.LSTM(input_size=hidden_dim, 
                            hidden_size=hidden_dim, 
                            num_layers=num_layers, 
                            bidirectional=bidirectional, 
                            dropout=dropout, 
                            batch_first=True)
        dims = hidden_dim*2 if bidirectional else hidden_dim
        self.fc = nn.Linear(dims, num_classes)

    def forward(self, x):
        """
        Forward pass through LSTMNet layers.

        Each sequence is summarised by the LSTM output at its own last valid
        step (``text_lengths[i] - 1``), not by the last padded position.

        NOTE(review): with ``bidirectional=True`` that position holds the
        backward direction's state after seeing only the final token; concatenating
        the final hidden states of both directions may be preferable.

        Args:
            x (tuple): Tuple of (text tensor, text lengths).

        Returns:
            Tensor: Log-softmax output.
        """
        text, text_lengths = x
        # pack_padded_sequence requires the lengths on the CPU.
        if torch.is_tensor(text_lengths):
            text_lengths = text_lengths.cpu()

        embedded = self.embedding(text)

        # Pack so the LSTM does not run over padding positions.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False)
        packed_output, (hidden, cell) = self.lstm(packed_embedded)

        # Unpack; positions past each sequence's length are zero-filled.
        out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        # Gather each sequence's last valid output. Indexing with -1 would read
        # zero padding for every sequence shorter than the longest in the batch.
        batch_idx = torch.arange(out.size(0), device=out.device)
        last_idx = (out_lengths - 1).to(out.device)
        out = torch.relu(out[batch_idx, last_idx])
        out = self.dropout(out)
        out = self.fc(out)
        out = F.log_softmax(out, dim=1)

        return out

# ====================================================================================================================

class fastText(nn.Module):
    def __init__(self, hidden_dim, padding_idx=0, vocab_size=98635, num_classes=10):
        """
        Build a fastText-style classifier: averaged embeddings, one hidden linear layer, output layer.

        Args:
            hidden_dim (int): Embedding size and hidden layer size.
            padding_idx (int): Padding index for embeddings.
            vocab_size (int): Vocabulary size.
            num_classes (int): Number of output classes.
        """
        super(fastText, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx)  # token lookup
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)                        # hidden layer
        self.fc = nn.Linear(hidden_dim, num_classes)                        # output layer

    def forward(self, x):
        """
        Average token embeddings over the sequence, then classify.

        The lengths in ``x`` are unpacked but not used: the mean runs over
        every position of ``text``, padding positions included.

        Args:
            x (tuple): Tuple of (text tensor, text lengths).

        Returns:
            Tensor: Log-softmax output over classes (dim 1).
        """
        text, _ = x
        sentence_vec = self.embedding(text).mean(dim=1)
        return F.log_softmax(self.fc(self.fc1(sentence_vec)), dim=1)

# ====================================================================================================================

class TextCNN(nn.Module):
    def __init__(self, hidden_dim, num_channels=100, kernel_size=[3,4,5], max_len=200, dropout=0.8, 
                padding_idx=0, vocab_size=98635, num_classes=10):
        """
        Build a TextCNN with three parallel convolution branches over token embeddings.

        Each branch is ``Conv1d(width k) -> ReLU -> MaxPool1d(max_len - k + 1)``.
        The pool spans every conv position of a ``max_len``-long sequence, giving
        one value per channel. The three branch outputs are concatenated,
        dropped out, and passed through ``fc1`` and ``fc``.
        See https://stackoverflow.com/questions/46503816/keras-conv1d-layer-parameters-filters-and-kernel-size/46504997
        for how Conv1d channels/kernel sizes work.

        Args:
            hidden_dim (int): Embedding size and ``fc1`` output size.
            num_channels (int): Output channels per convolution branch.
            kernel_size (list): Kernel widths; the first three are used by the branches.
            max_len (int): Sequence length the pooling windows are sized for.
            dropout (float): Dropout rate.
            padding_idx (int): Padding index for embeddings.
            vocab_size (int): Vocabulary size.
            num_classes (int): Number of output classes.
        """
        super(TextCNN, self).__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx)

        def branch(width):
            return nn.Sequential(
                nn.Conv1d(in_channels=hidden_dim, out_channels=num_channels, kernel_size=width),
                nn.ReLU(),
                nn.MaxPool1d(max_len - width + 1),
            )

        self.conv1 = branch(kernel_size[0])
        self.conv2 = branch(kernel_size[1])
        self.conv3 = branch(kernel_size[2])

        self.dropout = nn.Dropout(dropout)

        self.fc1 = nn.Linear(num_channels*len(kernel_size), hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """
        Embed tokens, run the three conv branches, and classify.

        Args:
            x (tuple): Tuple of (text tensor, text lengths); the lengths are unused.

        Returns:
            Tensor: Log-softmax output over classes (dim 1).
        """
        text, _ = x
        # Conv1d expects (batch, channels, length), so move the embedding axis to dim 1.
        embedded = self.embedding(text).permute(0, 2, 1)
        pooled = [conv(embedded).squeeze(2) for conv in (self.conv1, self.conv2, self.conv3)]
        features = self.dropout(torch.cat(pooled, 1))
        logits = self.fc(self.fc1(features))
        return F.log_softmax(logits, dim=1)

# ====================================================================================================================


# class linear(Function):
#   @staticmethod
#   def forward(ctx, input):
#     return input
  
#   @staticmethod
#   def backward(ctx, grad_output):
#     return grad_output