from collections import defaultdict
from functools import partial

import jax
from jax import jit
from jax import numpy as np

import datasets
import utils


class Metric:
    """Base class for metrics that are evaluated per batch and then combined.

    The subclasses in this file implement ``__call__(outputs, targets)`` to
    produce one value per batch; ``join`` folds a list of those values into the
    final result. Methods are jitted with ``self`` marked static, so instances
    must be hashable (the default identity hash is sufficient).
    """

    @partial(jit, static_argnums=(0,))
    def join(self, partials):
        """Return the unweighted arithmetic mean of the per-batch values.

        Args:
            partials: Non-empty list of per-batch metric values.

        Returns:
            ``sum(partials) / len(partials)``. Every batch contributes equally,
            regardless of how many samples it contained.
        """
        count = len(partials)
        total = sum(partials)
        return total / count


class OneHotAccuracyMetric(Metric):
    """Accuracy of the arg-max prediction against the given targets."""

    @partial(jit, static_argnums=(0,))
    def __call__(self, logits, targets):
        """Return the fraction of rows whose arg-max equals the target.

        Args:
            logits: Per-class scores; the arg-max is taken along axis 1.
            targets: Values compared element-wise with the arg-max indices
                (presumably integer class labels of shape ``(N,)`` -- confirm
                against callers).

        Returns:
            Mean of the element-wise equality mask.
        """
        predicted = np.argmax(logits, axis=1)
        hits = predicted == targets
        return np.mean(hits)


class SoftmaxCrossEntropyLossMetric(Metric):
    """Softmax cross-entropy between logits and one-hot encoded targets."""

    @partial(jit, static_argnums=(0,))
    def __call__(self, logits, targets):
        """Evaluate ``utils.softmax_cross_entropy`` on one batch.

        Args:
            logits: Per-class scores; the size of the last axis is used as the
                number of classes.
            targets: Class indices, expanded with ``jax.nn.one_hot``.

        Returns:
            Whatever ``utils.softmax_cross_entropy`` returns for the batch
            (presumably a scalar loss -- confirm in ``utils``).
        """
        # Array shapes are static under jit, so the class count can be read
        # from the logits instead of being fixed at 10. For 10-class logits
        # this is identical to the previous behaviour.
        num_classes = logits.shape[-1]
        one_hot_targets = jax.nn.one_hot(targets, num_classes)
        return utils.softmax_cross_entropy(logits, one_hot_targets)


class BalanceMetric(Metric):
    """Histogram of predicted classes, showing how balanced predictions are."""

    @partial(jit, static_argnums=(0,))
    def __call__(self, logits, targets):
        """Count how often each class is the arg-max prediction in a batch.

        Args:
            logits: Per-class scores; the arg-max is taken along axis 1 and
                the size of axis 1 is used as the number of classes.
            targets: Unused; accepted so the call signature matches the other
                metrics.

        Returns:
            Integer vector with one count per class (per column of ``logits``).
        """
        # Shapes are static under jit, so this is a valid static ``length``.
        # It replaces a hard-coded 10 and yields the same result for 10-class
        # logits.
        num_classes = logits.shape[1]
        predicted = np.argmax(logits, axis=1)
        return np.bincount(predicted, length=num_classes)

    @partial(jit, static_argnums=(0,))
    def join(self, partials):
        """Add up the per-batch histograms (counts are summed, not averaged)."""
        return sum(partials)


class BinaryAccuracyMetric(Metric):
    """Accuracy of a single-output binary classifier thresholded at 0.5."""

    @partial(jit, static_argnums=(0,))
    def __call__(self, logits, targets):
        """Return the fraction of thresholded outputs that equal the targets.

        Args:
            logits: 2-D model outputs; only column 0 is read.
                NOTE(review): the 0.5 threshold suggests these are
                probabilities rather than raw logits -- confirm with callers.
            targets: Values compared element-wise with the boolean decision.

        Returns:
            Mean of the element-wise equality mask.
        """
        positive = logits[:, 0] > 0.5
        agreement = positive == targets
        return np.mean(agreement)


class MSEClassificationAccuracy(Metric):
    """Accuracy obtained by rounding a scalar prediction to the nearest integer."""

    @partial(jit, static_argnums=(0,))
    def __call__(self, preds, targets):
        """Return the fraction of rounded predictions that equal the targets.

        Args:
            preds: 2-D model outputs; only column 0 is read and rounded with
                ``rint``.
            targets: Values compared element-wise with the rounded predictions.

        Returns:
            Mean of the element-wise equality mask.
        """
        rounded = np.rint(preds[:, 0])
        return np.mean(rounded == targets)


class RegressionMSE(Metric):
    """Mean squared error per batch, reported as a root-mean value after joining."""

    @partial(jit, static_argnums=(0,))
    def __call__(self, preds, targets):
        """Return the mean squared difference between ``preds`` and ``targets``.

        Args:
            preds: Model predictions.
            targets: Regression targets.

        Returns:
            Scalar mean of ``(preds - targets) ** 2``.
        """
        # Guard against silent broadcasting: preds of shape (N, 1) minus
        # targets of shape (N,) would produce an (N, N) matrix and a
        # meaningless mean. When both hold the same number of elements, align
        # preds with targets so the difference is element-wise. Shapes are
        # static under jit, so this branch is resolved at trace time.
        if preds.shape != targets.shape and preds.size == targets.size:
            preds = np.reshape(preds, targets.shape)
        squared_error = (preds - targets) ** 2
        return np.mean(squared_error)

    @partial(jit, static_argnums=(0,))
    def join(self, partials):
        """Return the square root of the unweighted mean of the batch MSEs.

        Note that, despite the class name, the joined value is therefore on
        the RMSE scale.
        """
        mean_mse = sum(partials) / len(partials)
        return mean_mse ** 0.5


