from torch import nn, optim
import torch
import torch.nn.functional as F
import numpy as np


# Define model
class CNN(nn.Module):
    """Compact CNN with two classification heads ("arousal" and "valence").

    Pipeline: temporal conv -> BatchNorm -> depthwise conv -> BatchNorm -> ELU ->
    avg-pool -> dropout -> conv -> BatchNorm -> ELU -> avg-pool -> dropout ->
    flatten -> Linear(128) -> ELU -> two independent Linear heads.

    Args:
        input_shape: ``(height, width)`` of a single input map. ``forward`` expects
            tensors of shape ``(batch, 1, height, width)``. Presumably
            (EEG channels, time samples) -- TODO confirm with the data loader.
        num_classes: number of output logits produced by EACH head.
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()

        # Temporal convolution: a 1x125 kernel with padding 62 keeps the width unchanged.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 125), padding=(0, 62), bias=False)
        self.bn1 = nn.BatchNorm2d(16)

        # Depthwise convolution (groups=16, depth multiplier 2) across the height axis.
        # NOTE(review): the kernel height is hard-coded to 28. It only collapses the
        # height to 1 when input_shape[0] == 28, and it fails for input_shape[0] < 28.
        # Consider tying it to input_shape[0].
        self.conv2_depthwise = nn.Conv2d(
            16, 32, kernel_size=(28, 1), groups=16, bias=False
        )
        self.bn2 = nn.BatchNorm2d(32)
        self.act2 = nn.ELU()
        self.avg_pool2 = nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.drop2 = nn.Dropout(0.5)

        # NOTE(review): despite its name, this is a regular (non-grouped) 1x16
        # convolution, not a depthwise-separable one. Padding 8 with kernel 16
        # grows the width by 1. Both are kept as-is for checkpoint compatibility.
        self.conv3_separable = nn.Conv2d(
            32, 32, kernel_size=(1, 16), padding=(0, 8), bias=False
        )
        self.bn3 = nn.BatchNorm2d(32)
        self.act3 = nn.ELU()
        self.avg_pool3 = nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.drop3 = nn.Dropout(0.5)

        # Infer the per-sample flattened feature size with a dummy forward pass.
        # Run it in eval mode and without autograd. In training mode the BatchNorm
        # layers would otherwise update running_mean/running_var/num_batches_tracked
        # from the fake all-zero batch, and Dropout would consume the global RNG.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, input_shape[0], input_shape[1])
            self.flattened_size = self._features(dummy_input)[0].numel()
        self.train(was_training)

        # Shared fully connected layer before splitting into arousal and valence.
        self.fc1 = nn.Linear(self.flattened_size, 128)
        self.act_fc1 = nn.ELU()

        # Independent output heads for arousal and valence.
        self.fc_arousal = nn.Linear(128, num_classes)
        self.fc_valence = nn.Linear(128, num_classes)

    def _features(self, x):
        """Run the convolutional feature extractor (every stage before ``fc1``).

        Shared by ``__init__`` (shape inference) and ``forward``, so the two cannot drift apart.
        """
        x = self.bn1(self.conv1(x))
        x = self.drop2(self.avg_pool2(self.act2(self.bn2(self.conv2_depthwise(x)))))
        x = self.drop3(self.avg_pool3(self.act3(self.bn3(self.conv3_separable(x)))))
        return x

    def forward(self, x):
        """Compute arousal and valence logits.

        Args:
            x: tensor of shape ``(batch, 1, height, width)`` matching ``input_shape``.

        Returns:
            ``(arousal, valence)``: two tensors, each of shape ``(batch, num_classes)``.
        """
        x = self._features(x)
        x = x.view(x.size(0), -1)
        x = self.act_fc1(self.fc1(x))

        arousal = self.fc_arousal(x)
        valence = self.fc_valence(x)

        return arousal, valence
