import os
from typing import List, Optional

import numpy as np
from tqdm import tqdm

from source.constants import ARROW_PATH


def _render_smooth_line(x0, y0, x1, y1, grid_size=32):
    """
    Render a line in [0, 1]x[0, 1] with smooth edges on a grid.

    Parameters:
        x0, y0, x1, y1: Line endpoints in [0, 1].
        grid_size: The size of the grid (default is 32x32).

    Returns:
        A 2D numpy array representing the grid with smooth edges.
    """
    # Map coordinates to grid space
    x0, y0 = x0 * (grid_size - 1), y0 * (grid_size - 1)
    x1, y1 = x1 * (grid_size - 1), y1 * (grid_size - 1)

    grid = np.zeros((grid_size, grid_size))

    # Wu's Line Algorithm
    def fpart(x):
        return x - np.floor(x)

    def rfpart(x):
        return 1 - fpart(x)

    def plot_intensity(grid, x, y, intensity):
        """
        Plot the intensity at a subpixel location (x, y) onto the grid.
        """
        ix, iy = int(x), int(y)
        if 0 <= ix < grid.shape[1] and 0 <= iy < grid.shape[0]:
            grid[iy, ix] += intensity

    steep = abs(y1 - y0) > abs(x1 - x0)
    if steep:
        x0, y0 = y0, x0
        x1, y1 = y1, x1

    if x0 > x1:
        x0, x1 = x1, x0
        y0, y1 = y1, y0

    dx = x1 - x0
    dy = y1 - y0
    gradient = dy / dx if dx != 0 else 1

    # First endpoint
    xend = round(x0)
    yend = y0 + gradient * (xend - x0)
    xgap = rfpart(x0 + 0.5)
    xpxl1 = int(xend)
    ypxl1 = int(yend)
    if steep:
        plot_intensity(grid, ypxl1, xpxl1, rfpart(yend) * xgap)
        plot_intensity(grid, ypxl1 + 1, xpxl1, fpart(yend) * xgap)
    else:
        plot_intensity(grid, xpxl1, ypxl1, rfpart(yend) * xgap)
        plot_intensity(grid, xpxl1, ypxl1 + 1, fpart(yend) * xgap)

    # Main loop
    intery = yend + gradient
    for x in range(xpxl1 + 1, int(x1) + 1):
        if steep:
            plot_intensity(grid, int(intery), x, rfpart(intery))
            plot_intensity(grid, int(intery) + 1, x, fpart(intery))
        else:
            plot_intensity(grid, x, int(intery), rfpart(intery))
            plot_intensity(grid, x, int(intery) + 1, fpart(intery))
        intery += gradient

    # Last endpoint
    xend = round(x1)
    yend = y1 + gradient * (xend - x1)
    xgap = fpart(x1 + 0.5)
    xpxl2 = int(xend)
    ypxl2 = int(yend)
    if steep:
        plot_intensity(grid, ypxl2, xpxl2, rfpart(yend) * xgap)
        plot_intensity(grid, ypxl2 + 1, xpxl2, fpart(yend) * xgap)
    else:
        plot_intensity(grid, xpxl2, ypxl2, rfpart(yend) * xgap)
        plot_intensity(grid, xpxl2, ypxl2 + 1, fpart(yend) * xgap)

    # Normalize grid
    grid = np.clip(grid, 0, 1)
    return grid


def _render_arrow(
    angle=0, grid_size=32, rng: Optional[np.random.Generator] = None
):
    """
    Render an arrow rotated by ``angle`` degrees on a square grid.

    The arrow is three anti-aliased strokes that all start at the tip vertex:
    two short strokes to the barbs form the head, and one long stroke to the
    tail forms the shaft. At ``angle=0`` the tip sits at y = 0.9 and the tail
    at y = 0.1.

    Parameters:
        angle: Rotation in degrees about the canvas centre (0.5, 0.5),
            counter-clockwise in the (x, y) coordinate frame.
        grid_size: Side length of the square output grid.
        rng: If given, used to draw a random translation in [-0.1, 0.1) for
            x and then for y, applied to the whole arrow.

    Returns:
        A (grid_size, grid_size) float array with intensities clipped to [0, 1].
    """
    # Rotation centre
    cx, cy = 0.5, 0.5

    # Arrow vertices in normalized [0, 1] space (unrotated).
    tip = (0.5, 0.9)
    left_barb = (0.3, 0.6)
    right_barb = (0.7, 0.6)
    tail = (0.5, 0.1)

    angle_rad = np.radians(angle)
    cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)

    def rotate(point):
        """Rotate ``point`` about (cx, cy) by ``angle``."""
        x, y = point[0] - cx, point[1] - cy  # translate to origin
        # Rotate, then translate back.
        return x * cos_a - y * sin_a + cx, x * sin_a + y * cos_a + cy

    tip, left_barb, right_barb, tail = (
        rotate(p) for p in (tip, left_barb, right_barb, tail)
    )

    # Random translation when rng is provided. Every vertex lies at most 0.4
    # from the centre, so rotated coordinates stay in [0.1, 0.9]. A shift of
    # at most 0.1 therefore keeps the arrow inside [0, 1].
    if rng is not None:
        shift_x, shift_y = rng.uniform(-0.1, 0.1), rng.uniform(-0.1, 0.1)
        tip, left_barb, right_barb, tail = (
            (x + shift_x, y + shift_y)
            for x, y in (tip, left_barb, right_barb, tail)
        )

    # Draw head strokes, then the shaft.
    grid = np.zeros((grid_size, grid_size))
    for end in (left_barb, right_barb, tail):
        grid += _render_smooth_line(tip[0], tip[1], end[0], end[1], grid_size)

    # Strokes overlap at the tip; cap the summed intensity.
    return np.clip(grid, 0, 1)


