import jax.nn.initializers as ji
import cola
from cola.ops import Dense, Permutation

import serket as sk
import numpy as np
import jax.numpy as jnp
import jax
import jax.random as jr
import functools as ft


class Sequential(sk.TreeClass):
    """Chain of callables applied one after another.

    The layers are stored, in call order, as a tuple in ``self.layers``.
    """

    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        # Left fold: the output of each layer becomes the input of the next.
        # With no layers at all, ``x`` is returned unchanged.
        return ft.reduce(lambda value, layer: layer(value), self.layers, x)


class MLP(sk.TreeClass):
    """Fully connected network: ``Linear -> Swish`` stages followed by a ``Linear`` head.

    Args:
        sizes: Layer widths; ``sizes[0]`` is the flattened input size and
            ``sizes[-1]`` the output size. At least two entries are needed
            for the final linear layer.
        key: JAX PRNG key, split into ``len(sizes)`` subkeys.
    """

    def __init__(self, sizes, key):
        layer_keys = jr.split(key, len(sizes))

        # Hidden stages cover every consecutive width pair except the last one;
        # zip stops at the shortest input, i.e. after len(sizes) - 2 stages.
        hidden = []
        for fan_in, fan_out, layer_key in zip(sizes[:-2], sizes[1:], layer_keys):
            stage = Sequential(sk.nn.Linear(fan_in, fan_out, key=layer_key), sk.nn.Swish())
            hidden.append(stage)

        # The output layer has no activation and takes the last subkey.
        head = sk.nn.Linear(sizes[-2], sizes[-1], key=layer_keys[-1])
        self.net = Sequential(*hidden, head)

    def __call__(self, x):
        """Flatten ``x`` to one dimension and run it through the network."""
        flat = x.reshape(-1)
        return self.net(flat)


def random_kron(n, m, key):
    """Build a random ``(n, m)`` operator from a permuted Kronecker product.

    Two dense factors of shape ``(ceil(sqrt(n)), ceil(sqrt(m)))`` are drawn
    with ``he_normal`` and combined with ``cola.kron``. When ``n`` or ``m`` is
    not a perfect square the product is larger than requested and is cropped
    to its leading ``n`` rows and ``m`` columns. The result is then multiplied
    on the left and on the right by ``Permutation`` operators built from
    random permutations of ``n`` and ``m`` indices.

    Args:
        n: Number of rows of the returned operator.
        m: Number of columns of the returned operator.
        key: JAX PRNG key; split into four subkeys (two factors, two permutations).

    Returns:
        The product ``P1 @ A @ P2``.
    """
    factor_key_a, factor_key_b, row_key, col_key = jr.split(key, 4)

    side_n = int(np.ceil(np.sqrt(n)))
    side_m = int(np.ceil(np.sqrt(m)))
    factor_shape = (side_n, side_m)

    # NOTE(review): he_normal defaults to in_axis=-2, so the fan-in used for
    # scaling is side_n (the row side) -- confirm that is the intended scaling.
    factor_a = ji.he_normal()(factor_key_a, factor_shape)
    factor_b = ji.he_normal()(factor_key_b, factor_shape)
    block = cola.kron(Dense(factor_a), Dense(factor_b))

    # The Kronecker product is (side_n**2, side_m**2); crop unless both match.
    exact_fit = side_n**2 == n and side_m**2 == m
    if not exact_fit:
        block = block[:n, :m]

    row_perm = Permutation(jr.permutation(row_key, n), dtype=block.dtype)
    col_perm = Permutation(jr.permutation(col_key, m), dtype=block.dtype)
    return row_perm @ block @ col_perm


def lora(n, m, key, *, rank=1, scale=.001):
    """Build a random low-rank ``(n, m)`` operator ``Dense(U) @ Dense(V)``.

    ``U`` has shape ``(n, rank)`` and is drawn with ``he_normal`` then
    multiplied by ``scale``; ``V`` has shape ``(rank, m)`` and is drawn with
    ``he_normal`` unscaled.

    Args:
        n: Number of rows of the returned operator.
        m: Number of columns of the returned operator.
        key: JAX PRNG key; split into two subkeys, one per factor.
        rank: Inner dimension of the factorization. Defaults to 1, the value
            that used to be hard-coded.
        scale: Multiplier applied to ``U``. Defaults to 0.001, the value that
            used to be hard-coded.

    Returns:
        The product ``Dense(U) @ Dense(V)``.
    """
    u_key, v_key = jr.split(key, 2)
    U = ji.he_normal()(u_key, (n, rank)) * scale
    V = ji.he_normal()(v_key, (rank, m))
    return Dense(U) @ Dense(V)


