import random
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.linalg import logm
from scipy.stats import beta as scipy_beta
from torch.distributions import Beta


# Ref: https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/utilities/seed.html#seed_everything
def seed_all(seed: int):
    """Seed every random number generator used by this module.

    The same seed is passed, in order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def fig_to_img(fig):
    """Read a figure's RGBA canvas buffer as an RGB ``uint8`` array and close it.

    Args:
        fig: Object exposing ``canvas.buffer_rgba()`` and
            ``canvas.get_width_height()``; it is passed to ``plt.close``.
            NOTE(review): presumably a Matplotlib figure whose canvas has
            already been drawn -- confirm against callers.

    Returns:
        Array of shape ``(height, width, 3)`` holding the first three
        channels of the RGBA buffer. No explicit copy is made here.
    """
    canvas = fig.canvas
    rgba_flat = np.frombuffer(canvas.buffer_rgba(), dtype="uint8")

    # get_width_height() reports (width, height); image arrays are row-major.
    width, height = canvas.get_width_height()
    rgba = rgba_flat.reshape((height, width, 4))

    # Drop the alpha channel.
    rgb = rgba[:, :, :3]

    plt.close(fig)
    return rgb


def blogm(mat):
    """Compute the matrix logarithm of every matrix in a batch.

    Each slice along the first dimension is passed to
    ``scipy.linalg.logm`` on the CPU; the results are stacked and moved
    back to the dtype and device of ``mat``.

    The computation happens outside autograd: the input is detached first
    (so tensors that require grad are accepted), and no gradient flows from
    the result back to ``mat``.

    Args:
        mat: Tensor whose first dimension indexes the matrices.
            Assumes shape ``(batch, n, n)`` -- TODO confirm with callers.

    Returns:
        Tensor with the dtype and device of ``mat`` containing the matrix
        logarithms. An empty batch yields an empty tensor with the same
        shape as ``mat``.
    """
    # detach(): Tensor.numpy() refuses tensors that require grad.
    mats_np = mat.detach().cpu().numpy()

    # np.array([]) would collapse to shape (0,); keep the input shape instead.
    if len(mats_np) == 0:
        return torch.empty_like(mat)

    # With disp=False, logm returns (result, error_estimate); keep the result.
    # NOTE(review): logm may return a complex array for some inputs; the cast
    # to mat.dtype below is unchanged from the original -- confirm intent.
    logs = np.array([logm(m, disp=False)[0] for m in mats_np])
    return torch.from_numpy(logs).to(dtype=mat.dtype, device=mat.device)


def safe_blogm(mat, eps=1e-6):
    """Add ``eps * I`` to each matrix, then take the batched matrix logarithm.

    The identity is sized from the last dimension of ``mat``, so square
    matrices of any size are accepted (not only 2x2).

    Args:
        mat: Batch of square matrices; must have at least three dimensions
            with the last two equal, e.g. ``(batch, n, n)``.
        eps: Scale of the identity added to every matrix before the
            logarithm. NOTE(review): presumably a small regulariser to keep
            the input away from singular matrices -- confirm.

    Returns:
        The result of ``blogm`` applied to ``mat + eps * I``.
    """
    size = mat.shape[-1]
    identity = torch.eye(size, device=mat.device, dtype=mat.dtype).unsqueeze(0)

    mat_stabilized = mat + eps * identity.expand_as(mat)
    return blogm(mat_stabilized)


def expand_like(x, y):
    """Expand ``x`` along trailing singleton dimensions to the shape of ``y``.

    ``x`` is viewed as ``(-1, 1, ..., 1)`` with ``y.dim()`` dimensions in
    total, expanded to ``y``'s shape and materialised contiguously.

    Args:
        x: Tensor whose elements map onto the first dimension of ``y``.
        y: Tensor providing the target shape.

    Returns:
        Contiguous tensor with the shape of ``y``.
    """
    singleton_dims = (1,) * (y.dim() - 1)
    column = x.view(-1, *singleton_dims)
    return column.expand_as(y).contiguous()


def batched_div(v_func, x, t):
    """Compute the per-sample divergence of a vector field with autograd.

    For every sample ``b`` this returns ``sum_i d v[b, i] / d x[b, i]``,
    where ``v = v_func(x, t)`` and both ``x`` and ``v`` are flattened over
    all non-batch dimensions. One backward pass is run per flattened
    feature.

    The input is detached before differentiation, so the caller's ``x`` is
    left untouched (its ``requires_grad`` flag is not modified). The
    returned divergence is not itself differentiable.

    Args:
        v_func: Callable ``v_func(x, t)`` returning a tensor with as many
            elements per sample as ``x``.
        x: Input tensor whose first dimension is the batch dimension.
        t: Passed through unchanged as the second argument of ``v_func``.

    Returns:
        Tensor of shape ``(batch,)`` with the dtype and device of ``x``.

    Note:
        Each gradient is taken of the batch-summed output, so the result is
        the per-sample divergence only if sample ``b`` of the output depends
        on sample ``b`` of the input alone.
    """
    orig_shape = x.shape
    batch_size = orig_shape[0]

    # Differentiate with respect to a private leaf so the caller's tensor is
    # never mutated.
    x_flat = x.detach().reshape(batch_size, -1).requires_grad_(True)

    v_flat = v_func(x_flat.view(orig_shape), t).reshape(batch_size, -1)

    div = torch.zeros(batch_size, device=x_flat.device, dtype=x_flat.dtype)
    for i in range(v_flat.shape[1]):
        (grad_i,) = torch.autograd.grad(
            v_flat[:, i],
            x_flat,
            grad_outputs=torch.ones_like(v_flat[:, i]),
            # The same graph is reused for every feature.
            retain_graph=True,
        )
        # Only the diagonal Jacobian entry contributes to the divergence.
        div += grad_i[:, i]

    return div


def batched_object_div(v_func, x, t):
    """Compute the per-sample divergence of the field ``v = x @ A^T``.

    ``A = v_func(x, t)`` is evaluated on a detached copy of ``x``; the
    velocity ``x @ A.transpose(-2, -1)`` is flattened per sample and its
    divergence ``sum_i d v[b, i] / d x[b, i]`` is accumulated with one
    backward pass per flattened feature.

    Args:
        v_func: Callable ``v_func(x, t)`` returning the matrices ``A``.
            Assumes ``A`` is shaped so that ``x @ A^T`` has the shape of
            ``x`` -- TODO confirm with callers.
        x: Input tensor whose first dimension is the batch dimension.
        t: Passed through unchanged as the second argument of ``v_func``.

    Returns:
        Tensor of shape ``(batch,)`` with the dtype and device of ``x``.
    """
    full_shape = x.shape
    n_batch = full_shape[0]

    # Detach so the caller's tensor is not touched by the autograd calls.
    x = x.detach().requires_grad_(True)
    flat_x = x.view(n_batch, -1).requires_grad_(True)

    # Evaluate the velocity on the original (un-flattened) layout.
    x_mat = flat_x.view(full_shape)
    coeffs = v_func(x_mat, t)
    velocity = x_mat @ coeffs.transpose(-2, -1)
    flat_v = velocity.view(n_batch, -1)

    n_features = flat_x.shape[1]
    div = torch.zeros(n_batch, device=x.device, dtype=x.dtype)
    for col in range(n_features):
        column = flat_v[:, col]
        jac_row = torch.autograd.grad(
            column,
            flat_x,
            grad_outputs=torch.ones_like(column),
            retain_graph=True,
        )[0]
        # Keep only the diagonal Jacobian entry for this feature.
        div += jac_row[:, col]

    return div


def sample_random_batch(dataset, batch_size, device, return_transform=False):
    """Draw a random batch (with replacement) from ``dataset.data``.

    Indices are sampled uniformly with ``torch.randint`` over
    ``len(dataset)``. Whether the data is treated as torch tensors or NumPy
    arrays is decided from the type of the indexed batch alone; the
    transforms, when requested, follow the same path.

    Args:
        dataset: Object supporting ``len()`` with an indexable ``data``
            attribute and, if ``return_transform`` is set, an indexable
            ``transform`` attribute.
        batch_size: Number of indices to draw.
        device: Device the returned tensors are moved to.
        return_transform: If true, also return ``dataset.transform`` at the
            same indices.

    Returns:
        The batch tensor, or a ``(batch, transform)`` tuple when
        ``return_transform`` is true.
    """
    rand_idx = torch.randint(len(dataset), (batch_size,))
    rand_batch = dataset.data[rand_idx]
    batch_is_tensor = isinstance(rand_batch, torch.Tensor)

    def _on_device(values):
        # Non-tensor data is taken to be a NumPy array, as in the batch itself.
        tensor = values if batch_is_tensor else torch.from_numpy(values)
        return tensor.to(device)

    if not return_transform:
        return _on_device(rand_batch)

    rand_transform = dataset.transform[rand_idx]
    return _on_device(rand_batch), _on_device(rand_transform)


def nonlinear_indices(N, n_elements=10, power=8):
    """
    Pick n_elements strictly increasing indices in [0, N], denser towards N.

    Evenly spaced points x in [0, 1] are warped by 1 - (1 - x) ** power and
    scaled by N; collisions are resolved by bumping an index past its
    predecessor while leaving room for the indices still to come.

    Parameters:
    - N: maximum index value
    - n_elements: number of indices to return (default 10)
    - power: warp exponent (higher = more clustering near N)

    Returns:
    - List of unique, increasing integer indices

    Raises:
    - ValueError: if n_elements exceeds the N + 1 available indices
    """
    if n_elements > N + 1:
        raise ValueError(f"Cannot generate {n_elements} unique indices from 0 to {N}")

    # Warped fractions of N, one per requested index.
    fractions = 1 - (1 - np.linspace(0, 1, n_elements)) ** power

    picked = []
    last = -1
    slots_after = n_elements - 1  # indices still to be placed after this one

    for frac in fractions:
        # Never repeat or go below the previous index.
        candidate = max(int(frac * N), last + 1)
        # Leave one free slot for each index that still has to follow.
        candidate = min(candidate, N - slots_after)

        picked.append(candidate)
        last = candidate
        slots_after -= 1

    return picked


def sample_power_distribution(
    a: float,
    b: float,
    skewness: float = 3.0,
    size: Union[int, Tuple[int, ...]] = 1,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    requires_grad: bool = False,
) -> torch.Tensor:
    """
    Draw samples on [a, b] from a power distribution skewed towards b.

    Inverse transform sampling is used: with U ~ Uniform(0, 1), the sample
    is X = a + (b - a) * U^(1 / skewness).

    Args:
        a: Lower bound of the distribution
        b: Upper bound of the distribution
        skewness: Positive exponent parameter (higher = more mass near b)
        size: Output shape; an int is treated as a one-dimensional shape
        device: Device on which the uniform draws are created
        dtype: Data type of the uniform draws
        requires_grad: Forwarded to ``torch.rand`` for the uniform draws

    Returns:
        Tensor of samples with the requested shape

    Raises:
        ValueError: If ``a >= b`` or ``skewness <= 0``

    Example:
        >>> samples = sample_power_distribution(0, 100, skewness=3, size=(1000,))
        >>> print(f"Mean: {samples.mean():.2f}")  # Should be around 75
    """
    if a >= b:
        raise ValueError(f"Lower bound {a} must be less than upper bound {b}")
    if skewness <= 0:
        raise ValueError(f"Skewness parameter must be positive, got {skewness}")

    # A bare int means a 1-D tensor of that length.
    shape = (size,) if isinstance(size, int) else size

    uniform = torch.rand(
        shape, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # Inverse CDF of the power distribution, rescaled onto [a, b].
    span = b - a
    return a + span * uniform.pow(1.0 / skewness)


def icdf_power(
    a: float,
    b: float,
    n: int,
    skewness: float = 3.0,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Evaluate the power distribution's inverse CDF at n evenly spaced quantiles.

    The quantiles q are ``torch.linspace(0, 1, n)`` and each is mapped to
    X = a + (b - a) * q^(1 / skewness), so the first value is ``a`` and the
    last is ``b`` (for n >= 2).

    Args:
        a: Lower bound
        b: Upper bound
        n: Number of points
        skewness: Parameter controlling skew (exponent is 1 / skewness)
        device: Device to place the tensor on
        dtype: Data type of the output tensor

    Returns:
        A single 1-D tensor of length ``n`` with the mapped values.
    """
    # Evenly spaced quantiles (more values land where probability is higher).
    quantiles = torch.linspace(0, 1, n, device=device, dtype=dtype)

    # Inverse CDF: X = a + (b - a) * q^(1 / skewness)
    return a + (b - a) * torch.pow(quantiles, 1.0 / skewness)


