# our base model will be an MLP
from abc import ABC, abstractmethod
from functools import partial
import pickle

from flax import linen as fnn
import jax
from jax import random
from jax import nn
import jax.numpy as jnp
from jax.tree_util import tree_map, Partial
from jaxopt import GradientDescent
from jaxopt import linear_solve
from jaxopt.tree_util import tree_sub, tree_l2_norm, tree_mul, tree_scalar_mul
from optax import softmax_cross_entropy


# Default settings for the jaxopt GradientDescent inner solver built in MetaNet;
# any solver kwargs passed to MetaNet are merged over (and override) these.
DEFAULT_GBM_SOLVER_KWARGS = dict(
    stepsize=0.1,  # inner step size
    maxiter=4,  # maximum number of inner iterations
    maxls=16,  # NOTE(review): line-search budget; presumably only used when stepsize <= 0 — confirm in jaxopt docs
    tol=1e-8,  # inner stopping tolerance
    acceleration=False,  # no Nesterov-style acceleration
    decrease_factor=0.5,  # NOTE(review): line-search shrink factor; presumably unused with a positive stepsize
)


@jax.jit
def linear(x, w, b):
    """Affine layer: contract the trailing axis of ``x`` with ``w`` and add ``b``.

    ``w`` is laid out as (n_in, n_out); any leading axes of ``x`` are preserved.
    """
    contracted = jnp.einsum("io, ...i -> ...o", w, x)
    return contracted + b


@partial(jax.jit, static_argnames=("nonlinearity",))
def mlp(x, params, nonlinearity="relu"):
    """Apply a multi-layer perceptron to ``x``.

    Args:
        x: input array whose trailing axis matches ``params[0]["weight"].shape[0]``;
            leading axes are carried through by ``linear``.
        params: list of ``{"weight": (n_in, n_out), "bias": (n_out,)}`` dicts, one per
            layer. The activation follows every layer except the last.
        nonlinearity (str, optional): "relu" or "sigmoid". Defaults to "relu".
            Static under ``jax.jit``.

    Returns:
        The output of the final (linear) layer.

    Raises:
        ValueError: if ``nonlinearity`` is unknown. The check runs before any layer is
            applied, so it also fires for single-layer (purely linear) models.
    """
    # Resolve the activation once, up front, so a bad name is never silently ignored.
    if nonlinearity == "relu":
        activation = nn.relu
    elif nonlinearity == "sigmoid":
        activation = nn.sigmoid
    else:
        raise ValueError(f"Unknown nonlinearity {nonlinearity}")
    h = x
    for layer in params[:-1]:
        h = activation(linear(h, layer["weight"], layer["bias"]))
    last = params[-1]
    return linear(h, last["weight"], last["bias"])


