"""Base version of the greedy local search on zonotope vertices."""
import dataclasses
import itertools

import numpy as np

from xoid import constants

from xoid.solvers import feasibility
from xoid.solvers import vertex_solvers

from xoid.util import basics
from xoid.util import vertex_util


# Short local alias for ``basics.to_np``; presumably converts array-likes
# (e.g. framework tensors) to NumPy arrays — TODO confirm in xoid.util.basics.
_np = basics.to_np


@dataclasses.dataclass()
class GlsOptions:
    """Options/parameters for the base mGLS.

    Attributes:
        loss_fn: Name of the loss function to optimize. Must be one of
            ``constants.LOSS_FNS``.

    Raises:
        ValueError: If ``loss_fn`` is not in ``constants.LOSS_FNS``.
    """
    loss_fn: str

    def __post_init__(self):
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with ``-O``, which would silently accept a bad loss_fn.
        if self.loss_fn not in constants.LOSS_FNS:
            raise ValueError(
                f'Unknown loss_fn {self.loss_fn!r}; expected one of '
                f'{constants.LOSS_FNS}.')


class Gls:
    """Greedy (steepest-descent) local search over zonotope vertices.

    Starting from a random vertex, ``solve`` repeatedly moves to the feasible
    single-entry-flip neighbor with the lowest loss reported by the vertex
    solver, stopping when no neighbor strictly improves the loss.
    """

    def __init__(self, X, Y, v, options: 'GlsOptions', *, eps=1e-7):
        """Set up the vertex solver / feasibility checker and a start vertex.

        Args:
            X: Input data; after conversion its shape is unpacked as (N, d).
            Y: Targets, forwarded to the vertex solver.
            v: Array whose leading dimension defines ``m`` (presumably the
                number of units — TODO confirm).
            options: ``GlsOptions`` selecting the loss function.
            eps: Tolerance forwarded to the vertex solver and feasibility
                checker; ``solve`` also stops once the loss is ``<= eps``.
        """
        self.options = options

        self.X = _np(X)
        self.Y = _np(Y)

        self.v = _np(v)

        self.m = self.v.shape[0]
        # Read the shape from the converted array (not the raw argument) so
        # any input accepted by ``_np`` works, consistent with ``self.v``.
        self.N, self.d = self.X.shape

        self.eps = eps
        self.dtype = self.X.dtype

        self.vertex_solver = vertex_solvers.VertexSolver(
            self.X, self.Y, loss_fn=self.options.loss_fn,
            m=self.m, v=self.v, eps=self.eps)
        self.feasibility_checker = feasibility.FeasibilityChecker(
            self.X, eps=self.eps)

        # Assumed shape = [m, N] (one row per unit): ``_all_neighbors`` and
        # ``_is_feasible`` rely on this layout — TODO confirm against
        # ``vertex_util.random_vertex``.
        self.current_vertex = vertex_util.random_vertex(self.X, self.m)

    def _is_feasible(self, vertex):
        """Return True iff every row of ``vertex`` passes the feasibility check.

        Each row is checked together with its row index; checking stops at
        the first infeasible row.
        """
        return all(
            self.feasibility_checker.is_feasible(unit, i)
            for i, unit in enumerate(vertex))

    def _all_neighbors(self, vertex):
        """Yield copies of 2-D ``vertex`` with exactly one entry flipped.

        Entry ``x`` becomes ``1 - x``; neighbors are produced in row-major
        order. ``vertex`` itself is never modified.
        """
        m, N = vertex.shape
        for i, j in itertools.product(range(m), range(N)):
            neighbor = np.copy(vertex)
            neighbor[i, j] = 1 - neighbor[i, j]
            yield neighbor

    def _all_feasible_neighbors(self, vertex):
        """Yield the neighbors of ``vertex`` that pass ``_is_feasible``."""
        for neighbor in self._all_neighbors(vertex):
            if self._is_feasible(neighbor):
                yield neighbor

    def _best_feasible_neighbor(self, vertex):
        """Return ``(loss, neighbor)`` for the lowest-loss feasible neighbor.

        On ties the first neighbor in generation order wins. Returns None if
        ``vertex`` has no feasible neighbor.
        """
        best = None
        for neighbor in self._all_feasible_neighbors(vertex):
            neighbor_loss = self.vertex_solver.solve(neighbor).loss
            # Compare losses only: comparing (loss, ndarray) tuples would fall
            # through to an ambiguous array comparison on tied losses.
            if best is None or neighbor_loss < best[0]:
                best = (neighbor_loss, neighbor)
        return best

    def solve(self, max_iters: int = 250, *, verbose: bool = True):
        """Run greedy local search starting from ``self.current_vertex``.

        Each iteration moves ``self.current_vertex`` to the best feasible
        neighbor if its loss is strictly lower than the current loss. The
        search stops when:
        - no neighbor improves the loss,
        - no feasible neighbor exists,
        - the loss drops to ``<= self.eps``, or
        - ``max_iters`` iterations have run.

        Args:
            max_iters: Maximum number of iterations.
            verbose: If True, print ``'<iteration>: <loss>'`` at the start of
                each iteration.

        Returns:
            ``(loss, n_iters)``: the loss at the final ``self.current_vertex``
            and the number of iterations entered. This includes a final one
            that found no improvement, and is 0 when ``max_iters <= 0``.
        """
        loss = self.vertex_solver.solve(self.current_vertex).loss
        n_iters = 0

        for i in range(max_iters):
            n_iters = i + 1
            if verbose:
                print(f'{i}: {loss}')

            best = self._best_feasible_neighbor(self.current_vertex)
            if best is None:
                # No feasible move exists; the search is stuck here.
                break

            best_loss, best_neighbor = best
            if best_loss < loss:
                self.current_vertex = best_neighbor
                loss = best_loss
            else:
                break

            if loss <= self.eps:
                break

        return loss, n_iters
