from spaghettini import quick_register
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils import spectral_norm


@quick_register
class SmallTwoConvTwoFC(nn.Module):
    """Two 5x5 conv + 2x2 max-pool stages followed by two fully connected layers.

    The conv trunk maps a ``(B, in_channels, S, S)`` image batch to ``(B, 50, s, s)`` with
    ``s = ((S - 4) // 2 - 4) // 2``; with the default ``input_size=32`` this is ``5 * 5 * 50``,
    the flatten size the layer was originally hard-coded for.

    Args:
        in_channels: Number of input image channels.
        num_fc_hidden: Width of the hidden fully connected layer.
        out_features: Size of the output vector.
        activation: Non-linearity applied after each conv and after ``fc1``.
        drop_prob: Dropout probability applied after the hidden fully connected layer.
        input_size: Height/width of the square input images (default 32). Use e.g. 28 for raw
            MNIST-sized inputs, which the hard-coded 5x5 flatten size could not handle.

    Raises:
        ValueError: If ``input_size`` is too small to survive both conv + pool stages.
    """

    def __init__(self, in_channels=1, num_fc_hidden=500, out_features=160, activation=F.relu, drop_prob=0.0,
                 input_size=32):
        super().__init__()
        # Each 5x5 'valid' conv (stride 1, no padding) shrinks the side by 4; each 2x2/stride-2
        # max-pool halves it, rounding down.
        side = ((input_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for two conv + pool stages")
        self.flat_features = side * side * 50

        self.conv1 = nn.Conv2d(in_channels, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(self.flat_features, num_fc_hidden)
        # NOTE(review): bn1 is never applied in forward(); presumably kept so existing
        # state_dicts/checkpoints still load — confirm before removing or wiring it in.
        self.bn1 = nn.BatchNorm1d(num_fc_hidden)
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.fc2 = nn.Linear(num_fc_hidden, out_features)
        self.activation = activation

    def forward(self, x):
        """Map a 4-D image batch ``(B, C, S, S)`` to a ``(B, out_features)`` tensor."""
        assert x.ndimension() == 4
        z = F.max_pool2d(self.activation(self.conv1(x)), 2, 2)
        z = F.max_pool2d(self.activation(self.conv2(z)), 2, 2)
        # view() to the exact expected size so an input of the wrong spatial size fails loudly here.
        z = z.view(x.shape[0], self.flat_features)
        z = self.dropout1(self.activation(self.fc1(z)))
        return self.fc2(z)


@quick_register
class TwoLayerFC(nn.Module):
    """Two-layer MLP computing ``final_activation(fc2(activation(fc1(x))))``.

    Inputs with more than two dimensions are flattened to ``(batch, -1)`` first.

    Args:
        num_inputs: Number of input features (after flattening).
        num_hidden: Width of the hidden layer.
        num_outputs: Number of output features.
        activation: Non-linearity applied after ``fc1``.
        final_activation: Callable applied to the output of ``fc2``. ``None`` (the default)
            means identity. The identity is an ``nn.Identity`` rather than a lambda so the model
            stays picklable (``torch.save(model)``, multiprocessing).
    """

    def __init__(self, num_inputs=256, num_hidden=1000, num_outputs=784, activation=F.relu,
                 final_activation=None):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs
        self.activation = activation
        # nn.Identity has no parameters/buffers, so the state_dict keys are unchanged.
        self.final_activation = nn.Identity() if final_activation is None else final_activation

        self.fc1 = nn.Linear(self.num_inputs, self.num_hidden)
        self.fc2 = nn.Linear(self.num_hidden, self.num_outputs)

    def forward(self, x):
        """Return ``(output, {})``: the ``(batch, num_outputs)`` output and an empty dict."""
        if x.ndimension() > 2:
            # Unlike view(), flatten() also accepts non-contiguous inputs and empty batches.
            x = x.flatten(start_dim=1)
        z = self.activation(self.fc1(x))
        z = self.final_activation(self.fc2(z))
        return z, {}


@quick_register
class SmallOneFCThree1DConv(nn.Module):
    """A linear layer whose output is reshaped into a sequence and passed through three 1D convs.

    The linear output (``fc_out_features`` wide) is reshaped to
    ``(batch, in_chn, fc_out_features // in_chn)``. The three convolutions (kernel 9/pad 4,
    kernel 3/pad 1, kernel 1/pad 0, all stride 1) preserve the sequence length, so the output is
    ``(batch, out_chn, fc_out_features // in_chn)``.

    Args:
        fc_in_features: Input features of the linear layer.
        fc_out_features: Output features of the linear layer; must be divisible by ``in_chn``.
        out_chn: Channels produced by the last convolution.
        in_chn: Channels of the reshaped sequence fed to the first convolution.
        hidden_chn1: Output channels of the first convolution.
        hidden_chn2: Output channels of the second convolution.
        activation: Non-linearity used after the linear layer and the first two convs.
        final_activation: Callable applied after the last conv. ``None`` (the default) means
            identity, implemented with ``nn.Identity`` so the module stays picklable.
    """

    def __init__(self, fc_in_features, fc_out_features, out_chn, in_chn=16, hidden_chn1=128, hidden_chn2=256,
                 activation=F.relu, final_activation=None):
        super().__init__()
        # Make sure that the output of the fully connected network can be reshaped into the input to the 1D convnet.
        assert fc_out_features % in_chn == 0, (
            f"fc_out_features ({fc_out_features}) must be divisible by in_chn ({in_chn})")
        self.fc_in_features = fc_in_features
        self.fc_out_features = fc_out_features
        self.out_chn = out_chn
        self.in_chn = in_chn
        self.hidden_chn1 = hidden_chn1
        self.hidden_chn2 = hidden_chn2
        self.act = activation
        # nn.Identity has no parameters/buffers, so the state_dict keys are unchanged.
        self.final_act = nn.Identity() if final_activation is None else final_activation

        self.fc = nn.Linear(self.fc_in_features, self.fc_out_features)
        self.bn_fc = nn.BatchNorm1d(num_features=self.fc_out_features)
        self.time_steps = fc_out_features // in_chn

        self.conv1 = nn.Conv1d(in_channels=self.in_chn, out_channels=self.hidden_chn1, kernel_size=9, stride=1,
                               padding=4)
        self.conv2 = nn.Conv1d(in_channels=self.hidden_chn1, out_channels=self.hidden_chn2, kernel_size=3, stride=1,
                               padding=1)
        self.conv3 = nn.Conv1d(in_channels=self.hidden_chn2, out_channels=self.out_chn, kernel_size=1, stride=1,
                               padding=0)

    def forward(self, x):
        """Map ``x`` (effectively ``(batch, fc_in_features)``) to ``(batch, out_chn, time_steps)``."""
        bs = x.shape[0]
        # ____ Pass through a FC. Batch norm is applied after the activation. ____
        z = self.bn_fc(self.act(self.fc(x)))

        # ____ Reshape into the input of the 1D convnet. ____
        z = z.view((bs, self.in_chn, self.time_steps))

        # ____ Pass through three layers of 1D convolution. ____
        z = self.act(self.conv1(z))
        z = self.act(self.conv2(z))
        z = self.final_act(self.conv3(z))

        return z


@quick_register
class OrthogonalInitLinear(nn.Linear):
    """``nn.Linear`` whose weight is initialised with ``init.orthogonal_``.

    The bias (if present) is drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, the same scheme
    ``nn.Linear`` uses for its bias.
    """

    def reset_parameters(self) -> None:
        """(Re)initialise ``weight`` orthogonally and ``bias`` uniformly; called by ``nn.Linear.__init__``."""
        init.orthogonal_(self.weight)
        if self.bias is not None:
            # weight has shape (out_features, in_features), so fan_in is its second dimension.
            fan_in = self.weight.size(1)
            # Guard in_features == 0 the way nn.Linear does, instead of dividing by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)


@quick_register
class DropoutLinear(nn.Module):
    """Dropout on the input followed by a wrapped ``nn.Linear`` layer.

    Args:
        linear_module: The ``nn.Linear`` instance to wrap (must be an ``nn.Linear``).
        input_dropout_rate: Dropout probability applied to the input before the linear layer.
        use_spectral_norm: If true, ``spectral_norm`` is applied to ``linear_module`` first.
    """

    def __init__(self, linear_module, input_dropout_rate=0., use_spectral_norm=False):
        super().__init__()
        assert isinstance(linear_module, nn.Linear)
        wrapped = spectral_norm(linear_module) if use_spectral_norm else linear_module
        # Register the linear layer before the dropout so submodule order matches the original.
        self.linear_module = wrapped
        self.input_dropout_rate = input_dropout_rate
        self.dropout = nn.Dropout(p=input_dropout_rate)

    def forward(self, xs):
        """Return ``linear_module(dropout(xs))``."""
        return self.linear_module(self.dropout(xs))


@quick_register
class IdentityModule(nn.Module):
    """Parameter-free module whose forward pass hands back its input object unchanged."""

    def __init__(self):
        super(IdentityModule, self).__init__()

    def forward(self, xs):
        """Return ``xs`` itself (no copy)."""
        output = xs
        return output


if __name__ == "__main__":
    # Smoke tests. Run from the repository root:
    #     python -m src.dl.models.building_blocks
    test_num = 0

    if test_num == 0:
        linear_module = nn.Linear(in_features=10, out_features=5)
        dropout_linear = DropoutLinear(linear_module=linear_module, input_dropout_rate=0.5)
        xs = torch.randn(size=(3, 10))
        ys = dropout_linear(xs)
        # Check the result so the smoke test verifies more than "the call did not raise".
        assert ys.shape == (3, 5), f"unexpected DropoutLinear output shape {tuple(ys.shape)}"
        print(f"DropoutLinear output shape: {tuple(ys.shape)}")
