from typing import Any

import torch
from torch import Tensor
from torch.nn.functional import log_softmax, one_hot, softmax

from mdp.mdp_dataset import MDPDataset, MDPDatasetTorch


def get_discounted_rewards(dataset: MDPDatasetTorch | MDPDataset, gamma: float) -> Tensor:
    """Compute the discounted reward-to-go for every step of every environment.

    Args:
        dataset: Object exposing ``rewards`` (indexed as ``[env, step]``),
            ``n_envs``, ``n_steps`` and ``device``.
        gamma: Discount factor applied to the return of the following step.

    Returns:
        Tensor of shape ``(n_envs, n_steps)`` where column ``t`` holds
        ``rewards[:, t] + gamma * returns[:, t + 1]``. The last column is the
        last reward itself (no bootstrap value is added).
    """
    # Rewards are detached so the returns never carry an autograd graph.
    step_rewards = dataset.rewards.detach()
    horizon = dataset.n_steps
    returns = torch.zeros((dataset.n_envs, horizon), device=dataset.device)

    # Seed the recursion with the final reward, then sweep backwards in time.
    returns[:, -1] = step_rewards[:, -1]
    for step in range(horizon - 2, -1, -1):
        returns[:, step] = step_rewards[:, step] + gamma * returns[:, step + 1]

    return returns


def gae(dataset: MDPDatasetTorch | MDPDataset, values: Tensor, gamma: float, lam: float, n_steps: int) -> Tensor:
    """Compute generalized advantage estimates over fixed-length rollout segments.

    The time axis of the dataset is treated as consecutive segments of
    ``n_steps`` steps. The advantage recursion runs backwards in time and is
    restarted at the last step of every segment, so nothing leaks from the
    first step of the following segment into the previous one.

    Args:
        dataset: Object exposing ``rewards`` (indexed as ``[env, step]``),
            ``n_envs``, ``n_steps`` and ``device``.
        values: Value estimates indexed as ``[env, step]``. They are not
            detached, so the result stays attached to their autograd graph.
        gamma: Discount factor.
        lam: GAE mixing coefficient (lambda).
        n_steps: Segment length; must divide ``dataset.n_steps``.

    Returns:
        Tensor of shape ``(n_envs, dataset.n_steps)`` with the advantages.

    Raises:
        AssertionError: If ``dataset.n_steps`` is not a multiple of ``n_steps``.
    """
    assert dataset.n_steps % n_steps == 0, "dataset.n_steps must be a multiple of n_steps"

    # Rewards are detached and standardised over the whole batch before use.
    rewards = normalize(dataset.rewards.detach())

    advantages = torch.zeros((dataset.n_envs, dataset.n_steps), device=dataset.device)
    last_advantage = torch.zeros(dataset.n_envs, device=dataset.device)
    # The final step has no successor, so it bootstraps from its own value.
    last_value = values[:, -1]

    for t in reversed(range(dataset.n_steps)):
        # Step t closes a segment exactly when t + 1 is a multiple of n_steps.
        # The very last step is skipped because the initial state above is
        # already the "fresh segment" state.
        if (t + 1) % n_steps == 0 and t + 1 != dataset.n_steps:
            last_value = values[:, t]
            last_advantage = torch.zeros_like(last_advantage)

        # One-step TD residual for step t.
        delta = rewards[:, t] + gamma * last_value - values[:, t]

        last_advantage = delta + gamma * lam * last_advantage
        last_value = values[:, t]

        advantages[:, t] = last_advantage

    return advantages


def ae(dataset: MDPDatasetTorch | MDPDataset, values: Tensor, gamma: float) -> Tensor:
    """Compute one-step advantage estimates (TD residuals) for every step.

    For each step ``t`` the result is
    ``rewards[:, t] + gamma * next_value - values[:, t]``, where ``next_value``
    is the value of the following step. The last step uses the last column of
    ``values`` as its ``next_value``.

    Args:
        dataset: Object exposing ``rewards`` (indexed as ``[env, step]``),
            ``n_envs``, ``n_steps`` and ``device``. Rewards are used as-is
            (not detached).
        values: Value estimates indexed as ``[env, step]``.
        gamma: Discount factor applied to the next step's value.

    Returns:
        Tensor of shape ``(n_envs, n_steps)`` holding the advantages.
    """
    horizon = dataset.n_steps
    result = torch.zeros((dataset.n_envs, horizon), device=dataset.device)

    # Walk backwards so the "next" value is always the one seen one iteration ago.
    next_value = values[:, -1]
    for step in range(horizon - 1, -1, -1):
        current_value = values[:, step]
        td_target = dataset.rewards[:, step] + gamma * next_value
        result[:, step] = td_target - current_value
        next_value = current_value

    return result


