from functools import partial
from typing import Any, Callable, Iterable, Sequence, Tuple

import pyt
from flax import linen as nn
from jax import numpy as np


class pytModule(nn.Module):
    """
    Wrapper around flax.linen.Module which automatically wraps/unwraps params in pyt.Params.

    `init` returns the usual flax variables with the "params" collection wrapped
    in `pyt.Params`; `apply` accepts such variables and unwraps the params (via
    their `.params` attribute) before delegating to flax. Raw, unwrapped params
    are also accepted by `apply` and passed through unchanged.
    """

    def init(self, *args, **kwargs):
        variables = super().init(*args, **kwargs)
        if "params" not in variables:
            # Parameter-free modules have nothing to wrap.
            return variables
        wrapped = pyt.Params(variables["params"])
        if isinstance(variables, dict):
            # Recent flax versions return a plain dict from `init`; `dict.copy()`
            # takes no arguments, so build the updated mapping directly.
            return {**variables, "params": wrapped}
        # FrozenDict: `copy` takes the entries to add or replace.
        return variables.copy({"params": wrapped})

    def apply(self, variables, *args, **kwargs):
        replace = {}
        if "params" in variables:
            params = variables["params"]
            # Unwrap pyt.Params; anything without a `.params` attribute (e.g. a
            # raw param tree) is passed through as-is.
            replace["params"] = getattr(params, "params", params)
        return super().apply({**variables, **replace}, *args, **kwargs)


class BasicMLP(pytModule):
    """
    MLP made of ReLU dense layers with widths given by `layer_sizes`, followed by
    a final dense layer with `num_outputs` units.

    The final layer returns raw logits: no log-softmax is applied.

    Attributes:
        layer_sizes: widths of the hidden ReLU dense layers, in order.
        num_outputs: width of the final (linear) dense layer.
        scale_inputs: multiplier applied to the inputs before flattening.
        first_layer_bias: whether the first hidden dense layer has a bias.
        extract_layer: if set to k (1-based), return the activations after the
            k-th hidden ReLU layer instead of the final output. All layers are
            still built, so the parameter tree does not depend on this value.
    """

    layer_sizes: Sequence[int]
    num_outputs: int
    scale_inputs: float = 1
    first_layer_bias: bool = True
    extract_layer: int|None = None

    @nn.compact
    def __call__(self, inputs, train=False):
        # `train` is unused here; it is accepted so callers can pass it uniformly.
        extracted = None
        x = inputs * self.scale_inputs
        x = x.reshape((x.shape[0], -1))  # flatten all non-batch axes
        for layer_index, size in enumerate(self.layer_sizes, start=1):
            # Only the first hidden layer's bias is configurable.
            use_bias = self.first_layer_bias if layer_index == 1 else True
            x = nn.Dense(features=size, use_bias=use_bias)(x)
            x = nn.relu(x)
            if self.extract_layer == layer_index:
                extracted = x
        # Always build the output layer so the params are the same whether or
        # not an intermediate layer is extracted.
        x = nn.Dense(features=self.num_outputs)(x)
        return x if extracted is None else extracted

    def describe(self):
        return f"Standard Relu MLP with sizes: {self.layer_sizes}"


class mupMLP(pytModule):
    """
    MLP in a muP-style (maximal update) parameterisation.

    It consists of ReLU dense layers with widths given by `layer_sizes`, followed
    by a dense layer with `num_outputs` units that returns raw logits (no
    log-softmax is applied).

    Relative to a standard MLP, the flattened inputs are multiplied by
    sqrt(input_dim), and the output layer's result is divided by sqrt(fan_in).
    Here fan_in is the width of the last hidden layer, or the input dimension
    when there are no hidden layers.
    """

    layer_sizes: Sequence[int]
    num_outputs: int

    @nn.compact
    def __call__(self, inputs, train=False):
        # `train` is unused; it is accepted so callers can pass it uniformly.
        x = inputs.reshape((inputs.shape[0], -1))  # flatten all non-batch axes
        in_dim = x.shape[1]
        # mup change: scale the input to the first layer by sqrt(d)
        x = x * (in_dim**0.5)
        for size in self.layer_sizes:
            x = nn.Dense(features=size)(x)
            x = nn.relu(x)
        x = nn.Dense(features=self.num_outputs)(x)
        # mup change: scale the output layer by 1/sqrt(n), where n is its fan-in
        fan_in = self.layer_sizes[-1] if self.layer_sizes else in_dim
        return x / (fan_in**0.5)

    def describe(self):
        return f"muP Relu MLP with sizes: {self.layer_sizes}"