class CNN(fnn.Module):
    """Convolutional base network: ``n_conv`` conv blocks followed by a dense head.

    Each block is a bias-free 3x3 ``Conv`` -> activation -> optional 2x2 max-pool,
    with a ``BatchNorm(use_running_average=False)`` placed either before the
    activation (default) or after the pooling (``zero_init=True``). The result is
    flattened per example and mapped to ``n_outputs`` values by a ``Dense`` layer.
    """

    # Flax modules are dataclasses: these annotated fields generate the
    # constructor, so the class defines no __init__ of its own.
    n_conv: int = 4  # number of conv blocks
    n_filters: int = 64  # output channels of every conv layer
    non_linearity: str = "relu"  # "relu" or "sigmoid"; other values raise in setup()
    n_outputs: int = 5  # width of the final Dense layer (logits)
    pooling: str = "max"  # "max": 2x2 max-pool; "conv": stride-2 convs; anything else: stride-1 convs, no pooling
    zero_init: bool = False  # zero-init the Dense head and move BatchNorm after pooling

    def setup(self):
        """Derive the activation, conv stride and Dense initializers from the fields.

        Raises:
            ValueError: if ``non_linearity`` is neither "relu" nor "sigmoid".
        """
        if self.non_linearity == "relu":
            self.activation = nn.relu
        elif self.non_linearity == "sigmoid":
            self.activation = nn.sigmoid
        else:
            raise ValueError(f"Unknown nonlinearity {self.non_linearity}")
        # "conv" pooling downsamples with strided convolutions instead of max-pool.
        self.conv_stride = 2 if self.pooling == "conv" else 1
        if self.zero_init:
            # Zero-initialized Dense head (MetaConvNet enables this via hypertorch_init).
            self.bias_init = fnn.initializers.zeros_init()
            self.dense_kernel_init = fnn.initializers.zeros_init()
        else:
            self.bias_init = fnn.initializers.uniform(0.05)
            self.dense_kernel_init = fnn.initializers.xavier_uniform()

    @fnn.compact
    # Submodules are declared inline below. Flax derives their parameter names from
    # creation order (NOTE(review): confirm), so reordering these calls would change
    # the layout of the parameter tree.
    def __call__(self, x):
        """Return the Dense-layer outputs (logits) for a batch of images ``x``.

        NOTE(review): ``x`` presumably is (batch, height, width, channels) as usual for
        Flax convs; the flatten below keeps only axis 0 as the batch axis.
        """
        for i in range(self.n_conv):
            x = fnn.Conv(
                features=self.n_filters,
                kernel_size=(3, 3),
                strides=self.conv_stride,
                kernel_init=fnn.initializers.xavier_uniform(),
                use_bias=False,  # not needed as there are batchnorms afterward
            )(x)
            # NOTE(review): with zero_init the BatchNorm comes after the activation and
            # pooling, so it does not absorb a conv bias; confirm dropping the bias is
            # intended in that mode.
            # https://github.com/cbfinn/maml/issues/9#issuecomment-321533256
            if not self.zero_init:
                x = fnn.BatchNorm(use_running_average=False, momentum=0)(x)
            x = self.activation(x)
            if self.pooling == "max":
                x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
            if self.zero_init:
                x = fnn.BatchNorm(use_running_average=False, momentum=0)(x)
        x = x.reshape((x.shape[0], -1)) # Flatten all but the leading (batch) axis
        x = fnn.Dense(
            features=self.n_outputs,
            kernel_init=self.dense_kernel_init,
            bias_init=self.bias_init,
        )(x)
        return x


