# Adapted from: https://github.com/locuslab/edge-of-stability
# Original paper: Cohen, J. M., Kaur, S., Li, Y., Kolter, J. Z., & Talwalkar, A. (2021).
# "Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability", ICLR 2021.
# If you use this code, please cite the original work.
#
# Our modifications: lines 13, 16-17, 38-41, 64-142, 212-243, 270-271, 294-301

from typing import List

import torch
import torch.nn as nn

from resnet_cifar import resnet32, resnet20
from vgg import vgg11_nodropout, vgg11_nodropout_bn
from data import num_classes, num_input_channels, image_size, num_pixels
from transformer_seq2seq import Seq2Seq
from vision_transformer import ViT

# Keyword arguments shared by the nn.Conv2d layers built in convnet and convnet_bn:
# 3x3 kernels, stride 1, and one pixel of padding on each side.
_CONV_OPTIONS = {"kernel_size": 3, "padding": 1, "stride": 1}

def get_activation(activation: str):
    """Return a freshly constructed activation module for ``activation``.

    Args:
        activation: name of the activation. The names "relu-kernel" and
            "relu-mup" both resolve to a plain ReLU module.

    Returns:
        A new ``torch.nn`` activation module built with default arguments.

    Raises:
        NotImplementedError: if ``activation`` is not one of the known names.
    """
    # Ordered (name, module class) pairs; scanned with == in this order.
    known_activations = (
        ('relu', torch.nn.ReLU),
        ('hardtanh', torch.nn.Hardtanh),
        ('leaky_relu', torch.nn.LeakyReLU),
        ('selu', torch.nn.SELU),
        ('elu', torch.nn.ELU),
        ("tanh", torch.nn.Tanh),
        ("softplus", torch.nn.Softplus),
        ("sigmoid", torch.nn.Sigmoid),
        ("relu-kernel", torch.nn.ReLU),
        ("relu-mup", torch.nn.ReLU),
    )
    for known_name, module_cls in known_activations:
        if activation == known_name:
            return module_cls()
    raise NotImplementedError("unknown activation function: {}".format(activation))

def get_pooling(pooling: str):
    """Return a freshly constructed 2x2 pooling module for ``pooling``.

    Args:
        pooling: 'max' for ``torch.nn.MaxPool2d`` or 'average' for
            ``torch.nn.AvgPool2d``.

    Returns:
        A new pooling module with a (2, 2) window.

    Raises:
        NotImplementedError: if ``pooling`` is not a known name. Failing here,
            in the same way get_activation does, avoids handing ``None`` to
            the caller that assembles the network.
    """
    if pooling == 'max':
        return torch.nn.MaxPool2d((2, 2))
    elif pooling == 'average':
        return torch.nn.AvgPool2d((2, 2))
    else:
        raise NotImplementedError("unknown pooling: {}".format(pooling))


def fully_connected_net(dataset_name: str, widths: List[int], activation: str, bias: bool = True) -> nn.Module:
    """Build a fully-connected classifier: Flatten, hidden Linear+activation pairs, Linear head.

    NOTE: a second function named ``fully_connected_net`` is defined further
    down in this module. That later definition rebinds the name, so it is the
    one callers get after the module has been imported.

    Args:
        dataset_name: dataset key passed to num_pixels and num_classes.
        widths: output size of each hidden layer, in order.
        activation: activation name understood by get_activation.
        bias: whether every Linear layer has a bias term.

    Returns:
        An ``nn.Sequential`` containing the layers described above.
    """
    layers = [nn.Flatten()]
    for idx, width in enumerate(widths):
        # The first hidden layer reads the flattened input; later ones read the previous width.
        fan_in = widths[idx - 1] if idx > 0 else num_pixels(dataset_name)
        layers.append(nn.Linear(fan_in, width, bias=bias))
        layers.append(get_activation(activation))
    layers.append(nn.Linear(widths[-1], num_classes(dataset_name), bias=bias))
    return nn.Sequential(*layers)


