import numpy as np
import torch
import math
import gym
from utils.common import LazyFrames, resolve_lazy_frames, get_device
from typing import (
    Any,
    TypeVar,
    Generic,
    Literal,
    List,
    Tuple,
    cast,
    Dict,
    Optional,
    Callable,
    Protocol,
    Union,
    TypeVar,
)
from tqdm import tqdm
from utils.algorithm import Algorithm, ActionInfo, Mode, ReportInfo
from utils.common import Action, Info, Reward, State
from torch.utils.data import DataLoader
from utils.transition import TransitionTuple
from utils.reporter import Reporter, ReportTrait
from utils.step import NotNoneStep, Step


class BaseTrainer(ReportTrait):
    """Drive an ``Algorithm`` through action selection, training and evaluation.

    The trainer owns no learning logic itself: it forwards to the wrapped
    algorithm (``take_action``, ``pretrain``, ``save``, ``eval``, ``valid``)
    and keeps the frame-budget bookkeeping and progress bars.

    NOTE(review): ``train`` calls ``self.rollout``, which is not defined in
    this class as far as this file shows -- presumably supplied by
    ``ReportTrait`` or by a subclass/mixin; confirm.
    """

    def __init__(
        self,
        algm: Algorithm,
        with_reporter: Optional[Reporter] = None,
    ):
        """Wrap ``algm`` and share the (optional) reporter with it.

        Args:
            algm: Algorithm to train; must have a truthy ``name``.
            with_reporter: Reporter handed both to ``ReportTrait`` and to
                ``algm.set_reporter``.
        """
        ReportTrait.__init__(self, with_reporter=with_reporter)
        self.algm = algm
        assert self.algm.name
        self.name: str = algm.name
        self.algm.set_reporter(with_reporter)

    def get_action(self, state: State, env: gym.Env, mode: Mode) -> ActionInfo:
        """Ask the algorithm for an action and sanity-check it against ``env``.

        ``algm.take_action`` may return either a bare action or an
        ``(action, info)`` tuple. For a bare action a default info list of
        ``{"end": False}`` entries is synthesised, one per element along the
        action's first dimension (a 0-dim action gets a single entry).

        Args:
            state: Current observation passed through to the algorithm.
            env: Environment whose ``action_space`` the action is checked
                against.
            mode: Mode forwarded to ``algm.take_action``.

        Returns:
            The ``(action, info)`` pair.

        Raises:
            AssertionError: If any info entry has a truthy ``"end"``, if the
                action space is not Discrete/Box/MultiDiscrete, or if the
                action shape does not fit a Discrete or Box space.
        """
        action_space = env.action_space
        actinfo = self.algm.take_action(mode, state, env)

        if isinstance(actinfo, tuple):
            act, info = actinfo[0], actinfo[1]
        else:
            act = actinfo
            # A scalar (0-dim) action is accepted for Discrete spaces below,
            # but ``size(0)`` raises on it, so count it as a single entry.
            n_entries = act.size(0) if act.dim() > 0 else 1
            info = [{"end": False} for _ in range(n_entries)]

        assert all(not i["end"] for i in info)

        assert isinstance(
            action_space,
            (gym.spaces.Discrete, gym.spaces.Box, gym.spaces.MultiDiscrete),
        )

        if isinstance(action_space, gym.spaces.Discrete):
            assert act.shape == (1,) or act.shape == tuple()

        if isinstance(action_space, gym.spaces.Box):
            assert act.shape == action_space.shape

        return (act, info)

    def train(self, info: Info, training_frames=int(1e6)):
        """Roll out in ``info["env"]`` until ``training_frames`` frames are used.

        Each iteration calls ``self.rollout(env, "train")``, which is expected
        to return ``(frames_consumed, stopped)``. For every entry ``i`` of
        ``stopped`` the summed ``env.last_episode_reward(i)`` is logged as
        ``episode_return`` through ``env.add_scalars``.

        Args:
            info: Must contain the training environment under ``"env"``.
                NOTE(review): the env must expose ``add_scalars`` and
                ``last_episode_reward``, so it is presumably a project
                wrapper rather than a plain ``gym.Env`` -- confirm.
            training_frames: Frame budget; the last rollout may overshoot it.
        """
        env = info["env"]

        with tqdm(total=training_frames) as pbar:
            frames = 0
            # Strict ``<``: once the budget is reached no further rollout is
            # started (``<=`` would run one extra rollout on an exact hit).
            while frames < training_frames:
                train_frames, stopped = self.rollout(env, "train")
                assert train_frames >= 0
                assert len(stopped) >= 1

                for i in stopped:
                    env.add_scalars(
                        dict(episode_return=sum(env.last_episode_reward(i))),
                        "train",
                    )

                pbar.update(train_frames)
                frames += train_frames

    def eval(self, info: Info):
        """Delegate evaluation to the wrapped algorithm."""
        self.algm.eval(info)

    def train_and_eval(
        self,
        info: Info,
        eval_env: gym.Env,
        seed: int,
        single_train_frames=int(1e4),
        total_train_frames=int(1e6),
    ):
        """Alternate training chunks with save / eval / valid checkpoints.

        After ``algm.pretrain`` and an initial checkpoint (index 0), the total
        budget is split into ``ceil(total_train_frames / single_train_frames)``
        chunks. After every chunk the algorithm is saved, evaluated on
        ``eval_env`` and validated, with ``progress``/``iter``/``total``
        merged into the info dict.

        Args:
            info: Base info dict forwarded to the algorithm and to ``train``.
            eval_env: Environment used for evaluation; it overrides any
                ``"env"`` entry of ``info`` in the dict passed to ``eval``.
            seed: Unused in this method; kept for interface compatibility.
            single_train_frames: Frame budget of one training chunk.
            total_train_frames: Overall frame budget.
        """
        s = math.ceil(total_train_frames / single_train_frames)

        self.algm.pretrain(info)
        self.algm.save(0)

        meta_info = {"progress": 0.0, "iter": 0, "total": s}
        # Dict-literal merge instead of ``dict(env=..., **info)``: the keyword
        # form raises TypeError when ``info`` already carries an "env" key.
        self.eval({**info, **meta_info, "env": eval_env})
        self.algm.valid({**info, **meta_info})

        for i in tqdm(range(s), desc="total training progress"):
            self.train({**info, **meta_info}, single_train_frames)

            meta_info["progress"] = (i + 1) / s
            meta_info["iter"] = i + 1
            self.algm.save(i + 1)
            self.eval({**info, **meta_info, "env": eval_env})
            self.algm.valid({**info, **meta_info})


