import torch
import torch.nn as nn


class Flatten(torch.nn.Module):
    """Merge every dimension after the leading (batch) one into a single axis."""

    def forward(self, x):
        # Dimension 0 is preserved; view() infers the size of the merged tail.
        return x.view(x.size(0), -1)
    
class MaskedReLU(torch.nn.Module):
    """ReLU whose output is optionally multiplied element-wise by ``mask``.

    When ``mask`` is None the module behaves like a plain ReLU.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask = None):
        activated = nn.ReLU()(x)
        if mask is None:
            return activated
        return activated * mask

class ArrhythmiaClassifierV2(nn.Module):
    """1-D CNN classifier with optional validity-mask propagation.

    Input is expected as ``(batch_size, 1, length)``. The shape comments
    below assume ``length == 3600``.

    An optional ``mask`` shaped like the input marks valid samples
    (non-zero = valid). The mask is propagated alongside the activations.
    After each convolution, an output position stays valid only if its
    receptive field covered at least one valid input position. Invalid
    positions are zeroed in the activations.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=17):
        super(ArrhythmiaClassifierV2, self).__init__()
        # Input signal is treated as (batch_size, 1, 3600).
        # NOTE: module registration order below is intentional. It fixes the
        # weight-init RNG order and the parameters() order that optimizer
        # checkpoints rely on.

        # --- Convolutional Block 1 with Early MaxPooling ---
        # Layer 1
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=16, padding='same')
        self.relu1 = MaskedReLU()

        # Layer 2: MaxPooling (used once, at the beginning)
        self.pool1 = nn.MaxPool1d(kernel_size=4, stride=4)
        # Output of pool1: (batch_size, 32, 3600 / 4) = (batch_size, 32, 900)

        # --- Feature Extraction Block ---
        # Layer 3
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=12, padding='same')
        # NOTE(review): bn2, bn3 and bn4 are registered (they appear in
        # parameters()/state_dict) but forward() never applies them. They
        # are kept for checkpoint compatibility. Confirm whether they should
        # be wired in (this changes outputs) or removed.
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = MaskedReLU()
        # Output of conv2: (batch_size, 64, 900)

        # Layers 4 and 5: strided, unpadded convolutions.
        # L_out = floor((L_in - kernel_size) / stride) + 1
        # (PyTorch does not allow padding='same' together with stride > 1.)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=9, stride=2)
        self.bn3 = nn.BatchNorm1d(128)
        self.relu3 = MaskedReLU()
        # Output of conv3: (batch_size, 128, floor((900 - 9) / 2) + 1) = (batch_size, 128, 446)

        self.conv4 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=7, stride=2)
        self.bn4 = nn.BatchNorm1d(128)
        self.relu4 = MaskedReLU()
        # Output of conv4: (batch_size, 128, floor((446 - 7) / 2) + 1) = (batch_size, 128, 220)

        self.dropout1 = nn.Dropout(p=0.25)

        # Pools the flattened 128 * 220 = 28,160 features down to 1024 per sample.
        self.adaptive_pool = nn.AdaptiveMaxPool1d(output_size=1024)

        # --- Classification/Regression Head ---
        self.flatten = nn.Flatten()

        self.fc1_in_features = 1024
        self.fc1 = nn.Linear(in_features=self.fc1_in_features, out_features=128)  # Reduced dense layer
        self.relu_fc1 = MaskedReLU()
        self.dropout2 = nn.Dropout(0.5)  # Dropout for regularization
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)
        # Raw logits are returned; apply softmax outside (e.g. nn.CrossEntropyLoss).

    @staticmethod
    def _propagate_mask(conv, mask):
        """Return the 0/1 validity mask for the output of ``conv``.

        ``mask`` is convolved with a normalized all-ones kernel. The kernel
        uses ``conv``'s stride, padding, dilation and groups, with no bias.
        An output position is marked 1 when any valid input fell in its
        receptive field, and 0 otherwise.

        ``conv``'s own parameters are never modified, so this is safe under
        exceptions, hooks and compilation. Assumes ``conv.padding_mode`` is
        ``'zeros'``, which is the case for every conv in this model.
        """
        with torch.no_grad():
            kernel = torch.ones_like(conv.weight) / conv.weight[0].numel()
            coverage = nn.functional.conv1d(
                mask.to(kernel.dtype), kernel, None,
                conv.stride, conv.padding, conv.dilation, conv.groups,
            )
            # Cast to the weight dtype so double/half models and bool/int
            # input masks work.
            return (coverage > 0).to(kernel.dtype)

    def forward(self, x, mask=None):
        """Compute class logits.

        Args:
            x: Input signal, expected shape (batch_size, 1, length).
            mask: Optional validity mask with the same shape as ``x``.
                Non-zero entries mark valid samples. ``None`` treats every
                sample as valid, which makes the masking a no-op.

        Returns:
            Logits of shape (batch_size, num_classes).
        """
        if mask is None:
            mask = torch.ones_like(x)

        # Convolutional Block 1
        x = self.conv1(x)                                 # (batch_size, 32, 3600)
        mask = self._propagate_mask(self.conv1, mask)
        x = self.relu1(x, mask)
        x = self.pool1(x)                                 # (batch_size, 32, 900)
        # Max-pooling a 0/1 mask keeps a window valid if any element was valid.
        mask = self.pool1(mask)
        x = x * mask

        # Feature Extraction Block
        x = self.conv2(x)                                 # (batch_size, 64, 900)
        mask = self._propagate_mask(self.conv2, mask)
        x = self.relu2(x, mask)

        x = self.conv3(x)                                 # (batch_size, 128, 446)
        mask = self._propagate_mask(self.conv3, mask)
        x = self.relu3(x, mask)

        x = self.conv4(x)                                 # (batch_size, 128, 220)
        mask = self._propagate_mask(self.conv4, mask)
        x = self.relu4(x, mask)

        x = self.dropout1(x)

        # Classification/Regression Head
        x = self.flatten(x)                               # (batch_size, 28160)
        mask = self.flatten(mask)

        # A 2-D (batch_size, L) tensor is treated by AdaptiveMaxPool1d as
        # unbatched (C, L), so each sample's row is pooled independently.
        # NOTE(review): pooling the flattened vector mixes channels and time
        # within a window; confirm this is intended.
        x = self.adaptive_pool(x)                         # (batch_size, 1024)
        mask = self.adaptive_pool(mask)                   # still 0/1 by construction
        x = x * mask

        x = self.fc1(x)                                   # (batch_size, 128)
        x = self.relu_fc1(x)
        x = self.dropout2(x)
        return self.fc2(x)                                # (batch_size, num_classes)
