from types import FunctionType
from typing import Iterable

import numpy as np
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, TransformerMixin

import multipers as mp
from multipers.simplex_tree_multi import SimplexTreeMulti


# Module-level alias of the private ``SimplexTreeMulti._reduce_grid`` attribute,
# exposed from this module under the public name ``reduce_grid``.
# NOTE(review): its signature and semantics are defined in
# multipers.simplex_tree_multi — check there before relying on it.
reduce_grid = SimplexTreeMulti._reduce_grid


def get_simplex_tree_from_delayed(x) -> mp.SimplexTreeMulti:
    """Evaluate a deferred call given as a ``(function, args, kwargs)`` triple.

    The triple is unpacked and ``function(*args, **kwargs)`` is returned as-is.
    No check is made that the result actually is a ``SimplexTreeMulti``.
    Unpacking raises ``ValueError`` if ``x`` does not have exactly 3 items.
    """
    func, positional, keyword = x
    simplextree = func(*positional, **keyword)
    return simplextree


def get_simplextree(x) -> mp.SimplexTreeMulti:
    """Return a ``SimplexTreeMulti`` from either a tree or a delayed call.

    Parameters
    ----------
    x
        Either a ``mp.SimplexTreeMulti``, which is returned unchanged, or a
        ``(callable, args, kwargs)`` triple (the shape produced by
        ``joblib.delayed(f)(*args, **kwargs)``). A triple is evaluated through
        ``get_simplex_tree_from_delayed``.

    Raises
    ------
    TypeError
        If ``x`` is neither of the accepted forms.
    """
    if isinstance(x, mp.SimplexTreeMulti):
        return x
    # Accept any callable, not only plain Python functions. Classes (e.g.
    # mp.SimplexTreeMulti itself), builtins, bound methods and
    # functools.partial objects are valid delayed callables too.
    try:
        is_delayed_triple = len(x) == 3 and callable(x[0])
    except (TypeError, KeyError):  # no len()/indexing, or a mapping without key 0
        is_delayed_triple = False
    if is_delayed_triple:
        return get_simplex_tree_from_delayed(x)
    raise TypeError("Not a valid SimplexTree !")


def filtration_grid_to_coordinates(F, return_resolution):
    """Flatten a product grid into an explicit array of grid points.

    Parameters
    ----------
    F
        Sequence of per-axis grid values, one entry per parameter (expected
        to be 1-D arrays).
    return_resolution
        If truthy, also return ``tuple(len(f) for f in F)``.

    Returns
    -------
    coordinates : np.ndarray
        Array with one row per grid point and one column per entry of ``F``.
    resolution : tuple of int
        Only returned when ``return_resolution`` is truthy.

    Notes
    -----
    ``np.meshgrid`` is called with its default ``indexing="xy"``. With two or
    more axes, the rows therefore enumerate the grid as an array of shape
    ``(len(F[1]), len(F[0]), ...)``: the first coordinate varies fastest.
    This does NOT match the order of the returned resolution tuple.
    """
    grids = np.meshgrid(*F)
    columns = [grid.ravel() for grid in grids]
    coordinates = np.stack(columns, axis=1)
    if not return_resolution:
        return coordinates
    resolution = tuple(len(axis_values) for axis_values in F)
    return coordinates, resolution


def get_filtration_weights_grid(
    num_parameters: int = 2,
    resolution: int | Iterable[int] = 3,
    *,
    min: float = 0,
    max: float = 20,
    dtype=float,
    remove_homothetie: bool = True,
    weights=None,
):
    """
    Provides a grid of weights, for filtration rescaling.
     - num parameter : the dimension of the grid tensor
     - resolution :  the size of each coordinate
     - min : minimum weight
     - max : maximum weight
     - weights : custom weights (instead of linspace between min and max)
     - dtype : the type of the grid values (useful for int weights)
    """
    from itertools import product

    # if isinstance(resolution, int):
    try:
        float(resolution)
        resolution = [resolution] * num_parameters
    except:
        pass
    if weights is None:
        weights = [
            np.linspace(start=min, stop=max, num=r, dtype=dtype) for r in resolution
        ]
    try:
        float(weights[0])  # same weights for each filtrations
        weights = [weights] * num_parameters
    except:
        None
    out = np.asarray(list(product(*weights)))
    if remove_homothetie:
        _, indices = np.unique(
            [x / x.max() for x in out if x.max() != 0], axis=0, return_index=True
        )
        out = out[indices]
    return list(out)


class SimplexTreeEdgeCollapser(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer applying edge collapses to simplex trees.

    For each input, the edge list is extracted, collapsed with
    ``mp._collapse_edge_list``, and a simplex tree is rebuilt from the
    collapsed edges via ``SimplexTreeMulti._reconstruct_from_edge_list``.

    Parameters
    ----------
    num_collapses
        Forwarded as ``num`` to ``mp._collapse_edge_list``.
    full
        Forwarded as ``full`` to ``mp._collapse_edge_list``.
    max_dimension
        Forwarded as ``expand_dim`` when reconstructing each simplex tree.
    n_jobs
        Number of joblib workers used by every parallel step of ``transform``.
    """

    def __init__(
        self,
        num_collapses: int = 0,
        full: bool = False,
        max_dimension: int | None = None,
        n_jobs: int = 1,
    ) -> None:
        super().__init__()
        self.full = full
        self.num_collapses = num_collapses
        self.max_dimension = max_dimension
        self.n_jobs = n_jobs

    def fit(self, X: np.ndarray | list, y=None):
        """No-op, present for scikit-learn API compatibility; returns ``self``."""
        return self

    def transform(self, X):
        """Collapse the edges of every simplex tree in ``X``.

        ``X`` is an iterable of ``SimplexTreeMulti`` objects or delayed
        ``(callable, args, kwargs)`` triples (see ``get_simplextree``).
        Returns a list with one reconstructed simplex tree per input.
        """
        simplextrees = [get_simplextree(x) for x in X]
        edges_list = Parallel(n_jobs=self.n_jobs, prefer="threads")(
            delayed(mp.SimplexTreeMulti.get_edge_list)(st) for st in simplextrees
        )
        # NOTE(review): confirm `_collapse_edge_list` is exported at the
        # multipers package level under this name.
        collapsed_edge_lists = Parallel(n_jobs=self.n_jobs)(
            delayed(mp._collapse_edge_list)(
                edges, full=self.full, num=self.num_collapses
            )
            for edges in edges_list
        )
        # `_reconstruct_from_edge_list` is called as an unbound method, so each
        # source tree must be passed explicitly as `self`, paired with its own
        # collapsed edges.
        collapsed_simplextrees = Parallel(n_jobs=self.n_jobs, prefer="threads")(
            delayed(mp.SimplexTreeMulti._reconstruct_from_edge_list)(
                st, edges, swap=True, expand_dim=self.max_dimension
            )
            for st, edges in zip(simplextrees, collapsed_edge_lists)
        )
        return collapsed_simplextrees
