import pathlib
from setuptools import setup

# Resolve files relative to this setup script so the build does not depend
# on the current working directory.
here = pathlib.Path(__file__).parent
README = (here / "README.md").read_text(encoding="utf-8")

# Dependencies: one requirement per line. Surrounding whitespace/newlines are
# stripped, and blank lines and full-line comments are skipped so they never
# reach install_requires.
requirements = [
    line.strip()
    for line in (here / "requirements.txt").read_text(encoding="utf-8").splitlines()
    if line.strip() and not line.strip().startswith("#")
]

setup(
    name="spatial-kfold",
    version="0.0.4",
    packages=["spatialkfold"],
    author="Walid Ghariani",
    author_email="walid11ghariani@gmail.com",
    description=(
        "spatial-kfold: A Python Package for Spatial Resampling Toward More Reliable Cross-Validation in Spatial Studies."
    ),
    long_description=README,
    long_description_content_type="text/markdown",
    license="GPL-3.0",
    keywords="cross-validation, machine-learning, GIS, spatial",
    url="https://github.com/WalidGharianiEAGLE/spatial-kfold",
    package_data={"spatialkfold": ["./data/*.geojson"]},
    include_package_data=True,
    # Dependencies
    install_requires=requirements,
    python_requires=">=3.7",
    # Classifiers
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        # Must agree with `license="GPL-3.0"` above.
        "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
    ],
    # testing
    setup_requires=["pytest-runner"],
    tests_require=["pytest"],
)


from typing import Union
import math

import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import box
from shapely.geometry import Polygon


def create_grid(
    gdf: gpd.GeoDataFrame,
    width: Union[int, float],
    height: Union[int, float],
    grid_type="rect",
) -> gpd.GeoDataFrame:
    """
    Create a grid of polygons with a specified width and height based on the bounds of a provided GeoDataFrame.

    Parameters
    ----------
    gdf : GeoDataFrame
        The GeoDataFrame containing the bounds to use for creating the grid.
    width : int or float
        The width of the grid cells in the x-dimension. For 'hex' grids this is
        the flat-to-flat width of each hexagon.
    height : int or float
        The height of the grid cells in the y-dimension (only used for 'rect').
    grid_type : str
        Either 'rect' for rectangular grid or 'hex' for hexagonal grid.

    Returns
    -------
    GeoDataFrame
        A GeoDataFrame containing the grid polygons, in the CRS of ``gdf``.
        For 'rect', cells are ordered column by column (x outer, y inner).
        At least one cell is produced even when the bounds are degenerate
        (e.g. a single point).

    Raises
    ------
    TypeError
        If ``gdf`` is not a GeoDataFrame.
    AttributeError
        If ``gdf`` has no CRS.
    ValueError
        If ``width``/``height`` are not positive numbers, ``grid_type`` is
        invalid, or ``gdf`` has no geometries to take bounds from.

    Source: Code for creating a grid was adapted from the solution provided by user "Mativane" in the following
    gis.stackexchange thread: https://gis.stackexchange.com/questions/269243/creating-polygon-grid-using-geopandas
    """
    # numpy scalars (e.g. np.int64) are accepted alongside Python numbers.
    numeric_types = (int, float, np.integer, np.floating)

    if not isinstance(gdf, gpd.GeoDataFrame):
        raise TypeError("Input must be a GeoDataFrame")
    if gdf.crs is None:
        raise AttributeError(
            "The passed GeoDataFrame has no CRS. Use `to_crs()` to reproject one of the input geometries."
        )
    if not (isinstance(width, numeric_types) and width > 0):
        raise ValueError("Width must be a positive number")
    if grid_type not in ["rect", "hex"]:
        raise ValueError(
            f"Invalid grid_type {grid_type}. Specify either 'rect' or 'hex'."
        )
    if grid_type == "rect":
        if not (isinstance(height, numeric_types) and height > 0):
            raise ValueError("Height must be a positive number for 'rect' grid.")

    # Get the bounds of the input geometries
    xmin, ymin, xmax, ymax = gdf.total_bounds
    # total_bounds is all-NaN for an empty frame (or only empty geometries).
    if not np.all(np.isfinite([xmin, ymin, xmax, ymax])):
        raise ValueError(
            "Cannot create a grid: the GeoDataFrame has no non-empty geometries."
        )
    polygons = []

    if grid_type == "rect":
        # Number of cells needed to cover the extent. At least one per axis so
        # that degenerate extents (a single point, or points aligned on one
        # x or y value) still get a cell instead of an empty grid.
        n_cols = max(1, math.ceil((xmax - xmin) / width))
        n_rows = max(1, math.ceil((ymax - ymin) / height))
        cols = xmin + np.arange(n_cols) * width
        rows = ymin + np.arange(n_rows) * height
        polygons = [box(x, y, x + width, y + height) for x in cols for y in rows]

    elif grid_type == "hex":
        sqrt3 = np.sqrt(3)
        cos = sqrt3 / 2  # cos(30°)
        r = width / 2  # center to flat
        R = r / cos  # circumradius (center to corner)
        dx = 3 / 2 * R  # horizontal distance between centers
        dy = sqrt3 * R  # vertical distance between rows

        x = xmin
        col = 0
        while x < xmax + dx:
            # Odd columns are shifted by half a row so the hexagons interlock.
            y_offset = 0 if col % 2 == 0 else dy / 2
            y = ymin + y_offset
            while y < ymax + dy:
                polygons.append(_create_flat_top_hexagon(x, y, R))
                y += dy
            x += dx
            col += 1

    return gpd.GeoDataFrame({"geometry": polygons}, crs=gdf.crs)


