import jax.numpy as jnp
import optax
import math
import json
import haiku as hk

from jax import vmap, jit, nn, random, tree_util, grad
from functools import partial
from collections import namedtuple
from dataclasses import dataclass, asdict
from optax import smooth_labels, softmax_cross_entropy

from comln.utils.metrics import accuracy_with_labels
from comln.metalearners.base import MetaLearner, MetaLearnerState
from comln.utils.gradient_flow import gradient_flow
from comln.utils.hk_utils import initializer_to_function


# Bundle of everything that is meta-learned: the encoder parameters (`model`),
# the weights of the linear classifier (`classifier`) and the scalar `t_final`.
COMLNMetaParameters = namedtuple(
    'COMLNMetaParameters',
    ('model', 'classifier', 't_final'),
)

@dataclass
class COMLNArguments:
    """Hyperparameters consumed by ``COMLN.from_args``.

    The fields are expanded with ``asdict`` and passed as keyword arguments
    to the ``COMLN`` constructor, so their names must match its parameters.
    """

    # Initial value of the time passed to the gradient flow (must be > 0,
    # since ``COMLN.init`` stores its logarithm).
    t_final: float = 1.0
    # JSON-encoded dictionary of extra keyword arguments for `gradient_flow`.
    odeint_kwargs: str = '{}'
    # Label-smoothing coefficient applied to the one-hot targets.
    eps: float = 0.0
    # Coefficient of the weight-decay term added to the gradients.
    weight_decay: float = 5e-4