class MetaNet(ABC):
    """Abstract meta-learner that adapts shared meta-parameters to each task.

    Subclasses supply the base model (``base_call``), the per-task loss
    (``base_loss``) and the initial meta-parameters (``initialize_params``).
    Calling an instance runs a jaxopt ``GradientDescent`` inner solver on every
    task of a batch (vmapped over the task axis) and returns the adapted
    parameters. Meta-gradients flow either through the unrolled inner loop
    (MAML, ``implicit=False``) or through implicit differentiation of a
    proximally regularized inner problem (iMAML, ``implicit=True``).
    """

    def __init__(
            self,
            implicit=False,
            debug_implicit=False,
            debug_inner=False,
            reg_init=2.0,
            learn_reg=False,
            per_param_reg=False,
            cg_steps=5,
            cg_damping=0.,
            zero_init=False,
            cg_tol=1e-7,
            **gbm_solver_kwargs,
        ):
        """Configure the meta-learner and build its inner solver.

        Args:
            implicit (boolean, optional): whether to use iMAML (True) or MAML (False).
                Defaults to False, which uses MAML.
            debug_implicit (boolean, optional): with ``implicit=True``, differentiate by
                unrolling the regularized inner solver instead of using implicit
                differentiation. Defaults to False.
            debug_inner (boolean, optional): whether to print the inner solver's mean error
                and stepsize on every call. Defaults to False.
            reg_init (float, optional): initial value of the *log* regularization strength;
                the strength actually applied is ``exp(reg)`` (the default 2.0 means ~7.39).
                The regularization only enters the inner loss when ``implicit=True``.
                Defaults to 2.0.
            learn_reg (boolean, optional): whether gradients flow into the regularization
                parameter. Defaults to False.
            per_param_reg (boolean, optional): whether to use one regularization value per
                parameter leaf instead of a single scalar. Defaults to False.
            cg_steps (int, optional): maximum number of CG iterations for the Hessian solve
                in the implicit gradient. Defaults to 5.
            cg_damping (float, optional): ridge term of the CG solve. Defaults to 0.
            zero_init (boolean, optional): with ``implicit=True``, start the inner solver from
                zeros instead of from the meta-parameters. Defaults to False.
            cg_tol (float, optional): tolerance of the CG solve. Defaults to 1e-7.
            gbm_solver_kwargs (dict, optional): Keyword arguments for the jaxopt
                ``GradientDescent`` inner solver, merged over ``DEFAULT_GBM_SOLVER_KWARGS``.
                Typically you would set stepsize, maxiter, maxls, tol, acceleration,
                decrease_factor.
        """
        self.implicit = implicit
        self.debug_implicit = debug_implicit
        self.debug_inner = debug_inner
        self.reg_init = reg_init
        self.learn_reg = learn_reg
        self.per_param_reg = per_param_reg
        self.zero_init = zero_init
        gbm_solver_kwargs = {**DEFAULT_GBM_SOLVER_KWARGS, **gbm_solver_kwargs}
        if self.implicit:
            # Linear solver used for the implicit-function-theorem system.
            gbm_solver_kwargs["implicit_diff_solve"] = Partial(
                partial(linear_solve.solve_cg, ridge=cg_damping),
                maxiter=cg_steps,
                tol=cg_tol,
            )
        self.gbm_solver = GradientDescent(  # inner-solver
            fun=self.loss,
            implicit_diff=self.implicit and not debug_implicit,
            **gbm_solver_kwargs
        )

    @abstractmethod
    def base_call(self, x, params):
        """Apply the base model with ``params`` to the inputs ``x``."""
        ...

    @abstractmethod
    def base_loss(self, predictions, task_output):
        """Return the scalar task loss of ``predictions`` against ``task_output``."""
        ...

    @abstractmethod
    def initialize_params(self):
        """Return the initial meta-parameters (a pytree)."""
        ...

    @partial(jax.jit, static_argnames=("self",))
    def reg_loss(self, task_adapted_params, meta_params, reg):
        """Proximal penalty ``0.5 * || sqrt(reg) * (task_adapted_params - meta_params) ||^2``.

        ``reg`` is a scalar, or a pytree with one value per parameter leaf when
        ``per_param_reg`` is set. It is the already-exponentiated strength.
        """
        param_diff = tree_sub(task_adapted_params, meta_params)
        if self.per_param_reg:
            sqrt_reg = tree_map(jnp.sqrt, reg)
            weighted_param_diff = tree_mul(param_diff, sqrt_reg)
        else:
            weighted_param_diff = tree_scalar_mul(jnp.sqrt(reg), param_diff)
        param_diff_norm = tree_l2_norm(weighted_param_diff, squared=True)
        r_loss = 0.5 * param_diff_norm
        return r_loss

    @partial(jax.jit, static_argnames=("self",))
    def loss(self, task_adapted_params, meta_params, task_input, task_output, reg=None):
        """Inner objective: base task loss, plus the proximal penalty when ``implicit``.

        ``reg`` must be provided when ``implicit`` is set; it is ignored otherwise.
        """
        predictions = self.base_call(task_input, task_adapted_params)
        orig_loss = self.base_loss(predictions, task_output)
        if self.implicit:
            r_loss = self.reg_loss(task_adapted_params, meta_params, reg)
            return orig_loss + r_loss
        return orig_loss

    def duplicate_params(self, meta_params, n_tasks):
        """Stack ``n_tasks`` copies of every leaf of ``meta_params`` along a new axis 0."""
        duplicated_params = tree_map(
            lambda x: jnp.repeat(x[None, ...], n_tasks, axis=0),
            meta_params,
        )
        return duplicated_params

    @partial(jax.jit, static_argnames=("self",))
    def __call__(self, training_task_batch, meta_params_and_reg):
        """Adapt the meta-parameters to every task of a batch.

        Args:
            training_task_batch: pair ``(tasks_inputs, tasks_outputs)``; both are mapped
                over their leading (task) axis.
            meta_params_and_reg: dict with ``"meta_params"`` (the shared initialization)
                and ``"reg"`` (log regularization strength, scalar or per-leaf pytree).

        Returns:
            The task-adapted parameters, stacked along a leading task axis.
        """
        tasks_inputs, tasks_outputs = training_task_batch
        meta_params = meta_params_and_reg["meta_params"]
        # "reg" is stored in log-space; exp keeps the applied strength positive.
        reg = tree_map(jnp.exp, meta_params_and_reg["reg"])
        if not self.learn_reg and reg is not None:
            reg = jax.lax.stop_gradient(reg)
        if self.implicit and self.zero_init:
            inner_initialization = tree_map(jnp.zeros_like, meta_params)
        else:
            inner_initialization = meta_params
            if self.implicit:
                # For iMAML the meta-gradient must not flow through the solver's
                # starting point, only through the regularized solution.
                inner_initialization = jax.lax.stop_gradient(inner_initialization)
        # vmap over tasks: inputs/outputs are batched on axis 0; the initialization,
        # meta_params and reg are shared by all tasks.
        inner_solver_result = jax.vmap(self.gbm_solver.run, (None, None, 0, 0, None))(
            inner_initialization,
            meta_params,
            tasks_inputs,
            tasks_outputs,
            reg,
        )
        if self.debug_inner:
            jax.debug.print("inner solver error {e}", e=jnp.mean(inner_solver_result.state.error))
            jax.debug.print("inner solver stepsize {s}", s=jnp.mean(inner_solver_result.state.stepsize))
        task_adapted_params = inner_solver_result.params
        return task_adapted_params

    @partial(jax.jit, static_argnames=("self",))
    def batch_predict(self, task_adapted_params, tasks_inputs):
        """Apply ``base_call`` per task (both arguments mapped over their leading axis)."""
        return jax.vmap(self.base_call)(tasks_inputs, task_adapted_params)

    def initialize_params_and_reg(self):
        """Return ``{"meta_params": ..., "reg": ...}`` for the start of meta-training.

        ``reg`` is ``reg_init`` (log-space) as a float32 scalar, or a pytree with one
        such scalar per meta-parameter leaf when ``per_param_reg`` is set.
        """
        meta_params = self.initialize_params()
        if self.per_param_reg:
            reg = tree_map(lambda x: jnp.array(self.reg_init, dtype=jnp.float32), meta_params)
        else:
            reg = jnp.array(self.reg_init, dtype=jnp.float32)
        return {"meta_params": meta_params, "reg": reg}