def sample_beta_distribution(alpha, beta, n, device="cpu", dtype=torch.float32):
    """Draw ``n`` samples from a Beta(alpha, beta) distribution.

    Args:
        alpha: First concentration parameter; converted with ``torch.tensor``.
        beta: Second concentration parameter; converted with ``torch.tensor``.
        n: Number of samples, used as the leading sample dimension.
        device: Device the parameter tensors are created on.
        dtype: Data type of the parameter tensors.

    Returns:
        Tensor of samples whose leading dimension has length ``n``.
    """
    tensor_kwargs = {"device": device, "dtype": dtype}

    # Parameters live on the target device so the samples are created there.
    concentration1 = torch.tensor(alpha, **tensor_kwargs)
    concentration0 = torch.tensor(beta, **tensor_kwargs)

    return Beta(concentration1, concentration0).sample((n,))


def icdf_beta(
    alpha: float,
    beta: float,
    n: int,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Evaluate the Beta(alpha, beta) inverse CDF at n evenly spaced quantiles.

    The quantiles are ``np.linspace(0, 1, n)``; each is mapped through the
    percent point function of ``scipy.stats.beta(alpha, beta)`` and the
    result is converted to a torch tensor.

    Args:
        alpha: First shape parameter of the Beta distribution
        beta: Second shape parameter of the Beta distribution
        n: Number of points
        device: Device to place the tensor on
        dtype: Data type of the output tensor

    Returns:
        A single 1-D tensor of length ``n`` with the mapped values.
    """
    # Evenly spaced quantiles (more values land where probability is higher).
    quantiles = np.linspace(0, 1, n)
    values = scipy_beta(alpha, beta).ppf(quantiles)

    # SciPy works in NumPy float64; convert to the requested torch dtype/device.
    return torch.from_numpy(values).to(device=device, dtype=dtype)
