import numpy as np
import jax.numpy as jnp
import optax
import jax

from jax import jit, grad, tree_util
from functools import partial
from tqdm import tqdm
from collections import namedtuple


# Host (CPU) devices. `MetaLearner.evaluate` moves each batch's logs to the
# first one, presumably so accumulated results do not hold accelerator memory.
cpus = jax.devices(backend='cpu')

# Non-parameter state carried between meta-training steps:
#   model     -- model state returned alongside params by `model.init`
#   optimizer -- optimizer state from `optimizer.init` / `optimizer.update`
#   key       -- PRNG key threaded through `outer_loss` calls
# All three fields are required (no defaults).
MetaLearnerState = namedtuple('MetaLearnerState', 'model optimizer key')


class MetaLearner:
    """Base class driving the outer (meta) optimisation loop of a meta-learner.

    Wraps a ``model`` (exposing ``init(key, *args, **kwargs) -> (params,
    model_state)``) and an optax-style ``optimizer`` (``init`` / ``update``).
    ``outer_loss`` is not defined here; subclasses must provide it with the
    signature both ``_train_step`` and ``evaluate`` call it with::

        outer_loss(params, model_state, key, train, test, args)
            -> (loss, (logs, new_model_state, new_key))
    """

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

        # Mode flag toggled by train()/eval(). It is not read in this class;
        # presumably subclasses' outer_loss consults it — confirm in subclasses.
        self._training = False

    def train_step(self, params, state, train, test, *args):
        """Switch to training mode and run one meta-training step.

        Extra positional ``args`` are forwarded to ``outer_loss`` as a single
        tuple that ``jit`` treats as static, so its contents must be hashable.

        Returns:
            ``(params, state, logs)`` after applying one optimizer update.
        """
        self.train()
        return self._train_step(params, state, train, test, args)

    @partial(jit, static_argnums=(0, 5))
    def _train_step(self, params, state, train, test, args):
        # `self` (index 0) and `args` (index 5) are static for jit: they are
        # hashed, and a new trace is compiled for each distinct value.
        grads, (logs, model_state, key) = grad(self.outer_loss, has_aux=True)(
            params, state.model, state.key, train, test, args)
        updates, opt_state = self.optimizer.update(grads, state.optimizer, params)
        params = optax.apply_updates(params, updates)
        # Every MetaLearnerState field is required, including the PRNG key
        # returned by outer_loss.
        state = MetaLearnerState(model=model_state, optimizer=opt_state, key=key)
        return params, state, logs

    def init(self, key, *args, **kwargs):
        """Initialise parameters and meta-learner state.

        ``key`` is split so the key stored in ``state.key`` (consumed later by
        ``outer_loss``) is independent of the one used for model init.

        Returns:
            ``(params, state)`` where ``state`` is a ``MetaLearnerState``.
        """
        key, init_key = jax.random.split(key)
        params, model_state = self.model.init(init_key, *args, **kwargs)
        state = MetaLearnerState(
            model=model_state,
            optimizer=self.optimizer.init(params),
            key=key,
        )
        return params, state

    def train(self):
        """Set the training-mode flag."""
        self._training = True

    def eval(self):
        """Clear the training-mode flag."""
        self._training = False

    def evaluate(self, params, state, dataset, *args):
        """Evaluate ``outer_loss`` on every batch of a finite ``dataset``.

        Each batch is indexed with ``'train'`` and ``'test'``. The PRNG key is
        threaded from batch to batch, starting at ``state.key``. The final key
        is not returned.

        Returns:
            The per-batch ``logs`` pytrees, concatenated leaf-wise along axis 0.
            Scalar leaves are promoted to shape ``(1,)`` first, so each yields
            one entry per batch.

        Raises:
            RuntimeError: if ``dataset.size`` is None (infinite dataset) or the
                dataset yields no batches.
        """
        self.eval()
        if dataset.size is None:
            raise RuntimeError('The dataset for evaluation must be finite.')

        # Bound method, so index 5 is `args` (params, model_state, key,
        # train, test, args).
        outer_loss = jit(self.outer_loss, static_argnums=(5,))
        results, key = [], state.key
        dataset.reset()
        for batch in dataset:
            _, (logs, _, key) = outer_loss(
                params, state.model, key, batch['train'], batch['test'], args)
            # Park each batch's logs on the host; presumably to keep
            # accelerator memory free while results accumulate.
            results.append(jax.device_put(logs, cpus[0]))

        if not results:
            # tree_map would otherwise fail with an opaque TypeError.
            raise RuntimeError('The dataset for evaluation is empty.')

        def _concat(*leaves):
            # jnp.concatenate rejects 0-d arrays; promote scalars to (1,).
            return jnp.concatenate([jnp.atleast_1d(leaf) for leaf in leaves], axis=0)

        return tree_util.tree_map(_concat, *results)