from typing import List, final

import torch
import torch.nn as nn
from torch import Block, Tensor, zeros_like
from torch.utils.data import Dataset

import math

from archs import get_activation, get_pooling, num_input_channels, num_classes, image_size, get_final_bn
from utilities import parameters_to_vector, iterate_dataset, lanczos, DEFAULT_PHYS_BS
import bp_preresnet, fsi_preresnet

from si_module import ModuleSI


def bn_cnn_make_layers(cfg, dataset_name: str, activation: str, pooling: str, bn_affine: bool, final_bn) -> List[nn.Module]:
    """Build the layer list of a batch-normalized, bias-free CNN.

    Args:
        cfg: Sequence of layer specs. The string ``'M'`` inserts a pooling
            layer; any other entry is used as the output channel count of a
            3x3 convolution followed by batch norm and an activation.
        dataset_name: Passed to the ``archs`` helpers to look up the input
            channel count, image size and number of classes.
        activation: Name forwarded to ``get_activation``.
        pooling: Name forwarded to ``get_pooling``.
        bn_affine: Whether the ``BatchNorm2d`` layers get affine parameters.
        final_bn: If truthy, forwarded to ``get_final_bn`` and the resulting
            module is appended after the linear classifier.

    Returns:
        A plain Python list of modules (not an ``nn.Module``); callers wrap it
        in ``nn.Sequential``.
    """
    layers: List[nn.Module] = []
    in_channels = num_input_channels(dataset_name)
    size = image_size(dataset_name)
    for v in cfg:
        if v == 'M':
            layers.append(get_pooling(pooling))
            # Track the spatial size for the classifier's input width.
            # NOTE(review): assumes each pooling layer halves height and width -- confirm in archs.get_pooling.
            size //= 2
            continue
        layers += [
            nn.Conv2d(in_channels, v, kernel_size=3, padding=1, stride=1, bias=False),
            # eps=0.0 removes the stabilizing epsilon from the BN denominator.
            # NOTE(review): presumably so the normalization is exactly scale-invariant -- confirm.
            nn.BatchNorm2d(v, eps=0.0, affine=bn_affine),
            get_activation(activation),
        ]
        in_channels = v

    n_classes = num_classes(dataset_name)
    layers += [
        nn.Flatten(),
        nn.Linear(in_channels * size * size, n_classes, bias=False),
    ]
    if final_bn:
        layers.append(get_final_bn(final_bn, num_classes=n_classes))

    return layers


def nsi_cnn_make_layers(cfg, dataset_name: str, activation: str, pooling: str) -> List[nn.Module]:
    """Build the layer list of a plain CNN without batch normalization.

    Args:
        cfg: Sequence of layer specs. The string ``'M'`` inserts a pooling
            layer; any other entry is used as the output channel count of a
            3x3 convolution (with bias) followed by an activation.
        dataset_name: Passed to the ``archs`` helpers to look up the input
            channel count, image size and number of classes.
        activation: Name forwarded to ``get_activation``.
        pooling: Name forwarded to ``get_pooling``.

    Returns:
        A plain Python list of modules (not an ``nn.Module``); callers wrap it
        in ``nn.Sequential``.
    """
    layers: List[nn.Module] = []
    in_channels = num_input_channels(dataset_name)
    size = image_size(dataset_name)
    for v in cfg:
        if v == 'M':
            layers.append(get_pooling(pooling))
            # Track the spatial size for the classifier's input width.
            # NOTE(review): assumes each pooling layer halves height and width -- confirm in archs.get_pooling.
            size //= 2
            continue
        layers += [
            nn.Conv2d(in_channels, v, kernel_size=3, padding=1, stride=1),
            get_activation(activation),
        ]
        in_channels = v

    layers += [
        nn.Flatten(),
        nn.Linear(in_channels * size * size, num_classes(dataset_name)),
    ]
    return layers


