from typing import Any, Dict, Optional, Union, List, Callable, Type
from collections import deque
import torch


class LinearSchedule:
    """Scalar schedule that moves linearly from ``init`` to ``final`` over ``steps`` steps.

    If both ``final`` and ``steps`` are omitted the schedule is constant and
    always yields ``init``. Otherwise the value changes by
    ``(final - init) / steps`` per step and is clamped at ``final`` once
    reached, in either direction (increasing or decreasing).
    """

    def __init__(self, init: float, final: Optional[float] = None, steps: Optional[int] = None):
        super().__init__()

        # A half-specified schedule previously failed deep inside the arithmetic
        # with an opaque TypeError; reject it up front instead.
        if (final is None) != (steps is None):
            raise ValueError("LinearSchedule requires both `final` and `steps`, or neither")

        self.init = init
        self.final = final
        self.steps = steps
        self.constant = final is None and steps is None

        if not self.constant:
            # steps == 0 would divide by zero; negative steps would flip the sign
            # of `delta` so the value runs away from `final` and is never clamped.
            if steps <= 0:
                raise ValueError(f"LinearSchedule `steps` must be positive, got {steps}")
            self.decreasing = final < init
            self.delta = (final - init) / steps

    def value(self, step):
        """Return the scheduled value at ``step``.

        Constant schedules return ``init``. Otherwise returns
        ``init + delta * step``, clamped so it never passes ``final``.
        """
        if self.constant:
            return self.init

        raw = self.init + self.delta * step
        # Clamp on the side we are moving towards so the schedule plateaus at `final`.
        if self.decreasing:
            return max(raw, self.final)
        return min(raw, self.final)


class MovingStatistics:
    """Registry of named sliding windows, each backed by a bounded ``deque``.

    Appending to a full window evicts its oldest entry, so each window keeps
    only its most recent ``maxlen`` values.
    """

    def __init__(self, key_len_dict: dict):
        # Maps each key to an empty window whose capacity is the given length.
        self._dict = {
            name: deque([], maxlen=window)
            for name, window in key_len_dict.items()
        }

    def __getitem__(self, key):
        """Return the live window stored under ``key`` (``KeyError`` if absent)."""
        return self._dict[key]

    def regist(self, key: str, length: int):
        """Create, or reset to empty, the window ``key`` with capacity ``length``."""
        window = deque([], maxlen=length)
        self._dict[key] = window


def update_optimizer(optimizer: torch.optim.Optimizer, key: str, value: Union[List[float], float]):
    """Overwrite hyperparameter ``key`` in every param group of ``optimizer``.

    Args:
        optimizer: Optimizer whose ``param_groups`` are modified in place.
        key: Param-group entry to overwrite (e.g. ``'lr'``). Groups that do not
            already contain ``key`` are left untouched.
        value: A single scalar applied to every group, or a sequence with exactly
            one value per param group, matched by position.

    Raises:
        ValueError: If a sequence is given whose length differs from the number
            of param groups.
    """
    groups = optimizer.param_groups

    # Broadcast scalars. `isinstance` (rather than `type(value) == float`) also
    # accepts ints such as `weight_decay=0` and float subclasses such as
    # `numpy.float64`, which would otherwise fall through to `len()` and crash.
    if isinstance(value, (int, float)):
        value = [value] * len(groups)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(value) != len(groups):
        raise ValueError(
            f"expected {len(groups)} values for '{key}', got {len(value)}")

    for group, group_value in zip(groups, value):
        if key in group:
            group[key] = group_value


def shuffle_matrix(shape: torch.Size, device: Union[torch.device, str] = 'cpu'):
    """Return an independent random permutation of ``range(shape[1])`` for each row.

    Each row is ordered by i.i.d. Gaussian scores, so the result is a
    ``shape``-sized integer tensor whose rows are permutations.
    """
    assert len(shape) == 2
    scores = torch.randn(shape, device=device)
    return scores.argsort(dim=1)


@torch.no_grad()
def entropy_from_policy(pol: torch.Tensor, eps=1e-8):
    """Mean entropy (natural log) of a batch of categorical distributions.

    ``pol`` is expected to hold probabilities, with each distribution laid out
    along dim 1; ``eps`` keeps the log finite for zero entries. Returns a
    Python float.
    """
    log_p = torch.log(pol + eps)
    per_row = (pol * log_p).sum(dim=1)
    return -per_row.mean().item()


@torch.no_grad()
def entropy_from_log_policy(log_pol: torch.Tensor):
    """Mean entropy (natural log) of a batch of categorical distributions given log-probabilities.

    Each distribution is laid out along dim 1 of ``log_pol``. Entries equal to
    ``-inf`` (zero probability, e.g. masked actions) contribute 0, which is the
    limit of ``p * log p`` as ``p -> 0``. Returns a Python float.
    """
    plogp = log_pol.exp() * log_pol
    # exp(-inf) * -inf evaluates to NaN and would poison the whole mean.
    plogp = plogp.masked_fill(torch.isneginf(log_pol), 0.0)
    return -plogp.sum(dim=1).mean().item()


