from neural_tangents import stax
from jax import vmap
import jax
import jax.numpy as np
import functools
import flax.linen as nn
import jax_models as jm
from flax.core.frozen_dict import unfreeze, freeze

def MyrtleNetwork(width, depth, W_std=np.sqrt(2.0), b_std=0.0, classes=10):
    """Build a three-stage ReLU CNN with average pooling as a ``stax`` model.

    Each stage is a run of 3x3 ``SAME`` convolutions, each followed by ReLU.
    A 2x2 average pool (stride 2) sits between consecutive stages, and the
    network ends with global average pooling, ``Flatten`` and a dense readout.

    Args:
        width: Number of output channels of every convolution.
        depth: Nominal depth; must be 5, 7 or 10. It selects how many
            conv+ReLU pairs each of the three stages contains.
        W_std: Weight standard deviation given to every conv and to the readout.
        b_std: Bias standard deviation given to every conv and to the readout.
        classes: Number of outputs of the final dense layer.

    Returns:
        Whatever ``stax.serial`` returns for the assembled layer list.

    Raises:
        KeyError: If ``depth`` is not one of 5, 7, 10.
    """
    pairs_per_stage = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
    relu = stax.Relu()
    conv3x3 = functools.partial(
        stax.Conv, W_std=W_std, b_std=b_std, padding="SAME", parameterization="standard"
    )

    def stage(repeats):
        # One conv+ReLU pair, repeated `repeats` times in the layer list.
        return [conv3x3(width, (3, 3)), relu] * repeats

    first, second, third = pairs_per_stage[depth]

    # NOTE(review): the readout Dense is not given parameterization="standard"
    # (unlike the convolutions), so it uses the library default -- confirm
    # that this is intended.
    layers = [
        *stage(first),
        stax.AvgPool((2, 2), strides=(2, 2)),
        *stage(second),
        stax.AvgPool((2, 2), strides=(2, 2)),
        *stage(third),
        stax.GlobalAvgPool(),
        stax.Flatten(),
        stax.Dense(classes, W_std, b_std),
    ]
    return stax.serial(*layers)


def MyrtleNetworkNoPooling(width, depth, W_std=np.sqrt(2.0), b_std=0.0, classes=10):
    """Build a GELU CNN with no pooling layers as a ``stax`` model.

    The network is a plain stack of 3x3 ``SAME`` convolutions, each followed
    by GELU, then ``Flatten`` and a dense readout. There is no spatial
    down-sampling anywhere in the stack.

    Args:
        width: Number of output channels of every convolution.
        depth: Nominal depth; must be 5, 7 or 10. It selects how many
            conv+GELU pairs each of the three groups contributes.
        W_std: Weight standard deviation given to every conv and to the readout.
        b_std: Bias standard deviation given to every conv and to the readout.
        classes: Number of outputs of the final dense layer.

    Returns:
        Whatever ``stax.serial`` returns for the assembled layer list.

    Raises:
        KeyError: If ``depth`` is not one of 5, 7, 10.
    """
    pairs_per_group = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
    # Swap in stax.Relu() or stax.Erf() here to try a different nonlinearity.
    nonlinearity = stax.Gelu()
    conv3x3 = functools.partial(
        stax.Conv, W_std=W_std, b_std=b_std, padding="SAME", parameterization="standard"
    )

    layers = []
    for group_index in range(3):
        repeats = pairs_per_group[depth][group_index]
        layers.extend([conv3x3(width, (3, 3)), nonlinearity] * repeats)

    # NOTE(review): the readout Dense is not given parameterization="standard"
    # (unlike the convolutions) -- confirm that this is intended.
    layers.extend([stax.Flatten(), stax.Dense(classes, W_std, b_std)])

    return stax.serial(*layers)