class SICNN(ModuleSI):
    """Batch-normalized CNN that records which of its parameters are in the "SI group".

    NOTE(review): "SI" presumably stands for scale-invariant -- confirm against ``ModuleSI``.

    The network is the ``nn.Sequential`` of ``bn_cnn_make_layers``. Group
    membership is tracked by parameter identity (``id``), so it is only valid
    while the same ``Parameter`` objects stay attached to the modules.
    """

    _LINEAR_INITS = ('one_over_sqrt_in', '0.01')

    def __init__(self, cfg: list, dataset_name: str, activation: str, pooling: str, bn_affine: bool, final_bn, linear_init: str) -> None:
        """Build the network, compute the SI group, and initialize the weights.

        Args:
            cfg: Layer configuration, see ``bn_cnn_make_layers``.
            dataset_name: Dataset used to size the input and the classifier.
            activation: Activation name; ``'relu'`` and ``'leaky_relu'`` also
                put batch-norm affine parameters into the SI group.
            pooling: Pooling name.
            bn_affine: Whether ``BatchNorm2d`` layers have affine parameters.
            final_bn: Optional final batch-norm spec, see ``bn_cnn_make_layers``.
            linear_init: ``'one_over_sqrt_in'`` (std ``sqrt(1 / in_features)``)
                or ``'0.01'`` (std 0.01) for linear layers.

        Raises:
            ValueError: If ``linear_init`` is not one of the supported values.
        """
        super().__init__()

        # Validate explicitly: an ``assert`` disappears under ``python -O`` and an
        # unknown value would then silently fall back to the 0.01 initialization.
        if linear_init not in self._LINEAR_INITS:
            raise ValueError(
                f"linear_init must be one of {self._LINEAR_INITS}, got {linear_init!r}"
            )

        layers = bn_cnn_make_layers(cfg, dataset_name, activation, pooling, bn_affine, final_bn)
        self.net = nn.Sequential(*layers)

        # Only layers strictly before the last batch-norm layer are candidates for
        # the SI group; the last batch norm and everything after it are excluded.
        # With no batch norm at all, the group is empty.
        batch_norms = (nn.BatchNorm2d, nn.BatchNorm1d)
        last_bn = max((i for i, m in enumerate(layers) if isinstance(m, batch_norms)), default=0)
        bn_affine_in_group = activation in ('relu', 'leaky_relu')

        self._si_group = set()
        for m in layers[:last_bn]:
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                self._si_group.add(id(m.weight))
            elif isinstance(m, nn.BatchNorm2d) and bn_affine_in_group:
                # weight/bias are None when the layer was built with affine=False.
                self._si_group.update(id(p) for p in (m.weight, m.bias) if p is not None)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                # Invariant of bn_cnn_make_layers: convolutions are bias-free.
                assert m.bias is None
            elif isinstance(m, nn.Linear):
                std = math.sqrt(1. / m.in_features) if linear_init == 'one_over_sqrt_in' else 0.01
                m.weight.data.normal_(0, std)

    def forward(self, x):
        """Run ``x`` through the sequential network."""
        return self.net(x)

    def in_si_group(self, param):
        """Return True if ``param`` (compared by identity) was recorded in the SI group."""
        return id(param) in self._si_group


class NSICNN(ModuleSI):
    """Plain CNN (no batch norm, biases enabled) whose SI group is empty.

    The network is the ``nn.Sequential`` of ``nsi_cnn_make_layers``;
    ``in_si_group`` therefore returns False for every parameter.
    """

    _LINEAR_INITS = ('one_over_sqrt_in', '0.01')

    def __init__(self, cfg: list, dataset_name: str, activation: str, pooling: str, linear_init: str) -> None:
        """Build the network and initialize weights and biases.

        Args:
            cfg: Layer configuration, see ``nsi_cnn_make_layers``.
            dataset_name: Dataset used to size the input and the classifier.
            activation: Activation name.
            pooling: Pooling name.
            linear_init: ``'one_over_sqrt_in'`` (std ``sqrt(1 / in_features)``)
                or ``'0.01'`` (std 0.01) for linear layers.

        Raises:
            ValueError: If ``linear_init`` is not one of the supported values.
        """
        super().__init__()

        # Validate explicitly: an ``assert`` disappears under ``python -O`` and an
        # unknown value would then silently fall back to the 0.01 initialization.
        if linear_init not in self._LINEAR_INITS:
            raise ValueError(
                f"linear_init must be one of {self._LINEAR_INITS}, got {linear_init!r}"
            )

        self.net = nn.Sequential(*nsi_cnn_make_layers(cfg, dataset_name, activation, pooling))

        # No parameter is ever added, so in_si_group() is always False.
        self._si_group = set()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                std = math.sqrt(2. / fan_out)
            elif isinstance(m, nn.Linear):
                std = math.sqrt(1. / m.in_features) if linear_init == 'one_over_sqrt_in' else 0.01
            else:
                continue
            m.weight.data.normal_(0, std)
            m.bias.data.zero_()

    def forward(self, x):
        """Run ``x`` through the sequential network."""
        return self.net(x)

    def in_si_group(self, param):
        """Return True if ``param`` (compared by identity) was recorded in the SI group."""
        return id(param) in self._si_group