def _create_flat_top_hexagon(cx: float, cy: float, R: float) -> Polygon:
    """
    Build a regular flat-topped hexagon.

    The six vertices lie on a circle of radius ``R`` around ``(cx, cy)``,
    starting at 0 degrees (the rightmost vertex) and advancing in 60-degree
    steps, so the vertices at 60 and 120 degrees form a horizontal top edge.
    """
    vertices = []
    for step in range(6):
        theta = math.radians(60 * step)
        vertices.append((cx + R * math.cos(theta), cy + R * math.sin(theta)))
    return Polygon(vertices)


def spatial_blocks(
    gdf: gpd.GeoDataFrame,
    width: Union[int, float],
    height: Union[int, float],
    nfolds: int,
    method="random",
    orientation="tb-lr",
    grid_type: str = "rect",
    random_state=None,
):
    """
    Create a grid of polygons based on the intersection with a provided GeoDataFrame and assign each polygon
    to a fold number.

    Parameters
    ----------
    gdf : GeoDataFrame
        The GeoDataFrame containing the points to use for creating the blocks.
    width : int or float
        The width of the grid cells in the x-dimension.
    height : int or float
        The height of the grid cells in the y-dimension.
    nfolds : int
        The number of folds to distribute the blocks over.
    method : str, optional
        The method to use for assigning folds to the blocks. Valid values are 'continuous' and 'random'.
        Default is 'random'.
    orientation : str, optional
        The orientation of the grid-folds when ``method='continuous'``. Can be 'tb-lr' (top-bottom, left-right)
        and 'bt-rl' (bottom-top, right-left). Default is 'tb-lr'. Ignored for ``method='random'``.
    grid_type : str
        'rect' or 'hex'.
    random_state : int, optional
        An optional integer seed to use when shuffling the grid cells. If provided, this allows the shuffling of the grid
        cells to be reproducible.

    Returns
    -------
    GeoDataFrame
        The grid cells that intersect ``gdf``, with a 'folds' column (1..nfolds) indicating the fold of each polygon.

    Raises
    ------
    ValueError
        If ``nfolds``, ``method`` or ``orientation`` are invalid (and anything raised by ``create_grid``).
    """
    import warnings

    if not (isinstance(nfolds, int) and nfolds > 0):
        raise ValueError("nfolds must be a positive int number.")
    if method not in ["random", "continuous"]:
        raise ValueError(
            f"Invalid method {method}. Specify either 'random' or 'continuous'."
        )
    if orientation not in ["tb-lr", "bt-rl"]:
        raise ValueError(
            f"Invalid orientation {orientation}. Specify either 'tb-lr' or 'bt-rl'. By default the orientation is 'tb-lr'."
        )

    # Create GeoDataFrame containing the grid of polygons (also validates gdf)
    grids = create_grid(gdf, width, height, grid_type)

    # Join against the geometry only: attribute columns are not needed and
    # could clash with the columns sjoin adds (e.g. 'index_right').
    points = gdf[[gdf.geometry.name]]
    in_grids = grids.sjoin(points, how="inner")
    # A cell intersecting several points appears once per point; every cell
    # has a unique index, so keep the first row per index.
    in_grids = in_grids[~in_grids.index.duplicated(keep="first")]
    # Keep only the geometry column with a fresh 0..n-1 index
    valid_grids = in_grids[["geometry"]].reset_index(drop=True)

    if method == "random":
        sp_blocks = valid_grids.sample(frac=1, random_state=random_state)
    elif orientation == "tb-lr":
        sp_blocks = valid_grids
    else:  # method == "continuous" and orientation == "bt-rl"
        sp_blocks = valid_grids.iloc[::-1].reset_index(drop=True)

    if nfolds > len(sp_blocks):
        warnings.warn(
            f"nfolds={nfolds} exceeds the number of blocks ({len(sp_blocks)}); "
            "some folds will contain no blocks.",
            UserWarning,
            stacklevel=2,
        )

    # Split the (ordered or shuffled) blocks into nfolds nearly equal groups
    block_indices = np.array_split(sp_blocks.index, nfolds)
    blocks_list = [
        sp_blocks.loc[idx].assign(folds=i)
        for i, idx in enumerate(block_indices, start=1)
    ]

    return gpd.GeoDataFrame(pd.concat(blocks_list))