class MetricComputer:
    """Evaluates a fixed set of metrics for a model on one data split.

    With ``batch_size=None`` the whole split is pushed through the model in a
    single forward pass. Otherwise the split is iterated with
    ``datasets.BatchLoader`` and every metric is evaluated per batch and then
    combined with the metric's ``join`` method.

    NOTE(review): the ``join`` implementations in this file weight every batch
    equally; if the loader yields a smaller final batch, averaged metrics will
    differ slightly from the full-dataset value -- confirm against
    ``BatchLoader``.
    """

    def __init__(self, metrics, data, targets, batch_size=None, prefix=""):
        """Instantiate the metrics and, if requested, the batch loader.

        Args:
            metrics: Mapping from metric name to a zero-argument callable
                (typically a metric class). Each product must be callable as
                ``metric(outputs, targets)`` and provide ``join(partials)``.
            data: Model inputs for this split.
            targets: Targets matching ``data``.
            batch_size: Batch size used for evaluation, or ``None`` to
                evaluate the whole split at once.
            prefix: String prepended to every metric name in the results.
        """
        self.metrics = {
            prefix + name: metric_cls() for name, metric_cls in metrics.items()
        }
        self.batch_size = batch_size
        self.data = data
        self.targets = targets
        # Always define the attribute so unbatched instances can be inspected
        # (and ``batch_logits`` called) without raising AttributeError.
        self.loader = None
        if batch_size is not None:
            self.loader = datasets.BatchLoader(
                data, targets, key=None, batch_size=batch_size
            )

    def get_metrics(self, params, forward):
        """Compute every configured metric for the given model.

        Args:
            params: Model parameters, passed as the first argument of
                ``forward``.
            forward: Callable ``forward(params, data)`` returning the model
                outputs. It is used as a static jit argument, so it must be
                hashable and traceable.

        Returns:
            Dict mapping each prefixed metric name to its joined value.
        """
        if self.batch_size is None:
            outputs = self.logits(self.data, params, forward)
            return {
                name: metric.join([metric(outputs, self.targets)])
                for name, metric in self.metrics.items()
            }

        partials = defaultdict(list)
        for data, targets in self.loader:
            # Use the jitted forward pass here as well, consistent with the
            # unbatched branch and with ``batch_logits``.
            outputs = self.logits(data, params, forward)
            for name, metric in self.metrics.items():
                partials[name].append(metric(outputs, targets))
        return {
            name: metric.join(partials[name])
            for name, metric in self.metrics.items()
        }

    @partial(jit, static_argnums=(0, 3))
    def logits(self, data, params, forward):
        """Jitted forward pass: returns ``forward(params, data)``.

        ``self`` and ``forward`` are static arguments, so a new ``forward``
        callable (or a new input shape) triggers a recompilation.
        """
        return forward(params, data)

    # NOTE: deliberately not jitted. Jitting this is *very* slow on the first
    # run because the loop over batches is unrolled; it needed ~2000 epochs to
    # become faster.
    def batch_logits(self, params, forward):
        """Return the model outputs for the whole split.

        In batched mode the per-batch outputs are concatenated along axis 0 in
        loader order. Without a loader (``batch_size=None``) a single forward
        pass over ``self.data`` is returned instead.
        """
        if self.loader is None:
            return self.logits(self.data, params, forward)
        batches = [self.logits(data, params, forward) for data, _ in self.loader]
        return np.concatenate(batches)


class FullMetricComputer:
    """Computes the same metrics on the train and test splits of a dataset.

    Two ``MetricComputer`` instances are created, one per split; their results
    are distinguished by the ``"train_"`` and ``"test_"`` name prefixes.
    """

    def __init__(self, metrics, dataset, batch_size=None):
        """Create the per-split metric computers.

        Args:
            metrics: Mapping from metric name to metric factory, forwarded to
                ``MetricComputer``.
            dataset: Object exposing ``train_data``, ``train_targets``,
                ``test_data`` and ``test_targets``.
            batch_size: Evaluation batch size shared by both splits, or
                ``None`` to evaluate each split in one pass.
        """
        train_data, train_targets = dataset.train_data, dataset.train_targets
        self.train_metric_computer = MetricComputer(
            metrics, train_data, train_targets, batch_size, prefix="train_"
        )
        test_data, test_targets = dataset.test_data, dataset.test_targets
        self.test_metric_computer = MetricComputer(
            metrics, test_data, test_targets, batch_size, prefix="test_"
        )

    def get_metrics(self, params, forward):
        """Return train and test metrics merged into a single dict.

        Train metrics are computed first. If both results share a key, the
        test value wins.
        """
        combined = dict(self.train_metric_computer.get_metrics(params, forward))
        combined.update(self.test_metric_computer.get_metrics(params, forward))
        return combined
