import math
from typing import Callable, List, Tuple

import numpy as np
import torch


def generate_grid(npoints: int = 12, starting_point: int = -1, ending_point: int = 1) -> Tuple[torch.Tensor, List[int]]:
    """Generate a regular square grid of 2-D points.

    Both axes use the same ``npoints`` evenly spaced coordinates between
    ``starting_point`` and ``ending_point`` (inclusive).

    Args:
        npoints: Number of coordinates per axis; the grid holds
            ``npoints * npoints`` points in total.
        starting_point: Lowest coordinate on both axes.
        ending_point: Highest coordinate on both axes.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints * npoints, 2)`` and ``indices`` is ``[0, npoints // 2]``.
        The first column cycles through the coordinates fastest; the second
        column changes once every ``npoints`` rows.
    """
    axis = torch.linspace(starting_point, ending_point, npoints)
    # Pass the indexing mode explicitly: calling meshgrid without it is
    # deprecated and the default is documented to change, which would
    # silently reorder the rows produced here.
    slow, fast = torch.meshgrid(axis, axis, indexing="ij")
    data_points = torch.stack([fast.ravel(), slow.ravel()], dim=1)
    return data_points, [0, npoints // 2]


def generate_diagonal(npoints: int = 12) -> Tuple[torch.Tensor, List[int]]:
    """Generate points evenly spaced along the diagonal from (-1, -1) to (1, 1).

    Args:
        npoints: Number of points to place on the diagonal.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints, 2)`` with identical x and y columns, and ``indices`` is
        ``[0, npoints // 2]``.
    """
    coords = torch.linspace(-1, 1, npoints)
    # Both columns carry the same coordinate, so every point lies on y = x.
    diagonal = torch.stack((coords, coords), dim=1)
    return diagonal, [0, npoints // 2]


def generate_spiral(npoints: int = 12, nturns: int = 5, max_radius: int = 10) -> Tuple[torch.Tensor, List[int]]:
    """Generate points along a spiral that grows outward from the origin.

    Angle and radius both increase linearly with the point index: the angle
    runs from 0 to ``2 * pi * nturns`` and the radius from 0 to
    ``max_radius``.

    Args:
        npoints: Number of points along the spiral.
        nturns: Number of full revolutions covered by the spiral.
        max_radius: Radius reached at the last point.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints, 2)`` and ``indices`` is ``[0, npoints // 2]``.
    """
    total_angle = 2 * np.pi * nturns
    phi = torch.linspace(0, total_angle, npoints)
    rho = torch.linspace(0, max_radius, npoints)
    # Polar -> Cartesian conversion, one (x, y) row per sample.
    spiral = torch.stack((rho * torch.cos(phi), rho * torch.sin(phi)), dim=1)
    return spiral, [0, npoints // 2]


def generate_square(npoints: int = 12) -> Tuple[torch.Tensor, List[int]]:
    """Generate points along the perimeter of an axis-aligned square.

    The perimeter is walked side by side: top (left to right), right (top to
    bottom), bottom (right to left), then left (bottom to top). Each side
    contributes ``npoints // 4 + 1`` points including both of its end
    corners, so every corner appears twice and the result holds
    ``4 * (npoints // 4 + 1)`` rows rather than exactly ``npoints``.

    Args:
        npoints: Controls the number of points per side and the size of the
            square (half side is ``(npoints // 4 + 1) // 2 + 1``).

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(4 * (npoints // 4 + 1), 2)`` and ``indices`` is
        ``[0, npoints // 4 + 1]``, i.e. the first row of the top side and
        the first row of the right side.
    """
    per_side = npoints // 4 + 1
    extent = per_side // 2 + 1

    # Building blocks shared by the x and y coordinate sequences.
    rising = torch.linspace(-extent, extent, per_side)
    falling = torch.linspace(extent, -extent, per_side)
    at_max = torch.full((per_side,), extent)
    at_min = torch.full((per_side,), -extent)

    # Side order: top, right, bottom, left.
    xs = torch.cat([rising, at_max, falling, at_min])
    ys = torch.cat([at_max, falling, at_min, rising])

    perimeter = torch.stack([xs, ys], dim=1)
    return perimeter, [0, per_side]


def generate_node(npoints: int = 12) -> Tuple[torch.Tensor, List[int]]:
    """Generate points whose signed radius sweeps from -1 to 1 over one turn.

    The angle runs linearly from 0 to ``2 * pi`` while the signed radius runs
    linearly from -1 to 1, so the curve passes through the origin at its
    midpoint.

    Args:
        npoints: Number of points along the curve.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints, 2)`` and ``indices`` is ``[0, npoints // 2]``.
    """
    theta = torch.linspace(0, 2 * math.pi, npoints)
    signed_radius = torch.linspace(-1, 1, npoints)
    # Unit direction per sample, scaled row-wise by the signed radius.
    directions = torch.stack((torch.cos(theta), torch.sin(theta)), dim=1)
    curve = signed_radius.unsqueeze(1) * directions
    return curve, [0, npoints // 2]


def generate_circle(npoints: int = 12) -> Tuple[torch.Tensor, List[int]]:
    """Generate points evenly spaced on the unit circle.

    Args:
        npoints: Number of points on the circle.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints, 2)`` starting at angle 0, and ``indices`` is
        ``[0, npoints // 4]``.
    """
    # Sample npoints + 1 angles and drop the last one so that the angle
    # 2 * pi (the same position as angle 0) is not emitted twice.
    theta = torch.linspace(0, 2 * math.pi, npoints + 1)[:-1]
    unit_radius = torch.ones(npoints)
    directions = torch.stack((torch.cos(theta), torch.sin(theta)), dim=1)
    circle = unit_radius.unsqueeze(1) * directions
    return circle, [0, npoints // 4]


def generate_ellipse(
    npoints: int = 12,
    a: float = 1.0,
    b: float = 0.5,
) -> Tuple[torch.Tensor, List[int]]:
    """Generate points on an origin-centred, axis-aligned ellipse.

    Points are sampled at evenly spaced parameter angles, starting at angle 0.

    Args:
        npoints: Number of points on the ellipse.
        a: Semi-axis length along x.
        b: Semi-axis length along y.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints, 2)`` and ``indices`` is ``[0, npoints // 4]``.
    """
    # Sample npoints + 1 angles and drop the last one so that the angle
    # 2 * pi (the same position as angle 0) is not emitted twice.
    theta = torch.linspace(0, 2 * math.pi, npoints + 1)[:-1]
    horizontal = a * torch.cos(theta)
    vertical = b * torch.sin(theta)
    ellipse = torch.stack((horizontal, vertical), dim=1)
    return ellipse, [0, npoints // 4]


def generate_cluster(
    npoints: int = 12,
    ncluster: int = 3,
    centers: torch.Tensor = torch.tensor([[-0.5, 0.5], [0.5, 0.5], [0.0, -0.5]]),
) -> Tuple[torch.Tensor, List[int]]:
    """Generate Gaussian blobs of 2-D points around the given centers.

    Rows are laid out cluster by cluster: cluster ``i`` occupies rows
    ``[i * (npoints // ncluster), (i + 1) * (npoints // ncluster))``. When
    ``npoints`` is not divisible by ``ncluster`` the last cluster also takes
    the leftover rows, so every returned row is initialised.

    Args:
        npoints: Total number of points to generate.
        ncluster: Number of clusters; the first ``ncluster`` rows of
            ``centers`` are used.
        centers: Cluster centers, one ``(x, y)`` row per cluster. The default
            tensor is shared between calls and is only read, never modified.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints, 2)`` and ``indices`` is ``[0, npoints // 4]``.

    Raises:
        ValueError: If ``ncluster`` is smaller than 1 or larger than the
            number of rows in ``centers``.
    """
    if ncluster < 1:
        raise ValueError(f"ncluster must be at least 1, got {ncluster}")
    centers = torch.as_tensor(centers)
    if centers.shape[0] < ncluster:
        raise ValueError(
            f"ncluster={ncluster} requires at least {ncluster} centers, got {centers.shape[0]}"
        )

    cluster_size = npoints // ncluster
    data_points = torch.empty((npoints, 2))
    for i in range(ncluster):
        start = i * cluster_size
        # The last cluster absorbs the remainder so no row of the
        # uninitialised buffer is left untouched.
        end = npoints if i == ncluster - 1 else start + cluster_size
        # Standard-normal noise scaled to a standard deviation of 0.1.
        data_points[start:end, :] = torch.randn(end - start, 2) * 0.1 + centers[i]
    return data_points, [0, npoints // 4]


def generate_random(npoints: int = 12, side_length: int = 2) -> Tuple[torch.Tensor, List[int]]:
    """Generate points drawn uniformly at random from ``[-1, 1)`` per coordinate.

    Args:
        npoints: Number of points (rows) to draw.
        side_length: Number of columns of the returned tensor, i.e. the
            number of coordinates per point. Despite its name, this value
            does not scale the sampled range.

    Returns:
        A tuple ``(data_points, indices)`` where ``data_points`` has shape
        ``(npoints, side_length)`` and ``indices`` is ``[0, npoints // 2]``.
    """
    unit_samples = torch.rand((npoints, side_length))
    # Map uniform [0, 1) samples onto [-1, 1).
    centred = unit_samples.mul(2).sub(1)
    return centred, [0, npoints // 2]


def generate_data_points(
    npoints: int,
    space_pattern: Callable,
    **kwargs,
) -> Tuple[torch.Tensor, List[int]]:
    """Dispatch to a pattern generator and return its result unchanged.

    Args:
        npoints: Forwarded to ``space_pattern`` as the ``npoints`` keyword.
        space_pattern: Callable that accepts an ``npoints`` keyword argument
            plus any extra keyword arguments supplied here.
        **kwargs: Additional keyword arguments forwarded verbatim to
            ``space_pattern``.

    Returns:
        Whatever ``space_pattern`` returns.
    """
    # ``npoints`` is a named parameter of this function, so it can never
    # also appear inside ``kwargs``; merging the two cannot clash.
    pattern_arguments = {"npoints": npoints, **kwargs}
    return space_pattern(**pattern_arguments)
