import torch
import torch.nn as nn
import torch.nn.functional as F

class Model_CNN(nn.Module):
    """1-D convolutional text classifier on top of frozen pretrained embeddings.

    Pipeline: embedding lookup -> three Conv1d(kernel=5) + ReLU stages, with a
    MaxPool1d(kernel=5, stride=1) after each of the first two -> global max
    over the sequence dimension -> Linear(128, 128) -> Dropout(0.7) ->
    Linear(128, num_classes).

    Every convolution and pooling layer uses kernel size 5, stride 1 and no
    padding, so each one shortens the sequence by 4. The input sequence
    therefore has to be at least 21 tokens long for conv3 to produce output.

    Args:
        embedding_matrix: 2-D float tensor of shape (vocab_size, embedding_dim)
            holding the pretrained word vectors. The weights are frozen.
        num_classes: Size of the output layer. Defaults to 20.
    """

    def __init__(self, embedding_matrix: torch.Tensor, num_classes: int = 20):
        super().__init__()
        self.embed = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        # Take the channel count from the embedding itself instead of assuming
        # 100-dimensional vectors, so any pretrained matrix width works.
        embedding_dim = self.embed.embedding_dim

        # nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0)
        self.conv1 = nn.Conv1d(embedding_dim, 128, 5, stride=1)
        # nn.MaxPool1d(kernel_size, stride=None, padding=0); stride is set to 1
        # explicitly because the default (None) would use kernel_size instead.
        self.pool1 = nn.MaxPool1d(5, stride=1)
        self.conv2 = nn.Conv1d(128, 128, 5, stride=1)
        self.pool2 = nn.MaxPool1d(5, stride=1)
        self.conv3 = nn.Conv1d(128, 128, 5, stride=1)

        # nn.Linear(in_features, out_features, bias=True): (N, Hin) -> (N, Hout)
        self.fc1 = nn.Linear(128, 128)
        self.dropout = nn.Dropout(0.7)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute unnormalised class scores for a batch of token-index sequences.

        Args:
            x: Integer tensor of token indices, shape (N, L) with L >= 21.

        Returns:
            Tensor of shape (N, num_classes); no softmax is applied.
        """
        x = self.embed(x)      # (N, L, E)
        x = x.transpose(1, 2)  # (N, E, L): Conv1d expects (batch, channels, length)

        x = F.relu(self.conv1(x))  # (N, 128, L - 4)
        x = self.pool1(x)          # (N, 128, L - 8)

        x = F.relu(self.conv2(x))  # (N, 128, L - 12)
        x = self.pool2(x)          # (N, 128, L - 16)

        x = F.relu(self.conv3(x))  # (N, 128, L - 20)
        # Global max pooling: reduce over the sequence-length dimension.
        x, _ = torch.max(x, dim=2)  # (N, 128)

        # NOTE(review): there is no non-linearity between fc1 and fc2, so with
        # dropout disabled (eval mode) the two layers compose into one affine
        # map. Adding an activation here would change the outputs of already
        # trained checkpoints, so it is left as-is -- confirm intent.
        x = self.fc1(x)      # (N, 128)
        x = self.dropout(x)  # (N, 128)
        x = self.fc2(x)      # (N, num_classes)
        return x