import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
 
class ResidualBlock(nn.Module):
    """Two-convolution 1-D residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    ``conv1`` carries the block's stride. Both main-path convolutions use
    kernel size 5, padding 2 and no bias. When the stride is not 1 or the
    channel count changes, the shortcut is a strided 1x1 convolution followed
    by BatchNorm; otherwise it is an empty ``nn.Sequential`` (identity).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path (creation order matters for seeded initialization).
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=5,
                               stride=stride, padding=2, bias=False)
        self.bn1 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=5,
                               stride=1, padding=2, bias=False)
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)

        # Shortcut path: project only when the shapes would otherwise differ.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + identity)
 
class InterpretableResNet(nn.Module):
    """1-D ResNet classifier with an optional mask-propagating explanation pass.

    The network is a stem convolution (1 -> ``in_channels`` channels), then
    ``layers`` stages of ``ResidualBlock``s, a flatten, and a linear head.

    In explanation mode a binary mask is pushed through every convolution
    alongside the activations. Any activation whose receptive field contains
    no unmasked input is forced to zero after each BatchNorm.

    Args:
        layers: number of residual stages; the first uses stride 1, the rest
            stride 2.
        hiden_size: output channels of every stage (misspelled name kept for
            backward compatibility with existing callers).
        block_size: number of ``ResidualBlock``s per stage.
        input_dim: length of the 1-D input sequence; used to size the head.
        in_channels: output channels of the stem convolution.
        n_classes: output size of the linear head.
    """

    def __init__(self, layers=6, hiden_size=100, block_size=2, input_dim=1356,
                 in_channels=64, n_classes=2):
        super(InterpretableResNet, self).__init__()
        self.hidden_sizes = [hiden_size] * layers
        self.num_blocks = [block_size] * layers

        assert len(self.num_blocks) == len(self.hidden_sizes)

        self.input_dim = input_dim
        # NOTE: _make_layer advances this as stages are built, so after
        # construction it holds the channel count of the last stage.
        self.in_channels = in_channels
        self.n_classes = n_classes

        self.conv1 = nn.Conv1d(1, self.in_channels, kernel_size=5, stride=1,
                               padding=2, bias=False)
        self.bn1 = nn.BatchNorm1d(self.in_channels)

        strides = [1] + [2] * (len(self.hidden_sizes) - 1)
        stages = []
        for hidden_size, n_blocks, stride in zip(self.hidden_sizes,
                                                 self.num_blocks, strides):
            stages.append(self._make_layer(hidden_size, n_blocks, stride=stride))
        self.encoder = nn.Sequential(*stages)

        self.z_dim = self._get_encoding_size()
        self.linear = nn.Linear(self.z_dim, self.n_classes)

    def encode(self, x):
        """Apply the stem and all residual stages, flattened to ``(N, z_dim)``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.encoder(x)
        z = x.view(x.size(0), -1)
        return z

    @staticmethod
    def _propagate_mask(conv, mask):
        """Return a 0/1 mask of the ``conv`` outputs that see any nonzero mask input.

        The mask is convolved with a strictly positive kernel of the same
        shape, stride, padding, dilation and groups as ``conv`` (no bias).
        For a non-negative mask, an output entry is therefore nonzero exactly
        when its receptive field contains a nonzero mask entry.

        This is done functionally rather than by temporarily swapping
        ``conv.weight``. As a result, the module's parameters are never
        replaced, even if an exception is raised mid-way.

        The mask is cast to the weight dtype so non-float32 models work.
        Assumes ``conv.padding_mode == 'zeros'``, which holds for every
        convolution built in this file.
        """
        weight = conv.weight
        kernel = torch.full_like(weight, 1.0 / weight[0].numel())
        response = F.conv1d(mask.to(weight.dtype), kernel, None, conv.stride,
                            conv.padding, conv.dilation, conv.groups)
        return (response != 0).to(weight.dtype)

    def masked_forward(self, input, layer, mask):
        """Run one residual stage while propagating a binary validity mask.

        Args:
            input: activations entering the stage.
            layer: iterable of blocks exposing ``conv1``/``bn1``/``conv2``/
                ``bn2``/``shortcut`` (as built by ``_make_layer``).
            mask: nonzero where ``input`` is valid; expected to have the same
                shape as ``input``. ``None`` means everything is valid.

        Returns:
            ``(output, mask)``: the stage output and the mask of its valid
            positions.
        """
        if mask is None:
            mask = torch.ones_like(input)
        output = input

        for block in layer:
            shortcut = output
            shortcut_mask = mask

            mask = self._propagate_mask(block.conv1, mask)
            output = block.bn1(block.conv1(output))
            # BatchNorm adds a shift, so re-zero positions with no valid input.
            output = torch.where(mask != 0, output, torch.zeros_like(output))
            output = F.relu(output)

            mask = self._propagate_mask(block.conv2, mask)
            output = block.bn2(block.conv2(output))
            output = torch.where(mask != 0, output, torch.zeros_like(output))

            if len(block.shortcut) > 0:
                shortcut_mask = self._propagate_mask(block.shortcut[0], shortcut_mask)
                shortcut = block.shortcut(shortcut)
                shortcut = torch.where(shortcut_mask != 0, shortcut,
                                       torch.zeros_like(shortcut))

            # The main path's receptive field contains the shortcut's, so the
            # main-path mask also describes the sum.
            output = output + shortcut
            output = F.relu(output)

        return output, mask

    def forward(self, x,
                explanation_mode=False,
                masking_value=None,
                explanation_mask=None):
        """Classify ``x``; optionally run the mask-propagating explanation pass.

        Args:
            x: input batch fed to the single-channel stem convolution.
            explanation_mode: if true, zero out masked inputs and propagate
                the mask through every stage.
            masking_value: input entries equal to this value are treated as
                masked. When given, it overrides ``explanation_mask``.
            explanation_mask: nonzero where the input is kept. A 2-D mask is
                unsqueezed to match a 3-D ``x``.

        Returns:
            Logits of shape ``(N, n_classes)``.

        Raises:
            AssertionError: in explanation mode, if neither a mask nor a
                masking value is given, or if an unsqueezed 2-D mask does not
                match ``x``'s shape.
        """
        if not explanation_mode:
            return self.linear(self.encode(x))

        assert explanation_mask is not None or masking_value is not None, "Explanation_mask or masking_value must be provided in explanation mode"

        if masking_value is not None:
            explanation_mask = torch.where(x == masking_value, 0, 1.0)

        if len(explanation_mask.shape) == 2 and len(x.shape) == 3:
            explanation_mask = explanation_mask.unsqueeze(1)
            assert explanation_mask.shape == x.shape, f"Explanation mask and input must have the same shape, Got {explanation_mask.shape} and {x.shape}"

        # The convolutions are bias-free, so zeroed inputs contribute exactly
        # nothing to the conv outputs; BatchNorm's shift is handled by
        # re-masking after every BatchNorm.
        x = torch.where(explanation_mask == 0, 0, x)

        # Propagate the mask through the stem conv ("any valid input in the
        # receptive field"), then mask the stem activations.
        explanation_mask = self._propagate_mask(self.conv1, explanation_mask)
        output = self.bn1(self.conv1(x))
        output = torch.where(explanation_mask == 0, 0, output)
        output = F.relu(output)

        # Iterate over every stage so any number of `layers` is supported.
        for stage in self.encoder:
            output, explanation_mask = self.masked_forward(
                input=output, layer=stage, mask=explanation_mask)

        output = output.view(output.size(0), -1)
        return self.linear(output)

    def _make_layer(self, out_channels, num_blocks, stride=1):
        """Build one stage of ``num_blocks`` ResidualBlocks.

        Only the first block uses ``stride``. As a side effect,
        ``self.in_channels`` is set to ``out_channels`` so the next stage
        chains correctly.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for s in strides:
            blocks.append(ResidualBlock(self.in_channels, out_channels,
                                        stride=s))
            self.in_channels = out_channels
        return nn.Sequential(*blocks)

    def _get_encoding_size(self):
        """Return the flattened encoder output size for a ``(N, 1, input_dim)`` input.

        The probe runs in eval mode under ``torch.no_grad()``. This avoids
        updating the BatchNorm running statistics (``running_mean``,
        ``running_var``, ``num_batches_tracked``) with random data, and avoids
        building an autograd graph. The previous train/eval mode is restored
        afterwards.
        """
        # torch.rand (rather than zeros) keeps RNG consumption identical to
        # earlier versions, so seeded initialization of later layers is unchanged.
        probe = torch.rand(2, 1, self.input_dim)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                z = self.encode(probe)
        finally:
            self.train(was_training)
        return z.size(1)