class MetaMLP(MetaNet):
    """Meta-learner whose base model is the ``mlp`` function with a mean-squared-error loss."""

    def __init__(self, n_hidden=2, n_units=40, nonlinearity="relu", n_output=1, n_input=1, stddev=1e-2, **meta_kwargs):
        """A Meta-learner for MLPs

        Args:
            n_hidden (int, optional): Number of hidden layers. Defaults to 2. ``0`` yields a
                single linear layer ``n_input -> n_output``.
            n_units (int, optional): Number of units per hidden layer. Defaults to 40.
            nonlinearity (str, optional): Nonlinearity to use ("relu" or "sigmoid").
                Defaults to "relu".
            n_output (int, optional): Number of output units. Defaults to 1.
            n_input (int, optional): Number of input units. Defaults to 1.
            stddev (float, optional): Standard deviation of the Gaussian weight
                initialization. Defaults to 1e-2.
            meta_kwargs (dict, optional): Keyword arguments for MetaNet. Defaults to {}.
                Typically you would set stepsize, maxiter, maxls, tol, acceleration,
                decrease_factor, implicit, reg_init, learn_reg, cg_steps, zero_init.
        """
        super().__init__(**meta_kwargs)
        self.n_hidden = n_hidden
        self.n_units = n_units
        self.nonlinearity = nonlinearity
        self.n_output = n_output
        self.n_input = n_input
        self.stddev = stddev

    @partial(jax.jit, static_argnames=("self",))
    def base_call(self, x, params):
        """Apply the MLP defined by ``params`` to ``x``."""
        predictions = mlp(x, params, nonlinearity=self.nonlinearity)
        return predictions

    @partial(jax.jit, static_argnames=("self",))
    def base_loss(self, predictions, task_output):
        """Mean squared error between ``predictions`` and ``task_output``."""
        return jnp.mean((predictions - task_output) ** 2)

    def initialize_params(self, seed=0):
        """Create the initial MLP parameters.

        Layer sizes are ``n_input -> n_units (repeated n_hidden times) -> n_output``.
        Each layer samples its weights from its own key split off ``PRNGKey(seed)``,
        so no key is used twice.

        Args:
            seed (int, optional): Seed of the PRNG key. Defaults to 0.

        Returns:
            list: one ``{"weight": (n_in, n_out), "bias": (n_out,)}`` dict per layer.
        """
        sizes = [self.n_input] + [self.n_units] * self.n_hidden + [self.n_output]
        layer_keys = random.split(random.PRNGKey(seed), len(sizes) - 1)
        params = []
        for layer_key, n_in, n_out in zip(layer_keys, sizes[:-1], sizes[1:]):
            self.initialize_layer(layer_key, n_in, n_out, params)
        return params

    def initialize_layer(self, key, n_in, n_out, params):
        """Append one layer ``{"weight", "bias"}`` to ``params`` and return ``params``.

        The weight is ``N(0, stddev^2)`` of shape (n_in, n_out); the bias is zeros.
        """
        w = random.normal(key, (n_in, n_out)) * self.stddev
        b = jnp.zeros(n_out)
        params.append({
            "weight": w,
            "bias": b,
        })
        return params