def _stacked_cfg(*stages):
    """Expand ``(width, repeat)`` stages into a flat layer config.

    Each stage contributes ``repeat`` copies of ``width`` followed by one
    ``'M'`` pooling marker.
    """
    flat = []
    for width, repeat in stages:
        flat.extend([width] * repeat)
        flat.append('M')
    return flat


# Layer configurations keyed by name: integers are convolution output channel
# counts, 'M' marks a pooling layer.
cfgs = {
    'A': _stacked_cfg((64, 1), (128, 1), (256, 2), (512, 2), (512, 2)),
    'B': _stacked_cfg((64, 2), (128, 2), (256, 2), (512, 2), (512, 2)),
    'D': _stacked_cfg((64, 2), (128, 2), (256, 3), (512, 3), (512, 3)),
    'E': _stacked_cfg((64, 2), (128, 2), (256, 4), (512, 4), (512, 4)),
    'C4': _stacked_cfg((64, 1), (128, 1), (256, 1)),
}


def sicnn_vgg11(dataset: str, activation: str, pooling: str, bn_affine: bool, final_bn, linear_init: str = '0.01'):
    """Build an ``SICNN`` from configuration "A" (VGG 11-layer, with batch normalization)."""
    vgg11_cfg = cfgs['A']
    return SICNN(
        vgg11_cfg,
        dataset,
        activation,
        pooling,
        bn_affine,
        final_bn,
        linear_init,
    )

def nsicnn_vgg11(dataset: str, activation: str, pooling: str, linear_init: str = '0.01'):
    """Build an ``NSICNN`` from configuration "A" (VGG 11-layer, without batch normalization)."""
    vgg11_cfg = cfgs['A']
    return NSICNN(
        vgg11_cfg,
        dataset,
        activation,
        pooling,
        linear_init,
    )

def sicnn_convnet4(dataset: str, activation: str, pooling: str, bn_affine: bool, final_bn, linear_init: str = '0.01'):
    """Build an ``SICNN`` from configuration "C4" (three conv stages, each followed by pooling)."""
    convnet4_cfg = cfgs['C4']
    return SICNN(
        convnet4_cfg,
        dataset,
        activation,
        pooling,
        bn_affine,
        final_bn,
        linear_init,
    )


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Two-convolution residual block with an optional zero-padded shortcut.

    When ``downsample`` is given, its output is concatenated with an equally
    shaped zero tensor along the channel dimension before being added to the
    residual branch, so the shortcut ends up with twice the channels that
    ``downsample`` produces.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # conv1 carries the stride, so conv1 and `downsample` both shrink the
        # input spatially when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        if self.downsample is None:
            shortcut = x
        else:
            pooled = self.downsample(x)
            # Pad the channel dimension with zeros to match the residual branch.
            shortcut = torch.cat((pooled, torch.zeros_like(pooled)), 1)

        residual += shortcut
        return self.relu(residual)


class ResNet(ModuleSI):
    """Three-stage residual network (16/32/64 channels) with a linear classifier.

    The parameters recorded in ``_nsi_group`` are the classifier weight and
    bias, the affine parameters of every ``bn2`` in ``layer3``, and those of
    the batch norm inside the first ``layer3`` block's ``downsample``.
    ``in_si_group`` reports every other parameter as being in the SI group.
    """

    def __init__(self, block, layers, num_classes=10):
        """Build the network.

        Args:
            block: Block class used for every stage (asserted to produce
                ``BasicBlock`` instances in ``layer3``).
            layers: Number of blocks per stage; the first three entries are used.
            num_classes: Output width of the final linear layer.
        """
        super().__init__()
        self.num_layers = sum(layers)
        self.inplanes = 16

        # Stem: the first convolution expects 3 input channels.
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

        # Parameters excluded from the SI group, tracked by identity.
        self._nsi_group = {id(self.fc.weight), id(self.fc.bias)}
        for position, res_block in enumerate(self.layer3):
            assert isinstance(res_block, BasicBlock)
            excluded_norms = []
            if position == 0:
                assert isinstance(res_block.downsample, nn.Sequential)
                shortcut_bn = res_block.downsample[1]
                assert isinstance(shortcut_bn, nn.BatchNorm2d)
                excluded_norms.append(shortcut_bn)
            excluded_norms.append(res_block.bn2)
            for norm in excluded_norms:
                self._nsi_group.update((id(norm.weight), id(norm.bias)))

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Zero-initialize the last BN in each residual branch so the branch starts
        # at zero and every block initially behaves like an identity
        # (reported to help by 0.2~0.3% in https://arxiv.org/abs/1706.02677).
        # This must stay a separate pass: modules() visits a block before its bn2,
        # so merging it into the loop above would reset bn2.weight back to one.
        for module in self.modules():
            if isinstance(module, BasicBlock):
                nn.init.zeros_(module.bn2.weight)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) first block, then ``blocks - 1`` more."""
        shortcut = None
        if stride != 1:
            # Strided stages use a 1x1 average pool followed by batch norm as shortcut.
            shortcut = nn.Sequential(
                nn.AvgPool2d(1, stride=stride),
                nn.BatchNorm2d(self.inplanes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = planes
        stage.extend(block(planes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the three stages, global average pooling, and the classifier."""
        out = self.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)

        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)

    def in_si_group(self, param):
        """Return True unless ``param`` (compared by identity) was recorded in ``_nsi_group``."""
        return id(param) not in self._nsi_group


