import math
from functools import partial
from torch import nn
from typing import Tuple, Dict, Any, Literal
import torch
from torch.distributions import Categorical, Normal


def _cosine_decay_warmup(iteration, warmup_iterations, total_iterations):
    """
    Linear warmup from 0 --> 1.0, then decay using cosine decay to 0.0
    """
    if iteration <= warmup_iterations:
        multiplier = iteration / warmup_iterations
    else:
        multiplier = (iteration - warmup_iterations) / (
            total_iterations - warmup_iterations
        )
        multiplier = 0.5 * (1 + math.cos(math.pi * multiplier))
    return multiplier


def cosine_annealing_with_warmup(optimizer, warmup_steps, total_steps):
    """
    Build a ``LambdaLR`` scheduler for ``optimizer``.

    The per-step multiplier is ``_cosine_decay_warmup`` with
    ``warmup_steps`` and ``total_steps`` bound as its warmup and total
    iteration counts.
    """
    # A ``partial`` (rather than a lambda) keeps the schedule a callable
    # object with its two arguments bound by keyword.
    lr_lambda = partial(
        _cosine_decay_warmup,
        warmup_iterations=warmup_steps,
        total_iterations=total_steps,
    )
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


class Tokenizer(nn.Module):
    """
    Embed observations, actions, rewards and termination flags as tokens.

    For each of obs/act/rwd, ``cfg["<x>_is_concrete"]`` selects between an
    ``nn.Embedding`` with ``cfg["<x>_nums"]`` entries and an ``nn.Linear``
    with ``cfg["<x>_dim"]`` input features. Termination flags always use a
    two-entry embedding. Every embedder outputs ``cfg["token_dim"]``
    features.
    """

    def __init__(self, cfg: Dict[str, Any], rcfg) -> None:
        super().__init__()
        width = cfg["token_dim"]
        self.rcfg = rcfg
        self.cfg = cfg

        # The embedders are created in the order obs, act, rwd, term so
        # that parameter ordering and seeded initialisation stay fixed.
        self.obs_emb = self._make_embedder(
            cfg, "obs_is_concrete", "obs_nums", "obs_dim", width
        )
        self.act_emb = self._make_embedder(
            cfg, "act_is_concrete", "act_nums", "act_dim", width
        )
        self.rwd_emb = self._make_embedder(
            cfg, "rwd_is_concrete", "rwd_nums", "rwd_dim", width
        )
        self.term_emb = nn.Embedding(2, width)

    @staticmethod
    def _make_embedder(cfg, flag_key, count_key, dim_key, width):
        """Return an embedding table or a linear projection per ``flag_key``."""
        # Only the key belonging to the chosen branch is read from ``cfg``.
        if cfg[flag_key]:
            return nn.Embedding(cfg[count_key], width)
        return nn.Linear(cfg[dim_key], width)

    def forward(self, obs, acts, rwds, terms):
        """
        Return the tuple of (obs, act, rwd, term) embeddings.

        Rewards are pre-processed according to ``rcfg.env`` first:
        for the "dense" kind of the dark-room style environments they are
        shifted by ``(grid_size - 1) * 2``; for the "quick" kind of the
        listed environments they are replaced by zeros.
        """
        env = self.rcfg.env

        if env.name in ["dark_key_to_door", "dark_room"] and env.kind == "dense":
            # presumably shifts rewards into a non-negative range so they
            # can index the reward embedding -- TODO confirm
            offset = (env.grid_size - 1) * 2
            rwds = rwds + offset

        if (
            env.name in ["dark_key_to_door", "dark_room", "reacher"]
            and env.kind == "quick"
        ):
            rwds = torch.zeros_like(rwds, dtype=rwds.dtype, device=rwds.device)

        obs_tok = self.obs_emb(obs)
        act_tok = self.act_emb(acts)
        rwd_tok = self.rwd_emb(rwds)
        term_tok = self.term_emb(terms)
        return (obs_tok, act_tok, rwd_tok, term_tok)

    def obs_to_int(self, obs):
        """
        Flatten a two-component observation into a single index.

        For the dark-room style environments the last axis is combined as
        ``obs[..., 0] * grid_size + obs[..., 1]``; for every other
        environment ``obs`` is returned unchanged.
        """
        env = self.rcfg.env
        if env.name not in ["dark_room", "dark_key_to_door"]:
            return obs
        first = obs[..., 0]
        second = obs[..., 1]
        return first * env.grid_size + second


