import torch
from functools import partial
import math
from typing import Tuple, Literal, Dict, Any, cast
from utils.time import Timeit
from torch import nn
from x_transformers import Decoder
from dataclasses import dataclass
import torch.nn.functional as F
from omegaconf import DictConfig
from algos.common import cosine_annealing_with_warmup, Tokenizer, Header, Context


# Environment families handled by ``refill_cfg``.
_GRID_ENVS = ("dark_room", "dark_key_to_door")
_CONTROL_ENVS = ("cartpole", "reacher")

# Default ``episodes_trained`` for the grid environments, keyed by
# (env.name, env.size_kind, env.kind). Combinations that are absent here are
# rejected with a ValueError.
_DEFAULT_EPISODES_TRAINED: Dict[Tuple[str, str, str], int] = {
    ("dark_room", "normal", "normal"): 125,
    ("dark_room", "large", "normal"): 150,
    ("dark_room", "large", "quick"): 250,
    ("dark_room", "large", "dense"): 175,
    ("dark_key_to_door", "normal", "normal"): 150,
    ("dark_key_to_door", "large", "normal"): 200,
    ("dark_key_to_door", "large", "quick"): 400,
    ("dark_key_to_door", "large", "dense"): 350,
}

# Factor applied to the default ``training_steps`` per env.kind. The integer
# 1 (rather than 1.0) keeps the intermediate value an int for those kinds.
_KIND_STEP_FACTOR: Dict[str, Any] = {"normal": 1, "quick": 1.5, "dense": 1}


def refill_cfg(cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    """Fill in missing config values and derive step counts from them.

    ``cfg`` is modified in place. The following fields are written:

    * ``cfg.episodes_trained`` -- from a per-environment default.
    * ``cfg.algo.ctx_len`` / ``cfg.algo.mem_len`` -- only when
      ``OmegaConf.is_missing`` reports them missing.
    * ``cfg.training_steps`` -- defaulted when missing, then always scaled by
      ``cfg.training_steps_reduce_factor``, floored, and rounded down to a
      multiple of ``cfg.split_nums``.

    Progress is reported with three ``print`` calls showing ``training_steps``
    before reduction, after reduction and after chunking.

    Args:
        cfg: Config exposing ``env.name``, ``env.episode_length``, ``algo``,
            ``training_steps_reduce_factor`` and ``split_nums``; for the grid
            environments also ``env.kind`` and ``env.size_kind``.

    Returns:
        A tuple ``(cfg, dcfg)`` where ``cfg`` is the same object that was
        passed in and ``dcfg`` holds ``"steps_trained"``
        (``episodes_trained * env.episode_length``) and ``"frames_per_split"``
        (``training_steps // split_nums``).

    Raises:
        ValueError: If ``env.name`` is not a supported environment, if the
            (name, size_kind, kind) combination has no default, or if
            ``split_nums`` is not positive.
        AssertionError: If ``episodes_trained`` is already set for a grid
            environment (only when assertions are enabled).
    """
    from omegaconf.omegaconf import OmegaConf

    env_name = cfg.env.name
    is_grid_env = env_name in _GRID_ENVS
    if not is_grid_env and env_name not in _CONTROL_ENVS:
        # Raised explicitly instead of asserted so the check survives
        # ``python -O``.
        raise ValueError(f"unsupported env.name: {env_name}")

    if is_grid_env:
        # NOTE(review): this assert rejects a pre-set ``episodes_trained``,
        # which makes the ``is_missing`` guard below redundant unless
        # assertions are disabled. Both are kept to preserve the existing
        # behaviour -- confirm which one is intended.
        assert OmegaConf.is_missing(cfg, "episodes_trained")
        size_kind = cfg.env.size_kind
        kind = cfg.env.kind
        default_episodes = _DEFAULT_EPISODES_TRAINED.get(
            (env_name, size_kind, kind)
        )
        if default_episodes is None:
            raise ValueError(
                "unsupported env config: "
                f"name={env_name}, size_kind={size_kind}, kind={kind}"
            )
        if OmegaConf.is_missing(cfg, "episodes_trained"):
            cfg.episodes_trained = default_episodes
    else:
        # Control environments always overwrite whatever was configured.
        cfg.episodes_trained = 20 if env_name == "cartpole" else 22

    dcfg: Dict[str, Any] = {
        "steps_trained": cfg.episodes_trained * cfg.env.episode_length
    }

    # ``and`` short-circuits, so ``env.kind`` / ``env.size_kind`` are only
    # read for dark_room, where they were already validated above.
    is_small_dark_room = (
        env_name == "dark_room"
        and cfg.env.kind == "normal"
        and cfg.env.size_kind == "normal"
    )

    if OmegaConf.is_missing(cfg.algo, "ctx_len"):
        if is_small_dark_room:
            cfg.algo.ctx_len = 10
        elif not is_grid_env:
            cfg.algo.ctx_len = 100
        else:
            cfg.algo.ctx_len = 25

    if OmegaConf.is_missing(cfg.algo, "mem_len"):
        if is_small_dark_room:
            cfg.algo.mem_len = 10
        elif not is_grid_env:
            cfg.algo.mem_len = 200 if cfg.algo.name != "mem" else 25
        else:
            cfg.algo.mem_len = 50

    if OmegaConf.is_missing(cfg, "training_steps"):
        if is_grid_env:
            steps = int(1e5) if env_name == "dark_room" else int(2e5)
            if cfg.env.size_kind == "large":
                steps *= 2
            kind_factor = _KIND_STEP_FACTOR.get(cfg.env.kind)
            if kind_factor is None:
                raise ValueError(f"unsupported env.kind: {cfg.env.kind}")
            # The order of the operations below is kept as-is so the
            # floating-point result does not change.
            steps *= kind_factor
            steps /= 1.5 if is_small_dark_room else 1.7
            # Original author's note: "multiply by 5 for xls".
            steps *= 5
        else:
            steps = int(6e5)
            steps *= 2
        # May be a non-integral float here; it is floored further down.
        cfg.training_steps = steps

    print(f"training_steps before reduced: {cfg.training_steps}")
    cfg.training_steps = math.floor(
        cfg.training_steps * cfg.training_steps_reduce_factor
    )
    print(f"training_steps after reduced: {cfg.training_steps}")

    split_nums = cfg.split_nums
    if split_nums <= 0:
        # Zero would raise ZeroDivisionError below and a negative value would
        # yield a negative frames_per_split; fail with a clear message.
        raise ValueError(f"split_nums must be positive, got {split_nums}")

    # Drop the remainder so every split receives the same number of steps.
    cfg.training_steps -= cfg.training_steps % split_nums
    print(f"training_steps after chunked: {cfg.training_steps}")
    assert cfg.training_steps % split_nums == 0
    # Floor division is exact here; ``int`` guards against the config
    # handing the value back as a float.
    dcfg["frames_per_split"] = int(cfg.training_steps // split_nums)

    return cfg, dcfg