class OnlineTrainer(BaseTrainer):
    """Trainer that uses ``BaseTrainer`` unchanged (no overrides)."""


class OfflineTrainer(BaseTrainer):
    """Trainer whose training loop calls ``algm.manual_train`` directly.

    Unlike ``BaseTrainer.train`` no environment rollout happens here;
    progress is measured by how far ``algm.trained_steps`` advances.
    """

    def train(self, info: Info, training_frames=int(1e6)):
        """Call ``algm.manual_train(info)`` until the step budget is used.

        Args:
            info: Info dict forwarded unchanged to ``algm.manual_train``.
            training_frames: Budget counted in increments of
                ``algm.trained_steps``; the last call may overshoot it.

        Raises:
            AssertionError: If a ``manual_train`` call does not increase
                ``algm.trained_steps`` (this would otherwise loop forever).
        """
        with tqdm(total=training_frames, desc="single training progress") as pbar:
            frames = 0
            # Strict ``<``: stop as soon as the budget is reached (``<=``
            # would run one extra training call on an exact hit).
            while frames < training_frames:
                steps_before = self.algm.trained_steps
                self.algm.manual_train(info)
                train_frames = self.algm.trained_steps - steps_before

                assert train_frames > 0

                pbar.update(train_frames)
                frames += train_frames


# Type alias covering both concrete trainer flavours defined in this module.
AllTrainer = Union[OnlineTrainer, OfflineTrainer]