import geopandas as gpd
from sklearn.cluster import KMeans
from sklearn.cluster import BisectingKMeans


def spatial_kfold_clusters(
    gdf: gpd.GeoDataFrame,
    name: str,
    nfolds: int,
    algorithm="kmeans",
    random_state=None,
    **kwargs
):
    """
    Perform spatial clustering using KMeans or BisectingKMeans on a GeoDataFrame with coordinates
    and assign each geo point to a fold.

    Parameters
    ----------
    gdf : GeoDataFrame
        The GeoDataFrame with a geometry column containing the points to use for spatial clustering.
    name : str
        Name of the column that identifies each unique geospatial point (e.g., station_id or city_code).
        Every row sharing a ``name`` value receives the same fold; if one ``name`` occurs with several
        coordinates, the first occurrence's centroid is used for clustering.
    nfolds : int
        The number of clusters/folds to assign for each geospatial point.
    algorithm : str, optional
        The clustering algorithm to use ('kmeans' or 'bisectingkmeans'). Default is 'kmeans'.
    random_state : int, optional
        An optional integer seed to use for centroid initialization.
    kwargs : additional keyword arguments forwarded to the chosen estimator:
        - For kmeans from sklearn API:
        https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html
        e.g., algorithm {"lloyd", "elkan"}, default="lloyd"
        - For bisectingkmeans from sklearn API:
        https://scikit-learn.org/stable/modules/generated/sklearn.cluster.BisectingKMeans.html
        e.g., algorithm {"lloyd", "elkan"}, default="lloyd"

    Returns
    -------
    GeoDataFrame
        ``gdf`` (left-merged on ``name``) with an extra 'folds' column holding values 1..nfolds.
        It has one row per input row.

    Raises
    ------
    TypeError
        If ``gdf`` is not a GeoDataFrame.
    AttributeError
        If ``gdf`` has no CRS.
    ValueError
        If ``nfolds`` or ``algorithm`` are invalid.
    """
    if not isinstance(gdf, gpd.GeoDataFrame):
        raise TypeError("Input must be a GeoDataFrame.")
    if gdf.crs is None:
        raise AttributeError(
            "The passed GeoDataFrame has no CRS. Use `to_crs()` to reproject one of the input geometries."
        )
    if not (isinstance(nfolds, int) and nfolds > 0):
        raise ValueError("nfolds must be a positive int number.")
    if algorithm not in ["kmeans", "bisectingkmeans"]:
        raise ValueError(
            'Unsupported clustering algorithm. Use "kmeans" or "bisectingkmeans".'
        )

    # drop=True gives a clean positional copy even when gdf already has an
    # 'index' column or a named index (both broke the previous reset/drop).
    gdf_copy = gdf.reset_index(drop=True)
    centroids = gdf_copy.geometry.centroid
    gdf_copy["lon"] = centroids.x
    gdf_copy["lat"] = centroids.y

    # Cluster each location id once. Deduplicating on `name` guarantees one
    # fold per id, so the merge below cannot multiply rows of the input.
    gdf_sp = gdf_copy[[name, "lon", "lat"]].drop_duplicates(subset=name)
    # Only 'lon' and 'lat' are used as clustering features
    lon_lat = gdf_sp[["lon", "lat"]]

    model_cls = KMeans if algorithm == "kmeans" else BisectingKMeans
    clustering_model = model_cls(
        n_clusters=nfolds, random_state=random_state, **kwargs
    )
    clustering_model.fit(lon_lat)
    cluster_labels = clustering_model.predict(lon_lat)

    # Folds are 1-based to match spatial_blocks
    folds_by_name = gdf_sp[[name]].assign(folds=cluster_labels + 1)
    # Assign the folds-clusters to the original gdf
    return gdf.merge(folds_by_name, on=name, how="left")


import geopandas as gpd
import pkg_resources