def MyrtleNetwork3(
    width,
    depth,
    W_std=np.sqrt(2.0),
    b_std=0.0,
    classes=1,
    activation_fn=stax.Gelu(),
    parameterization="standard",
    widen_factor=2,
    pooling="stride",
):
    """Build a configurable three-stage CNN as a ``stax`` model.

    The ``depth - 1`` convolutions are split over three stages; any remainder
    of the division by three goes to the first stage, then the second. Stage
    widths are ``width``, ``widen_factor * width`` and
    ``widen_factor**2 * width``. The first convolution of stages two and
    three is produced by ``downsample`` and depends on ``pooling``:

    * ``"flat"``   -- ordinary conv (no down-sampling); head is ``Flatten``.
    * ``"stride"`` -- stride-2 conv; head is ``Flatten``.
    * ``"avg"``    -- 2x2 average pool then ordinary conv; head is
      ``GlobalAvgPool``.
    * ``"avg2"``   -- stride-2 conv; head is ``GlobalAvgPool``.

    Args:
        width: Channel count of the first stage.
        depth: Nominal depth; ``depth - 1`` convolutions are distributed over
            the three stages.
        W_std: Weight standard deviation for every conv and the readout.
        b_std: Bias standard deviation for every conv and the readout.
        classes: Number of outputs of the final dense layer.
        activation_fn: ``stax`` nonlinearity inserted after every convolution.
        parameterization: Forwarded to every conv and to the readout.
        widen_factor: Channel multiplier applied at each stage transition.
        pooling: One of ``"flat"``, ``"stride"``, ``"avg"``, ``"avg2"``.

    Returns:
        Whatever ``stax.serial`` returns for the assembled layer list.

    Raises:
        AssertionError: If ``pooling`` is not one of the four accepted values.
    """
    assert pooling in ("flat", "stride", "avg", "avg2")

    # Split the depth - 1 convolutions across three stages.
    base, remainder = divmod(depth - 1, 3)
    l1 = base + min(remainder, 1)
    l2 = base + max(remainder - 1, 0)
    l3 = base

    conv = functools.partial(
        stax.Conv,
        W_std=W_std,
        b_std=b_std,
        padding="SAME",
        parameterization=parameterization,
    )

    def conv_act(channels, repeats):
        # One conv+activation pair, repeated `repeats` times in the layer list.
        return [conv(channels, (3, 3)), activation_fn] * repeats

    def downsample(channels):
        # First conv of a new stage; how it down-samples depends on `pooling`.
        if pooling == "avg":
            return [
                stax.AvgPool((2, 2), strides=(2, 2)),
                conv(channels, (3, 3)),
                activation_fn,
            ]
        if pooling == "flat":
            return [conv(channels, (3, 3)), activation_fn]
        if pooling in ("stride", "avg2"):
            return [conv(channels, (3, 3), (2, 2)), activation_fn]
        # Only reachable when the assert above has been disabled.
        return None

    second_width = widen_factor * width
    third_width = widen_factor**2 * width

    layers = conv_act(width, l1)
    layers = layers + downsample(second_width) + conv_act(second_width, l2 - 1)
    layers = layers + downsample(third_width) + conv_act(third_width, l3 - 1)

    if pooling in ("flat", "stride"):
        layers.append(stax.Flatten())
    elif pooling in ("avg", "avg2"):
        layers.append(stax.GlobalAvgPool())

    layers.append(
        stax.Dense(classes, W_std, b_std, parameterization=parameterization)
    )

    return stax.serial(*layers)


# def MyrtleNetwork5(width, *args, **kwargs):
#     return MyrtleNetwork(width, 5, *args, **kwargs)


def FC3Network(width=512, W_std=1, b_std=0.05, classes=10):
    """Build a fully connected network with two Erf hidden layers.

    The input is flattened, passed through two ``Dense(width)`` + ``Erf``
    pairs, and read out by a final dense layer.

    Args:
        width: Number of units in each of the two hidden layers.
        W_std: Weight standard deviation for every dense layer.
        b_std: Bias standard deviation for every dense layer.
        classes: Number of outputs of the final dense layer. Defaults to 10,
            the value that used to be hard-coded.

    Returns:
        Whatever ``stax.serial`` returns for the assembled layers.
    """
    return stax.serial(
        stax.Flatten(),
        stax.Dense(width, W_std, b_std),
        stax.Erf(),
        stax.Dense(width, W_std, b_std),
        stax.Erf(),
        stax.Dense(classes, W_std, b_std),
    )


