import torch
import torch.nn as nn
import torch.nn.functional as F

import torch

import torch.nn as nn
import torch.nn.functional as F


def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """Return the spatial ``(height, width)`` of a 2-D convolution's output.

    Applies, independently per dimension, the output-size formula from the
    ``torch.nn.Conv2d`` documentation::

        floor((size + 2 * pad - dilation * (kernel_size - 1) - 1) / stride + 1)

    Args:
        h_w: Input spatial size. A single number is used for both height and
            width; a tuple/list (including ``torch.Size``) is read as
            ``(height, width)``.
        kernel_size: Kernel size, a single number or an ``(kh, kw)`` pair.
        stride: Stride, a single number or an ``(sh, sw)`` pair.
        pad: Zero padding per side, a single number or a ``(ph, pw)`` pair.
        dilation: Dilation, a single number or a ``(dh, dw)`` pair.

    Returns:
        ``(h, w)`` tuple of ints.
    """
    from math import floor

    def _pair(value):
        # isinstance (rather than ``type(value) is tuple``) so tuple
        # subclasses such as torch.Size, and plain lists, are treated as
        # (h, w) pairs instead of being wrapped as if they were scalars.
        if isinstance(value, (tuple, list)):
            return value[0], value[1]
        return value, value

    sizes = _pair(h_w)
    kernels = _pair(kernel_size)
    strides = _pair(stride)
    pads = _pair(pad)
    dilations = _pair(dilation)
    return tuple(
        floor((size + 2 * p - d * (k - 1) - 1) / s + 1)
        for size, k, s, p, d in zip(sizes, kernels, strides, pads, dilations)
    )


# Convention: the last fully connected layer of every model below must be named `fc`.

class Net(nn.Module):
    """Two-hidden-layer MLP that maps a flattened input to a single logit.

    Args:
        width: Number of units in each of the two hidden layers.
        input_dim: ``(channels, height, width)`` of one input sample; only
            the product is used, as the length of the flattened input.
    """

    def __init__(self, width=512, input_dim=(3, 28, 28)):
        super(Net, self).__init__()
        # Stored as the flattened feature count, not as the original tuple.
        self.input_dim = input_dim[0] * input_dim[1] * input_dim[2]
        self.fc1 = nn.Linear(self.input_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc = nn.Linear(width, 1)
        # Re-initialise every layer: Xavier-uniform weights, zero biases
        # (overrides nn.Linear's default initialisation).
        for fc in [self.fc1, self.fc2, self.fc]:
            nn.init.xavier_uniform_(fc.weight)
            nn.init.zeros_(fc.bias)

    def forward(self, x, with_feats=False):
        """Return logits of shape ``(batch, 1)``.

        If ``with_feats`` is true, return ``(logits, feats)``, where
        ``feats`` are the last hidden activations of shape ``(batch, width)``.
        """
        # torch.flatten (unlike x.view(batch, -1)) also works for an empty
        # batch and for non-contiguous inputs (e.g. permuted tensors).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc(x)
        if with_feats:
            return logits, x
        return logits


class ConvNet(nn.Module):
    """Two conv/max-pool stages, a hidden fully connected layer and a 1-logit head.

    Args:
        width: Size of the hidden fully connected layer (and of the features
            returned when ``forward(..., with_feats=True)``).
        input_dim: ``(channels, height, width)`` of one input image.
    """

    def __init__(self, width=512, input_dim=(3, 28, 28)):
        super(ConvNet, self).__init__()
        self.input_dim = input_dim
        # Honour the declared channel count (previously hard-coded to 3).
        self.conv1 = nn.Conv2d(input_dim[0], 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # Size fc1 by running a dummy image through the exact conv/pool stack
        # used in forward. This gets max-pool flooring and non-square inputs
        # right; the earlier float "/2" arithmetic was wrong, e.g. for 30x30.
        with torch.no_grad():
            n_flat = self._conv_features(torch.zeros(1, *input_dim)).shape[1]
        self.fc1 = nn.Linear(n_flat, width)
        self.fc = nn.Linear(width, 1)

    def _conv_features(self, x):
        # (conv -> relu -> 2x2 max-pool) twice, then flatten to (batch, features).
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # torch.flatten (unlike view(batch, -1)) also handles an empty batch.
        return torch.flatten(x, 1)

    def forward(self, x, with_feats=False):
        """Return logits of shape ``(batch, 1)``.

        If ``with_feats`` is true, return ``(logits, feats)``, where ``feats``
        are the hidden fully connected activations of shape ``(batch, width)``.
        """
        x = self._conv_features(x)
        x = F.relu(self.fc1(x))
        logits = self.fc(x)
        if with_feats:
            return logits, x
        return logits

# adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py
class LeNet(nn.Module):
    """LeNet-style network: two conv/max-pool stages and three fully connected layers.

    Args:
        in_channels: Number of channels of the input image.
        width: Output size of the final ``fc`` layer.
        input_size: Spatial size of the input, an int (square) or an
            ``(height, width)`` pair. The default of 28 reproduces the
            previously hard-coded ``16 * 4 * 4`` flattened size.
    """

    def __init__(self, in_channels=3, width=512, input_size=28):
        super(LeNet, self).__init__()
        if isinstance(input_size, (tuple, list)):
            height, width_px = input_size[0], input_size[1]
        else:
            height, width_px = input_size, input_size
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Derive fc1's input size from the conv/pool stack instead of
        # hard-coding 16*4*4, which only matches 28x28 inputs.
        with torch.no_grad():
            n_flat = self._conv_features(
                torch.zeros(1, in_channels, height, width_px)).shape[1]
        self.fc1   = nn.Linear(n_flat, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc   = nn.Linear(84, width)

    def _conv_features(self, x):
        # (conv -> relu -> 2x2 max-pool) twice, then flatten to (batch, features).
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # torch.flatten (unlike view(batch, -1)) also handles an empty batch.
        return torch.flatten(x, 1)

    def forward(self, x, with_feats=False):
        """Return the ``fc`` output of shape ``(batch, width)``.

        If ``with_feats`` is true, return ``(out, feats)``, where ``feats`` is
        the 84-dimensional input to ``fc``. This mirrors the ``with_feats``
        option of the other models in this file.
        """
        x = self._conv_features(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        out = self.fc(x)
        if with_feats:
            return out, x
        return out