def normalize(x: Tensor, dim=None) -> Tensor:
    return (x - x.mean(dim=dim)) / (x.std(dim=dim) + 1e-8)


def project_onto_simplex_batch(v: Tensor) -> Tensor:
    """Project every vector along the last dimension onto the probability simplex.

    Uses the sort-based thresholding scheme: each vector is shifted by a
    per-vector threshold ``theta`` and clamped at zero, with ``theta`` chosen
    so that the surviving entries sum to one.

    Args:
        v: Tensor whose last dimension holds the vectors to project. All
            leading dimensions are treated as batch dimensions.

    Returns:
        Tensor with the same shape as ``v``; every vector along the last
        dimension is non-negative and sums to one.
    """
    orig_shape = v.shape
    v = v.reshape(-1, orig_shape[-1])  # (batch, n)

    v_sorted = torch.sort(v, dim=-1, descending=True).values
    cssv = torch.cumsum(v_sorted, dim=-1) - 1
    counts = torch.arange(1, v.shape[-1] + 1, device=v.device).float()
    cond = v_sorted > cssv / counts

    # Index of the last sorted entry that stays positive after the shift.
    # keepdim retains the batch axis even for a single row; collapsing it to a
    # 0-dim tensor would make the division below run in float32 for float64
    # input because of PyTorch's zero-dim type-promotion rule.
    rho = torch.sum(cond, dim=-1, keepdim=True) - 1  # (batch, 1)
    theta = torch.gather(cssv, 1, rho) / (rho + 1).float()  # (batch, 1)

    return torch.clamp(v - theta, min=0).reshape(orig_shape)


def compute_fisher_matrix(policy: Tensor, states: Tensor) -> Tensor:
    """Build a Fisher-style matrix for a tabular softmax policy.

    For every action ``a`` the gradient ``g_a`` of ``sum_t log pi(a | s_t)``
    with respect to ``policy`` is computed and flattened. The result is
    ``(1 / n_steps) * sum_t sum_a pi(a | s_t) * g_a g_a^T``.

    NOTE(review): ``g_a`` is the gradient of the log-probability summed over
    all visited states, not a per-step gradient. The textbook Fisher matrix
    uses per-step gradients; confirm this aggregation is intended.

    Args:
        policy: Logit table indexed by state; must require grad because it is
            differentiated with ``torch.autograd.grad``.
            Presumably of shape ``(n_states, n_actions)`` -- TODO confirm.
        states: Indices into the first dimension of ``policy``; its first
            dimension is the number of steps.

    Returns:
        Square tensor of side ``policy.numel()``. The action probabilities are
        not detached, so the result stays attached to the autograd graph.
    """
    n_steps = states.shape[0]
    logits = policy[states]  # (n_steps, n_actions)
    probs = softmax(logits, dim=-1)  # (n_steps, n_actions)
    log_probs = log_softmax(logits, dim=-1)  # (n_steps, n_actions)

    # One flattened gradient per action. The graph is retained because it is
    # reused for the next action and is still referenced by ``probs``.
    grad_log_probs = torch.stack(
        [
            torch.autograd.grad(log_probs[:, action].sum(), policy, retain_graph=True)[0].flatten()
            for action in range(logits.shape[1])
        ],
        dim=0,
    )  # (n_actions, D)

    # The outer product g_a g_a^T does not depend on the step, so the sum over
    # steps only needs the total probability mass each action received.
    action_weights = probs.sum(dim=0)  # (n_actions,)

    # Accumulate into a default-dtype buffer so the output dtype is the same
    # as when the matrix was built by repeated in-place additions.
    n_params = grad_log_probs.shape[1]
    fim = torch.zeros((n_params, n_params), device=policy.device)
    fim += (grad_log_probs.T * action_weights) @ grad_log_probs

    return fim / n_steps