def siresnet20(dataset: str):
    """Build a ``ResNet`` with three ``BasicBlock``s per stage, sized for ``dataset``'s classes."""
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes(dataset))


def siresnet56(dataset: str):
    """Build a ``ResNet`` with nine ``BasicBlock``s per stage, sized for ``dataset``'s classes."""
    blocks_per_stage = [9, 9, 9]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes(dataset))


def si_load_architecture(arch_id: str, dataset_name: str) -> ModuleSI:
    """Instantiate the architecture registered under ``arch_id``.

    Args:
        arch_id: Architecture identifier, e.g. ``'sicnn_vgg11_swish_noba'``.
        dataset_name: Dataset name forwarded to the CNN / ResNet factories.
            The ``bp_preresnet*`` / ``fsi_preresnet*`` factories are called
            without it.

    Returns:
        The constructed network, or ``None`` when ``arch_id`` is not recognized.
    """
    # Keyword presets shared by several variants; every CNN variant uses average pooling.
    swish = dict(activation='swish', pooling='average')
    relu = dict(activation='relu', pooling='average')
    swish_noba = dict(swish, bn_affine=False)
    relu_noba = dict(relu, bn_affine=False)
    loosi = dict(linear_init='one_over_sqrt_in')

    # arch_id -> (factory taking the dataset name first, extra keyword arguments)
    dataset_builders = {
        'sicnn_vgg11_swish': (sicnn_vgg11, dict(swish, bn_affine=True, final_bn=False)),
        'sicnn_vgg11_relu': (sicnn_vgg11, dict(relu, bn_affine=True, final_bn=False)),
        'sicnn_vgg11_swish_noba': (sicnn_vgg11, dict(swish_noba, final_bn=False)),
        'sicnn_vgg11_swish_noba_loosi': (sicnn_vgg11, dict(swish_noba, final_bn=False, **loosi)),
        'sicnn_vgg11_swish_noba_wfbn': (sicnn_vgg11, dict(swish_noba, final_bn=True)),
        'sicnn_vgg11_swish_noba_wfbn_loosi': (sicnn_vgg11, dict(swish_noba, final_bn=True, **loosi)),
        'nsicnn_vgg11_swish': (nsicnn_vgg11, dict(swish)),
        'nsicnn_vgg11_swish_loosi': (nsicnn_vgg11, dict(swish, **loosi)),
        'sicnn_vgg11_relu_noba': (sicnn_vgg11, dict(relu_noba, final_bn=False)),
        'sicnn_vgg11_relu_noba_wfbn': (sicnn_vgg11, dict(relu_noba, final_bn=True)),
        'sicnn4_swish_noba': (sicnn_convnet4, dict(swish_noba, final_bn=False)),
        'sicnn4_swish_noba_wfbn': (sicnn_convnet4, dict(swish_noba, final_bn=True)),
        'sicnn4_swish_noba_ffbn135': (sicnn_convnet4, dict(swish_noba, final_bn='fixed135')),
        'siresnet20': (siresnet20, {}),
        'siresnet56': (siresnet56, {}),
    }
    # 'sicnn_vgg11_swish_noba_ffbn<k>' maps to final_bn='fixed<k>'; the '135' variant is marked "final".
    for suffix in ('', '1', '2', '3', '4', '5', '135'):
        dataset_builders['sicnn_vgg11_swish_noba_ffbn' + suffix] = (
            sicnn_vgg11, dict(swish_noba, final_bn='fixed' + suffix)
        )

    if arch_id in dataset_builders:
        factory, extra_kwargs = dataset_builders[arch_id]
        return factory(dataset_name, **extra_kwargs)

    # Lambdas defer the module attribute lookup until the matching id is requested.
    preresnet_builders = {
        'bp_preresnet20': lambda: bp_preresnet.preresnet20(),  # final
        'bp_preresnet56': lambda: bp_preresnet.preresnet56(),
        'fsi_preresnet20': lambda: fsi_preresnet.preresnet20(),  # final
        'fsi_preresnet56': lambda: fsi_preresnet.preresnet56(),
    }
    if arch_id in preresnet_builders:
        return preresnet_builders[arch_id]()

    # Unknown identifiers yield None, matching the original fall-through.
    return None


