from typing import Callable, List

import torch
from scipy.stats import ortho_group


def custom_rotation(
    points: torch.Tensor,
    angle: torch.Tensor = torch.tensor(45.0),
) -> torch.Tensor:
    """Rotate 2-D row-vector points counter-clockwise by ``angle`` degrees.

    Args:
        points: Tensor whose last dimension is 2 (one point per row).
        angle: Rotation angle in degrees. Expected to be a scalar; a plain
            Python number is also accepted. Defaults to 45 degrees.

    Returns:
        The rotated points. They are in ``points``' dtype when it is floating
        point, otherwise in the default torch dtype, and on ``points``' device.
    """
    dtype = points.dtype if points.is_floating_point() else torch.get_default_dtype()
    # as_tensor keeps the autograd graph of ``angle`` (a plain torch.tensor(...)
    # copy of cos/sin would detach it) and moves it to the points' dtype/device.
    theta = torch.deg2rad(torch.as_tensor(angle, dtype=dtype, device=points.device))
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    # torch.stack builds the matrix differentiably in the same dtype/device.
    rotation_matrix = torch.stack(
        [
            torch.stack([cos_t, -sin_t]),
            torch.stack([sin_t, cos_t]),
        ]
    )
    # Row-vector convention: p' = p @ R^T equals (R @ p^T)^T.
    return points.to(dtype) @ rotation_matrix.T


def custom_isotropic_scale(
    points: torch.Tensor,
    scale_factor: torch.Tensor = torch.as_tensor(0.8),
) -> torch.Tensor:
    """Multiply every coordinate of ``points`` by the same factor.

    Args:
        points: Input points.
        scale_factor: Multiplier broadcast against ``points``. Defaults to 0.8.

    Returns:
        The element-wise product ``points * scale_factor``.
    """
    scaled_points = torch.mul(points, scale_factor)
    return scaled_points


def custom_translation(
    points: torch.Tensor,
    translation_vector: torch.Tensor = torch.as_tensor([3.0, 3.0]),
) -> torch.Tensor:
    """Shift ``points`` by ``translation_vector`` (broadcast over rows).

    Args:
        points: Input points.
        translation_vector: Offset added to every point. Defaults to
            ``[3.0, 3.0]``, which only broadcasts against 2-column input.

    Returns:
        The element-wise sum ``points + translation_vector``.
    """
    shifted_points = torch.add(points, translation_vector)
    return shifted_points


def custom_orthogonal(points: torch.Tensor) -> torch.Tensor:
    """Right-multiply ``points`` by a random orthogonal matrix.

    The matrix is drawn with ``scipy.stats.ortho_group``, using SciPy's
    random state rather than torch's. Row norms and pairwise distances are
    preserved.

    Args:
        points: 2-D tensor of shape (N, d). scipy requires ``d`` > 1.

    Returns:
        The transformed points. They are in ``points``' dtype when it is
        floating point, otherwise in the default torch dtype, and on
        ``points``' device.
    """
    dtype = points.dtype if points.is_floating_point() else torch.get_default_dtype()
    # Match the points' dtype/device; a hard-coded float32 matrix made the
    # matmul fail for float64 input.
    orthogonal_matrix = torch.as_tensor(
        ortho_group.rvs(dim=points.shape[1]),
        dtype=dtype,
        device=points.device,
    )
    return points.to(dtype) @ orthogonal_matrix


def permute_indices(
    points: torch.Tensor,
) -> torch.Tensor:
    """Return ``points`` with its columns (coordinate axes) reordered.

    With exactly two columns the columns are always swapped. Otherwise a
    uniformly random permutation from ``torch.randperm`` is used, which may
    be the identity.

    Args:
        points: 2-D tensor of shape (N, d).

    Returns:
        A new tensor with permuted columns. The input tensor is not modified.
    """
    if points.shape[1] == 2:
        # Advanced indexing returns a copy. The previous in-place swap mutated
        # the caller's tensor and failed on leaf tensors that require grad.
        return points[:, [1, 0]]
    random_perm = torch.randperm(points.shape[1])
    return points[:, random_perm]


def affine_transformation(
    points: torch.Tensor,
) -> torch.Tensor:
    """Apply a random linear map followed by a +3.0 shift on every axis.

    The d x d matrix has standard-normal entries drawn from torch's RNG. For
    d == 2 the shift equals ``custom_translation``'s default ``[3.0, 3.0]``.

    Args:
        points: 2-D tensor of shape (N, d).

    Returns:
        The transformed points. They are in ``points``' dtype when it is
        floating point, otherwise in the default torch dtype, and on
        ``points``' device.
    """
    dtype = points.dtype if points.is_floating_point() else torch.get_default_dtype()
    dim = points.shape[1]
    matrix = torch.randn(dim, dim, dtype=dtype, device=points.device)
    points = points.to(dtype) @ matrix
    # Size the shift to the actual dimension; the 2-element default only
    # broadcast for d == 2.
    translation_vector = torch.full((dim,), 3.0, dtype=dtype, device=points.device)
    return custom_translation(points, translation_vector)


def linear_transformation(
    points: torch.Tensor,
) -> torch.Tensor:
    """Right-multiply ``points`` by a random d x d matrix with N(0, 1) entries.

    The matrix is drawn from torch's global RNG, so ``torch.manual_seed``
    makes the result reproducible.

    Args:
        points: 2-D tensor of shape (N, d).

    Returns:
        The transformed points. They are in ``points``' dtype when it is
        floating point, otherwise in the default torch dtype, and on
        ``points``' device.
    """
    dtype = points.dtype if points.is_floating_point() else torch.get_default_dtype()
    dim = points.shape[1]
    # Draw in the points' dtype/device so float64 or non-CPU input does not
    # hit a matmul dtype/device mismatch.
    matrix = torch.randn(dim, dim, dtype=dtype, device=points.device)
    return points.to(dtype) @ matrix


def apply_transformations(
    data_points: torch.Tensor,
    transf_funcs: List[Callable[[torch.Tensor], torch.Tensor]],
    **kwargs,
) -> torch.Tensor:
    """Pipe ``data_points`` through each function in ``transf_funcs`` in order.

    Every function receives the previous result plus the same ``**kwargs``,
    so each function must accept all of the supplied keyword arguments.
    With an empty ``transf_funcs`` the input object itself is returned.
    """
    current = data_points
    for step in transf_funcs:
        current = step(current, **kwargs)
    return current