def my_weird(n, m, key):
    """Return a normalized sum of ``k = 5`` independent ``random_kron`` operators.

    Args:
        n: Row count forwarded to ``random_kron``.
        m: Column count forwarded to ``random_kron``.
        key: JAX PRNG key; split into ``k + 1`` subkeys.

    Returns:
        The sum of ``random_kron(n, m, key_i) / sqrt(k)`` over the first ``k``
        subkeys.
    """
    k = 5
    # One subkey per Kronecker term, plus a spare that stays reserved for the
    # currently disabled low-rank correction: + lora(n, m, keys[-1]).
    keys = jr.split(key, k + 1)
    # Sum exactly k terms so the 1/sqrt(k) factor matches the number of
    # summands; iterating over every subkey would add a (k + 1)-th term built
    # from the spare key.
    return sum(random_kron(n, m, keyi) / np.sqrt(k) for keyi in keys[:k])


class ColaLinear(sk.TreeClass):
    """Affine layer ``A @ x + b`` around a given operator ``A``.

    Args:
        A: Operator supporting ``A.shape`` and ``A @ x``; ``A.shape[0]`` is
            the output size.
        key: JAX PRNG key used to draw the bias.
    """

    def __init__(self, A, key):
        self.A = A
        out_dim = A.shape[0]
        # Earlier alternative: a zero bias, np.zeros(A.shape[0]).
        # NOTE: key usage not totally kosher here -- the incoming key is split
        # in two and only the first half is consumed.
        bias_key = jr.split(key, 2)[0]
        # jr.uniform draws from its default range, then scales by 1/sqrt(out_dim).
        self.b = jr.uniform(bias_key, (out_dim, )) / np.sqrt(out_dim)

    def __call__(self, x):
        """Apply the operator to ``x`` and add the bias."""
        projected = self.A @ x
        return projected + self.b


class KronMLP(sk.TreeClass):
    """MLP whose hidden layers use ``my_weird`` operators wrapped in ``ColaLinear``.

    The flattened input is reordered by a fixed random permutation before it
    enters the network. Hidden stages are ``ColaLinear -> Swish``; the output
    layer is a plain ``sk.nn.Linear``.

    Args:
        sizes: Layer widths; ``sizes[0]`` is the flattened input size and
            ``sizes[-1]`` the output size.
        key: JAX PRNG key, split into ``len(sizes) + 1`` subkeys.
    """

    # perm: sk.field(on_getattr=[sk.unfreeze], on_setattr=[sk.freeze])
    def __init__(self, sizes, key):
        subkeys = jr.split(key, len(sizes) + 1)

        hidden = []
        for fan_in, fan_out, subkey in zip(sizes[:-2], sizes[1:], subkeys):
            # The operator maps fan_in inputs to fan_out outputs, hence the
            # (rows, cols) = (fan_out, fan_in) argument order.
            operator = my_weird(fan_out, fan_in, subkey)
            # The same subkey is handed to ColaLinear for its bias.
            hidden.append(Sequential(ColaLinear(operator, subkey), sk.nn.Swish()))

        # The last two subkeys go to the output layer and the input permutation.
        head = sk.nn.Linear(sizes[-2], sizes[-1], key=subkeys[-2])
        self.net = Sequential(*hidden, head)
        self.perm = jr.permutation(subkeys[-1], sizes[0])

    def __call__(self, x):
        """Flatten ``x``, reorder it with ``self.perm`` and run the network."""
        flat = x.reshape(-1)
        shuffled = flat[self.perm]
        return self.net(shuffled)


@jax.vmap
def softmax_cross_entropy(logits, label):
    """Negative log-softmax of ``logits`` at index ``label``.

    The function body handles a single example; ``jax.vmap`` maps it over the
    leading axis of both arguments, giving one loss value per example.
    """
    log_probs = jax.nn.log_softmax(logits)
    return -log_probs[label]


# Wrapped in jax.grad(has_aux=True): calling ``loss_fn(nn, x, y)`` returns
# ``(grads, (loss, logits))`` with ``grads`` taken with respect to ``nn``.
@ft.partial(jax.grad, has_aux=True)
def loss_fn(nn, x, y):
    """Mean softmax cross-entropy of ``nn`` on the batch ``(x, y)``.

    ``nn`` is passed through ``sk.tree_unmask`` and the result is applied to
    each element of ``x`` with ``jax.vmap``. The undecorated function returns
    ``(loss, (loss, logits))``; the first entry is what gets differentiated
    and the second is the auxiliary output.
    """
    model = sk.tree_unmask(nn)
    logits = jax.vmap(model)(x)
    per_example = softmax_cross_entropy(logits, y)
    loss = jnp.mean(per_example)
    aux = (loss, logits)
    return loss, aux


@jax.vmap
def accuracy_func(logits, y):
    """Whether the arg-max of ``logits`` equals ``y``.

    The function body handles a single example; ``jax.vmap`` maps it over the
    leading axis of both arguments, giving one boolean per example.
    """
    predicted = jnp.argmax(logits)
    return predicted == y