@torch.no_grad()
def policy_smoothing(pol: torch.Tensor, eps: float):
    """Mix ``pol`` with a uniform distribution over its last dimension.

    Returns ``(1 - eps) * pol + eps / A``, where ``A`` is the size of the last
    dimension of ``pol``.
    """
    num_actions = pol.size(-1)
    uniform_mass = eps / num_actions
    return pol * (1 - eps) + uniform_mass


def batch_vdot(x: torch.Tensor, y: torch.Tensor):
    """Row-wise dot product of two ``(B, D)`` tensors, returning shape ``(B,)``.

    Differentiable (no ``no_grad``), so it can be used inside a loss.
    """
    # bmm: (B, 1, D) @ (B, D, 1) -> (B, 1, 1). Drop only the two trailing
    # singleton dims; a bare `.squeeze()` would also drop the batch dim when
    # B == 1 and return a 0-d tensor.
    return torch.bmm(x.unsqueeze(1), y.unsqueeze(-1))[:, 0, 0]


def get_paired(x: torch.Tensor, x_n: torch.Tensor, splits: List[int]):
    """Pair every element of ``x`` with its successor inside its segment.

    ``x`` is split along dim 0 into consecutive segments of lengths ``splits``.
    Each segment's successor sequence is the segment shifted left by one, with
    the matching entry of ``x_n`` appended after its last element.

    Returns:
        ``(x, x_next)``: ``x`` unchanged, and the concatenation of all
        per-segment successor sequences.
    """
    successors = []
    for segment, tail in zip(x.split(splits), x_n):
        successors += [segment[1:], tail[None]]
    return x, torch.cat(successors)


@torch.no_grad()
@torch.compile
def tree_backup_target(
        act: torch.Tensor,
        rew: torch.Tensor,
        target_q: torch.Tensor,
        target_pol: torch.Tensor,
        target_v_next: torch.Tensor,
        dones: List[bool],
        splits: List[int],
        gamma: float):
    """Compute tree-backup targets for a batch of concatenated trajectory segments.

    The N transitions are laid out back to back along dim 0 and grouped into
    consecutive segments whose lengths are ``splits`` (they must sum to N).
    Within each segment the targets follow the backward recursion

        V(s)     = sum_a pi(a|s) * Q(s, a)
        delta_t  = r_t + gamma * V(s_{t+1}) - Q(s_t, a_t)
        A_t      = delta_t + gamma * pi(a_{t+1}|s_{t+1}) * A_{t+1}
        target_t = Q(s_t, a_t) + A_t

    starting from ``A_last = delta_last`` at the end of each segment. The
    recursion never crosses a segment boundary.

    Args:
        act: ``(N,)`` integer (int64) action indices.
        rew: ``(N,)`` rewards.
        target_q: ``(N, A)`` action values.
        target_pol: ``(N, A)`` action probabilities, matching ``target_q``.
        target_v_next: one value per segment for the state after the segment's
            last transition; each entry must hold a single element (0-d or
            shape ``(1,)``).
        dones: one flag per segment; when true that segment's bootstrap value
            is zeroed.
        splits: segment lengths along dim 0.
        gamma: discount factor.

    Returns:
        ``(N,)`` tensor of targets, computed without autograd tracking.
    """
    # NOTE(review): dynamo unrolls the Python loops below and guards on the ints
    # in `splits` / bools in `dones`; each new segment layout likely triggers a
    # recompile (up to the cache limit). Confirm torch.compile pays off here.

    # State values under the target policy: V(s_t) = sum_a pi(a|s_t) * Q(s_t, a).
    target_v = (target_q * target_pol).sum(dim=-1)

    # V(s_{t+1}) for every transition: shift each segment left by one and append
    # its bootstrap value, zeroed when the segment ended in a terminal state.
    # reshape(-1) turns 0-d entries (from iterating a 1-D `target_v_next`), which
    # torch.cat rejects, into shape (1,); shape-(1,) entries pass through as-is.
    v_next_parts = []
    for seg_v, boot_v, done in zip(target_v.split(splits), target_v_next, dones):
        v_next_parts.append(seg_v[1:])
        v_next_parts.append(boot_v.reshape(-1) * (1 - float(done)))
    v_next = torch.cat(v_next_parts)

    act_idx = act[:, None]
    q_act = target_q.gather(-1, act_idx).squeeze(-1)
    p_act = target_pol.gather(-1, act_idx).squeeze(-1)
    td_error = rew + gamma * v_next - q_act

    # Backward recursion A_t = delta_t + gamma * pi(a_{t+1}) * A_{t+1}, per segment.
    corrections = []
    for seg_delta, seg_p, length in zip(td_error.split(splits), p_act.split(splits), splits):
        seg_out = []
        acc = None
        for t in reversed(range(length)):
            if acc is None:
                # Last transition of the segment: nothing to propagate back.
                acc = seg_delta[t]
            else:
                acc = seg_delta[t] + acc * gamma * seg_p[t + 1]
            seg_out.append(acc)
        seg_out.reverse()
        corrections.extend(seg_out)
    return torch.stack(corrections) + q_act