def fully_connected_net_scaled(dataset_name: str, widths: List[int], activation: str, bias: bool = True, scaling: float = 1.0) -> nn.Module:
    """Build a fully-connected classifier whose initial Linear parameters are multiplied by ``scaling``.

    The network is Flatten, then a Linear+activation pair per entry of
    ``widths``, then a Linear head. After construction, the weight (and bias,
    when present) of every Linear layer is scaled in place.

    Args:
        dataset_name: dataset key passed to num_pixels and num_classes.
        widths: output size of each hidden layer, in order.
        activation: activation name understood by get_activation.
        bias: whether every Linear layer has a bias term.
        scaling: factor applied to the freshly initialised weights and biases.

    Returns:
        The scaled ``nn.Sequential`` model.
    """
    layers = [nn.Flatten()]
    for idx, width in enumerate(widths):
        fan_in = num_pixels(dataset_name) if idx == 0 else widths[idx - 1]
        layers.append(nn.Linear(fan_in, width, bias=bias))
        layers.append(get_activation(activation))
    layers.append(nn.Linear(widths[-1], num_classes(dataset_name), bias=bias))

    model = nn.Sequential(*layers)

    # Rescale the default initialisation in place, without recording autograd history.
    with torch.no_grad():
        for module in model.modules():
            if not isinstance(module, nn.Linear):
                continue
            module.weight.mul_(scaling)
            if module.bias is not None:
                module.bias.mul_(scaling)
    return model

class ScaledActivation(nn.Module):
    """Wrap an activation module and multiply its output by a fixed scalar."""

    def __init__(self, activation_fn: nn.Module, scale: float):
        """Store the wrapped activation module and the scalar applied to its output."""
        super().__init__()
        self.scale = scale
        self.activation = activation_fn

    def forward(self, x):
        """Return ``scale * activation(x)``."""
        activated = self.activation(x)
        return self.scale * activated
    
class OutputScaler(nn.Module):
    """Parameter-free module that multiplies its input by a fixed scalar."""

    def __init__(self, scale: float):
        """Store the scalar that forward multiplies its input by."""
        super().__init__()
        self.scale = scale

    def forward(self, x):
        """Return ``scale * x``."""
        factor = self.scale
        return factor * x
    
def fully_connected_net(dataset_name: str, widths: List[int], activation: str, bias: bool = True) -> nn.Module:
    """Build a fully-connected classifier, with width-scaled variants for "relu-kernel" and "relu-mup".

    For ordinary activations the network is Flatten, a default-initialised
    Linear+activation pair per entry of ``widths``, and a default-initialised
    Linear head.

    For "relu-kernel" and "relu-mup" every Linear layer has weights drawn from
    N(0, 1) and a zeroed bias, and each hidden activation is wrapped in a
    ScaledActivation with scale ``1 / sqrt(fan_in)`` of the preceding Linear
    layer. "relu-mup" additionally appends an OutputScaler with scale
    ``1 / sqrt(widths[-1])`` after the head; "relu-kernel" leaves the head
    unscaled.

    Args:
        dataset_name: dataset key passed to num_pixels and num_classes.
        widths: output size of each hidden layer, in order.
        activation: activation name understood by get_activation.
        bias: whether every Linear layer has a bias term.

    Returns:
        An ``nn.Sequential`` containing the layers described above.
    """
    width_scaled = activation in ("relu-kernel", "relu-mup")

    def unit_normal_linear(fan_in, fan_out, use_bias):
        """Linear layer with N(0, 1) weights and, when present, a zero bias."""
        linear = nn.Linear(fan_in, fan_out, bias=use_bias)
        nn.init.normal_(linear.weight, mean=0.0, std=1.0)
        if use_bias:
            nn.init.zeros_(linear.bias)
        return linear

    layers = [nn.Flatten()]
    for idx, width in enumerate(widths):
        fan_in = widths[idx - 1] if idx > 0 else num_pixels(dataset_name)
        if width_scaled:
            layers.append(unit_normal_linear(fan_in, width, bias))
            layers.append(ScaledActivation(get_activation(activation), scale=1.0 / fan_in ** 0.5))
        else:
            layers.append(nn.Linear(fan_in, width, bias=bias))
            layers.append(get_activation(activation))

    if width_scaled:
        layers.append(unit_normal_linear(widths[-1], num_classes(dataset_name), bias))
        if activation == "relu-mup":
            # Only the "relu-mup" variant rescales the head's output.
            layers.append(OutputScaler(1.0 / widths[-1] ** 0.5))
    else:
        layers.append(nn.Linear(widths[-1], num_classes(dataset_name), bias=bias))

    return nn.Sequential(*layers)


