from datetime import datetime

import torch

# Scalar types that is_primitive() treats as "primitive" (NoneType included via
# type(None)). They are checked with isinstance, so subclasses match too. bool is
# already an int subclass, so listing it separately is redundant but harmless.
PRIMITIVES = (bool, str, int, float, type(None))


def snake_to_camel(snake_str: str) -> str:
    """Convert ``snake_case`` text to UpperCamelCase (PascalCase).

    The input is split on ``_`` and each piece is passed through ``str.title``
    before the pieces are glued back together without separators.

    Because ``str.title`` is used, every piece is lower-cased except its first
    letter and any letter that follows a non-letter:
    ``"HTTP_url"`` -> ``"HttpUrl"`` and ``"layer2d"`` -> ``"Layer2D"``.
    Empty pieces (from doubled underscores) vanish.
    """
    titled_pieces = map(str.title, snake_str.split("_"))
    return "".join(titled_pieces)


def camel_to_snake(camel_str: str) -> str:
    """Convert camelCase / PascalCase text to ``snake_case``.

    Every upper-case character is replaced by ``_`` followed by its lower-case
    form. Afterwards all leading underscores are stripped, which also removes
    any underscores the input itself started with.

    Consecutive capitals are split one letter at a time, e.g.
    ``"HTTPServer"`` -> ``"h_t_t_p_server"``.
    """
    pieces = []
    for char in camel_str:
        if char.isupper():
            pieces.append("_")
            pieces.append(char.lower())
        else:
            pieces.append(char)
    joined = "".join(pieces)
    return joined.lstrip("_")


def is_primitive(obj) -> bool:
    """Tell whether ``obj`` is an instance of one of the types in ``PRIMITIVES``.

    The check uses ``isinstance``, so subclasses of those types count as
    primitive too (e.g. an ``enum.IntEnum`` member, which is an ``int``).
    """
    allowed_types = PRIMITIVES
    return isinstance(obj, allowed_types)


def timestamp_file_signature(when: "datetime | None" = None) -> str:
    """Return a filename-safe timestamp such as ``2024-01-02T030405123456``.

    The ISO-8601 form of the moment is taken with a fixed microsecond field,
    and its ``:`` and ``.`` characters are removed so the result can be
    embedded in file names.

    Args:
        when: Moment to format. Defaults to ``datetime.now()``, which is naive
            local time.

    Returns:
        The formatted signature. For naive datetimes it is always 23 characters
        (``YYYY-MM-DDTHHMMSSffffff``). Aware datetimes also carry their UTC
        offset, with the colon removed.
    """
    if when is None:
        when = datetime.now()
    # Plain isoformat() drops the fractional part when microsecond == 0, which
    # made the signature's length and shape vary. Always emit microseconds.
    iso_text = when.isoformat(timespec="microseconds")
    return iso_text.replace(":", "").replace(".", "")


def distance_between_points(point1: list, point2: list):
    """Return the Euclidean distance between two points as a Python float.

    Args:
        point1: Coordinates of the first point, e.g. a list of numbers
            (anything ``torch.tensor`` accepts).
        point2: Coordinates of the second point; must broadcast against
            ``point1``.

    Returns:
        ``sqrt(sum((point1 - point2) ** 2))`` summed over all elements, as a
        Python ``float``.
    """
    # Build float64 tensors explicitly. With torch's defaults, float lists
    # become float32 and int lists become int64, which loses precision in the
    # returned float (and large integer coordinates can overflow when squared).
    first = torch.tensor(point1, dtype=torch.float64)
    second = torch.tensor(point2, dtype=torch.float64)
    return (first - second).pow(2).sum().sqrt().item()


def distance_between_many_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor):
    """Return Euclidean distances between corresponding vectors along the last dim.

    Args:
        tensor1: Tensor whose last dimension holds vector coordinates.
        tensor2: Tensor broadcastable against ``tensor1``.

    Returns:
        Tensor of ``||tensor1 - tensor2||_2`` over the last dimension; that
        dimension is reduced away.
    """
    # ``norm`` already takes the square root of the sum of squares. The former
    # trailing ``.sqrt()`` returned sqrt(distance) instead of the distance.
    return (tensor1 - tensor2).norm(dim=-1)


def generate_orthonormal_matrix(
    dim: int,
    dtype: torch.dtype = torch.float64,
    device: "int | str | torch.device | None" = None,
):
    """Sample a random ``dim x dim`` matrix with orthonormal columns.

    The matrix is the Q factor of the QR decomposition of a matrix with
    i.i.d. standard-normal entries. Each column of Q is multiplied by the
    sign of the matching diagonal entry of R. This makes the factorization
    unique, so the result follows the Haar (uniform) distribution rather than
    one skewed by the QR routine's sign convention.

    Args:
        dim: Number of rows and columns.
        dtype: dtype of the result (default ``torch.float64``).
        device: Device to create the matrix on; passed straight to
            ``torch.randn``.

    Returns:
        A ``(dim, dim)`` tensor ``Q``. For real dtypes, ``Q @ Q.T`` is the
        identity up to round-off.
    """
    gaussian = torch.randn(dim, dim, device=device, dtype=dtype)
    q, r = torch.linalg.qr(gaussian)
    # sgn equals sign for real dtypes and gives the unit phase for complex ones.
    # A zero diagonal entry has probability 0 for Gaussian input; map it to 1
    # rather than zeroing out a column.
    signs = torch.sgn(torch.diagonal(r))
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    q = q * signs  # broadcasts over rows: column j is scaled by signs[j]
    # Renormalise each column to absorb round-off. The previous code divided by
    # *row* norms broadcast along the last axis, i.e. it scaled column j by the
    # norm of row j.
    return q / q.norm(dim=0, keepdim=True)


def torch_uniform(low, high, shape, dtype=None, device=None):
    """Sample a tensor uniformly from ``[low, high)``.

    Args:
        low: Lower bound (inclusive). A number, or a tensor that broadcasts
            against the sample tensor.
        high: Upper bound (exclusive, up to floating-point rounding). A
            number, or a tensor that broadcasts against the sample tensor.
        shape: Shape of the samples: an int, or any iterable of ints (tuple,
            list, ``torch.Size``). ``()`` yields a 0-d tensor.
        dtype: dtype of the samples; ``None`` uses torch's default dtype.
        device: Device of the samples; ``None`` uses torch's default device.

    Returns:
        ``low + (high - low) * U`` where ``U = torch.rand(shape)`` is uniform
        on ``[0, 1)``.
    """
    # Normalise the shape. ``torch.rand(*shape)`` raised on a bare int and
    # could not consume a generator.
    size = (shape,) if isinstance(shape, int) else tuple(shape)
    unit_samples = torch.rand(size, dtype=dtype, device=device)
    return low + (high - low) * unit_samples
