from typing import List, Optional, Union

import torch
from torch import Tensor
from torch.nn.functional import pad

from handlers.drawers.base_drawer import Drawer
from handlers.utils import convert_to_real_class, map_parameter_to_real


def pad_2d_to_nd_graph(
    x_grid: Tensor, y_grid: Tensor, dims: List[int], n: int, pad_value: int = 0
) -> Tensor:
    """Embed 2-D grid coordinates as points in an ``n``-dimensional space.

    Every element of ``x_grid`` / ``y_grid`` becomes one row of the result:
    the x value is written to column ``dims[0]``, the y value to column
    ``dims[1]``, and every other column is filled with ``pad_value``.

    Args:
        x_grid: Tensor of x coordinates; it is flattened, so any shape works.
        y_grid: Tensor of y coordinates with the same number of elements as
            ``x_grid``.
        dims: Pair ``(x_dim, y_dim)`` of target column indices. They must be
            distinct and lie in ``[0, n)``; either order is accepted.
        n: Number of columns (dimensionality) of the output points.
        pad_value: Fill value for every column other than ``dims``.

    Returns:
        Tensor of shape ``(x_grid.numel(), n)`` on ``x_grid``'s device, with
        the dtype obtained by promoting ``x_grid`` and ``y_grid`` dtypes.

    Raises:
        ValueError: If ``dims`` is not a pair of distinct indices in
            ``[0, n)``, or if the grids hold different numbers of elements.
    """
    x_dim, y_dim = dims[0], dims[1]
    # The previous pad/cat/pad formulation silently cropped columns (negative
    # padding) when dims were unordered, equal, or out of range; reject those
    # explicitly instead of returning a wrong embedding.
    if x_dim == y_dim or not (0 <= x_dim < n and 0 <= y_dim < n):
        raise ValueError(
            f"dims must be two distinct indices in [0, {n}), got {tuple(dims)}"
        )

    x_flat = x_grid.reshape(-1)
    y_flat = y_grid.reshape(-1)
    # Guard explicitly: a 1-element y would otherwise broadcast silently.
    if x_flat.numel() != y_flat.numel():
        raise ValueError(
            "x_grid and y_grid must have the same number of elements, got "
            f"{x_flat.numel()} and {y_flat.numel()}"
        )

    points = torch.full(
        (x_flat.numel(), n),
        pad_value,
        dtype=torch.promote_types(x_flat.dtype, y_flat.dtype),
        device=x_flat.device,
    )
    # Indexed assignment is differentiable, so gradients still flow to the grids.
    points[:, x_dim] = x_flat
    points[:, y_dim] = y_flat
    return points


def create_grid_points(
    x_lower_bounds: float,
    x_upper_bounds: float,
    y_lower_bounds: float,
    y_upper_bounds: float,
    num_of_points: int,
    dim_size: int,
    dims: list,
    device: Optional[Union[int, str, torch.device]] = None,
) -> Tensor:
    """Build a regular 2-D grid and embed it into ``dim_size``-dimensional points.

    ``num_of_points`` evenly spaced values are taken on each axis (bounds
    inclusive, via ``torch.linspace``), combined into a full grid, and handed
    to ``pad_2d_to_nd_graph`` so that x lands in column ``dims[0]`` and y in
    column ``dims[1]``.

    Args:
        x_lower_bounds: First value on the x axis.
        x_upper_bounds: Last value on the x axis.
        y_lower_bounds: First value on the y axis.
        y_upper_bounds: Last value on the y axis.
        num_of_points: Number of samples per axis.
        dim_size: Dimensionality ``n`` of the output points.
        dims: Pair of column indices receiving the x and y coordinates.
        device: Device passed to ``torch.linspace``; ``None`` uses torch's
            default device.

    Returns:
        Tensor with ``num_of_points ** 2`` rows and ``dim_size`` columns; row
        ``i * num_of_points + j`` holds ``x_axis[i]`` and ``y_axis[j]``.
    """
    x_axis = torch.linspace(x_lower_bounds, x_upper_bounds, num_of_points, device=device)
    y_axis = torch.linspace(y_lower_bounds, y_upper_bounds, num_of_points, device=device)
    # "ij" is the layout the implicit default produced; stating it explicitly
    # keeps the same row order and avoids torch's meshgrid deprecation warning.
    x_grid, y_grid = torch.meshgrid(x_axis, y_axis, indexing="ij")
    # tuple() materialises dims once since it is indexed more than once downstream.
    return pad_2d_to_nd_graph(x_grid, y_grid, tuple(dims), dim_size)


def convert_to_real_drawer(drawer: Drawer):
    """Apply the "real" conversions to ``drawer`` and return it.

    First calls ``convert_to_real_class(drawer)``. Then, in order, each of the
    attributes ``draw_data``, ``update_data``, ``start_drawing`` and
    ``end_drawing`` is replaced by the value ``map_parameter_to_real`` returns
    for it. The same ``drawer`` object is mutated in place and returned.
    """
    convert_to_real_class(drawer)

    callback_names = ("draw_data", "update_data", "start_drawing", "end_drawing")
    for name in callback_names:
        current = getattr(drawer, name)
        setattr(drawer, name, map_parameter_to_real(current))
    return drawer