def load_ames():
    """
    Load the Ames example dataset shipped with the package.

    Returns
    -------
    GeoDataFrame
        The contents of ``data/ames.geojson``, located in the ``data``
        directory next to this module.
    """
    import pathlib

    # Resolve the file next to this module rather than through the deprecated
    # pkg_resources API; equivalent for regular (non-zipped) installations.
    filepath = pathlib.Path(__file__).resolve().parent / "data" / "ames.geojson"
    return gpd.read_file(filepath)


from typing import Union

import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from sklearn.model_selection import LeaveOneGroupOut


def spatial_kfold_plot(
    X: Union[np.ndarray, pd.DataFrame, pd.Series],
    geometry: Union[pd.Series, pd.DataFrame],
    groups: Union[np.ndarray, pd.Series],
    fold_num: int,
    ax=None,
    **kwargs,
):
    """
    Plot the train/test partition of the data for one fold of a leave-one-group-out cross-validation.

    Parameters
    ----------
    X : np.ndarray or pd.DataFrame or pd.Series
        The feature data; only used by the splitter to determine the number of samples.
    geometry : pd.DataFrame or pd.Series
        Geometries of the samples, row-aligned with ``X``.
    groups : np.ndarray or pd.Series
        Group (fold) label of each sample, as produced by the spatial resampling.
    fold_num : int
        1-based number of the fold to plot. Folds follow the sorted order of the unique group labels.
    ax : matplotlib Axes, optional
        Axes to draw on. If omitted, the current axes are used and ``plt.show()`` is called.
    **kwargs
        Forwarded to ``GeoDataFrame.plot``.

    Raises
    ------
    ValueError
        If ``fold_num`` is not between 1 and the number of unique groups.
    """
    n_folds = len(np.unique(np.asarray(groups)))

    if not 1 <= fold_num <= n_folds:
        raise ValueError(
            f"The provided fold number {fold_num} is out of range. Valid fold numbers are 1 to {n_folds}."
        )

    # LeaveOneGroupOut yields one split per unique group; stop at the requested one.
    spatial_kfold = LeaveOneGroupOut()
    for idx, (train_index, test_index) in enumerate(
        spatial_kfold.split(X, geometry, groups=groups)
    ):
        if idx == fold_num - 1:
            break

    # Split indices are positional, so select with iloc (works for any index).
    gdf_train = gpd.GeoDataFrame(geometry.iloc[train_index])
    gdf_test = gpd.GeoDataFrame(geometry.iloc[test_index])
    # Label each row with its role in this fold
    gdf_train["folds"] = "train"
    gdf_test["folds"] = "test"
    # Combine as single gdf
    gdf_cv = pd.concat([gdf_train, gdf_test])

    # Draw on the caller's axes when given; otherwise use the current axes and show the figure.
    show = ax is None
    if show:
        ax = plt.gca()
    gdf_cv.plot(column="folds", legend=True, ax=ax, **kwargs)
    ax.set_title(f"Fold {fold_num}", fontweight="bold")

    if show:
        plt.show()


from typing import Union

import numpy as np
import pandas as pd
from sklearn.model_selection import LeaveOneGroupOut


def spatial_kfold_stats(
    X: Union[np.ndarray, pd.DataFrame, pd.Series],
    y: Union[np.ndarray, pd.Series, pd.DataFrame],
    groups: Union[np.ndarray, pd.Series],
):
    """
    Generate a DataFrame with the number of train and test samples in each split of a spatial resampling procedure.

    Parameters
    ----------
    X : np.ndarray or pd.DataFrame or pd.Series
        The feature data.
    y : np.ndarray or pd.Series or pd.DataFrame
        The target values (accepted for API symmetry; not used for splitting).
    groups : np.ndarray or pd.Series
        Group label of each sample (the spatially resampled groups).

    Returns
    -------
    pandas DataFrame
        Columns 'split' (1-based), 'train' and 'test' with the sample counts of each split.

    This function uses scikit-learn's LeaveOneGroupOut to perform a "leave-location-out" (LLO) procedure
    over a predefined group of folds:
     >> Each group of clustered, blocked or user-defined locations is used once as the test set.
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.LeaveOneGroupOut.html
    """
    counts = {"split": [], "train": [], "test": []}
    spatial_kfold = LeaveOneGroupOut()

    # Only the sizes are needed. Counting the positional split indices avoids
    # label-based .loc lookups, which break on non-Range or duplicated indexes.
    for split_no, (train_index, test_index) in enumerate(
        spatial_kfold.split(X, y=None, groups=groups), start=1
    ):
        counts["split"].append(split_no)
        counts["train"].append(len(train_index))
        counts["test"].append(len(test_index))

    return pd.DataFrame(counts)


