import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
 
class ResidualBlock(nn.Module):
    """Basic 1-D residual unit: two 5-wide convolutions plus a skip path.

    The main path is ``conv -> BN -> ReLU -> conv -> BN``. Its result is summed
    with the skip path and passed through a final ReLU.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the block.
        stride: stride of the first convolution (and of the projection on the
            skip path, when one is needed).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()

        # Main path. kernel_size=5 with padding=2 keeps the length unchanged
        # at stride 1; only the first convolution may downsample.
        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=5, stride=stride, padding=2, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels,
            kernel_size=5, stride=1, padding=2, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)

        # Skip path. An empty Sequential passes the input through untouched;
        # a 1x1 convolution + BN is used whenever the main path changes the
        # tensor shape, so that the two branches can be added.
        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv1d(
                in_channels, out_channels,
                kernel_size=1, stride=stride, bias=False,
            )
            self.shortcut = nn.Sequential(
                projection,
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)

        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        # In-place add onto the freshly produced BN output.
        hidden += self.shortcut(x)
        return F.relu(hidden)
 
class ResNet(nn.Module):
    """1-D residual network: conv stem, residual stages, linear classifier.

    Inputs are expected as ``(batch, 1, input_dim)``: the stem convolution
    takes a single input channel and the linear head is sized for sequences
    of length ``input_dim``.

    Args:
        layers: number of residual stages in the encoder.
        hiden_size: output channels of every stage (the spelling is kept so
            existing keyword callers continue to work).
        block_size: number of ``ResidualBlock`` units requested per stage.
        input_dim: length of the input sequence.
        in_channels: output channels of the stem convolution.
        n_classes: number of outputs of the linear head.

    Note:
        ``self.in_channels`` doubles as the "current width" while stages are
        being built, so after construction it holds the width of the last
        stage built (if any) rather than the ``in_channels`` argument.
    """

    def __init__(self, layers=6, hiden_size=100, block_size=2, input_dim=1356,
                 in_channels=64, n_classes=2):
        super(ResNet, self).__init__()
        self.hidden_sizes = [hiden_size] * layers
        self.num_blocks = [block_size] * layers

        self.input_dim = input_dim
        self.in_channels = in_channels
        self.n_classes = n_classes

        # Stem: lifts the single input channel to `in_channels`, length kept.
        self.conv1 = nn.Conv1d(1, self.in_channels, kernel_size=5, stride=1,
                               padding=2, bias=False)
        self.bn1 = nn.BatchNorm1d(self.in_channels)

        # The first stage keeps the resolution; every later stage uses
        # stride 2 in its first block.
        stage_strides = [1] + [2] * (len(self.hidden_sizes) - 1)
        stages = []
        for width, depth, stage_stride in zip(self.hidden_sizes,
                                              self.num_blocks,
                                              stage_strides):
            # _make_layer advances self.in_channels, so order matters here.
            stages.append(self._make_layer(width, depth, stride=stage_stride))
        self.encoder = nn.Sequential(*stages)

        # The flattened feature size depends on input_dim and on the strides
        # above, so it is measured with a probe batch.
        self.z_dim = self._get_encoding_size()
        self.linear = nn.Linear(self.z_dim, self.n_classes)

    def encode(self, x):
        """Return the flattened encoder features, shape ``(batch, z_dim)``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.encoder(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Return the linear head's output, shape ``(batch, n_classes)``."""
        return self.linear(self.encode(x))

    def _make_layer(self, out_channels, num_blocks, stride=1):
        """Build one stage of residual blocks and advance ``self.in_channels``.

        Only the first block uses ``stride``; the remaining blocks use
        stride 1. At least one block is always created, even when
        ``num_blocks`` is 0 or negative.
        """
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(ResidualBlock(self.in_channels, out_channels,
                                        stride=block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*blocks)

    def _get_encoding_size(self):
        """Measure ``z_dim`` by pushing a probe batch through ``encode``.

        The probe runs in eval mode under ``no_grad`` so the BatchNorm layers
        do not absorb the random probe into their running statistics and no
        autograd graph is built. The previous training flag is restored
        afterwards.
        """
        # The random draw is kept (rather than zeros) so the global RNG stream
        # -- and hence the seeded initialisation of `linear`, which is created
        # right after this call -- stays unchanged.
        probe = torch.rand(2, 1, self.input_dim)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                z = self.encode(probe)
        finally:
            self.train(was_training)
        return z.size(1)