import torch
from functools import partial
import math
from typing import Tuple, Literal, Dict, Any, cast, Optional, List
from utils.time import Timeit
from torch import nn
from dataclasses import dataclass
from algos.ad.casual_transformer import TransformerModel, Cfg as TransformerModelCfg
import torch.nn.functional as F
from omegaconf import DictConfig
from algos.common import cosine_annealing_with_warmup, Tokenizer, Header, Context


# Five-tensor trajectory batch, in the order ``AD.update`` unpacks it:
# (obs, true_actions, rewards, terms, pos).
TRAJS = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


@dataclass
class Cfg:
    """Hyper-parameters read by ``AD.__init__`` when it builds the model stack."""

    # Added to the tokenizer config as "token_dim" and passed to TransformerModelCfg.
    token_dim: int
    # First argument given to ``Context``.
    num_envs: int
    # Passed to ``Header`` and to TransformerModelCfg.
    hidden_dim: int
    # Passed to TransformerModelCfg.
    n_layers: int
    # Device the tokenizer and head are moved to; also passed to ``Context``
    # and TransformerModelCfg.
    device: str
    # Passed to ``Context``; ``AD.update`` asserts that dimension 1 of the
    # action/reward/term/pos tensors equals this value.
    ctx_len: int
    # Passed to TransformerModelCfg as ``max_seq_len``.
    steps_trained: int
    # Learning rate handed to AdamW.
    lr: float
    # grid_size: int
    # ``total_steps`` of the LR scheduler; 10% of it is used as warmup steps.
    num_train_steps: int
    # Second argument to ``acc.clip_grad_norm_`` in ``AD.update``.
    clip_grad_norm: float


