"""Wave Function Collapse solver. Implementation based on https://github.com/ikarth/wfc_2019f"""
from __future__ import annotations

import itertools
import logging
import math
from typing import Any, Callable, Collection, Iterable, Iterator, Mapping, TypeVar

# from scipy import sparse  # type: ignore
import numpy
import numpy as np
from numpy.typing import NBitBase, NDArray

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=NBitBase)


class Contradiction(Exception):
    """Raised when the wave reaches a state that cannot be completed as-is.

    Recovering from it requires backtracking to an earlier state or
    restarting the solve.
    """


class TimedOut(Exception):
    """Signals that a solve ran out of time before finishing."""


class StopEarly(Exception):
    """Signals that a solve is being aborted before completion."""


class Solver:
    """Stateful WFC solver: owns the wave and the undo stack used for backtracking."""

    def __init__(
        self,
        *,
        wave: NDArray[np.bool_],
        adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
        periodic: bool = False,
        backtracking: bool = False,
        on_backtrack: Callable[[], None] | None = None,
        on_choice: Callable[[int, int, int], None] | None = None,
        on_observe: Callable[[NDArray[numpy.bool_]], None] | None = None,
        on_propagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
        check_feasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
    ) -> None:
        """Store the problem definition and the optional hooks.

        Args:
            wave: Boolean array whose first axis indexes patterns and whose
                remaining two axes index cells; it is modified in place.
            adj: Adjacency matrices keyed by direction, handed to ``propagate``.
            periodic: Forwarded to ``propagate``.
            backtracking: When true, a copy of the wave is saved before each
                step so a contradiction can be undone.
            on_backtrack: Called with no arguments right before a saved wave
                is restored.
            on_choice: Called with ``(pattern, i, j)`` once a cell and pattern
                have been picked.
            on_observe: Called with the wave after the chosen cell is collapsed.
            on_propagate: Forwarded to ``propagate`` as ``onPropagate``.
            check_feasible: Called with the wave at the start of each step; a
                falsy result raises ``Contradiction``.
        """
        # Problem definition.
        self.wave = wave
        self.adj = adj
        self.periodic = periodic
        # Backtracking state: snapshots of the wave, most recent last.
        self.backtracking = backtracking
        self.history: list[NDArray[np.bool_]] = []
        # Optional hooks.
        self.check_feasible = check_feasible
        self.on_backtrack = on_backtrack
        self.on_choice = on_choice
        self.on_observe = on_observe
        self.on_propagate = on_propagate

    @property
    def is_solved(self) -> bool:
        """True once every cell has exactly one pattern left."""
        cell_count = self.wave.shape[1] * self.wave.shape[2]
        total_matches = self.wave.sum() == cell_count
        return total_matches and (self.wave.sum(axis=0) == 1).all()

    def _propagate(self) -> None:
        """Run ``propagate`` on the current wave with this solver's settings."""
        propagate(
            self.wave, self.adj, periodic=self.periodic, onPropagate=self.on_propagate
        )

    def solve_next(
        self,
        location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
        pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
    ) -> bool:
        """Perform one observe/collapse/propagate step.

        Returns:
            True when the wave was already solved on entry, otherwise False
            (including after a backtrack); call again to continue.

        Raises:
            Contradiction: If ``check_feasible`` rejects the wave, if the
                initial propagation fails, or if the step fails while
                backtracking is disabled.
        """
        if self.is_solved:
            return True

        feasible = self.check_feasible
        if feasible and not feasible(self.wave):
            raise Contradiction("Not feasible.")

        if self.backtracking:
            # Snapshot taken before anything in this step touches the wave.
            self.history.append(self.wave.copy())
        self._propagate()

        chosen, row, col = None, None, None
        try:
            chosen, row, col = observe(
                self.wave, location_heuristic, pattern_heuristic
            )
            if self.on_choice:
                self.on_choice(chosen, row, col)
            # Collapse the cell to the single chosen pattern.
            self.wave[:, row, col] = False
            self.wave[chosen, row, col] = True
            if self.on_observe:
                self.on_observe(self.wave)
            self._propagate()
        except Contradiction:
            if not self.backtracking:
                raise
            if not self.history:
                raise Contradiction("Every permutation has been attempted.")
            if self.on_backtrack:
                self.on_backtrack()
            # Undo the step and forbid the choice that led to the dead end.
            self.wave = self.history.pop()
            # NOTE(review): if the Contradiction is raised before a choice was
            # made, chosen/row/col are still None here and this index addresses
            # the entire array -- confirm that is acceptable.
            self.wave[chosen, row, col] = False

        # Steps may remain; if not, the next call reports True.
        return False

    def solve(
        self,
        location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
        pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
    ) -> NDArray[np.int64]:
        """Step until ``solve_next`` reports completion, then return the pattern index per cell."""
        finished = False
        while not finished:
            finished = self.solve_next(
                location_heuristic=location_heuristic,
                pattern_heuristic=pattern_heuristic,
            )
        return numpy.argmax(self.wave, axis=0)


