from typing import Any
import math

import torch
from torch import nn
import einops


def _no_grad_trunc_normal_(tensor: torch.Tensor,
                           mean: float,
                           std: float,
                           a: float,
                           b: float) -> torch.Tensor:
    """
    Fill ``tensor`` in place with values drawn from a normal distribution
    truncated to ``[a, b]``.

    Sampling uses the inverse-CDF method: uniform values are drawn between
    the CDF values of the two bounds and mapped back through the inverse
    error function.
    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: torch.Tensor
            The tensor to overwrite. It is modified in place.
        mean: float
            Mean of the underlying normal distribution.
        std: float
            Standard deviation of the underlying normal distribution.
        a: float
            Lower truncation bound.
        b: float
            Upper truncation bound.
    Returns:
        tensor: the same tensor object that was passed in.
    """
    root_two = math.sqrt(2.)

    def _standard_normal_cdf(value: float):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / root_two)) / 2.

    # All in-place updates run without autograd tracking.
    with torch.no_grad():
        # CDF values of the (standardised) lower and upper bounds.
        cdf_low = _standard_normal_cdf((a - mean) / std)
        cdf_high = _standard_normal_cdf((b - mean) / std)

        # Draw uniformly from [2 * cdf_low - 1, 2 * cdf_high - 1], i.e. the
        # CDF interval rescaled to the domain of erfinv.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)

        # erfinv turns the uniform draws into a truncated standard normal
        # (up to the sqrt(2) factor applied next).
        tensor.erfinv_()

        # Rescale and shift to the requested std and mean.
        tensor.mul_(std * root_two)
        tensor.add_(mean)

        # Final clamp keeps every value inside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor: torch.Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> torch.Tensor:
    """
    Fill ``tensor`` in place with truncated-normal values.

    Public entry point that forwards every argument unchanged to
    ``_no_grad_trunc_normal_``.

    Args:
        tensor: torch.Tensor
            The tensor to overwrite.
        mean: float
            Mean of the normal distribution.
        std: float
            Standard deviation of the normal distribution.
        a: float
            Lower truncation bound.
        b: float
            Upper truncation bound.
    Returns:
        Whatever ``_no_grad_trunc_normal_`` returns for these arguments.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)


def to_fp32(state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Build a new dict in which every value has been passed through ``.float()``.

    Args:
        state_dict: dict[str, Any]
            Mapping whose values all expose a ``.float()`` method
            (presumably tensors of a model state dict).
    Returns:
        A new dict with the same keys, in the same order. The input mapping
        itself is not modified.
    """
    converted: dict[str, Any] = {}
    for name, value in state_dict.items():
        converted[name] = value.float()
    return converted



def pos_encode_time(n_times, n_dim, max_n_times, out: torch.Tensor | None = None):
    """
    1-dimensional positional encoding.

    Args:
        n_times: int
            Number of time samples to encode.
        n_dim: int
            Number of dimensions of the positional encoding. Must be even.
        max_n_times: int
            The largest possible number of time samples to encode.
            Used to scale the positional encoding.
        out: torch.Tensor | None
            The output tensor. If None, a new tensor will be created.
            Must have shape (n_times, n_dim).
    Returns:
        out: (n_times, n_dim)
    """
    out = torch.empty((n_times, n_dim), dtype=torch.float32) if out is None else out
    assert out.shape == (n_times, n_dim)
    assert n_dim % 2 == 0
    position = torch.arange(n_times, device=out.device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, n_dim, 2, device=out.device) * (-math.log(max_n_times) / n_dim)
    )
    out[:, 0::2] = torch.sin(position * div_term)
    out[:, 1::2] = torch.cos(position * div_term)
    return out


def pos_encode_continuous(
    x, x_min, x_max, n_dim, out: torch.Tensor | None = None
) -> torch.Tensor:
    """
    1-dimensional positional encoding.

    Args:
        x: float
            The position to encode.
        x_min: float
            The minimum possible value of x.
        x_max: float
            The maximum possible value of x.
        n_dim: int
            Number of dimensions of the positional encoding. Must be even.
        out: torch.Tensor | None
            The output tensor. If None, a new tensor will be created.
            Must have shape (n_dim,).
    Returns:
        out: (n_dim,)
    """
    out = torch.empty((n_dim,), dtype=torch.float32) if out is None else out
    assert out.shape == (n_dim,)
    assert n_dim % 2 == 0
    div_term = torch.exp(
        (1 - torch.arange(0, n_dim, 2, device=out.device) / n_dim) * 2 * math.pi
    )
    xx = (x - x_min) / (x_max - x_min)
    out[0::2] = torch.sin(xx * div_term)
    out[1::2] = torch.cos(xx * div_term)
    return out


def pos_encode_continuous_batched(
    x, x_min, x_max, n_dim, out: torch.Tensor | None = None
) -> torch.Tensor:
    """
    1-dimensional positional encoding.

    Args:
        x: torch.Tensor of shape (*xdims)
            The positions to encode.
        x_min: float
            The minimum possible value of x.
        x_max: float
            The maximum possible value of x.
        n_dim: int
            Number of dimensions of the positional encoding. Must be even.
        out: torch.Tensor | None
            The output tensor. If None, a new tensor will be created.
            Must have shape (*xdims, n_dim).
    Returns:
        out: (*xdims, n_dim)
    """
    out = torch.empty(x.shape + (n_dim,), dtype=torch.float32) if out is None else out
    assert out.shape == x.shape + (n_dim,)
    assert n_dim % 2 == 0
    div_term = torch.exp(
        (1 - torch.arange(0, n_dim, 2, device=out.device) / n_dim) * 2 * math.pi
    )
    xx = torch.as_tensor((x - x_min) / (x_max - x_min)).unsqueeze(-1)
    out[..., 0::2] = torch.sin(xx * div_term)
    out[..., 1::2] = torch.cos(xx * div_term)
    return out


