from typing import Optional
import multiprocessing as mp

import numpy as np

from gluonts.evaluation import Evaluator
from gluonts.meta_tools import get_forecasts_and_series
from gluonts.model.predictor import SimpleParallelizedPredictor


class TrainerMetricsCallback:
    """
    Training callback that records losses and forecasting metrics per epoch.

    Forecasting metrics are computed on fixed, reproducibly chosen subsets of the
    train and validation datasets, so values from different epochs are comparable.

    Usage:
    callback = TrainerMetricsCallback(
        train_dataset=dataset.train, val_dataset=dataset.test, has_gpu=False
    )
    output = estimator.train_model(dataset.train, dataset.test, callback=callback)
    print(callback.metrics)
    """

    def __init__(
        self,
        train_dataset,
        val_dataset,
        has_gpu,
        num_val: Optional[int] = 100,
        eval_interval: Optional[int] = 1,
        num_pred_jobs: Optional[int] = mp.cpu_count(),
        prediction_length: Optional[int] = None,
        num_samples: int = 1,
    ):
        """
        :param train_dataset: iterable of training entries on which forecasting
            metrics are computed (materialized into a list, then subsampled).
        :param val_dataset: iterable of validation entries on which forecasting
            metrics are computed (materialized into a list, then subsampled).
        :param has_gpu: when true, predictions are not parallelized across worker
            processes (``num_pred_jobs`` is replaced by None).
        :param num_val: maximum number of entries kept from each dataset. The same
            entries are used for every evaluation. None keeps every entry.
        :param eval_interval: metrics are computed every ``eval_interval`` epochs
            (and always at epoch 0). None means 1.
        :param num_pred_jobs: number of workers for SimpleParallelizedPredictor.
            None or a value <= 1 disables parallel prediction.
        :param prediction_length: forwarded to ``get_forecasts_and_series``.
        :param num_samples: number of samples drawn per forecast, forwarded to
            ``get_forecasts_and_series``.
        :raises ValueError: if ``eval_interval`` is smaller than 1.
        """
        # list of dicts; __call__ appends one "train" and one "val" dict per epoch
        self.metrics = []
        self.num_samples = num_samples
        self.num_ts_evals = num_val
        self.eval_interval = eval_interval if eval_interval is not None else 1
        if self.eval_interval < 1:
            # a zero interval would otherwise fail later with ZeroDivisionError
            raise ValueError(
                f"eval_interval must be a positive integer, got {eval_interval!r}"
            )
        self.num_pred_jobs = num_pred_jobs if not has_gpu else None

        self.train_dataset = self._subsample(train_dataset, self.num_ts_evals)
        self.val_dataset = self._subsample(val_dataset, self.num_ts_evals)
        self.prediction_length = prediction_length

    @staticmethod
    def _subsample(dataset, num_series: Optional[int]) -> list:
        """Return at most ``num_series`` entries of ``dataset``, chosen with a fixed seed."""
        entries = list(dataset)
        if num_series is None:
            return entries
        size = min(num_series, len(entries))
        if size == 0:
            return []
        # Use a fixed seed so every evaluation sees the same series and errors stay
        # comparable. Sample indices rather than the entries themselves so numpy
        # never tries to coerce the entries into a (possibly multi-dimensional)
        # array; for RandomState this picks the same items as choice(entries, ...).
        indices = np.random.RandomState(seed=0).choice(
            len(entries), size, replace=False
        )
        return [entries[i] for i in indices]

    def _is_eval_epoch(self, epoch: int) -> bool:
        """Tell whether forecasting metrics should be computed at ``epoch``."""
        return (epoch + 1) % self.eval_interval == 0 or epoch == 0

    def _build_predictor(self, ctx, transformation, net, create_predictor):
        """Create the predictor inside ``ctx``, parallelized when configured."""
        with ctx:
            # ensure that the prediction network is created within the same MXNet
            # context as the one that was used during training
            predictor = create_predictor(transformation, net)
            if self.num_pred_jobs and self.num_pred_jobs > 1:
                predictor = SimpleParallelizedPredictor(
                    base_predictor=predictor,
                    num_workers=self.num_pred_jobs,
                )
        return predictor

    def _compute_metrics(self, dataset, predictor) -> dict:
        """Return the aggregate metrics of ``predictor`` evaluated on ``dataset``."""
        forecast_it, ts_it = get_forecasts_and_series(
            dataset,
            predictor,
            prediction_length=self.prediction_length,
            num_samples=self.num_samples,
        )
        # adding more than one worker throws an error, not sure why
        agg_metrics, _ = Evaluator(num_workers=0)(
            ts_it,
            forecast_it,
            num_series=len(dataset),
        )
        return agg_metrics

    def __call__(
        self,
        epoch: int,
        ctx,
        transformation,
        net,
        create_predictor,
        train_epoch_loss: Optional[float] = None,
        val_epoch_loss: Optional[float] = None,
    ):
        """
        Record the metrics of one epoch.

        Appends two dicts to ``self.metrics``: first the train one, then the val
        one. Each holds ``dataset``, ``epoch`` and, when given, ``epoch_loss``. On
        evaluation epochs both dicts are also updated with the aggregate metrics
        returned by ``Evaluator``.

        :param epoch: epoch index. The ``epoch + 1`` in the interval check
            presumes it is zero-based.
        :param ctx: context manager entered while the predictor is created.
        :param transformation: passed to ``create_predictor``.
        :param net: passed to ``create_predictor``.
        :param create_predictor: callable ``(transformation, net) -> predictor``.
        :param train_epoch_loss: optional training loss stored as ``epoch_loss``.
        :param val_epoch_loss: optional validation loss stored as ``epoch_loss``.
        """
        train_agg_metrics = {"dataset": "train", "epoch": epoch}
        val_agg_metrics = {"dataset": "val", "epoch": epoch}

        if train_epoch_loss is not None:
            train_agg_metrics["epoch_loss"] = train_epoch_loss
        if val_epoch_loss is not None:
            val_agg_metrics["epoch_loss"] = val_epoch_loss

        if self._is_eval_epoch(epoch):
            predictor = self._build_predictor(
                ctx, transformation, net, create_predictor
            )
            train_agg_metrics.update(
                self._compute_metrics(self.train_dataset, predictor)
            )
            val_agg_metrics.update(
                self._compute_metrics(self.val_dataset, predictor)
            )

        self.metrics.append(train_agg_metrics)
        self.metrics.append(val_agg_metrics)


