from spaghettini import quick_register

import numpy as np
import torch
from torch import nn
from torch.nn import Softplus, ReLU
from torch.nn.utils import spectral_norm


@quick_register
class FCNetFixedWidth(nn.Module):
    """Fully connected network whose hidden layers all share the same width.

    The layers are stored as an ``nn.Sequential`` in ``self.net``. Every hidden
    layer is a linear module, optionally followed by ``nn.LayerNorm``, followed
    by an activation. The output layer is a bare linear module unless
    ``add_final_activation`` is set, in which case it receives the same
    normalization/activation treatment as the hidden layers.

    Args:
        num_inputs: Number of features per sample after flattening.
        num_hidden_dim: Width of every hidden layer.
        num_outputs: Width of the output layer.
        num_hidden_layers: Number of hidden layers; the network contains
            ``num_hidden_layers + 1`` linear modules.
        activation_init: Zero-argument callable returning a new activation
            module (called once per activation slot).
        use_layer_norm: If True, insert ``nn.LayerNorm`` before each activation.
        add_final_activation: If True, the output layer is also followed by
            the (optional) layer norm and the activation.
        use_spectral_norm: If True, wrap every linear module in
            ``torch.nn.utils.spectral_norm``.
    """

    def __init__(self, num_inputs=256, num_hidden_dim=1000, num_outputs=10, num_hidden_layers=1,
                 activation_init=ReLU, use_layer_norm=True, add_final_activation=False, use_spectral_norm=False):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_hidden_dim = num_hidden_dim
        self.num_outputs = num_outputs
        self.num_hidden_layers = num_hidden_layers
        self.activation_init = activation_init
        self.use_layer_norm = use_layer_norm
        self.add_final_activation = add_final_activation
        self.use_specnorm = use_spectral_norm

        # Construct the network.
        self.net = self._construct_network()

    def _construct_network(self):
        """Build the ``nn.Sequential`` described by the constructor arguments.

        Returns:
            An ``nn.Sequential`` of linear modules interleaved with the
            optional layer norms and the activations.
        """
        # Layer widths: input, then the repeated hidden width, then output.
        widths = [self.num_inputs] + self.num_hidden_layers * [self.num_hidden_dim] + [self.num_outputs]
        output_index = self.num_hidden_layers

        # Stack the ops.
        ops = []
        for i in range(self.num_hidden_layers + 1):
            # Add linear module, optionally wrapped in spectral normalization.
            linear = nn.Linear(widths[i], widths[i + 1])
            ops.append(spectral_norm(linear) if self.use_specnorm else linear)

            # The output layer stays purely linear unless a final activation
            # was explicitly requested.
            if i == output_index and not self.add_final_activation:
                break

            # Add normalization and activation.
            if self.use_layer_norm:
                ops.append(nn.LayerNorm(widths[i + 1]))
            ops.append(self.activation_init())

        return nn.Sequential(*ops)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten each sample of ``x`` and run it through the network.

        Args:
            x: Input whose first dimension is treated as the batch dimension;
                all remaining dimensions are flattened into one.

        Returns:
            The output of ``self.net`` applied to the flattened input.
        """
        # Flatten. `reshape` is used instead of `view` so that non-contiguous
        # inputs (e.g. the result of a transpose/permute) are accepted too;
        # for contiguous inputs it behaves exactly like `view`.
        bs = x.shape[0]
        x = x.reshape(bs, -1)

        # Run the network.
        return self.net(x)


@quick_register
class SmallInitLinear(nn.Linear):
    """Linear layer whose parameters start out very close to zero."""

    def reset_parameters(self) -> None:
        """Draw the weight from N(0, 0.0001^2) and zero the bias, if present."""
        nn.init.normal_(self.weight, 0.0, 1e-4)
        if self.bias is None:
            return
        nn.init.constant_(self.bias, 0.0)


if __name__ == "__main__":
    # Run from root.
    #     python -m src.dl.models.fully_connected
    test_num = 0

    if test_num == 0:
        # Check if spectral norm hook works: print the largest singular value
        # of the normalized weight after each of several forward passes.
        in_features, out_features = 10, 15
        ones_input = torch.ones((1, in_features))

        # Instantiate Linear module with spectral normalization.
        layer = spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))
        print(layer)

        # Run.
        for _ in range(10):
            # The output is not needed; the forward call is kept because it is
            # presumably what makes the spectral norm hook refresh `weight`.
            layer(ones_input)
            weight_np = layer.weight.detach().cpu().numpy()
            singular_values = np.linalg.svd(weight_np, full_matrices=True)[1]
            print(singular_values[0])