class AD:
    """Builds, queries and trains a causal ``TransformerModel``.

    The instance owns the tokenizer, the rollout ``Context``, the action
    ``Header``, the transformer itself, an AdamW optimizer and a cosine
    schedule with warmup.  ``rcfg.algo.name`` selects the variant and must be
    one of ``"ad"``, ``"mem"``, ``"xl"`` or ``"ed"``.
    """

    def __init__(self, cfg: Cfg, tcfg: Dict[str, Any], rcfg) -> None:
        """Construct every component from the three config objects.

        Args:
            cfg: Model and optimisation hyper-parameters.
            tcfg: Tokenizer/task settings.  ``act_is_concrete`` chooses the
                loss, and ``act_nums`` or ``act_dim`` gives the expected size
                of the model output's last dimension.
            rcfg: Run-level config.  ``rcfg.algo.name`` is the variant and
                ``rcfg.algo.mem_len`` is read for every variant except "ad".
        """
        self.rcfg = rcfg
        self.cfg = cfg
        self.tcfg = tcfg
        device = cfg.device

        # Construction order of the modules is kept fixed on purpose so that
        # random initialisation consumes the RNG in the same sequence.
        self.tokn = Tokenizer({**tcfg, "token_dim": cfg.token_dim}, rcfg).to(device)

        algo_name = self.rcfg.algo.name
        assert algo_name in ["ad", "mem", "xl", "ed"]
        self.variant: Literal["mem", "ad", "ed", "xl"] = algo_name

        self.ctx_len = cfg.ctx_len
        self.ctx = Context(cfg.num_envs, self.ctx_len, device, tcfg)

        self.head = Header(cfg.hidden_dim, tcfg).to(device)
        model_cfg = TransformerModelCfg(
            token_dim=cfg.token_dim,
            hidden_dim=cfg.hidden_dim,
            n_layers=cfg.n_layers,
            max_seq_len=cfg.steps_trained,
            device=device,
            # Only the plain "ad" variant runs without a memory length.
            mem_len=rcfg.algo.mem_len if self.variant != "ad" else None,
            all_cfg=self.rcfg,
        )
        self.model = TransformerModel(
            model_cfg,
            tokenizers=self.tokn,
            head=self.head,
            variant=self.variant,
        )
        self.model.apply(self.model._init_weights)

        self.optim = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=cfg.lr,
            weight_decay=1e-4,
            betas=(0.9, 0.999),
        )
        print([f"{n}: {t.shape}" for n, t in self.model.named_parameters()])

        total_steps = cfg.num_train_steps
        self.scheduler = cosine_annealing_with_warmup(
            optimizer=self.optim,
            warmup_steps=int(total_steps * 0.1),
            total_steps=total_steps,
        )
        self.clip_grad_norm = cfg.clip_grad_norm

        # Concrete actions use cross-entropy over ``act_nums`` outputs;
        # otherwise MSE over ``act_dim`` outputs.
        if tcfg["act_is_concrete"]:
            self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
            self.act_dim = tcfg["act_nums"]
        else:
            self.loss_fn = F.mse_loss
            self.act_dim = tcfg["act_dim"]

    def _split_model_output(self, output):
        """Return ``(logits, emb_mem)`` for a raw model output.

        For the "xl" and "ed" algorithms the model output is unpacked as a
        ``(logits, new_emb_mem)`` pair; for the others it is used as the
        logits directly and the memory slot is ``None``.
        """
        if self.rcfg.algo.name in ["xl", "ed"]:
            logits, fresh_mem = output
            return logits, fresh_mem
        return output, None

    def enroll_obs(self, obs, pos):
        """Convert ``obs`` with the tokenizer and record it in the context."""
        obs_ids = self.tokn.obs_to_int(obs)
        return self.ctx.enroll_obs(obs_ids, pos)

    @torch.no_grad()
    def make_action(self, emb_mems: Optional[List[torch.Tensor]] = None):
        """Run the model on the current context and pick an action.

        Args:
            emb_mems: Optional memory tensors forwarded to the model.

        Returns:
            ``(action, new_emb_mem)``: the head's actions sliced at the last
            index of dimension 1, and the model's new memory for "xl"/"ed"
            (``None`` for the other algorithms).
        """
        obs, actions, rewards, terms, pos = self.ctx.extract()
        output = self.model(obs, actions, rewards, terms, pos, emb_mems=emb_mems)
        logits, fresh_mem = self._split_model_output(output)
        chosen = self.model.head.to_act(logits)
        return chosen[:, -1], fresh_mem

    def enroll_rest(self, acts, rwds, terms, pos):
        """Record actions, rewards and termination flags in the context."""
        return self.ctx.enroll_rest(acts, rwds, terms, pos)

    def update(self, trajs: TRAJS, acc, emb_mems: Optional[List[torch.Tensor]] = None):
        """Perform one optimisation step on a batch of trajectories.

        Args:
            trajs: ``(obs, true_actions, rewards, terms, pos)``; dimension 1
                of the last four must equal ``self.ctx_len``.
            acc: Object providing ``autocast()``, ``backward()``,
                ``clip_grad_norm_()`` and ``optimizer_step_was_skipped``
                (presumably a Hugging Face ``Accelerator`` -- TODO confirm).
            emb_mems: Optional memory tensors forwarded to the model.

        Returns:
            ``(loss_value, new_emb_mem)`` where ``new_emb_mem`` is ``None``
            unless the algorithm is "xl" or "ed".
        """
        obs, true_actions, rewards, terms, pos = trajs
        obs = self.tokn.obs_to_int(obs)

        assert (
            true_actions.shape[1]
            == self.ctx_len
            == rewards.shape[1]
            == terms.shape[1]
            == pos.shape[1]
        )

        # Forward pass.
        output = self.model(obs, true_actions, rewards, terms, pos, emb_mems=emb_mems)
        logits, fresh_mem = self._split_model_output(output)
        assert logits.shape[1:] == (self.ctx_len, self.act_dim)

        # Concrete actions are cast to int64 before the loss is computed.
        targets = (
            true_actions.to(torch.int64)
            if self.tcfg["act_is_concrete"]
            else true_actions
        )

        with acc.autocast():
            loss = self.loss_fn(logits.flatten(0, 1), targets.flatten(0, 1))

        # Backward pass and parameter update.
        self.optim.zero_grad(set_to_none=True)
        acc.backward(loss)
        acc.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
        self.optim.step()
        # Only advance the LR schedule when the optimizer step really ran.
        if not acc.optimizer_step_was_skipped:
            self.scheduler.step()

        return loss.item(), fresh_mem

    def eval(self):
        """Switch the model to evaluation mode."""
        self.model.eval()

    def train(self):
        """Switch the model to training mode."""
        self.model.train()

    def reset(self):
        """Reset the rollout context."""
        self.ctx.reset()


