from functools import lru_cache
from typing import Callable, Type

import torch
from distributions import BaseDistribution, MixtureDistribution
from torch import Tensor
from torch.distributions import Categorical


def get_categorical_rejection_constant(
    proposal_dist: Type[Categorical],
    target_dist: Type[Categorical],
    delta: float = 0.05,
) -> Tensor:
    """Compute a rejection constant between two categorical distributions.

    The constant is derived from the per-category ratio
    ``target_dist.probs / (proposal_dist.probs + 1e-5)``, reduced over the
    last dimension and never smaller than 1.

    Args:
        proposal_dist: Distribution whose ``probs`` form the denominator.
        target_dist: Distribution whose ``probs`` form the numerator.
        delta: Selects how the ratios are reduced.
            * ``delta > 0``: categories are ranked by ratio (largest first)
              and the leading ones are zeroed out until the cumulative
              target probability reaches ``delta``; the maximum of what
              remains is returned.
            * ``delta == -1``: the ratio at the target's most probable
              category is returned.
            * any other value: the plain maximum ratio is returned.

    Returns:
        Tensor with the category dimension reduced away, clamped to ``>= 1``.
    """
    target_probs = target_dist.probs
    # The small epsilon keeps the division finite for zero-probability
    # proposal categories.
    prob_ratio: Tensor = target_probs / (proposal_dist.probs + 1e-5)

    if delta > 0:
        ordered_ratio, order = prob_ratio.sort(dim=-1, descending=True)
        # Target mass accumulated while walking from the worst ratio down.
        mass_so_far = target_probs.gather(-1, order).cumsum(dim=-1)
        keep = (mass_so_far >= delta).float()
        return (ordered_ratio * keep).max(-1).values.clamp(1)

    if delta == -1:
        top_category = target_probs.argmax(dim=-1, keepdim=True)
        return prob_ratio.gather(-1, top_category).clamp(min=1).squeeze(-1)

    return prob_ratio.max(-1).values.clamp(1)


def _get_component_segments(
    dist: Type[BaseDistribution],
    x_grid: Tensor,
    _func: Callable,
) -> tuple[Tensor, Tensor]:
    """Estimate the density at both ends of every grid segment.

    For each pair of consecutive grid points, the density is extrapolated
    linearly from the segment midpoint (using the PDF value and its
    derivative there) to the left and right endpoints. Each extrapolated
    value is then combined with the density evaluated directly at that
    endpoint through ``_func``.

    Args:
        dist: Object providing ``param_shape``, ``derivative`` and
            ``log_prob``.
        x_grid: Grid points along the last dimension.
            NOTE(review): the ``movedim(2, 0)`` calls assume the grid axis
            sits at index 2 — confirm against callers.
        _func: Binary callable merging the tangent estimate with the direct
            PDF value (called as ``_func(tangent, direct)``).

    Returns:
        ``(y_left, y_right)``: combined values at the left and right end of
        each segment.
    """
    seg_start = x_grid[..., :-1]
    seg_end = x_grid[..., 1:]
    seg_centre = (seg_end + seg_start) / 2

    # Append a trailing axis when the distribution's parameter rank does not
    # match the grid's leading dimensions.
    if len(dist.param_shape) != x_grid.ndim - 1:
        seg_start, seg_centre, seg_end = (
            points.unsqueeze(-1) for points in (seg_start, seg_centre, seg_end)
        )

    def density_at(points: Tensor) -> Tensor:
        # The distribution is queried with the grid axis moved to the front;
        # the result is moved back afterwards.
        return dist.log_prob(points.movedim(2, 0)).exp().movedim(0, 2)

    # Slope and value of the PDF at the segment midpoints.
    slope = dist.derivative(seg_centre.movedim(2, 0)).movedim(0, 2)
    density_mid = density_at(seg_centre)

    # Tangent line through the midpoint, evaluated at both endpoints.
    tangent_left = slope * (seg_start - seg_centre) + density_mid
    tangent_right = slope * (seg_end - seg_centre) + density_mid

    # PDF evaluated directly at the endpoints.
    direct_left = density_at(seg_start)
    direct_right = density_at(seg_end)

    merged_left = _func(tangent_left, direct_left)
    merged_right = _func(tangent_right, direct_right)
    return merged_left, merged_right


def _get_segments_for_mixture(
    mixture_dist: Type[MixtureDistribution],
    x_grid: Tensor,
    _func: Callable,
) -> tuple[Tensor, Tensor]:
    """Compute segment endpoint values for a mixture distribution.

    The per-component values come from ``_get_component_segments`` applied
    to ``mixture_dist.component_distribution``; they are then multiplied by
    ``mixture_dist.weights`` and summed over the last dimension.

    Args:
        mixture_dist: Object providing ``component_distribution`` and
            ``weights``.
        x_grid: Grid points, forwarded unchanged.
        _func: Combination function, forwarded unchanged.

    Returns:
        ``(y_left, y_right)`` with the last dimension summed out.
    """
    per_component = _get_component_segments(
        mixture_dist.component_distribution, x_grid, _func
    )

    def mix(values: Tensor) -> Tensor:
        # The weights get an extra axis at position -2 so they broadcast
        # against the values before the weighted sum.
        return (values * mixture_dist.weights.unsqueeze(-2)).sum(dim=-1)

    left_values, right_values = per_component
    return mix(left_values), mix(right_values)