def convnet(dataset_name: str, widths: List[int], activation: str, pooling: str, bias: bool) -> nn.Module:
    """Build a CNN classifier from conv -> activation -> pool stages followed by a linear head.

    Args:
        dataset_name: dataset key passed to image_size, num_input_channels and num_classes.
        widths: output channel count of each convolutional stage, in order.
        activation: activation name understood by get_activation.
        pooling: pooling name understood by get_pooling.
        bias: whether the convolutional layers have a bias term. The final
            Linear layer is built with nn.Linear's default bias setting
            regardless of this flag.

    Returns:
        An ``nn.Sequential`` of the stages, ``nn.Flatten`` and the Linear head.
    """
    side = image_size(dataset_name)
    layers = []
    for idx, out_channels in enumerate(widths):
        in_channels = num_input_channels(dataset_name) if idx == 0 else widths[idx - 1]
        conv = nn.Conv2d(in_channels, out_channels, bias=bias, **_CONV_OPTIONS)
        layers += [conv, get_activation(activation), get_pooling(pooling)]
        # The head is sized assuming each stage halves the spatial side length.
        side //= 2
    layers.append(nn.Flatten())
    flat_features = widths[-1] * side * side
    layers.append(nn.Linear(flat_features, num_classes(dataset_name)))
    return nn.Sequential(*layers)


def convnet_bn(dataset_name: str, widths: List[int], activation: str, pooling: str, bias: bool) -> nn.Module:
    """Build a CNN classifier from conv -> activation -> batch norm -> pool stages plus a linear head.

    Args:
        dataset_name: dataset key passed to image_size, num_input_channels and num_classes.
        widths: output channel count of each convolutional stage, in order.
        activation: activation name understood by get_activation.
        pooling: pooling name understood by get_pooling.
        bias: whether the convolutional layers have a bias term. The final
            Linear layer is built with nn.Linear's default bias setting
            regardless of this flag.

    Returns:
        An ``nn.Sequential`` of the stages, ``nn.Flatten`` and the Linear head.
    """
    side = image_size(dataset_name)
    layers = []
    for idx, out_channels in enumerate(widths):
        in_channels = num_input_channels(dataset_name) if idx == 0 else widths[idx - 1]
        conv = nn.Conv2d(in_channels, out_channels, bias=bias, **_CONV_OPTIONS)
        # Batch norm sits after the activation and before pooling.
        layers += [
            conv,
            get_activation(activation),
            nn.BatchNorm2d(out_channels),
            get_pooling(pooling),
        ]
        # The head is sized assuming each stage halves the spatial side length.
        side //= 2
    layers.append(nn.Flatten())
    flat_features = widths[-1] * side * side
    layers.append(nn.Linear(flat_features, num_classes(dataset_name)))
    return nn.Sequential(*layers)

def make_deeplinear(L: int, d: int, seed=8):
    """Build a seeded deep linear network of ``L`` bias-free ``d x d`` layers.

    Each layer's weight is initialised with Xavier-normal initialisation after
    the global torch RNG has been seeded, so repeated calls with the same
    arguments produce the same weights.

    Args:
        L: number of Linear layers.
        d: input and output dimension of every layer.
        seed: value passed to ``torch.manual_seed``. Note that this reseeds
            the global torch RNG as a side effect.

    Returns:
        An ``nn.Sequential`` of the layers, moved to the CUDA device when one
        is available and left on the CPU otherwise.
    """
    torch.manual_seed(seed)
    layers = []
    for _ in range(L):
        layer = nn.Linear(d, d, bias=False)
        nn.init.xavier_normal_(layer.weight)
        layers.append(layer)
    network = nn.Sequential(*layers)
    # Only move to the GPU when there is one, so CPU-only hosts can still
    # build this architecture instead of failing inside .cuda().
    if torch.cuda.is_available():
        network = network.cuda()
    return network