def batched_index_select(x: torch.Tensor, index: torch.Tensor, selected_dim: int = 1) -> torch.Tensor:
    """
    Similar to torch.index_select(dim=0) but with batch dimensions.

    For every batch element, entries of ``x`` are picked along
    ``selected_dim`` at the positions given by ``index``; the index
    dimensions replace the selected dimension in the result.

    Args:
        x: (*bdims, selected_dim_size, *xdims)
        index: (*bdims, *idims)
            Integer positions along ``selected_dim`` of ``x``.
        selected_dim: int
            Dimension of ``x`` to select from; also the number of leading
            batch dimensions shared by ``x`` and ``index``. Negative values
            count from the last dimension of ``x``.
    Returns:
        out: (*bdims, *idims, *xdims)
    """
    if selected_dim < 0:
        # Normalise so the shape slices below split x at the right place.
        selected_dim += x.ndim

    batch_shape = tuple(x.shape[:selected_dim])
    assert (
        batch_shape == tuple(index.shape[:selected_dim])
    ), "x and index must share the same number of batch elements"
    index_shape = tuple(index.shape[selected_dim:])
    trailing_shape = tuple(x.shape[selected_dim + 1 :])

    # Collapse all index dims into one axis sitting at `selected_dim`, then
    # broadcast over the trailing dims of x. `expand` creates a view, so no
    # index as large as the output is materialised.
    n_selected = math.prod(index_shape)
    flat_index = index.reshape(
        batch_shape + (n_selected,) + (1,) * len(trailing_shape)
    ).expand(batch_shape + (n_selected,) + trailing_shape)

    gathered = x.gather(dim=selected_dim, index=flat_index)

    # Unfold the collapsed axis back into the original index dims.
    return gathered.reshape(batch_shape + index_shape + trailing_shape)


def batched_index_unselect(x, index):
    """
    Reverse function of batched_index_select.

    The entries of ``x`` are reordered along the last dimension of ``index``
    using ``argsort(index)``, which undoes a selection made with ``index``.

    Args:
        x: (*bdims, selected_dim_size, *xdims)
            The output of batched_index_select
        index: (*bdims, selected_dim_size)
            The index used to select the elements from x.
            Only works if the index is 1D (one index dimension after the
            batch dimensions).
    Returns:
        out: (*bdims, selected_dim_size, *xdims)
    """
    # The last dimension of `index` is the one that was selected along.
    dim = index.ndim - 1
    assert (
        x.shape[: dim + 1] == index.shape
    ), "x and index must share the same number of batch elements"
    # The largest index must be the last position along the selected dim.
    assert index.max() == x.shape[dim] - 1, "sparse index"

    # Name each trailing dim of x so einops can broadcast the index over it.
    trailing = {f"x{pos}": size for pos, size in enumerate(x.shape[dim + 1 :])}
    trailing_axes = " ".join(trailing)

    # argsort of the index gives, for each original position, where it
    # currently sits in x.
    inverse = einops.repeat(
        torch.argsort(index, dim=dim),
        f"... i -> ... i {trailing_axes}",
        **trailing,
    )
    return x.gather(dim=dim, index=inverse)


def get_euclidean_diatance_matrix(pos):
    """
    Pairwise euclidean distances between positions
    (same as torch.cdist when batch dimension is present).

    :param pos: torch.array, shape (n_channels, 3) or  (batch_size, n_channels, 3)
        XYZ positions of the EEG channels
    :return: torch.array, shape (n_channels, n_channels) or (batch_size, n_channels, n_channels)
        the euclidean distance matrix between the different EEG channels
    """
    # Insert a singleton axis on either side of the channel axis so that the
    # subtraction broadcasts to every (channel_i, channel_j) pair.
    as_rows = pos.unsqueeze(dim=-2)
    as_cols = pos.unsqueeze(dim=-3)
    squared_diff = (as_rows - as_cols) ** 2
    return squared_diff.sum(dim=-1) ** 0.5


def get_covariance_matrix(x):
    """
    Sample covariance matrix between channels, computed over the last
    (time) dimension with the ``n_time_samples - 1`` normalisation.

    The input tensor is left untouched: centering is done on a copy.

    :param x: torch.array, shape (n_channels, n_time_samples) or (batch_size, n_channels, n_time_samples)
    :return: torch.array, shape (n_channels, n_channels) or (batch_size, n_channels, n_channels)
        With a single time sample the normalisation divides by zero.
    """
    # Out-of-place subtraction: an in-place `x -= mean` would silently
    # modify the caller's data.
    centered = x - x.mean(dim=-1, keepdim=True)
    n_time_samples = x.shape[-1]
    # Sum over time of centered[i, t] * centered[j, t], written as a matrix
    # product so no (n_channels, n_channels, n_time_samples) intermediate is
    # built.
    return (centered @ centered.transpose(-1, -2)) / (n_time_samples - 1)