class MetaConvNet(MetaNet):
    """Meta-learner whose base model is the ``CNN`` Flax module with a softmax cross-entropy loss."""

    def __init__(
        self,
        n_conv=4,
        n_filters=64,
        non_linearity="relu",
        n_output=1,
        image_size=28,
        pooling="max",
        hypertorch_init=False,
        n_channels=1,
        **meta_kwargs,
    ):
        """A Meta-learner for convolutional networks

        Args:
            n_conv (int, optional): Number of convolutional layers. Defaults to 4.
            n_filters (int, optional): Number of filters per convolutional layer. Defaults to 64.
            non_linearity (str, optional): Non linearity to use ("relu" or "sigmoid").
                Defaults to "relu".
            n_output (int, optional): Number of output units (logits). Defaults to 1.
            image_size (int, optional): Height and width of the square input images.
                Defaults to 28.
            pooling (str, optional): "max" for 2x2 max-pooling, "conv" for stride-2
                convolutions. Defaults to "max".
            hypertorch_init (bool, optional): Whether to zero-initialize the dense head of the
                meta weights (this also moves batch norm after pooling). Defaults to False.
            n_channels (int, optional): Number of channels of the input images. Defaults to 1.
            meta_kwargs (dict, optional): Keyword arguments for MetaNet. Defaults to {}.
                Typically you would set stepsize, maxiter, maxls, tol, acceleration,
                decrease_factor, implicit, reg_init, learn_reg, cg_steps, zero_init.
        """
        super().__init__(**meta_kwargs)
        self.n_conv = n_conv
        self.n_filters = n_filters
        self.non_linearity = non_linearity
        self.n_output = n_output
        self.image_size = image_size
        self.n_channels = n_channels
        self.pool = pooling
        self.hypertorch_init = hypertorch_init
        self.cnn = CNN(
            n_conv=n_conv,
            n_filters=n_filters,
            non_linearity=non_linearity,
            n_outputs=n_output,
            pooling=pooling,
            zero_init=hypertorch_init,
        )
        self.compute_frozen_batch_stats()

    def _dummy_input(self):
        """Return a single all-ones image (batch of 1) used to initialize the CNN."""
        return jnp.ones([1, self.image_size, self.image_size, self.n_channels])

    def compute_frozen_batch_stats(self):
        """Initialize the CNN once and keep its initial ``batch_stats`` collection.

        Returns:
            The ``batch_stats`` collection, also stored as ``self.frozen_batch_stats``.
        """
        # The key does not matter: only the batch stats are kept, and their
        # initialization is deterministic.
        key = random.PRNGKey(0)
        batch_stats = self.cnn.init(key, self._dummy_input())["batch_stats"]
        self.frozen_batch_stats = batch_stats
        return batch_stats

    @partial(jax.jit, static_argnames=("self",))
    def base_call(self, x, params):
        """Apply the CNN with ``params`` and the frozen batch stats; return the logits."""
        # CNN's BatchNorm layers use use_running_average=False, so "batch_stats" must be
        # mutable. The second return value of apply (the updated batch statistics) is
        # discarded, so every call starts from the same frozen stats.
        predictions, _ = self.cnn.apply(
            {
                "params": params,
                "batch_stats": self.frozen_batch_stats,
            },
            x,
            mutable=["batch_stats"],
        )
        return predictions

    @partial(jax.jit, static_argnames=("self",))
    def base_loss(self, predictions, task_output):
        """Mean softmax cross-entropy between the logits and ``task_output``.

        NOTE(review): optax's softmax_cross_entropy takes ``labels`` as class
        probabilities (e.g. one-hot) shaped like ``predictions``; confirm with callers.
        """
        return jnp.mean(softmax_cross_entropy(logits=predictions, labels=task_output))

    def initialize_params(self, seed=0):
        """Return freshly initialized CNN parameters drawn from ``PRNGKey(seed)``."""
        key = random.PRNGKey(seed)
        params = self.cnn.init(key, self._dummy_input())["params"]
        return params


def save_weights(meta_params, save_path):
    """Pickle ``meta_params`` to the file at ``save_path``, overwriting any existing file."""
    with open(save_path, mode="wb") as handle:
        pickle.dump(meta_params, handle)


def load_weights(load_path):
    """Unpickle and return the object stored at ``load_path``.

    Security: unpickling can execute arbitrary code, so only load files from
    trusted sources.
    """
    with open(load_path, mode="rb") as handle:
        return pickle.load(handle)