class Header(nn.Module):
    """
    Action head: project selected hidden states to action logits.

    Args:
        hidden_dim: Size of the last dimension of the incoming hiddens.
        tcfg: Config mapping. ``tcfg["act_is_concrete"]`` selects the output
            width: ``tcfg["act_nums"]`` when true, ``tcfg["act_dim"]``
            otherwise.
        n_step_comp: Stride used to sub-sample dim 1 of the hiddens in
            ``forward``. Defaults to 4, the value that used to be
            hard-coded (presumably the number of tokens per environment
            step -- verify against the model that feeds this head).

    Raises:
        ValueError: If ``n_step_comp`` is smaller than 1.
    """

    def __init__(
        self,
        hidden_dim: int,
        tcfg: Dict[str, Any],
        n_step_comp: int = 4,
    ) -> None:
        super().__init__()
        if n_step_comp < 1:
            raise ValueError(f"n_step_comp must be >= 1, got {n_step_comp}")
        self.head_mlp = nn.Linear(
            hidden_dim,
            tcfg["act_nums"] if tcfg["act_is_concrete"] else tcfg["act_dim"],
            bias=False,
        )
        self.n_step_comp = n_step_comp
        self.dcfg = tcfg

    def forward(self, hiddens):
        """
        Return logits for every ``n_step_comp``-th position along dim 1.

        Positions 0, n_step_comp, 2 * n_step_comp, ... of ``hiddens`` are
        kept and passed through the bias-free linear projection.
        """
        logits = self.head_mlp(hiddens[:, :: self.n_step_comp])

        return logits

    def to_act(self, logits):
        """
        Turn logits into actions.

        When ``act_is_concrete`` is set, actions are sampled from a
        ``Categorical`` distribution over the last dimension of ``logits``;
        otherwise ``tanh`` is applied to the logits element-wise.
        """
        if self.dcfg["act_is_concrete"]:
            return Categorical(logits=logits).sample()
        return logits.tanh()


class Context:
    def __init__(self, num_envs: int, ctx_len: int, device, tcfg) -> None:
        self.num_envs = num_envs
        self.ctx_len = ctx_len
        self.tcfg = tcfg
        self.device = device
        self.reset()

    @property
    def is_full(self):
        assert not self.wrapped
        return self.ctx_len == self.ptr

    def reset(self):
        tcfg = self.tcfg
        device = self.device
        self.obs = torch.zeros(
            (self.num_envs, self.ctx_len)
            + (tuple() if tcfg["obs_is_concrete"] else (tcfg["obs_dim"],)),
            device=device,
            dtype=torch.int32 if tcfg["obs_is_concrete"] else torch.float32,
        )
        self.acts = torch.zeros(
            (self.num_envs, self.ctx_len)
            + (tuple() if tcfg["act_is_concrete"] else (tcfg["act_dim"],)),
            device=device,
            dtype=torch.int32 if tcfg["act_is_concrete"] else torch.float32,
        )
        self.rwds = torch.zeros(
            (self.num_envs, self.ctx_len)
            + (tuple() if tcfg["rwd_is_concrete"] else (tcfg["rwd_dim"],)),
            device=device,
            dtype=torch.int32 if tcfg["rwd_is_concrete"] else torch.float32,
        )
        self.terms = torch.zeros(
            (self.num_envs, self.ctx_len), device=device, dtype=torch.int32
        )
        self.pos = torch.zeros(
            (self.num_envs, self.ctx_len), device=device, dtype=torch.int32
        )
        self.ptr = 0
        self.wrapped = False

    def enroll_obs(self, obs, pos):
        if self.ptr == self.ctx_len:
            self.ptr = 0
            self.wrapped = True

        self.obs[:, self.ptr] = obs
        self.pos[:, self.ptr] = pos

    def enroll_rest(self, acts, rwds, terms, pos):

        assert torch.all(self.pos[:, self.ptr] == pos)
        self.acts[:, self.ptr] = acts
        self.rwds[:, self.ptr] = rwds
        self.terms[:, self.ptr] = terms
        self.ptr += 1

    def extract(self):
        if not self.wrapped:
            _p = self.ptr + 1
            _obs = self.obs[:, :_p]
            _acts = self.acts[:, :_p]
            _rwds = self.rwds[:, :_p]
            _terms = self.terms[:, :_p]
            _pos = self.pos[:, :_p]
            return (_obs, _acts, _rwds, _terms, _pos)

        _p = self.ptr + 1
        _obs = torch.cat((self.obs[:, _p:], self.obs[:, :_p]), dim=1)
        _acts = torch.cat((self.acts[:, _p:], self.acts[:, :_p]), dim=1)
        _rwds = torch.cat((self.rwds[:, _p:], self.rwds[:, :_p]), dim=1)
        _terms = torch.cat((self.terms[:, _p:], self.terms[:, :_p]), dim=1)
        _pos = torch.cat((self.pos[:, _p:], self.pos[:, :_p]), dim=1)

        return (_obs, _acts, _rwds, _terms, _pos)