def WideResnetBlock(
    channels, strides, channel_mismatch, parameterization, activation_fn
):
    """Build one pre-activation residual block as a ``stax`` model.

    The residual branch is activation -> 3x3 conv (with ``strides``) ->
    activation -> 3x3 conv. The skip branch is the identity, or a strided
    3x3 conv when ``channel_mismatch`` is true. The two branches are summed.

    Args:
        channels: Output channels of every convolution in the block.
        strides: Strides of the first residual conv and of the skip conv.
        channel_mismatch: When true, the skip branch is a convolution instead
            of the identity.
        parameterization: Forwarded to every convolution.
        activation_fn: ``stax`` nonlinearity used in the residual branch.

    Returns:
        Whatever ``stax.serial`` returns for fan-out, the two parallel
        branches, and the summing fan-in.
    """
    residual_branch = stax.serial(
        activation_fn,
        stax.Conv(
            channels, (3, 3), strides, padding="SAME", parameterization=parameterization
        ),
        activation_fn,
        stax.Conv(channels, (3, 3), padding="SAME", parameterization=parameterization),
    )

    if channel_mismatch:
        # The skip path needs its own conv so its output matches the residual.
        skip_branch = stax.Conv(
            channels, (3, 3), strides, padding="SAME", parameterization=parameterization
        )
    else:
        skip_branch = stax.Identity()

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(residual_branch, skip_branch),
        stax.FanInSum(),
    )


def WideResnetGroup(n, channels, strides, parameterization, activation_fn):
    """Chain residual blocks into one group.

    The first block uses ``strides`` and a convolutional skip connection.
    It is followed by ``n - 1`` stride-1 blocks with identity skips. The
    first block is always present, even when ``n`` is less than 1.

    Args:
        n: Nominal number of blocks in the group.
        channels: Output channels of every block.
        strides: Strides of the first block.
        parameterization: Forwarded to every block.
        activation_fn: ``stax`` nonlinearity forwarded to every block.

    Returns:
        Whatever ``stax.serial`` returns for the chained blocks.
    """
    entry_block = WideResnetBlock(
        channels, strides, True, parameterization, activation_fn
    )
    follow_up_blocks = [
        WideResnetBlock(channels, (1, 1), False, parameterization, activation_fn)
        for _ in range(n - 1)
    ]
    return stax.serial(entry_block, *follow_up_blocks)


def WideResnet(
    block_size, k, num_classes, parameterization="standard", activation_fn=stax.Gelu()
):
    """Build a wide residual network classifier as a ``stax`` model.

    Layout: a 16-channel 3x3 stem conv, then three residual groups with
    ``int(16 * k)``, ``int(32 * k)`` and ``int(64 * k)`` channels (strides
    1, 2, 2), then global average pooling, ``Flatten`` and a dense readout.

    Args:
        block_size: Number of residual blocks requested per group.
        k: Widening factor applied to the base channel counts 16/32/64.
        num_classes: Number of outputs of the final dense layer.
        parameterization: Forwarded to every conv and to the readout.
        activation_fn: ``stax`` nonlinearity used inside the residual blocks.

    Returns:
        Whatever ``stax.serial`` returns for the assembled layers.
    """
    stem = stax.Conv(16, (3, 3), padding="SAME", parameterization=parameterization)

    # (base channel count, strides of the group's first block)
    group_specs = ((16, (1, 1)), (32, (2, 2)), (64, (2, 2)))
    groups = [
        WideResnetGroup(
            block_size, int(base_channels * k), strides, parameterization, activation_fn
        )
        for base_channels, strides in group_specs
    ]

    head = [
        stax.GlobalAvgPool(),
        stax.Flatten(),
        stax.Dense(num_classes, 1.0, 0.0, parameterization=parameterization),
    ]

    return stax.serial(stem, *groups, *head)