class COMLN(MetaLearner):
    """Meta-learner that adapts a linear classifier with `gradient_flow`.

    An encoder (``model``) maps inputs to features. For every task, the
    meta-learned classifier weights are adapted on the task's training
    features by ``gradient_flow`` and then evaluated on its test features.
    The meta-parameters are the encoder parameters, the classifier
    initialization and ``t_final`` (stored as its logarithm).
    """

    def __init__(
            self,
            model,
            num_ways,
            optimizer,
            t_final=1e+0,
            odeint_kwargs='{}',
            eps=0.,
            weight_decay=5e-4
    ):
        """Create the meta-learner.

        Args:
            model: Encoder exposing ``init``/``apply`` in the style of
                ``hk.transform_with_state``.
            num_ways: Number of classes per task; sets the number of rows of
                the classifier and the width of the one-hot labels.
            optimizer: Optax gradient transformation used for the
                meta-parameters.
            t_final: Initial (positive) value of the time passed to
                ``gradient_flow``; its logarithm is what gets meta-learned.
            odeint_kwargs: JSON string decoded into extra keyword arguments
                for ``gradient_flow``.
            eps: Label-smoothing coefficient used while training.
            weight_decay: Coefficient of the weight-decay term added to the
                gradients of the encoder and classifier.
        """
        super().__init__(model, optimizer)
        self.num_ways = num_ways
        self.t_final = t_final
        self.odeint_kwargs = json.loads(odeint_kwargs)
        self.eps = eps
        self.weight_decay = weight_decay

    def loss(self, params, inputs, labels):
        """Cross-entropy of a bias-free linear classifier.

        Args:
            params: Classifier weights; multiplied as ``inputs @ params.T``,
                i.e. one row per class.
            inputs: Feature vectors, one per row.
            labels: Target distributions over classes (one-hot or smoothed).

        Returns:
            A pair ``(loss, logs)`` where ``loss`` is the mean softmax
            cross-entropy and ``logs`` holds the loss, the softmax
            predictions and the accuracy.
        """
        logits = jnp.matmul(inputs, params.T)
        loss = jnp.mean(softmax_cross_entropy(logits, labels))
        logs = {
            'loss': loss,
            'predictions': nn.softmax(logits, axis=-1),
            'accuracy': accuracy_with_labels(logits, labels)
        }
        return loss, logs

    # `self` (0) and `args` (6) are static, so `args` must be hashable.
    @partial(jit, static_argnums=(0, 6))
    def outer_loss(self, params, state, key, train, test, args):
        """Mean test loss over a batch of tasks after adaptation.

        Args:
            params: ``COMLNMetaParameters`` shared by all tasks.
            state: Encoder state shared by all tasks.
            key: PRNG key; one sub-key is derived per task.
            train: Batch of training sets with ``inputs`` and ``targets``
                attributes, tasks along the leading axis.
            test: Batch of test sets, same layout as ``train``.
            args: Extra positional arguments forwarded to ``model.apply``.

        Returns:
            ``(loss, (logs, state, key))``: the loss averaged over tasks,
            per-task logs (test metrics under ``after/...``, training
            metrics under ``after/train/...``), the encoder state averaged
            over tasks and a fresh PRNG key.
        """
        # Parameters and state are broadcast; keys and data are per task.
        @partial(vmap, in_axes=(None, None, 0, 0, 0))
        def _outer_loss(params, state, subkey, train, test):
            subkey1, subkey2 = random.split(subkey)

            train_features, state = self.model.apply(params.model, state,
                subkey1, train.inputs, *args)
            train_labels = self._smooth_labels(train.targets)
            # `t_final` is stored in log-space, hence the exponential.
            adapted_params = gradient_flow(self.loss, params.classifier,
                train_features, train_labels, jnp.exp(params.t_final),
                **self.odeint_kwargs)

            test_features, state = self.model.apply(params.model, state,
                subkey2, test.inputs, *args)
            test_labels = self._smooth_labels(test.targets)
            outer_loss, outer_logs = self.loss(
                adapted_params, test_features, test_labels)

            # Log information on training data
            _, train_logs = self.loss(adapted_params, train_features, train_labels)

            return outer_loss, outer_logs, train_logs, state

        # One key to carry forward plus one per task; slicing keeps the keys
        # as a single array instead of rebuilding it from a Python list.
        keys = random.split(key, train.inputs.shape[0] + 1)
        key, subkeys = keys[0], keys[1:]
        outer_losses, outer_logs, train_logs, states = _outer_loss(
            params, state, subkeys, train, test)
        logs = {
            **{f'after/{k}': v for (k, v) in outer_logs.items()},
            **{f'after/train/{k}': v for (k, v) in train_logs.items()}
        }
        # Collapse the per-task encoder states into a single one.
        state = tree_util.tree_map(partial(jnp.mean, axis=0), states)
        return jnp.mean(outer_losses), (logs, state, key)

    def init(self, key, *args, **kwargs):
        """Initialize the meta-parameters and the meta-learner state.

        Args:
            key: PRNG key.
            *args: Positional arguments forwarded to ``model.init`` and
                ``model.apply`` (e.g. example inputs).
            **kwargs: Keyword arguments forwarded likewise.

        Returns:
            ``(params, state)``: a ``COMLNMetaParameters`` and a
            ``MetaLearnerState`` holding the encoder state, the optimizer
            state and a PRNG key.

        Raises:
            ValueError: If ``self.t_final`` is not strictly positive (raised
                by ``math.log``).
        """
        key, subkey = random.split(key)
        encoder_params, encoder_state = self.model.init(subkey, *args, **kwargs)
        # Only the trailing dimension and dtype of the features are used.
        features, _ = self.model.apply(encoder_params, encoder_state, key, *args, **kwargs)

        # Initialization meta-parameters
        stddev = 1. / math.sqrt(features.shape[-1])
        w_init = initializer_to_function(hk.initializers.TruncatedNormal(stddev=stddev))
        shape = (self.num_ways, features.shape[-1])

        key, subkey = random.split(key)
        params = COMLNMetaParameters(
            model=encoder_params,
            classifier=w_init(subkey, shape, features.dtype),
            t_final=jnp.asarray(math.log(self.t_final))
        )
        state = MetaLearnerState(
            model=encoder_state,
            optimizer=self.optimizer.init(params),
            key=key
        )
        return params, state

    # `self` (0) and `args` (5) are static, so `args` must be hashable.
    @partial(jit, static_argnums=(0, 5))
    def _train_step(self, params, state, train, test, args):
        """Run one meta-optimization step on a batch of tasks.

        Args:
            params: Current ``COMLNMetaParameters``.
            state: Current ``MetaLearnerState``.
            train: Batch of training sets (see ``outer_loss``).
            test: Batch of test sets (see ``outer_loss``).
            args: Extra positional arguments forwarded to ``model.apply``.

        Returns:
            ``(params, state, logs)`` with the updated meta-parameters, the
            updated state and the logs produced by ``outer_loss``.
        """
        grads, (logs, model_state, key) = grad(self.outer_loss, has_aux=True)(
            params, state.model, state.key, train, test, args)

        # Apply weight decay (encoder and classifier only; the gradient of
        # `t_final` is left untouched). `tree_map` accepts several trees and
        # replaces `tree_multimap`, which no longer exists in recent JAX.
        grads = grads._replace(
            model=tree_util.tree_map(lambda g, p: g + self.weight_decay * p,
                grads.model, params.model),
            classifier=grads.classifier + self.weight_decay * params.classifier
        )

        updates, opt_state = self.optimizer.update(grads, state.optimizer, params)
        params = optax.apply_updates(params, updates)
        state = MetaLearnerState(model=model_state, optimizer=opt_state, key=key)

        return params, state, logs

    def _smooth_labels(self, targets):
        """One-hot encode `targets`, smoothing with ``eps`` when training.

        Smoothing is applied only if ``self._training`` is truthy; that
        attribute is not set in this class (presumably managed by the
        ``MetaLearner`` base class -- verify there).
        """
        labels = nn.one_hot(targets, num_classes=self.num_ways)
        return smooth_labels(labels, self.eps) if self._training else labels

    @classmethod
    def from_args(cls, num_ways, args, model=None, optimizer=None):
        """Build a ``COMLN`` from a ``COMLNArguments`` instance.

        Args:
            num_ways: Number of classes per task.
            args: Dataclass whose fields are passed as keyword arguments to
                the constructor.
            model: Encoder; defaults to an identity encoder that returns its
                inputs unchanged.
            optimizer: Optax transformation; defaults to Adam (1e-3) for the
                encoder and classifier and Nesterov SGD (1e-2, momentum 0.9)
                for ``t_final``.

        Returns:
            A new instance of ``cls``.
        """
        if model is None:
            # Identity encoder
            model = hk.transform_with_state(lambda inputs, is_training: inputs)

        if optimizer is None:
            # The second argument labels each field of the meta-parameters
            # with the name of the transformation that updates it.
            optimizer = optax.multi_transform({
                'model': optax.adam(1e-3),
                'classifier': optax.adam(1e-3),
                't_final': optax.sgd(1e-2, momentum=0.9, nesterov=True)
            }, COMLNMetaParameters(model='model', classifier='classifier', t_final='t_final'))

        return cls(model, num_ways, optimizer, **asdict(args))