def make_one_layer_network(h=10, seed=0, activation='tanh', sigma_w=1.9):
    """Build a seeded one-hidden-layer network mapping one input feature to one output.

    Args:
        h: number of hidden units.
        seed: value passed to ``torch.manual_seed``. Note that this reseeds
            the global torch RNG as a side effect.
        activation: activation name understood by get_activation.
        sigma_w: gain used for the Xavier-normal initialisation of the hidden
            layer's weight.

    Returns:
        An ``nn.Sequential`` of Linear(1, h) with bias, the activation, and a
        bias-free Linear(h, 1).
    """
    torch.manual_seed(seed)
    hidden = nn.Linear(1, h, bias=True)
    nonlinearity = get_activation(activation)
    readout = nn.Linear(h, 1, bias=False)

    # Re-initialise after both layers exist, so RNG draws happen in a fixed order.
    nn.init.xavier_normal_(hidden.weight, gain=sigma_w)
    nn.init.zeros_(hidden.bias)
    nn.init.xavier_normal_(readout.weight)
    return nn.Sequential(hidden, nonlinearity, readout)


def load_architecture(arch_id: str, dataset_name: str) -> nn.Module:
    """Construct the network registered under ``arch_id``.

    Args:
        arch_id: identifier of the architecture, e.g. 'fc-relu', 'cnn-bn-tanh',
            'resnet32', 'vit-cifar'.
        dataset_name: dataset key forwarded to the builders that size their
            layers from the dataset (the fully-connected and convolutional
            builders, and the class count of the ViT variants). The resnet,
            vgg, 'deeplinear', 'regression' and 'transformer-reverse' branches
            do not use it.

    Returns:
        A newly constructed ``nn.Module``.

    Raises:
        NotImplementedError: if ``arch_id`` is not a known identifier, rather
            than silently returning ``None``.
    """
    #  ======   fully-connected networks =======
    if arch_id == 'fc-relu':
        return fully_connected_net(dataset_name, [200, 200], 'relu', bias=True)
    elif arch_id == 'fc-elu':
        return fully_connected_net(dataset_name, [200, 200], 'elu', bias=True)
    elif arch_id == 'fc-tanh':
        return fully_connected_net(dataset_name, [200, 200], 'tanh', bias=True)
    elif arch_id == 'fc-hardtanh':
        return fully_connected_net(dataset_name, [200, 200], 'hardtanh', bias=True)
    elif arch_id == 'fc-softplus':
        return fully_connected_net(dataset_name, [200, 200], 'softplus', bias=True)
    elif arch_id == 'fc-relu-wider':
        return fully_connected_net(dataset_name, [400, 400], 'relu', bias=True)
    elif arch_id == 'fc-relu-widest':
        return fully_connected_net(dataset_name, [600, 600], 'relu', bias=True)
    elif arch_id == 'fc-relu-megawide':
        return fully_connected_net(dataset_name, [2000, 2000], 'relu', bias=True)
    elif arch_id == 'fc-relu-deeper':
        return fully_connected_net(dataset_name, [200, 200, 200, 200], 'relu', bias=True)
    elif arch_id == 'fc-relu-deepest':
        return fully_connected_net(dataset_name, [200, 200, 200, 200, 200, 200], 'relu', bias=True)
    elif arch_id == 'fc-relu-bother':
        return fully_connected_net(dataset_name, [400, 400, 400, 400], 'relu', bias=True)
    elif arch_id == 'fc-relu-bothest':
        return fully_connected_net(dataset_name, [600, 600, 600, 600, 600, 600], 'relu', bias=True)
    elif arch_id == 'fc-relu-5x':
        return fully_connected_net_scaled(dataset_name, [200, 200], 'relu', bias=True, scaling=5.0)
    elif arch_id == 'fc-relu-10x':
        return fully_connected_net_scaled(dataset_name, [200, 200], 'relu', bias=True, scaling=10.0)
    elif arch_id == 'fc-relu-kernel':
        return fully_connected_net(dataset_name, [200, 200], 'relu-kernel', bias=True)
    elif arch_id == 'fc-relu-mup':
        return fully_connected_net(dataset_name, [200, 200], 'relu-mup', bias=True)
    elif arch_id == 'fc-relu-kernel-400':
        return fully_connected_net(dataset_name, [400, 400], 'relu-kernel', bias=True)
    elif arch_id == 'fc-relu-mup-400':
        return fully_connected_net(dataset_name, [400, 400], 'relu-mup', bias=True)
    elif arch_id == 'fc-relu-kernel-600':
        return fully_connected_net(dataset_name, [600, 600], 'relu-kernel', bias=True)
    elif arch_id == 'fc-relu-mup-600':
        return fully_connected_net(dataset_name, [600, 600], 'relu-mup', bias=True)
    elif arch_id == 'fc-relu-mup-100':
        return fully_connected_net(dataset_name, [100, 100], 'relu-mup', bias=True)

    #  ======   convolutional networks =======
    elif arch_id == 'cnn-relu':
        return convnet(dataset_name, [32, 32], activation='relu', pooling='max', bias=True)
    elif arch_id == 'cnn-elu':
        return convnet(dataset_name, [32, 32], activation='elu', pooling='max', bias=True)
    elif arch_id == 'cnn-tanh':
        return convnet(dataset_name, [32, 32], activation='tanh', pooling='max', bias=True)
    elif arch_id == 'cnn-avgpool-relu':
        return convnet(dataset_name, [32, 32], activation='relu', pooling='average', bias=True)
    elif arch_id == 'cnn-avgpool-elu':
        return convnet(dataset_name, [32, 32], activation='elu', pooling='average', bias=True)
    elif arch_id == 'cnn-avgpool-tanh':
        return convnet(dataset_name, [32, 32], activation='tanh', pooling='average', bias=True)

    #  ======   convolutional networks with BN =======
    elif arch_id == 'cnn-bn-relu':
        return convnet_bn(dataset_name, [32, 32], activation='relu', pooling='max', bias=True)
    elif arch_id == 'cnn-bn-elu':
        return convnet_bn(dataset_name, [32, 32], activation='elu', pooling='max', bias=True)
    elif arch_id == 'cnn-bn-tanh':
        return convnet_bn(dataset_name, [32, 32], activation='tanh', pooling='max', bias=True)

    #  ======   real networks on CIFAR-10  =======
    elif arch_id == 'resnet32':
        return resnet32()
    elif arch_id == 'resnet20':
        return resnet20()
    elif arch_id == 'vgg11':
        return vgg11_nodropout()
    elif arch_id == 'vgg11-bn':
        return vgg11_nodropout_bn()

    # ====== additional networks ========
    # elif arch_id == 'transformer':
        # return TransformerModelFixed()
    elif arch_id == 'deeplinear':
        return make_deeplinear(20, 50)
    elif arch_id == 'regression':
        return make_one_layer_network(h=100, activation='tanh')

    # ======= vary depth =======
    elif arch_id == 'fc-tanh-depth1':
        return fully_connected_net(dataset_name, [200], 'tanh', bias=True)
    elif arch_id == 'fc-tanh-depth2':
        return fully_connected_net(dataset_name, [200, 200], 'tanh', bias=True)
    elif arch_id == 'fc-tanh-depth3':
        return fully_connected_net(dataset_name, [200, 200, 200], 'tanh', bias=True)
    elif arch_id == 'fc-tanh-depth4':
        return fully_connected_net(dataset_name, [200, 200, 200, 200], 'tanh', bias=True)

    # ======= transformers =======
    elif arch_id == 'transformer-reverse':
        return Seq2Seq(seq_len=10, vocab_size=10)
    elif arch_id == 'vit-mnist':
        return ViT(img_size=28, patch_size=7, in_chans=1, num_classes=num_classes(dataset_name),
                   embed_dim=64, depth=4, num_heads=4)
    elif arch_id == 'vit-cifar':
        return ViT(img_size=32, patch_size=4, in_chans=3, num_classes=num_classes(dataset_name),
                   embed_dim=128, depth=6, num_heads=4)
    else:
        # Fail loudly instead of falling off the chain and returning None.
        raise NotImplementedError("unknown architecture: {}".format(arch_id))