def _generate_data(
    n_samples: int = 90_000,
    pixels: int = 32,
    angle_range: Optional[List] = None,
    seed: int = 2357,
):
    """
    Generate randomly rotated and translated arrow images with their angles.

    Parameters:
        n_samples: Number of images to render.
        pixels: Side length of each square image.
        angle_range: ``(low, high)`` bounds in degrees for the uniformly drawn
            rotation angles. Defaults to ``[0, 180]``.
        seed: Seed for the single generator that draws the angles and the
            per-image translations, making the output reproducible.

    Returns:
        ``(samples, labels)``. ``samples`` has shape
        ``(n_samples, pixels, pixels)``. ``labels`` (float64, length
        ``n_samples``) holds the rotation angle used for each image.

    Raises:
        ValueError: if ``angle_range`` does not contain exactly two values.
    """
    # None sentinel instead of a shared mutable list default.
    if angle_range is None:
        angle_range = [0, 180]
    # Explicit unpacking: a malformed range would otherwise slide n_samples
    # into rng.uniform's ``size`` slot instead of failing.
    low, high = angle_range

    rng = np.random.default_rng(seed)
    # All angles are drawn before any translation. Keep this order so a
    # given seed reproduces the same dataset.
    angles = rng.uniform(low, high, n_samples)

    samples = np.zeros((n_samples, pixels, pixels))
    for n in tqdm(range(n_samples)):
        # Render the arrow; pass rng so each image gets a random translation.
        samples[n] = _render_arrow(angles[n], pixels, rng)

    return samples, angles


def generate_and_save_data(save_path: str = os.path.join(ARROW_PATH, "arrows.npz")):
    """
    Generate the arrow dataset and write it to ``save_path`` as an .npz archive.

    The archive contains ``samples`` (uint8 images, 0-255) and ``labels``
    (float32 rotation angles in degrees).

    Parameters:
        save_path: Destination file. Its parent directory is created if it
            does not exist. ``np.savez`` appends ``.npz`` if the name lacks it.
    """
    # Create the directory that will actually hold the file. For the default
    # path that is ARROW_PATH; for a custom path it may be somewhere else.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    samples, labels = _generate_data(
        n_samples=90_000, pixels=32, angle_range=[0, 180], seed=2357
    )

    # Quantize [0, 1] intensities to uint8 with rounding. A bare astype()
    # truncates, biasing every pixel downwards.
    samples = np.rint(samples * 255).astype(np.uint8)
    labels = labels.astype(np.float32)

    np.savez(save_path, samples=samples, labels=labels)


def get_data(seed=2357):
    """
    Load the saved arrow dataset and split it into train/val/test sets.

    Parameters:
        seed: Seed for the shuffle that assigns samples to the splits.

    Returns:
        ``(train_samples, train_labels, val_samples, val_labels,
        test_samples, test_labels)``.
        - Samples are float32, scaled by 1/255, with a channel axis inserted
          at position 1.
        - Labels are float32 angles divided by 180.
        - After shuffling, the first 60,000 samples form the training set,
          the next 15,000 the validation set, and the rest the test set.
    """
    # The context manager closes the file handle held by the NpzFile.
    with np.load(os.path.join(ARROW_PATH, "arrows.npz")) as data:
        samples = data["samples"].astype(np.float32)
        # Keep labels continuous. Casting to an integer type would truncate
        # every angle to whole degrees and bias the regression target.
        labels = data["labels"].astype(np.float32)

    # Normalize samples.
    samples /= 255
    # Add a channel axis: (N, H, W) -> (N, 1, H, W).
    samples = samples[:, None, :, :]

    labels /= 180

    # Split into train, val and test.
    rng = np.random.default_rng(seed)
    indices = np.arange(len(samples))
    rng.shuffle(indices)

    train_indices = indices[:60_000]
    val_indices = indices[60_000:75_000]
    test_indices = indices[75_000:]

    return (
        samples[train_indices],
        labels[train_indices],
        samples[val_indices],
        labels[val_indices],
        samples[test_indices],
        labels[test_indices],
    )