class BlockCNN(pytModule):
    """
    CNN made of blocks of convolutional layers followed by a dense head.

    Block i consists of `counts[i]` ReLU convolutions, each with `features[i]`
    output channels and a `kernel` x `kernel` kernel. Each block ends with 2x2
    max pooling with stride 2. The pooled output is fed to a `BasicMLP` with
    hidden widths `dense_sizes` and `num_outputs` outputs.

    If `extract_layer` is k (1-based), the activations after the k-th layer's
    ReLU are returned instead of the final output. Layers are counted over the
    conv layers first, then over the dense hidden layers.

    Raises:
        ValueError: if `features` and `counts` have different lengths.
    """

    features: Sequence[int]
    counts: Sequence[int]
    dense_sizes: Sequence[int]
    num_outputs: int
    kernel: int = 3
    extract_layer: int|None = None

    @nn.compact
    def __call__(self, inputs, train=False):
        if len(self.features) != len(self.counts):
            # zip() would otherwise silently drop the extra blocks.
            raise ValueError(
                "features and counts must have the same length, got "
                f"{len(self.features)} and {len(self.counts)}"
            )
        conv_layers_seen = 0
        extracted = None
        x = inputs
        for feature_size, count in zip(self.features, self.counts):
            for _ in range(count):
                x = nn.Conv(
                    features=feature_size, kernel_size=(self.kernel, self.kernel)
                )(x)
                x = nn.relu(x)
                conv_layers_seen += 1
                if self.extract_layer == conv_layers_seen:
                    extracted = x
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Re-index the requested layer relative to the dense head. A value <= 0
        # means it was a conv layer, and BasicMLP (which counts from 1) never matches it.
        dense_extract_layer = (
            None if self.extract_layer is None else self.extract_layer - conv_layers_seen
        )
        x = BasicMLP(
            layer_sizes=self.dense_sizes,
            num_outputs=self.num_outputs,
            extract_layer=dense_extract_layer,
        )(x, train=train)
        return x if extracted is None else extracted

    def describe(self):
        return f"CNN with maxpooling: features: {self.features}, counts of each: {self.counts}, sizes for dense layers: {self.dense_sizes}"


def GeneralLeNet(features, dense_sizes):
    """
    Return a partially-applied `BlockCNN` describing a LeNet-style network.

    Each entry of `features` becomes a block holding a single 5x5 convolution.
    `dense_sizes` gives the hidden widths of the dense head. The remaining
    `BlockCNN` fields (e.g. `num_outputs`) are supplied when the returned
    partial is called.
    """
    single_conv_per_block = [1] * len(features)
    # Bound positionally so callers may still pass `num_outputs` positionally.
    return partial(
        BlockCNN,
        features,
        single_conv_per_block,
        dense_sizes,
        kernel=5,
    )


# LeNet-style networks: one 5x5 conv per block, then a 120 -> 84 dense head.
LeNet = GeneralLeNet(features=[6, 16], dense_sizes=[120, 84])
LeNetLarge = GeneralLeNet(features=[50, 50], dense_sizes=[120, 84])


def _vgg(counts, dense_sizes):
    """Return a BlockCNN partial with the VGG feature widths and the given conv counts / dense head."""
    return partial(
        BlockCNN,
        features=[64, 128, 256, 512, 512],
        counts=counts,
        dense_sizes=dense_sizes,
    )


# NOTE(review): these dense heads differ from the VGG paper's 4096/4096 hidden
# layers; confirm this is intentional.
VGG16 = _vgg(counts=[2, 2, 3, 3, 3], dense_sizes=[1024, 256])
VGG19 = _vgg(counts=[2, 2, 3, 4, 4], dense_sizes=[4096, 1000])