def refill_cfg(cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    """Fill in the config values that follow from the chosen environment.

    ``cfg`` is mutated in place: ``episodes_trained``, ``env.episode_length``
    (dark environments only), ``algo.ctx_len`` (when missing) and
    ``training_steps`` (when missing, and always coerced to ``int``).

    Returns:
        ``(cfg, dcfg)`` where ``dcfg`` has ``"steps_trained"``
        (``episodes_trained * env.episode_length``) and
        ``"frames_per_split"`` (``training_steps / split_nums``).

    Raises:
        ValueError: For an unsupported dark-environment name/size/kind
            combination, or an unknown ``env.kind``.
        AssertionError: If one of the sanity assertions on ``cfg`` fails.
    """
    from omegaconf.omegaconf import OmegaConf

    dark_envs = ["dark_room", "dark_key_to_door"]
    other_envs = ["cartpole", "reacher"]

    # (env.name, env.size_kind, env.kind) -> (episodes_trained, episode_length)
    presets = {
        ("dark_room", "normal", "normal"): (125, 20),
        ("dark_room", "large", "normal"): (150, 50),
        ("dark_room", "large", "quick"): (250, 50),
        ("dark_room", "large", "dense"): (175, 50),
        ("dark_key_to_door", "normal", "normal"): (150, 50),
        ("dark_key_to_door", "large", "normal"): (200, 70),
        ("dark_key_to_door", "large", "quick"): (400, 50),
        ("dark_key_to_door", "large", "dense"): (350, 70),
    }

    derived: Dict[str, Any] = {}

    if cfg.env.name in dark_envs:
        assert OmegaConf.is_missing(cfg, "episodes_trained")
        preset = presets.get((cfg.env.name, cfg.env.size_kind, cfg.env.kind))
        if preset is None:
            raise ValueError(
                f"unsupport env config: {cfg.env.kind}, {cfg.env.size_kind}"
            )
        episodes, episode_length = preset
        if OmegaConf.is_missing(cfg, "episodes_trained"):
            cfg.episodes_trained = episodes
        cfg.env.episode_length = episode_length
    else:
        assert cfg.env.name in other_envs, cfg.env.name
        assert cfg.env.episode_length == 1000
        cfg.episodes_trained = 20 if cfg.env.name == "cartpole" else 22

    derived["steps_trained"] = cfg.episodes_trained * cfg.env.episode_length

    if OmegaConf.is_missing(cfg.algo, "ctx_len"):
        if cfg.algo.reduced:
            if (
                cfg.env.name == "dark_room"
                and cfg.env.kind == "normal"
                and cfg.env.size_kind == "normal"
            ):
                cfg.algo.ctx_len = 10
            elif cfg.env.name in other_envs:
                cfg.algo.ctx_len = 125
            else:
                cfg.algo.ctx_len = 25
        else:
            # 2.5 episodes of context, capped at 125.
            wanted = int(2.5 * cfg.env.episode_length)
            cfg.algo.ctx_len = min(wanted, 125)
            if wanted > 125:
                print(f"too large ctx_len {wanted}, reduce to {cfg.algo.ctx_len}")

    if OmegaConf.is_missing(cfg, "training_steps"):
        if cfg.env.name in dark_envs:
            steps = int(1e5) if cfg.env.name == "dark_room" else int(2e5)
            if cfg.env.size_kind == "large":
                steps *= 2
            # Per-kind multiplier; "quick" turns the running value into a float.
            kind_factor = {"normal": 1, "quick": 1.5, "dense": 1}
            kind = cfg.env.kind
            if kind not in kind_factor:
                raise ValueError(f"unsupported env.kind: {cfg.env.kind}")
            steps *= kind_factor[kind]
        else:
            assert cfg.env.name in other_envs
            steps = int(6e5)

        # Reduced runs train longer: x2 for cartpole/reacher, x5 otherwise.
        if cfg.algo.reduced:
            steps *= 2 if cfg.env.name in other_envs else 5
        cfg.training_steps = steps

    cfg.training_steps = int(cfg.training_steps)
    assert cfg.training_steps % cfg.split_nums == 0
    derived["frames_per_split"] = int(cfg.training_steps / cfg.split_nums)

    return cfg, derived
