import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
 
class ConvNet(nn.Module):
    """1-D CNN classifier with an optional masked ("explanation") forward pass.

    Three conv blocks -- Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d -- feed a
    two-layer fully connected head that produces 3 outputs. ``fc1`` expects
    8640 flattened features (40 channels x 216 positions); with the kernel and
    stride settings below that corresponds to single-channel inputs of length
    1351-1356, i.e. ``x`` of shape ``(batch, 1, L)``.

    Every MaxPool1d is built with ``return_indices=True`` so the explanation
    pass can route a binary mask through the same pooling choices as the
    activations. Because the last module of each block therefore returns a
    ``(values, indices)`` tuple, the blocks must not be called directly as
    ``nn.Sequential``; ``_run_layer`` handles that.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv1d(1, 20, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm1d(20),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=7, stride=3, return_indices=True))
        self.layer2 = nn.Sequential(
            nn.Conv1d(20, 40, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm1d(40),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=5, stride=2, return_indices=True))
        self.layer3 = nn.Sequential(
            nn.Conv1d(40, 40, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm1d(40),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=1, return_indices=True))
        self.fc1 = nn.Linear(8640, 60)
        self.fc2 = nn.Linear(60, 3)

    @staticmethod
    def _run_layer(layer, x):
        """Apply a conv block and return its pooling ``(values, indices)`` pair."""
        *body, pool = layer
        for module in body:
            x = module(x)
        return pool(x)

    def mask_forward(self, mask, layer):
        """Propagate a binary mask through the convolution of ``layer``.

        An output position is marked valid (1.0) when at least one input in its
        receptive field is non-zero, and masked (0.0) otherwise. The convolution
        is evaluated functionally with an averaging kernel (all ones divided by
        the per-filter fan-in) and no bias, so the module's own parameters are
        never touched.

        Args:
            mask: 0/1 mask laid out like the input of ``layer[0]``; zeros mark
                masked positions.
            layer: A conv block whose first module is an ``nn.Conv1d``.

        Returns:
            ``(mask, score_mask)``: the propagated 0/1 mask, and a tensor of the
            same shape holding 1.0 at valid positions and, at masked positions,
            the bias of the corresponding output channel (the value the
            convolution yields for an all-zero receptive field).
        """
        conv = layer[0]
        weight = conv.weight
        # Averaging kernel: for a non-negative mask the result is > 0 iff some
        # input in the window is > 0. F.conv1d implies zero padding, which is
        # what every conv in this model uses (padding_mode is the default).
        avg_kernel = torch.ones_like(weight) / weight[0].numel()
        coverage = F.conv1d(mask.to(weight.dtype), avg_kernel, None,
                            conv.stride, conv.padding, conv.dilation, conv.groups)
        mask = torch.where(coverage > 0, 1.0, 0)
        score_mask = torch.where(mask > 0, 1.0, conv.bias.view(1, -1, 1))
        return mask, score_mask

    def _masked_layer(self, layer, x, mask):
        """Run one conv block on ``x`` while carrying ``mask`` along.

        Returns the pooled activations and the mask pooled with the same
        argmax positions.
        """
        mask, scores = self.mask_forward(mask, layer)
        out = layer[1](layer[0](x))
        # Fully masked positions take their score (the conv bias).
        # NOTE(review): this substitution happens after BatchNorm, so masked
        # positions receive the raw bias rather than BatchNorm(bias) -- confirm
        # that is intended.
        out = torch.where(mask == 0, scores, out)
        activated = layer[2](out)
        pooled, indices = layer[3](activated)
        # MaxPool1d indices address the last dim of each (sample, channel) row,
        # so gather row-wise. Expanding first covers a mask that was broadcast
        # against a larger batch.
        mask = torch.gather(mask.expand_as(activated), -1, indices)
        return pooled, mask

    def _classify(self, features):
        """Flatten conv features and apply the fully connected head."""
        flat = features.reshape(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        # F.dropout defaults to training=True; tie it to train/eval mode.
        hidden = F.dropout(hidden, 0.2, training=self.training)
        return self.fc2(hidden)

    def forward(self, x,
                explanation_mode=False,
                masking_value=None,
                explanation_mask=None):
        """Classify ``x``, optionally with masked inputs.

        Args:
            x: Input batch fed to ``layer1`` (1 input channel).
            explanation_mode: When True, masked inputs are zeroed and the mask
                is propagated through every conv block, replacing fully masked
                activations with the conv bias.
            masking_value: In explanation mode, input entries equal to this
                value are treated as masked. Takes precedence over
                ``explanation_mask`` when both are given.
            explanation_mask: In explanation mode, a mask broadcastable to ``x``
                where 0 marks masked entries.

        Returns:
            Output of ``fc2``, shape ``(batch, 3)``.

        Raises:
            ValueError: ``explanation_mode`` is set but neither
                ``explanation_mask`` nor ``masking_value`` is provided.
        """
        layers = (self.layer1, self.layer2, self.layer3)

        if not explanation_mode:
            out = x
            for layer in layers:
                out, _ = self._run_layer(layer, out)
            return self._classify(out)

        if explanation_mask is None and masking_value is None:
            raise ValueError("Explanation_mask or masking_value must be provided in explanation mode")

        if masking_value is not None:
            explanation_mask = torch.where(x == masking_value, 0, 1.0)

        out = torch.where(explanation_mask == 0, 0, x)
        mask = explanation_mask
        for layer in layers:
            out, mask = self._masked_layer(layer, out, mask)
        return self._classify(out)