def WideResnetReps(
    block_size, k, num_classes, parameterization="standard", activation_fn=stax.Gelu()
):
    """Build the headless (representation) variant of the wide ResNet.

    Same stem and first two groups as ``WideResnet``, but the third group is
    requested with ``block_size - 1`` blocks. No pooling, ``Flatten`` or dense
    readout is attached.

    Args:
        block_size: Number of residual blocks requested for the first two
            groups; the last group is requested with one fewer.
        k: Widening factor applied to the base channel counts 16/32/64.
        num_classes: Accepted for signature parity with ``WideResnet``;
            unused here because no readout layer is built.
        parameterization: Forwarded to every convolution.
        activation_fn: ``stax`` nonlinearity used inside the residual blocks.

    Returns:
        Whatever ``stax.serial`` returns for the stem plus the three groups.
    """
    stem = stax.Conv(16, (3, 3), padding="SAME", parameterization=parameterization)

    # (blocks requested, base channel count, strides of the group's first block)
    group_specs = (
        (block_size, 16, (1, 1)),
        (block_size, 32, (2, 2)),
        (block_size - 1, 64, (2, 2)),
    )
    groups = [
        WideResnetGroup(
            n_blocks, int(base_channels * k), strides, parameterization, activation_fn
        )
        for n_blocks, base_channels, strides in group_specs
    ]

    # Deliberately no GlobalAvgPool / Flatten / Dense head in this variant.
    return stax.serial(stem, *groups)


# Hand-recorded timing and batch-size settings for the "WRN34X5G" configuration.
# NOTE(review): presumably a WideResnet variant defined by the builders in
# this file -- TODO confirm which exact arguments the name stands for.
# Units are taken from the constant names (seconds / milliseconds).
WRN34X5G_COMPILE_TIME_SEC = 29
WRN34X5G_EVAL_TIME_MSEC = 4.7
WRN34X5G_EVAL64_TIME_MSEC = 42
# Keyed by (device name, 32 or 64); each value is a pair of integers.
# NOTE(review): 32/64 presumably denotes float precision in bits, and the
# pair presumably holds two batch sizes -- TODO confirm what each element
# of the pair is used for.
WRN34X5G_BATCH_SIZES = {
    ("cpu", 32): (20, 20),
    ("cpu", 64): (20, 20),
    ("2080ti", 32): (50, 30),
    ("2080ti", 64): (40, 15),
    ("rtx6k", 32): (100, 60),
    ("rtx6k", 64): (100, 30),
    ("a40", 32): (100, 100),
    ("a40", 64): (100, 60),
}

# Hand-recorded timing and batch-size settings for the "WRN34X5GE" configuration.
# NOTE(review): presumably a second WideResnet variant -- TODO confirm what
# the "GE" suffix stands for.
# Units are taken from the constant names (seconds / milliseconds).
WRN34X5GE_COMPILE_TIME_SEC = 38
WRN34X5GE_EVAL_TIME_MSEC = 1.2
WRN34X5GE_EVAL64_TIME_MSEC = 30
# Keyed by (device name, 32 or 64); each value is a pair of integers.
# NOTE(review): 32/64 presumably denotes float precision in bits, and the
# pair presumably holds two batch sizes -- TODO confirm what each element
# of the pair is used for.
WRN34X5GE_BATCH_SIZES = {
    ("cpu", 32): (20, 20),
    ("cpu", 64): (20, 20),
    ("2080ti", 32): (60, 10),
    ("2080ti", 64): (30, 1),
    ("rtx6k", 32): (100, 20),
    ("rtx6k", 64): (60, 10),
    ("a40", 32): (100, 80),
    ("a40", 64): (100, 40),
}