def si_compute_global_hvp(network: ModuleSI, loss_fn: nn.Module, siwd, nsiwd,
                          dataset: Dataset, vector: Tensor, physical_batch_size: int = DEFAULT_PHYS_BS):
    """Compute a Hessian-vector product over all trainable parameters.

    The product is accumulated batch by batch from ``loss_fn(network(X), y) / n``
    and then a diagonal term is added: ``siwd * vector`` on entries belonging
    to parameters in the SI group and ``nsiwd * vector`` on all others.

    Args:
        network: Model exposing ``trainable_parameters()`` and ``in_si_group()``.
        loss_fn: Loss applied to ``(network(X), y)``.
            NOTE(review): dividing by ``len(dataset)`` assumes the loss is
            summed (not averaged) over the batch -- confirm against callers.
        siwd: Coefficient of the diagonal term for SI-group parameters.
        nsiwd: Coefficient of the diagonal term for all other parameters.
        dataset: Dataset iterated with ``iterate_dataset``.
        vector: Flat vector with one entry per trainable parameter element.
        physical_batch_size: Batch size used while iterating the dataset.

    Returns:
        Flat tensor with the Hessian-vector product, on the parameters' device.
    """
    # Enumerate the parameters once; they are differentiated against twice per batch.
    params = list(network.trainable_parameters())
    # Follow the model's device instead of assuming the default CUDA device.
    device = params[0].device if params else vector.device

    n = len(dataset)
    hvp = torch.zeros(len(parameters_to_vector(params)), device=device)
    vector = vector.to(device)
    for (X, y) in iterate_dataset(dataset, physical_batch_size):
        loss = loss_fn(network(X), y) / n
        # create_graph=True keeps the gradient differentiable for the second pass.
        grads = torch.autograd.grad(loss, inputs=params, create_graph=True)
        dot = parameters_to_vector(grads).dot(vector)
        second = torch.autograd.grad(dot, params, retain_graph=False)
        hvp += parameters_to_vector([g.contiguous() for g in second])

    # Diagonal contribution, selected per parameter by SI-group membership.
    diagonal = parameters_to_vector([
        (siwd if network.in_si_group(param) else nsiwd) * torch.ones_like(param)
        for param in params
    ])
    hvp += diagonal * vector

    return hvp


def si_compute_spherical_hvp(network: ModuleSI, loss_fn: nn.Module,
                          dataset: Dataset, vector: Tensor, physical_batch_size: int = DEFAULT_PHYS_BS):
    """Compute a Hessian-vector product restricted to the SI parameters, scaled by ``si_norm2()``.

    The product is accumulated batch by batch from ``loss_fn(network(X), y) / n``
    with respect to ``network.si_parameters()`` only, and the result is
    multiplied by ``network.si_norm2()``.
    NOTE(review): ``si_norm2`` is presumably the squared norm of the SI
    parameters -- confirm against ``ModuleSI``.

    Args:
        network: Model exposing ``si_parameters()`` and ``si_norm2()``.
        loss_fn: Loss applied to ``(network(X), y)``.
            NOTE(review): dividing by ``len(dataset)`` assumes the loss is
            summed (not averaged) over the batch -- confirm against callers.
        dataset: Dataset iterated with ``iterate_dataset``.
        vector: Flat vector with one entry per SI parameter element.
        physical_batch_size: Batch size used while iterating the dataset.

    Returns:
        Flat tensor with the scaled Hessian-vector product, on the parameters' device.
    """
    # Enumerate the parameters once; they are differentiated against twice per batch.
    params = list(network.si_parameters())
    # Follow the model's device instead of assuming the default CUDA device.
    device = params[0].device if params else vector.device

    n = len(dataset)
    hvp = torch.zeros(len(parameters_to_vector(params)), device=device)
    vector = vector.to(device)
    for (X, y) in iterate_dataset(dataset, physical_batch_size):
        loss = loss_fn(network(X), y) / n
        # create_graph=True keeps the gradient differentiable for the second pass.
        grads = torch.autograd.grad(loss, inputs=params, create_graph=True)
        dot = parameters_to_vector(grads).dot(vector)
        second = torch.autograd.grad(dot, params, retain_graph=False)
        hvp += parameters_to_vector([g.contiguous() for g in second])

    return hvp * network.si_norm2()


