"""Solvers that find the optimal set of parameters within a vertex."""
import dataclasses
import functools
import multiprocessing as mp
from typing import Any, Dict, Optional
import uuid

import cvxpy as cp
import numpy as np

from xoid import constants

from xoid.solvers import multiprocess_solver

from xoid.util import basics
from xoid.util import solver_util
from xoid.util import network_util
from xoid.util import numerics


# Short module-level aliases for helpers used throughout this file.
_np, about_equal = basics.to_np, numerics.about_equal


@dataclasses.dataclass
class SolveResults:
    """Outcome of solving the convex program for a single vertex.

    Field order defines the positional constructor, so keep it stable.
    """
    # The vertex that was solved for (VertexSolver.solve stores it as int32).
    vertex: np.ndarray
    # Objective value returned by cvxpy's `Problem.solve`.
    loss: float
    # Which feasibility constraints hold with equality at the solution.
    binding_constraints: np.ndarray
    # Fitted parameters of the model.
    model_params: network_util.ModelParams
    # Per-example predictions; left as None unless explicitly requested.
    predictions: Optional[np.ndarray] = None


class VertexSolver:
    """Fits model parameters for one fixed activation pattern ("vertex").

    Warm startable: the vertex is held in a `cp.Parameter`, so a single
    `cp.Problem` is built once in `__init__` and re-solved (with
    `warm_start=True`) for every vertex passed to `solve`.

    The prediction built in `_set_up_problem` is, for each example `n`,

        Y_pred[n] = sum_j v[j] * vertex[n, j] * (X @ w + b)[n, j] + c

    Unless `ignore_feasibility` is set, the constraints
    `(1 - vertex) * (X @ w + b) <= vertex * (X @ w + b)` are imposed. For a
    0/1 vertex this means each pre-activation is >= 0 where the entry is 1
    and <= 0 where it is 0.
    """

    def __init__(self, X, Y, loss_fn, m, v=1.0,
                 regularization=None, regularization_constant=0.0, eps=None,
                 *, ignore_feasibility=False, no_scs=False, include_predictions=False):
        """Builds the parametrized cvxpy problem.

        Args:
            X: Inputs of shape [N, d]; converted with `basics.to_np`.
            Y: Targets; converted with `basics.to_np`.
            loss_fn: One of `constants.LOSS_FNS`.
            m: Number of columns of `w` / entries of `b`.
            v: Fixed output weights; a scalar, or an ndarray reshaped to [1, m].
            regularization: None or one of `constants.REGULARIZERS`.
            regularization_constant: Weight applied to the regularization loss
                (also passed to the regularization constraints).
            eps: Tolerance used to decide which constraints are binding.
                Required.
            ignore_feasibility: If True, the vertex-consistency constraints
                are dropped.
            no_scs: If True, `solve` returns None instead of retrying with SCS
                when ECOS raises.
            include_predictions: If True, `solve` also returns `Y_pred`.
        """
        assert loss_fn in constants.LOSS_FNS
        assert eps is not None

        # NOTE: We only regularize the w variable, not the b or c variables.
        if regularization is not None:
            assert regularization in constants.REGULARIZERS

        self.X = _np(X)
        self.Y = _np(Y)

        self.dtype = self.X.dtype

        self.loss_fn = loss_fn
        self.regularization = regularization
        self.regularization_constant = regularization_constant

        self.ignore_feasibility = ignore_feasibility
        self.no_scs = no_scs
        self.include_predictions = include_predictions

        if isinstance(v, np.ndarray):
            v = np.reshape(v, [1, m])
        self.v = v

        self.eps = eps

        # Read shapes from the converted array rather than the raw argument,
        # so inputs that `_np` accepts but that lack `.shape` also work.
        self.N = self.X.shape[0]
        self.d = self.X.shape[-1]
        self.m = m

        # NOTE: Transpose of shape in the non-warm start version.
        self.vertex = cp.Parameter([self.N, self.m], name='vertex')

        self.w = cp.Variable([self.d, self.m], name='w')
        self.b = cp.Variable([1, self.m], name='b')
        self.c = cp.Variable([], name='c')

        self._set_up_problem()

    def _set_up_problem(self):
        """Builds the objective, constraints and `cp.Problem` once."""
        w, b, c = self.w, self.b, self.c
        # Pre-activations, shape [N, m].
        relu_arg = self.X @ w + b

        # Output weights masked by the vertex; shape [N, m].
        vvtx = cp.multiply(self.v, self.vertex)
        Y_pred = cp.sum(cp.multiply(vvtx, relu_arg), axis=1) + c

        self.relu_arg = relu_arg
        self.Y_pred = Y_pred

        self.pred_loss = solver_util.compute_loss(self.loss_fn, self.Y, Y_pred)

        obj = self.pred_loss + self._compute_regularization_loss(w)
        self.objective = cp.Minimize(obj)

        # Kept as attributes so `get_binding_constraints` can read their
        # values after a solve.
        self.constraint_lhs = cp.multiply(1 - self.vertex, relu_arg)
        self.constraint_rhs = cp.multiply(self.vertex, relu_arg)
        if self.ignore_feasibility:
            self.constraints = []
        else:
            self.constraints = [self.constraint_lhs <= self.constraint_rhs]

        self.constraints.extend(self._compute_regularization_constraints(w, vvtx))

        self.prob = cp.Problem(self.objective, self.constraints)

    def _compute_regularization_loss(self, w):
        """Returns the weighted regularization term, or 0.0 if there is none."""
        reg_loss = solver_util.compute_loss_regularization(self.regularization, w)
        if reg_loss is None:
            return 0.0
        return self.regularization_constant * reg_loss

    def _compute_regularization_constraints(self, w, vvtx):
        """Returns the list of regularization constraints from `solver_util`."""
        return solver_util.compute_regularization_constraints(
            self.regularization,
            self.regularization_constant,
            w,
            vvtx,
        )

    def solve(self, vertex, *, ecos_kwargs=None, scs_kwargs=None):
        """Solves the problem for `vertex`, given with shape [m, N].

        ECOS is tried first. If it raises `cp.error.SolverError`, the solve
        is retried with SCS, unless `no_scs` is set.

        Args:
            vertex: Activation pattern, shape [m, N].
            ecos_kwargs: Extra keyword arguments for the ECOS solve.
            scs_kwargs: Extra keyword arguments for the SCS fallback solve.

        Returns:
            A `SolveResults`, or None if ECOS failed and `no_scs` is True.
            Note that the parameter keeps the new vertex value even then.

        Raises:
            cp.error.SolverError: If the SCS fallback also fails.
        """
        # Stored internally as [N, m] to line up with `X @ w + b`.
        vertex = _np(vertex).astype(self.dtype).T

        self.vertex.value = vertex

        if ecos_kwargs is None:
            ecos_kwargs = {}
        if scs_kwargs is None:
            scs_kwargs = {}

        try:
            loss = self.prob.solve(warm_start=True, solver=cp.ECOS, **ecos_kwargs)
        except cp.error.SolverError:
            if self.no_scs:
                return None
            print('SCS')
            loss = self.prob.solve(warm_start=True, solver=cp.SCS, **scs_kwargs)

        binding_constraints = self.get_binding_constraints()

        predictions = self.Y_pred.value if self.include_predictions else None

        return SolveResults(
            vertex=vertex.T.astype(np.int32),
            loss=loss,
            binding_constraints=binding_constraints,
            predictions=predictions,
            model_params=network_util.ModelParams(
                w=self.w.value,
                b=self.b.value,
                v=self.v,
                c=self.c.value,
            ))

    def is_current_vertex(self, vertex):
        """Returns True iff `vertex` (shape [m, N]) equals the current parameter value.

        `np.array_equal` also requires matching shapes. A plain `==` would
        broadcast, so e.g. a single-example vertex could compare equal to a
        parameter whose rows all match it.
        """
        vertex = _np(vertex).astype(self.dtype).T
        return np.array_equal(self.vertex.value, vertex)

    def get_binding_constraints(self):
        """Flags which feasibility constraints are tight after the last solve.

        Binding means the inequality is an equality at the solution. The
        comparison is `about_equal(lhs, rhs, self.eps)` on the [N, m]
        constraint values, transposed back to the caller's [m, N] layout.
        Presumably only meaningful after a successful `solve` (before that
        the expression values are unset) -- TODO confirm.
        """
        # NOTE: Might have connection to violator spaces when analyzing this
        # theoretically.
        bcs = about_equal(self.constraint_lhs.value, self.constraint_rhs.value, self.eps)
        return bcs.T