def FC3Network2(
    width=512,
    W_std=1,
    b_std=0.05,
    depth=2,
    activation_fn=stax.Relu(),
    parameterization="standard",
    classes=1,
):
    """Build a fully connected network of configurable depth.

    The input is flattened, passed through ``depth - 1`` hidden
    ``Dense(width)`` + activation pairs, and read out by a final dense layer.
    With ``depth <= 1`` there are no hidden layers.

    Args:
        width: Number of units in each hidden layer.
        W_std: Weight standard deviation for every dense layer.
        b_std: Bias standard deviation for every dense layer.
        depth: Total number of dense layers, counting the readout.
        activation_fn: ``stax`` nonlinearity placed after each hidden layer.
        parameterization: Forwarded to every dense layer.
        classes: Number of outputs of the final dense layer. Defaults to 1,
            the value that used to be hard-coded.

    Returns:
        Whatever ``stax.serial`` returns for the assembled layers.
    """
    flatten = stax.Flatten()
    # One Dense+activation pair, repeated depth - 1 times in the layer list.
    hidden_layers = [
        stax.Dense(width, W_std, b_std, parameterization=parameterization),
        activation_fn,
    ] * (depth - 1)
    readout = stax.Dense(classes, W_std, b_std, parameterization=parameterization)
    return stax.serial(flatten, *hidden_layers, readout)


# CNconv = functools.partial(stax.Conv, padding="SAME", parameterization="standard")
# CNdwconv2d = functools.partial(CNconv, dimension_numbers=("NHWC", "HWIO", "NHWC"))


# def ConvNeXtBlock(dim):
#     return stax.serial(
#         stax.FanOut(2),
#         stax.parallel(
#             stax.serial(
#                 CNdwconv2d(dim, (7, 7)),
#                 stax.LayerNorm(),
#                 CNconv(4 * dim, (1, 1)),
#                 stax.Gelu(),
#                 CNconv(dim, (1, 1)),
#                 # CNdwconv2d(dim, (1, 1), W_std=1e-6, b_std=None),
#             ),
#             stax.Identity(),
#         ),
#         stax.FanInSum(),
#     )


# def ConvNeXt(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), num_classes=1):
#     layers = [
#         CNconv(dims[0], (1, 1)),
#         stax.LayerNorm(),
#     ]
#     for j in range(depths[0]):
#         layers += [ConvNeXtBlock(dims[0])]

#     for i in range(3):
#         layers += [
#             CNconv(dims[i + 1], (2, 2), (2, 2)),
#             stax.LayerNorm(),
#             ConvNeXtBlock(dims[i + 1]),
#         ]

#     layers += [
#         stax.GlobalAvgPool(),
#         stax.Flatten(),
#         stax.LayerNorm(),
#         stax.Dense(num_classes),
#     ]
#     return stax.serial(*layers)


def laplace_kernel(c, gamma=1):
    """Return a pairwise kernel ``k(x, y) = exp(-||x - y|| ** gamma / c)``.

    The scalar function of the difference vector carries a hand-written JVP
    instead of relying on automatic differentiation of the norm. The tangent
    divides by ``norm + 1e-12``, so with the default ``gamma=1`` the
    derivative at a zero difference evaluates to 0.

    Args:
        c: Scale that divides the (powered) distance inside the exponential.
        gamma: Exponent applied to the norm of the difference.

    Returns:
        A function ``kernel_fn(x, y)`` that maps over the leading axis of
        ``x`` and of ``y`` and returns the matrix whose ``[i, j]`` entry is
        the kernel evaluated at ``x[i] - y[j]``.
    """

    @jax.custom_jvp
    def radial(diff):
        return np.exp(-np.linalg.norm(diff) ** gamma / c)

    @radial.defjvp
    def radial_jvp(primals, tangents):
        (diff,) = primals
        (diff_dot,) = tangents
        dist = np.linalg.norm(diff)
        value = np.exp(-(dist**gamma) / c)
        # Small offset keeps the division finite when the difference is zero.
        eps = 1e-12
        directional = np.vdot(diff, diff_dot)
        tangent = -gamma * dist ** (gamma - 1) * value / c / (dist + eps) * directional
        return value, tangent

    def pair_value(a, b):
        return radial(a - b)

    def kernel_fn(x, y):
        # Inner map: one fixed row of x against every row of y.
        one_vs_all = vmap(pair_value, in_axes=(None, 0))
        # Outer map: every row of x.
        all_vs_all = vmap(one_vs_all, in_axes=(0, None))
        return all_vs_all(x, y)

    return kernel_fn