def si_get_global_hessian_eigenvalues(network: ModuleSI, loss_fn: nn.Module, dataset: Dataset,
                                      neigs=6, physical_batch_size=1000, siwd=0.0, nsiwd=0.0):
    """Run ``lanczos`` on the global Hessian-vector product to get ``neigs`` leading eigenvalues.

    The operator handed to ``lanczos`` is ``si_compute_global_hvp`` (including
    its ``siwd`` / ``nsiwd`` diagonal term), with results detached and moved
    to the CPU.
    """
    def hvp_delta(delta):
        product = si_compute_global_hvp(network, loss_fn, siwd, nsiwd, dataset, delta,
                                        physical_batch_size=physical_batch_size)
        return product.detach().cpu()

    flat_params = parameters_to_vector(network.trainable_parameters())
    nparams = len(flat_params)
    return lanczos(hvp_delta, nparams, neigs=neigs)


def si_get_spherical_hessian_eigen(network: ModuleSI, loss_fn: nn.Module, dataset: Dataset,
                                         neigs=6, physical_batch_size=1000):
    """Run ``lanczos`` on the spherical Hessian-vector product for ``neigs`` leading eigenpairs.

    The operator handed to ``lanczos`` is ``si_compute_spherical_hvp`` over the
    SI parameters, with results detached and moved to the CPU. ``lanczos`` is
    called with ``return_eigenvectors=True``.
    """
    def hvp_delta(delta):
        product = si_compute_spherical_hvp(network, loss_fn, dataset, delta,
                                           physical_batch_size=physical_batch_size)
        return product.detach().cpu()

    flat_params = parameters_to_vector(network.si_parameters())
    nparams = len(flat_params)
    return lanczos(hvp_delta, nparams, neigs=neigs, return_eigenvectors=True)


def si_get_spherical_hess_grad_product(network: ModuleSI, loss_fn: nn.Module, dataset: Dataset,
                                            physical_batch_size=1000):
    """Multiply the SI-parameter Hessian by the gradients currently stored in ``.grad``.

    For every SI parameter ``p`` the vector is a snapshot of ``p.grad``, so the
    caller must have populated ``.grad`` beforehand. The per-parameter products
    are accumulated over the dataset from ``loss_fn(network(X), y) / n`` and
    finally scaled by ``network.si_norm2()``.

    Args:
        network: Model exposing ``si_parameters()`` and ``si_norm2()``.
        loss_fn: Loss applied to ``(network(X), y)``.
            NOTE(review): dividing by ``len(dataset)`` assumes the loss is
            summed (not averaged) over the batch -- confirm against callers.
        dataset: Dataset iterated with ``iterate_dataset``.
        physical_batch_size: Batch size used while iterating the dataset.

    Returns:
        List of tensors, one per SI parameter and shaped like it.
    """
    # Enumerate the parameters once; they are reused for every batch.
    params = list(network.si_parameters())
    n = len(dataset)
    hvp = [torch.zeros_like(p) for p in params]
    # Snapshot the gradients so the vector stays fixed during the computation.
    pgrads = [p.grad.clone() for p in params]

    for (X, y) in iterate_dataset(dataset, physical_batch_size):
        loss = loss_fn(network(X), y) / n
        # create_graph=True keeps the gradient differentiable for the second pass.
        grads = torch.autograd.grad(loss, inputs=params, create_graph=True)
        dot = sum(g.view(-1).dot(pg.view(-1)) for g, pg in zip(grads, pgrads))
        for acc, g in zip(hvp, torch.autograd.grad(dot, params, retain_graph=False)):
            acc += g

    # Sanity check: the stored .grad values must not have changed meanwhile.
    for p, snapshot in zip(params, pgrads):
        assert torch.linalg.norm(snapshot - p.grad) < 1e-7

    # Evaluate the scale once instead of once per parameter tensor.
    scale = network.si_norm2()
    for acc in hvp:
        acc *= scale
    return hvp