###############################################################################


@dataclasses.dataclass
class SolveRequest:
    """A unit of work for a solver subprocess: solve one vertex."""
    # Hex identifier used to pair the eventual response with this request.
    uuid: str
    # Vertex handed to the solver's `solve` method.
    vertex: np.ndarray


@dataclasses.dataclass
class SolveResponse:
    """Reply produced for one `SolveRequest`."""
    # Same value as the `uuid` of the originating request.
    uuid: str
    # Solver output; None when the solve raised or produced no result.
    results: Optional[SolveResults]
    # The exception caught while solving, or None.
    error: Optional[Any]


class VertexSolverProcess(multiprocess_solver.SolverProcessAbc):
    """Handles `SolveRequest` messages for a solver subprocess."""

    def _process_message(self, msg: SolveRequest) -> SolveResponse:
        """Solves `msg.vertex`; a `cp.error.SolverError` becomes an error response."""
        try:
            outcome = self.solver.solve(msg.vertex)
        except cp.error.SolverError as exc:
            print('Solver errored.')
            return SolveResponse(uuid=msg.uuid, results=None, error=exc)
        return SolveResponse(uuid=msg.uuid, results=outcome, error=None)
            

class MultiprocessVertexSolver:
    """
    Essentially the same as `VertexSolver` but can solve different inputs
    in parallel.

    `n_processes` worker processes share one request queue and one response
    queue. Every worker builds its own solver from `solver_cls(*args, **kwargs)`.
    """
    def __init__(self, *args, n_processes: int, solver_cls=VertexSolver, **kwargs):
        """Starts `n_processes` workers.

        Args:
            *args, **kwargs: Forwarded to `solver_cls` in each worker.
            n_processes: Number of worker processes to start.
            solver_cls: Solver class instantiated in each worker.
        """
        self.n_processes = n_processes
        self.solver_cls = solver_cls

        self.request_queue = mp.Queue()
        self.response_queue = mp.Queue()
        # Request id to request.
        self.pending_requests: Dict[str, SolveRequest] = {}

        # Uuids of pending messages to ignore.
        self.ignored_uuids = set()

        # A `functools.partial` is picklable, unlike a lambda. The factory can
        # therefore be sent to children under the "spawn"/"forkserver" start
        # methods, provided solver_cls and its arguments are picklable.
        # Calling it is equivalent to `solver_cls(*args, **kwargs)`.
        instantiate_solver = functools.partial(solver_cls, *args, **kwargs)

        self.processes = []
        for _ in range(self.n_processes):
            p = VertexSolverProcess(
                request_queue=self.request_queue,
                response_queue=self.response_queue,
                instantiate_solver=instantiate_solver,
            )
            p.start()
            self.processes.append(p)

    def n_pending(self):
        """Number of launched requests whose responses have not been consumed."""
        return len(self.pending_requests)

    # def n_active_processes(self):
    #     return len(self.pending_requests) + len(self.ignored_uuids)

    def shutdown_subprocesses(self):
        """Sends one kill signal per worker, then joins and closes every worker."""
        # TODO: This class probably requires some manually resource deallocation
        # upon garbage collection and/or way to make sure the processes and this
        # object can garbage collected.
        for _ in self.processes:
            self.request_queue.put(multiprocess_solver.KILL_SIGNAL)
        for p in self.processes:
            p.join()
        for p in self.processes:
            p.close()

    def launch_solve(self, vertex):
        """Enqueues a solve for `vertex` and returns the request's uuid."""
        msg = SolveRequest(uuid=uuid.uuid4().hex, vertex=_np(vertex))
        self.request_queue.put(msg)
        self.pending_requests[msg.uuid] = msg
        return msg.uuid

    def response_stream(self):
        """Yields `(request, response)` pairs in the order responses arrive.

        Each iteration blocks on the response queue. The loop never ends on
        its own: if nothing is pending, the next iteration waits for a
        response that will not come. Callers must stop iterating, or launch
        more work before resuming. Responses for uuids in `ignored_uuids`
        are consumed silently.
        """
        # NOTE: Deliberately `while True` rather than `while self.pending_requests`,
        # so a long-lived generator can be resumed after new launches.
        while True:
            response = self.response_queue.get()
            request = self.pending_requests.pop(response.uuid)
            if request.uuid in self.ignored_uuids:
                self.ignored_uuids.remove(request.uuid)
                continue
            yield request, response

    def kill_pending_requests(self):
        """Drains (and discards) responses until no requests are pending."""
        if not self.pending_requests:
            return

        # # TODO: Actually stop the process instead of just ignoring their results.
        # self.ignored_uuids.update(self.pending_requests.keys())

        # TODO: Actually stop the process instead of waiting to finish.
        for _ in self.response_stream():
            if not self.pending_requests:
                return