class Alexnet(pytModule):
    """
    Defines a CNN architecture very similar to AlexNet.

    The network has five ReLU convolutions with 96, 256, 384, 384 and 256
    channels. A 3x3, stride-2, VALID max pool follows the 1st, 2nd and 5th
    convolutions. A `BasicMLP` head with two hidden layers of width 4096 then
    produces `num_outputs` raw outputs.
    """

    num_outputs: int

    @nn.compact
    def __call__(self, inputs, train=False):
        x = inputs
        # First conv is strided (2x2) with SAME padding.
        x = nn.Conv(features=96, kernel_size=(5, 5), strides=(2, 2), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID")
        x = nn.Conv(features=256, kernel_size=(5, 5), strides=(1, 1), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID")
        x = nn.Conv(features=384, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.Conv(features=384, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.Conv(features=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID")
        # BasicMLP flattens the pooled feature map before its dense layers.
        x = BasicMLP(layer_sizes=[4096, 4096], num_outputs=self.num_outputs)(
            x, train=train
        )
        return x

    def describe(self):
        return "Basic Alexnet-like CNN"


# ResNet based off of the flax examples

# Loose type alias for fields that hold module constructors (classes, partials
# or other callables returning a module-like callable), e.g. the `conv`,
# `norm` and `block_cls` fields below.
ModuleDef = Any


class ResNetBlock(nn.Module):
    """
    Basic two-convolution ResNet block.

    Computes act(shortcut + norm(conv3x3(act(norm(conv3x3_strided(x)))))).
    The shortcut is `x` itself, or a strided 1x1 projection followed by a norm
    when the main path changes the shape.

    Attributes:
        filters: output channels of both 3x3 convolutions (and the projection).
        conv: convolution constructor, called as conv(features, kernel_size[, strides]).
        norm: normalisation constructor, called with optional keyword arguments.
        act: activation function.
        strides: strides of the first convolution and of the projection.
    """

    filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        # Main path; submodules are created conv, norm, conv, norm.
        out = self.conv(self.filters, (3, 3), self.strides)(x)
        out = self.norm()(out)
        out = self.act(out)
        out = self.conv(self.filters, (3, 3))(out)
        # NOTE: this norm keeps its default scale init; the flax example
        # zero-initialises it (as BottleneckResNetBlock does).
        out = self.norm()(out)

        shortcut = x
        needs_projection = shortcut.shape != out.shape
        if needs_projection:
            shortcut = self.conv(
                self.filters, (1, 1), self.strides, name="conv_proj"
            )(shortcut)
            shortcut = self.norm(name="norm_proj")(shortcut)

        summed = shortcut + out
        return self.act(summed)


class BottleneckResNetBlock(nn.Module):
    """
    Bottleneck ResNet block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand to 4x filters.

    The last norm uses a zero scale init. The shortcut is `x` itself, or a
    strided 1x1 projection to `4 * filters` channels followed by a norm when
    the shapes differ.

    Attributes:
        filters: bottleneck width; the block outputs `4 * filters` channels.
        conv: convolution constructor, called as conv(features, kernel_size[, strides]).
        norm: normalisation constructor, called with optional keyword arguments.
        act: activation function.
        strides: strides of the 3x3 convolution and of the projection.
    """

    filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        expanded = self.filters * 4

        # 1x1 reduction.
        out = self.conv(self.filters, (1, 1))(x)
        out = self.norm()(out)
        out = self.act(out)
        # 3x3 (possibly strided) convolution.
        out = self.conv(self.filters, (3, 3), self.strides)(out)
        out = self.norm()(out)
        out = self.act(out)
        # 1x1 expansion, with a zero-initialised norm scale.
        out = self.conv(expanded, (1, 1))(out)
        out = self.norm(scale_init=nn.initializers.zeros)(out)

        shortcut = x
        if shortcut.shape != out.shape:
            projection = self.conv(expanded, (1, 1), self.strides, name="conv_proj")
            shortcut = projection(shortcut)
            shortcut = self.norm(name="norm_proj")(shortcut)

        return self.act(shortcut + out)


class ResNet(pytModule):
    """
    ResNetV1 adapted for low-resolution inputs.

    The stem is a single 3x3 stride-1 convolution, norm and ReLU, with no 7x7
    strided conv or max pooling as in the ImageNet stem. It is followed by
    `len(stage_sizes)` stages of `block_cls` blocks. Stage i uses
    `num_filters * 2**i` filters, and every stage except the first downsamples
    with stride 2 in its first block. Features are averaged over axes 1 and 2
    (the spatial axes, assuming channels-last inputs). A dense layer then
    returns raw logits.

    Attributes:
        stage_sizes: number of blocks in each stage.
        block_cls: block module class (e.g. ResNetBlock or BottleneckResNetBlock).
        num_outputs: number of output logits.
        num_filters: filters in the stem and the first stage.
        dtype: dtype passed to conv/norm/dense layers and used for the output.
        act: activation used inside the blocks (the stem always uses ReLU).
        norm: normalisation type: "batch", "layer" or "none".

    Raises:
        ValueError: if `norm` is not one of the supported values.
    """

    stage_sizes: Sequence[int]
    block_cls: ModuleDef
    num_outputs: int
    num_filters: int = 64
    dtype: Any = np.float32
    act: Callable = nn.relu
    norm: str = "batch"

    @nn.compact
    def __call__(self, x, train: bool = False):
        conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
        if self.norm == "layer":
            norm = partial(
                nn.LayerNorm,
                epsilon=1e-5,
                dtype=self.dtype,
            )
        elif self.norm == "batch":
            # Batch statistics while training, running averages otherwise.
            norm = partial(
                nn.BatchNorm,
                use_running_average=not train,
                momentum=0.9,
                epsilon=1e-5,
                dtype=self.dtype,
            )
        elif self.norm == "none":
            # Accept (and ignore) constructor kwargs such as `name`/`scale_init`
            # and return the identity.
            norm = lambda **kwargs: lambda x: x
        else:
            raise ValueError(
                f"Unknown norm {self.norm!r}; expected 'batch', 'layer' or 'none'"
            )

        # Low-resolution stem. The standard ImageNet stem instead is:
        #   conv(num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name="conv_init")
        #   -> norm(name="bn_init") -> relu
        #   -> max_pool((3, 3), strides=(2, 2), padding="SAME")
        x = conv(self.num_filters, (3, 3), (1, 1), padding="SAME", name="conv_init")(x)
        x = norm(name="bn_init")(x)
        x = nn.relu(x)
        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                # Downsample at the start of every stage except the first.
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = self.block_cls(
                    self.num_filters * 2**i,
                    strides=strides,
                    conv=conv,
                    norm=norm,
                    act=self.act,
                )(x)
        x = np.mean(x, axis=(1, 2))
        x = nn.Dense(self.num_outputs, dtype=self.dtype)(x)
        # Raw logits: no log_softmax is applied.
        x = np.asarray(x, self.dtype)
        return x

    def describe(self):
        return f"ResNet with stage sizes: {self.stage_sizes} and block {self.block_cls.__name__}"


def _resnet(stage_sizes, block_cls, **overrides):
    """Return a ResNet partial with the given stage sizes, block class and extra field overrides."""
    return partial(ResNet, stage_sizes=stage_sizes, block_cls=block_cls, **overrides)


# Basic-block variants.
ResNet18 = _resnet([2, 2, 2, 2], ResNetBlock)
ResNet18NoNorm = _resnet([2, 2, 2, 2], ResNetBlock, norm="none")
ResNet34 = _resnet([3, 4, 6, 3], ResNetBlock)
ResNet34NoNorm = _resnet([3, 4, 6, 3], ResNetBlock, norm="none")

# Bottleneck-block variants.
ResNet50 = _resnet([3, 4, 6, 3], BottleneckResNetBlock)
ResNet101 = _resnet([3, 4, 23, 3], BottleneckResNetBlock)
ResNet152 = _resnet([3, 8, 36, 3], BottleneckResNetBlock)
ResNet200 = _resnet([3, 24, 36, 3], BottleneckResNetBlock)
ResNet50NoNorm = _resnet([3, 4, 6, 3], BottleneckResNetBlock, norm="none")
ResNet50LayerNorm = _resnet([3, 4, 6, 3], BottleneckResNetBlock, norm="layer")