def get_segments(
    dist: Type[BaseDistribution] | Type[MixtureDistribution],
    x_grid: Tensor,
    _func: Callable,
) -> tuple[Tensor, Tensor]:
    """Dispatch segment computation to the mixture or single-component path.

    Args:
        dist: Distribution object. Mixture distributions are routed to
            ``_get_segments_for_mixture``; everything else goes to
            ``_get_component_segments``.
        x_grid: Grid points, forwarded unchanged.
        _func: Combination function, forwarded unchanged.

    Returns:
        ``(y_left, y_right)`` as produced by the selected helper.
    """
    # isinstance also recognises subclasses of MixtureDistribution, which the
    # previous exact class-name comparison sent down the single-component
    # path. The name comparison is kept as a fallback so any object that was
    # dispatched as a mixture before still is.
    is_mixture = (
        isinstance(dist, MixtureDistribution)
        or type(dist).__name__ == "MixtureDistribution"
    )
    if is_mixture:
        return _get_segments_for_mixture(dist, x_grid, _func)
    return _get_component_segments(dist, x_grid, _func)


@lru_cache(maxsize=1)
def _get_grid_linspace(num_points: int, device: torch.device, shape: tuple[int]) -> Tensor:
    x = torch.linspace(0, 1, num_points).to(device)
    x = x.view(*([1] * (len(shape) - 1)), num_points).repeat(*shape[:-1], 1)
    return x


def get_grid(
    proposal_dist: Type[BaseDistribution],
    target_dist: Type[BaseDistribution],
    exact: bool,
    num_points: int = 0,
) -> Tensor:
    """Build the grid of x values on which both densities are compared.

    The grid always spans the values returned by
    ``target_dist.percentile`` for the levels 0.05 and 0.95.

    Args:
        proposal_dist: Object providing ``mode()`` and
            ``inflection_points()``.
        target_dist: Object providing ``percentile()``, ``mode()`` and
            ``inflection_points()``.
        exact: Forwarded to ``target_dist.percentile``.
        num_points: When positive, return ``num_points`` evenly spaced
            values between the two bounds. Otherwise return the sorted
            collection of the bounds, both modes and both sets of
            inflection points.

    Returns:
        Grid tensor with the grid points along the last dimension.
    """
    bounds = target_dist.percentile(torch.tensor([0.05, 0.95]), exact=exact)

    if num_points > 0:
        unit_grid = _get_grid_linspace(num_points, bounds.device, bounds.shape)
        # Rescale the unit grid from [0, 1] onto [lower bound, upper bound].
        return unit_grid * (bounds[..., 1:] - bounds[..., :-1]) + bounds[..., :-1]

    mode_of_target = target_dist.mode()
    mode_of_proposal = proposal_dist.mode().expand_as(mode_of_target)

    # Give the modes a trailing axis so they can be concatenated with the
    # bounds along the last dimension.
    if mode_of_target.ndim < bounds.ndim:
        mode_of_target = mode_of_target.unsqueeze(-1)
        mode_of_proposal = mode_of_proposal.unsqueeze(-1)

    inflections_of_target = target_dist.inflection_points()
    inflections_of_proposal = proposal_dist.inflection_points().expand_as(
        inflections_of_target
    )

    knots = torch.cat(
        (
            bounds,
            mode_of_proposal,
            mode_of_target,
            inflections_of_proposal,
            inflections_of_target,
        ),
        dim=-1,
    )
    return torch.sort(knots, dim=-1).values


def get_rejection_constant(
    proposal_dist: Type[BaseDistribution],
    target_dist: Type[BaseDistribution],
    exact: bool = True,
    num_points: int = 0,
    top_k: int = 1,
) -> Tensor:
    """Estimate the rejection constant between a target and a proposal.

    Both densities are evaluated segment-wise on the grid from ``get_grid``:
    the target with ``torch.maximum`` and the proposal with
    ``torch.minimum`` as the combination function. The target/proposal
    ratio is formed at both ends of every segment, floored at 1, and reset
    to 1 wherever either proposal value is below ``1e-3``.

    Args:
        proposal_dist: Proposal distribution.
        target_dist: Target distribution.
        exact: Forwarded to ``get_grid``.
        num_points: Forwarded to ``get_grid``.
        top_k: ``1`` returns the largest ratio; any other value returns the
            ``top_k``-th largest.

    Returns:
        Tensor with the grid dimension reduced away, clamped to ``>= 1``.
    """
    x_grid = get_grid(proposal_dist, target_dist, exact, num_points)

    target_left, target_right = get_segments(target_dist, x_grid, torch.maximum)
    proposal_left, proposal_right = get_segments(proposal_dist, x_grid, torch.minimum)

    # Segments where the proposal value is too small for a meaningful ratio.
    unreliable = (proposal_left < 1e-3) | (proposal_right < 1e-3)

    left_ratio = (target_left / proposal_left).clamp(min=1)
    right_ratio = (target_right / proposal_right).clamp(min=1)
    for side_ratio in (left_ratio, right_ratio):
        side_ratio[unreliable] = 1

    # Pair the left end of each segment with the right end of the segment
    # before it and keep the smaller of the two ratios.
    knot_ratio = torch.minimum(left_ratio[..., 1:], right_ratio[..., :-1])

    if top_k != 1:
        return knot_ratio.topk(top_k, dim=-1).values[..., -1].clamp(1)

    # NOTE(review): this branch drops the first and last entries before
    # taking the maximum, while the top-k branch uses every entry — confirm
    # the asymmetry is intentional.
    return knot_ratio[..., 1:-1].max(-1).values.clamp(1)