def makeWave(
    n: int, w: int, h: int, ground: Iterable[int] | None = None
) -> NDArray[numpy.bool_]:
    """Create the initial wave: an ``(n, w, h)`` boolean array of allowed patterns.

    Without ``ground`` every pattern is allowed in every cell.  With
    ``ground``, the last slice along the third axis (index ``h - 1``) admits
    only the listed patterns, and those patterns are forbidden everywhere else.
    """
    wave: NDArray[numpy.bool_] = numpy.full((n, w, h), True, dtype=numpy.bool_)
    if ground is None:
        return wave

    last = h - 1
    # Start with nothing allowed on the ground slice...
    wave[:, :, last] = False
    for pattern in ground:
        # ...then confine each ground pattern to exactly that slice.
        wave[pattern, :, :] = False
        wave[pattern, :, last] = True
    return wave


def makeAdj(
    adjLists: Mapping[tuple[int, int], Collection[Iterable[int]]]
) -> dict[tuple[int, int], NDArray[numpy.bool_]]:
    """Convert per-direction adjacency lists into boolean adjacency matrices.

    Args:
        adjLists: For each direction, one iterable per pattern ``i`` listing
            the pattern indices ``j`` that are marked adjacent to ``i``.

    Returns:
        A dict with the same keys; each value is a square boolean matrix ``m``
        where ``m[i, j]`` is True for every listed pair.  An empty mapping
        yields an empty dict.
    """
    if not adjLists:
        # Nothing to size the matrices from; avoid indexing an empty list.
        return {}

    # The pattern count is taken from the first direction's list.
    num_patterns = len(next(iter(adjLists.values())))
    adjMatrices: dict[tuple[int, int], NDArray[numpy.bool_]] = {}
    for direction, neighbours_by_pattern in adjLists.items():
        matrix = numpy.zeros((num_patterns, num_patterns), dtype=bool)
        for i, js in enumerate(neighbours_by_pattern):
            for j in js:
                matrix[i, j] = True
        # If scipy is available, a sparse matrix could be used here instead.
        adjMatrices[direction] = matrix
    return adjMatrices


######################################
# Location Heuristics


def makeRandomLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a heuristic that picks the unresolved cell with the lowest preference.

    ``preferences`` holds one weight per cell.  Presumably the caller fills it
    with random values, which is what makes the order random -- verify
    against the caller.
    """

    def randomLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        options_per_cell = numpy.count_nonzero(wave, axis=0)
        # Cells with at most one option are pushed to +inf so they never win.
        masked = numpy.where(options_per_cell > 1, preferences, numpy.inf)
        row, col = numpy.unravel_index(masked.argmin(), masked.shape)
        return row.item(), col.item()

    return randomLocationHeuristic


def makeEntropyLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a heuristic that picks the unresolved cell with the fewest options.

    Each unresolved cell scores ``preferences + number of remaining patterns``
    and the lowest score wins, so ``preferences`` acts as a per-cell offset
    (e.g. a tie-breaker).
    """

    def entropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        options_per_cell = numpy.count_nonzero(wave, axis=0)
        scores = numpy.where(
            options_per_cell > 1,
            preferences + options_per_cell,
            numpy.inf,  # cells with at most one option must never be selected
        )
        row, col = numpy.unravel_index(scores.argmin(), scores.shape)
        return row.item(), col.item()

    return entropyLocationHeuristic


def makeAntiEntropyLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a heuristic that picks the unresolved cell with the most options.

    Each unresolved cell scores ``preferences + number of remaining patterns``
    and the highest score wins.
    """

    def antiEntropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        options_per_cell = numpy.count_nonzero(wave, axis=0)
        scores = numpy.where(
            options_per_cell > 1,
            preferences + options_per_cell,
            -numpy.inf,  # cells with at most one option must never be selected
        )
        row, col = numpy.unravel_index(scores.argmax(), scores.shape)
        return row.item(), col.item()

    return antiEntropyLocationHeuristic


def spiral_transforms() -> Iterator[tuple[int, int]]:
    """Yield an endless stream of unit steps that trace an outward spiral.

    For each ring ``N = 1, 2, 3, ...`` one leading step is emitted, followed
    by ``N`` steps along the first axis and ``N`` steps along the second.
    Odd rings go left/up/right, even rings go right/down/left.
    """
    left, right = (0, -1), (0, 1)
    up, down = (-1, 0), (1, 0)
    for ring in itertools.count(start=1):
        if ring % 2:
            lead, first_leg, second_leg = left, up, right
        else:
            lead, first_leg, second_leg = right, down, left
        yield lead
        yield from itertools.repeat(first_leg, ring)
        yield from itertools.repeat(second_leg, ring)


def spiral_coords(x: int, y: int) -> Iterator[tuple[int, int]]:
    """Yield ``(x, y)`` and then every coordinate reached by following ``spiral_transforms``.

    The stream is endless and is not clipped to any grid.
    """
    position = (x, y)
    yield position
    for dx, dy in spiral_transforms():
        position = (position[0] + dx, position[1] + dy)
        yield position


def fill_with_curve(
    arr: NDArray[np.floating[T]], curve_gen: Iterable[Iterable[int]]
) -> NDArray[np.floating[T]]:
    """Overwrite ``arr`` with each cell's normalised visiting rank along a curve.

    The k-th coordinate of ``curve_gen`` that lies inside ``arr`` receives
    ``k / arr.size`` (k counted from 0).  Coordinates outside the array are
    skipped without using up a rank.  Iteration stops as soon as
    ``arr.size`` ranks have been handed out, so ``curve_gen`` may be infinite.

    Args:
        arr: Array to fill; it is modified in place.
        curve_gen: Coordinates to visit, in order.

    Returns:
        The same ``arr`` object.
    """
    arr_len = arr.size
    fill = 0
    for coord in curve_gen:
        if fill >= arr_len:
            break
        index = tuple(coord)
        # NumPy wraps negative indices instead of raising IndexError, so an
        # off-grid point would silently overwrite a real cell; reject it here.
        if any(component < 0 for component in index):
            continue
        try:
            arr[index] = fill / arr_len
        except IndexError:
            # Past the upper edge of the array: not part of the grid.
            continue
        fill += 1
    return arr


def makeSpiralLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a heuristic that visits unresolved cells in spiral order from the centre.

    ``preferences`` is passed to ``fill_with_curve``, which writes the spiral
    ranks into the array it is given, so the caller's array is overwritten
    and then used as the visiting order.
    """
    # https://stackoverflow.com/a/23707273/5562922
    centre_row = preferences.shape[0] // 2
    centre_col = preferences.shape[1] // 2
    cell_order = fill_with_curve(preferences, spiral_coords(centre_row, centre_col))

    def spiralLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        options_per_cell = numpy.count_nonzero(wave, axis=0)
        # Cells with at most one option are pushed to +inf so they never win.
        ranked = numpy.where(options_per_cell > 1, cell_order, numpy.inf)
        row, col = numpy.unravel_index(ranked.argmin(), ranked.shape)
        return row.item(), col.item()

    return spiralLocationHeuristic


def makeHilbertLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a heuristic that visits unresolved cells in Hilbert-curve order.

    ``preferences`` is passed to ``fill_with_curve``, which writes the curve
    ranks into the array it is given, so the caller's array is overwritten
    and then used as the visiting order.

    Requires the third-party ``hilbertcurve`` package, imported lazily here.
    """
    from hilbertcurve.hilbertcurve import HilbertCurve  # type: ignore

    # A 2-D curve of order p spans 2**p points per side.  Order 4 (16 x 16)
    # was previously hard-coded; keep it as the minimum so grids up to 16
    # cells per side get the same ordering, and grow it for larger grids so
    # the curve reaches every cell.
    longest_side = max(preferences.shape[0], preferences.shape[1], 1)
    curve_size = max(4, (longest_side - 1).bit_length())
    logger.debug(curve_size)
    h_curve = HilbertCurve(curve_size, 2)
    h_coords = (h_curve.point_from_distance(i) for i in itertools.count())
    cell_order = fill_with_curve(preferences, h_coords)

    def hilbertLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
        # Cells with at most one option are pushed to +inf so they never win.
        cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
        row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
        return row.item(), col.item()

    return hilbertLocationHeuristic


def simpleLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
    """Return the unresolved cell with the fewest remaining patterns.

    Ties go to the first such cell in row-major order.
    """
    options_per_cell = numpy.count_nonzero(wave, axis=0)
    # Cells with at most one option are pushed to +inf so they never win.
    scores = numpy.where(options_per_cell > 1, options_per_cell, numpy.inf)
    row, col = numpy.unravel_index(scores.argmin(), scores.shape)
    return row.item(), col.item()


def lexicalLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
    """Return the first unresolved cell in row-major order.

    When no cell has more than one option left, ``(0, 0)`` is returned.
    """
    unresolved = numpy.count_nonzero(wave, axis=0) > 1
    # argmax on a boolean mask gives the first True, or 0 when there is none.
    row, col = numpy.unravel_index(unresolved.argmax(), unresolved.shape)
    return row.item(), col.item()


#####################################
# Pattern Heuristics


def lexicalPatternHeuristic(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
    """Return the lowest index at which ``weights`` is nonzero; ``wave`` is unused."""
    set_indices = numpy.nonzero(weights)[0]
    lowest = set_indices[0]
    return lowest.item()


def makeWeightedPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Build a heuristic that samples a pattern in proportion to its weight.

    Args:
        weights: One weight per pattern.
        np_random: Generator to draw from; a fresh ``default_rng()`` is
            created when omitted.

    Returns:
        A function ``(cell, total_wave) -> pattern index``.  Only patterns
        that are True in ``cell`` can be drawn; ``total_wave`` is ignored.
    """
    num_of_patterns = len(weights)
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )

    def weightedPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
        # TODO: there's maybe a faster, more controlled way to do this sampling...
        # Zero the weight of every pattern the cell no longer allows.
        masked_weights = weights * wave
        # Out-of-place division so integer weights are accepted as well.
        probabilities = masked_weights / masked_weights.sum()
        # Hand back a plain int, as the declared return type promises.
        return int(rng.choice(num_of_patterns, p=probabilities))

    return weightedPatternHeuristic


def makeRarestPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Return a function that chooses the rarest (currently least-used) pattern.

    The returned function counts, for every pattern, the cells of the whole
    wave in which it is still possible, and picks uniformly among the
    patterns allowed in the current cell that have the highest count.
    NOTE(review): a high remaining count is taken to mean "least used so
    far" -- confirm this reading of "rarest".

    Args:
        weights: Unused; kept for a uniform factory signature.
        np_random: Generator to draw from; a fresh ``default_rng()`` is
            created when omitted.
    """
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )

    def rarestPatternHeuristic(
        wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
    ) -> int:
        logger.debug(total_wave.shape)
        wave_sums = numpy.sum(total_wave, (1, 2))
        # Only patterns still possible in this cell are valid choices.
        allowed = numpy.asarray(wave, dtype=bool)
        if allowed.shape != wave_sums.shape or not allowed.any():
            # No usable cell mask: consider every pattern.
            allowed = numpy.ones(wave_sums.shape, dtype=bool)
        best = wave_sums[allowed].max()
        candidates = numpy.flatnonzero(allowed & (wave_sums == best))
        return int(rng.choice(candidates))

    return rarestPatternHeuristic


def makeMostCommonPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Return a function that chooses the most common (currently most-used) pattern.

    The returned function counts, for every pattern, the cells of the whole
    wave in which it is still possible, and picks uniformly among the
    patterns allowed in the current cell that have the lowest count.
    NOTE(review): a low remaining count is taken to mean "most used so
    far" -- confirm this reading of "most common".

    Args:
        weights: Unused; kept for a uniform factory signature.
        np_random: Generator to draw from; a fresh ``default_rng()`` is
            created when omitted.
    """
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )

    def mostCommonPatternHeuristic(
        wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
    ) -> int:
        logger.debug(total_wave.shape)
        wave_sums = numpy.sum(total_wave, (1, 2))
        # Only patterns still possible in this cell are valid choices.
        allowed = numpy.asarray(wave, dtype=bool)
        if allowed.shape != wave_sums.shape or not allowed.any():
            # No usable cell mask: consider every pattern.
            allowed = numpy.ones(wave_sums.shape, dtype=bool)
        best = wave_sums[allowed].min()
        candidates = numpy.flatnonzero(allowed & (wave_sums == best))
        return int(rng.choice(candidates))

    return mostCommonPatternHeuristic


def makeRandomPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Build a heuristic that picks uniformly among the patterns a cell still allows.

    Args:
        weights: Only its length is used, as the number of patterns.
        np_random: Generator to draw from; a fresh ``default_rng()`` is
            created when omitted.

    Returns:
        A function ``(cell, total_wave) -> pattern index``; ``total_wave`` is
        ignored.
    """
    num_of_patterns = len(weights)
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )

    def randomPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
        # TODO: there's maybe a faster, more controlled way to do this sampling...
        # Equal probability for each allowed pattern, zero for the rest.
        probabilities = 1.0 * wave
        probabilities /= probabilities.sum()
        # Hand back a plain int, as the declared return type promises.
        return int(rng.choice(num_of_patterns, p=probabilities))

    return randomPatternHeuristic


######################################
# Global Constraints


def make_global_use_all_patterns() -> Callable[[NDArray[np.bool_]], bool]:
    """Build a feasibility check requiring every pattern to remain possible somewhere."""

    def global_use_all_patterns(wave: NDArray[np.bool_]) -> bool:
        """Returns true if at least one instance of each pattern is still possible."""
        pattern_still_possible = wave.any(axis=(1, 2))
        return pattern_still_possible.all().item()

    return global_use_all_patterns


#####################################
# Solver


def propagate(
    wave: NDArray[np.bool_],
    adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
    periodic: bool = False,
    onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
) -> None:
    """Completely propagate any newly collapsed waves to all areas.

    ``wave`` is narrowed in place until a full pass removes nothing more.

    Args:
        wave: Boolean array indexed ``[pattern, x, y]``.
        adj: For each direction ``(dx, dy)``, a pattern-by-pattern boolean
            matrix; row ``q`` marks the patterns that may sit at offset
            ``(dx, dy)`` from pattern ``q``.
        periodic: When true the grid wraps around; otherwise cells beyond the
            border are treated as allowing every pattern.
        onPropagate: Called once with the final wave, before the
            contradiction check.

    Raises:
        Contradiction: If some cell ends up with no possible pattern.
    """
    # One cell of padding around the two spatial axes, none on the pattern axis.
    cell_padding = ((0, 0), (1, 1), (1, 1))
    width, height = wave.shape[1], wave.shape[2]
    remaining = wave.sum()

    while True:
        if periodic:
            padded = numpy.pad(wave, cell_padding, mode="wrap")
        else:
            padded = numpy.pad(
                wave, cell_padding, mode="constant", constant_values=True
            )

        # Directions are stored as relative (x, y) offsets, so the neighbour
        # of every cell in a direction is just the padded wave shifted by that
        # offset.  All supports are computed from the same snapshot before the
        # wave itself is touched.
        supports = []
        for (dx, dy), allowed in adj.items():
            neighbours = padded[
                :, 1 + dx : 1 + width + dx, 1 + dy : 1 + height + dy
            ]
            # Flatten the cells so a single matrix product answers, for every
            # pattern and cell, "is any compatible pattern still possible in
            # the neighbour?", then restore the grid shape.
            flat_neighbours = neighbours.reshape(neighbours.shape[0], -1)
            supported = (allowed @ flat_neighbours).reshape(neighbours.shape) > 0
            supports.append(supported)

        # A pattern survives only if it is supported from every direction.
        for supported in supports:
            wave *= supported

        now_remaining = wave.sum()
        if now_remaining == remaining:
            break  # Nothing was removed in this pass: fully propagated.
        remaining = now_remaining

    if onPropagate:
        onPropagate(wave)

    if (wave.sum(axis=0) == 0).any():
        raise Contradiction("Wave is in a contradictory state and can not be solved.")


def observe(
    wave: NDArray[np.bool_],
    locationHeuristic: Callable[[NDArray[np.bool_]], tuple[int, int]],
    patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
) -> tuple[int, int, int]:
    """Choose the next cell to collapse and the pattern to collapse it to.

    ``locationHeuristic`` picks the cell from the whole wave;
    ``patternHeuristic`` then receives that cell's pattern slice together
    with the whole wave.  The wave is not modified.

    Returns:
        ``(pattern, i, j)``.
    """
    row, col = locationHeuristic(wave)
    cell_options = wave[:, row, col]
    return patternHeuristic(cell_options, wave), row, col


def run(
    wave: NDArray[np.bool_],
    adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
    locationHeuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
    patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
    periodic: bool = False,
    backtracking: bool = False,
    onBacktrack: Callable[[], None] | None = None,
    onChoice: Callable[[int, int, int], None] | None = None,
    onObserve: Callable[[NDArray[numpy.bool_]], None] | None = None,
    onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
    checkFeasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
    onFinal: Callable[[NDArray[numpy.bool_]], None] | None = None,
    depth: int = 0,
    depth_limit: int | None = None,
) -> NDArray[numpy.int64]:
    """Solve a wave from start to finish using a ``Solver``.

    The camelCase arguments are forwarded to the matching snake_case
    ``Solver`` parameters.  ``onFinal``, when given, is called with the
    finished wave before the result is computed.  ``depth`` and
    ``depth_limit`` are accepted but not used by this function.

    Returns:
        For every cell, the index of the first pattern still set in the
        solver's wave.
    """
    solver = Solver(
        wave=wave,
        adj=adj,
        periodic=periodic,
        backtracking=backtracking,
        on_backtrack=onBacktrack,
        on_choice=onChoice,
        on_observe=onObserve,
        on_propagate=onPropagate,
        check_feasible=checkFeasible,
    )

    # Step until the solver reports that nothing is left to do.
    finished = False
    while not finished:
        finished = solver.solve_next(
            location_heuristic=locationHeuristic, pattern_heuristic=patternHeuristic
        )

    if onFinal:
        onFinal(solver.wave)
    return numpy.argmax(solver.wave, axis=0)
