"""mGLS implementation that can test out several neighboring vertices at once."""
import dataclasses
from typing import Optional

import numpy as np

from xoid.gls import base_mgls
from xoid.solvers import vertex_solvers

from xoid.util import basics
from xoid.util import numerics
from xoid.util import vertex_util


# Short aliases for helpers used throughout this module.
_np, _vtx_key = basics.to_np, vertex_util.to_vertex_key

# Make the base option/solver classes available under this module's namespace.
BaseMglsOptions, BaseMgls = base_mgls.BaseMglsOptions, base_mgls.BaseMgls


@dataclasses.dataclass()
class MultiprocessMglsOptions(BaseMglsOptions):
    """Options for the multiprocess mGLS solver.

    Attributes:
      n_processes: Number of processes for the multiprocess vertex solver.
        Required and must be positive. It carries a `None` default
        presumably only so it can follow defaulted fields inherited from
        `BaseMglsOptions` (TODO confirm).
      max_parallelism: Maximum number of vertex solves kept pending at once.
        Defaults to `n_processes` when not set / set to None; must be
        positive otherwise.

    Raises:
      ValueError: From `__post_init__` when either option is invalid.
    """

    n_processes: Optional[int] = None
    # Defaults to `n_processes` if not set / set to None.
    max_parallelism: Optional[int] = None

    def __post_init__(self):
        super().__post_init__()
        # Explicit checks instead of `assert`: asserts are stripped when
        # Python runs with `-O`, which would let invalid options through.
        if self.n_processes is None:
            raise ValueError("n_processes must be set.")
        if self.n_processes <= 0:
            raise ValueError(
                f"n_processes must be positive, got {self.n_processes}.")
        if self.max_parallelism is None:
            self.max_parallelism = self.n_processes
        elif self.max_parallelism <= 0:
            # A non-positive limit would mean no solve could ever be launched.
            raise ValueError(
                f"max_parallelism must be positive, got {self.max_parallelism}.")


class _MultiSolverContext:
    """Context manager that cleans up the solver pool when its block exits.

    On every exit, normal or exceptional, it calls
    `mgls.multiprocess_vertex_solvers.kill_pending_requests()`. Entering
    yields None. Exceptions raised inside the block are never suppressed.
    """

    def __init__(self, mgls):
        # Owning mGLS object; its solver pool is looked up only at exit time.
        self.mgls = mgls

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, exc_traceback):
        solver_pool = self.mgls.multiprocess_vertex_solvers
        solver_pool.kill_pending_requests()
        # Falsy return value: let any in-flight exception propagate.
        return None


class MultiprocessMgls(BaseMgls):
    """mGLS variant that solves several vertices concurrently.

    Vertex solves are dispatched through a
    `vertex_solvers.MultiprocessVertexSolver`. At most
    `options.max_parallelism` solves are kept pending at a time.
    """

    def shutdown_subprocesses(self):
        """Shut down the subprocesses owned by the multiprocess vertex solver."""
        self.multiprocess_vertex_solvers.shutdown_subprocesses()

    def _initialize_vertex_solver(self):
        """Create the multiprocess solver, then run the base initializer."""
        self.multiprocess_vertex_solvers = vertex_solvers.MultiprocessVertexSolver(
            self.X, self.Y, self.options.loss_fn,
            n_processes=self.options.n_processes,
            m=self.m,
            v=self.v,
            regularization=self.options.regularization,
            regularization_constant=self.options.regularization_constant,
            eps=self.eps,
            no_scs=self.options.no_scs,
        )
        return super()._initialize_vertex_solver()

    def _solve_vertex_iter(self, vertex_iter):
        """Solve every vertex of `vertex_iter`, yielding results as they arrive.

        Yields `(vertex, loss, results)` triples:
          * A vertex whose key is already in `self.vertex_to_loss` is yielded
            immediately as `(vertex, cached_loss, None)`.
          * Any other vertex is launched on the multiprocess solver. Its
            triple is yielded when its response is read from the solver's
            response stream. That is presumably completion order, so it is
            not necessarily input order. `vertex` there is
            `results.vertex`. Each such loss is also stored in
            `self.vertex_to_loss`.
          * Responses whose `results` is None are skipped.

        Every step, including draining the final pending solves, runs inside
        `_MultiSolverContext`. So pending requests are killed if an error is
        raised or this generator is closed before it finishes.

        Raises:
          The `error` carried by a failed solver response, re-raised as-is.
        """
        solvers = self.multiprocess_vertex_solvers
        max_pending = self.options.max_parallelism
        response_stream = iter(solvers.response_stream())

        def record(response):
            """Cache one solver response; return its yield triple or None."""
            # TODO: Figure out what good error handling would do.
            if response.error:
                raise response.error
            results = response.results
            if results is None:
                return None
            self.vertex_to_loss[_vtx_key(results.vertex)] = results.loss
            return results.vertex, results.loss, results

        with _MultiSolverContext(self):
            for vertex in vertex_iter:
                vertex = _np(vertex)
                vtx_key = _vtx_key(vertex)

                if vtx_key in self.vertex_to_loss:
                    yield vertex, self.vertex_to_loss[vtx_key], None
                    continue

                # Consume responses until there is room for another solve.
                # The current vertex is then always launched, even when a
                # consumed response carried no results.
                while solvers.n_pending() >= max_pending:
                    _, response = next(response_stream)
                    finished = record(response)
                    if finished is not None:
                        yield finished

                solvers.launch_solve(vertex)

            # Drain the solves still pending. The count is checked before each
            # blocking fetch, so a result-less response cannot make us wait
            # on a stream with nothing left to deliver.
            while solvers.n_pending():
                item = next(response_stream, None)
                if item is None:
                    # The stream ended; nothing more will arrive.
                    break
                finished = record(item[1])
                if finished is not None:
                    yield finished