class ConvNeXt1(nn.Module):
    """Wrap a ConvNeXt backbone with pooling, LayerNorm and an optional head.

    The backbone output is averaged over axes 1 and 2, normalised by a
    ``LayerNorm`` named ``"norm"`` and, when ``attach_head`` is true, mapped
    to ``num_classes`` outputs by a dense layer named ``"head"``.
    """

    # Backbone module; called as ``convnext(inputs, True)``.
    convnext: nn.Module
    # Number of outputs of the dense head.
    num_classes: int
    # Whether to append the dense head after the LayerNorm.
    attach_head: bool

    @nn.compact
    def __call__(self, inputs):
        """Run the backbone, pool, normalise and optionally apply the head.

        Args:
            inputs: Batch passed straight to the backbone.

        Returns:
            The normalised pooled features, or the head's output when
            ``attach_head`` is true.
        """
        # NOTE(review): the positional ``True`` is presumably the backbone's
        # deterministic/inference flag -- TODO confirm against jax_models.
        x = self.convnext(inputs, True)
        # Average over axes 1 and 2 before normalising.
        x = nn.LayerNorm(name="norm")(np.mean(x, [1, 2]))
        if self.attach_head:
            # Use the declared field; the head width used to be hard-coded
            # to 1, which silently ignored ``num_classes``.
            x = nn.Dense(
                self.num_classes,
                kernel_init=nn.initializers.variance_scaling(
                    0.2, "fan_in", distribution="truncated_normal"
                ),
                name="head",
            )(x)

        return x


def ConvNeXt(num_classes=1, version="convnext-tiny", pretrained=False, attach_head=True):
    """Create a ConvNeXt model exposed as an ``(init_fn, apply_fn, None)`` triple.

    The backbone is loaded through ``jm.load_model`` with its own head
    disabled and wrapped in ``ConvNeXt1``, which adds pooling, LayerNorm and
    (optionally) a dense head.

    Args:
        num_classes: Forwarded to ``ConvNeXt1`` as its ``num_classes`` field.
        version: Model identifier passed to ``jm.load_model``.
        pretrained: When true, ``jm.load_model`` is expected to return a
            ``(model, params)`` pair, and ``init_fn`` overwrites the freshly
            initialised backbone parameters with those ``params``.
        attach_head: Whether ``ConvNeXt1`` appends its dense head.

    Returns:
        A tuple ``(init_fn, apply_fn, None)``. ``init_fn(key, input_shape)``
        returns ``(None, variables)``: an unfrozen variable dict when
        ``pretrained`` is false, a frozen one otherwise. ``apply_fn`` is the
        wrapped module's ``apply``.
    """
    loaded = jm.load_model(version, pretrained=pretrained, attach_head=False)
    if pretrained:
        backbone, pretrained_params = loaded
    else:
        backbone, pretrained_params = loaded, None

    model = ConvNeXt1(backbone, num_classes, attach_head)

    def init_fn(key, input_shape):
        # Initialise on an all-zeros input of the requested shape.
        variables = unfreeze(model.init(key, np.zeros(input_shape)))

        if not pretrained:
            return None, variables

        # Swap the backbone's random initialisation for the loaded weights.
        # The result stays local: rebinding the closed-over params (as the
        # previous version did) made a second call nest the whole variable
        # tree under ["params"]["convnext"].
        variables["params"]["convnext"] = pretrained_params
        return None, freeze(variables)

    apply_fn = model.apply
    return init_fn, apply_fn, None
