import logging
from pathlib import Path

import torch as th
import tqdm
from stable_baselines3 import RecurrentPPO
from stable_baselines3.common.type_aliases import check_cast, non_null

from learned_planners.convlstm import ConvLSTMCell, ConvLSTMFeaturesExtractor
from learned_planners.train import (
    TrainConfig,
    create_vec_env_and_eval_callbacks,
    make_model,
)

log = logging.getLogger(__name__)


class Benchmark(TrainConfig):
    """Benchmark config that repeatedly trains through a single compiled ConvLSTM cell.

    Only the first cell of the policy's ``ConvLSTMFeaturesExtractor`` is exercised;
    the rest of the policy is built but never called in the timed loop.
    """

    def run(self, run_dir: Path):
        """Build the env and model, compile the first cell's forward, and loop over train steps.

        Each of the 20000 iterations clears gradients, runs the compiled forward
        pass on an all-zeros input, backpropagates the sum of the first output,
        and takes an optimizer step. Progress is reported through ``tqdm``.

        Args:
            run_dir: Directory handed to the env/callback and model factories.
        """
        env, _ = create_vec_env_and_eval_callbacks(self, run_dir, eval_freq=1)
        ppo = check_cast(RecurrentPPO, make_model(self, run_dir, env, []))
        env = non_null(ppo.env)  # Bring the wrapped env back

        # The reset observation is only used for its trailing two (spatial) dimensions.
        first_obs = check_cast(th.Tensor, env.reset()).to(self.device)
        height, width = first_obs.shape[-2], first_obs.shape[-1]

        extractor = check_cast(ConvLSTMFeaturesExtractor, ppo.policy.features_extractor)
        first_cell: ConvLSTMCell = extractor.cell_list[0]
        initial_state = first_cell.recurrent_initial_state(
            (height, width),
            self.env.n_envs,
            device=self.device,
        )

        # Compile just this cell's forward as one full graph, with CUDA graphs enabled.
        compiled_forward = th.compile(
            first_cell.forward,
            fullgraph=True,
            backend="inductor",
            options={"triton.cudagraphs": True},
        )

        # NOTE(review): the 32 channels are hard-coded; presumably this matches the
        # cell's expected input channel count -- confirm against the model config.
        cell_input = th.zeros(
            self.env.n_envs,
            32,
            height,
            width,
            dtype=th.float32,
            device=self.device,
        )
        # The first element of the initial state, minus its leading dim, is fed as the second input.
        first_state = initial_state[0]
        prev_hidden = first_state.squeeze(0)

        optimizer = ppo.policy.optimizer
        for _ in tqdm.trange(20000):
            optimizer.zero_grad()
            result = compiled_forward(cell_input, prev_hidden, initial_state)
            th.sum(result[0]).backward()
            optimizer.step()

class BenchmarkRecurrent(TrainConfig):
    """Benchmark config that repeatedly trains through the policy's value prediction.

    The policy's features extractor is replaced by its ``th.compile``-d version
    before the loop starts.
    """

    def run(self, run_dir: Path):
        """Build the env and model, compile the features extractor, and loop over train steps.

        Each of the 20000 iterations clears gradients, calls ``predict_values`` on
        the reset observation repeated ``n_steps`` times, backpropagates the sum of
        the result, and takes an optimizer step. Progress is reported through ``tqdm``.

        Args:
            run_dir: Directory handed to the env/callback and model factories.
        """
        env, _ = create_vec_env_and_eval_callbacks(self, run_dir, eval_freq=1)
        ppo = check_cast(RecurrentPPO, make_model(self, run_dir, env, []))
        env = non_null(ppo.env)  # Bring the wrapped env back
        policy = ppo.policy

        # Repeat the single reset observation n_steps times along dim 0.
        first_obs = check_cast(th.Tensor, env.reset())
        tiled_obs = th.cat([first_obs for _ in range(self.n_steps)], dim=0)
        seq_obs = tiled_obs.to(dtype=th.float32, device=self.device)

        initial_state = policy.recurrent_initial_state(self.env.n_envs, device=self.device)
        # All-False flags, one per (env, step) pair.
        episode_starts = th.zeros(
            self.env.n_envs * self.n_steps,
            dtype=th.bool,
            device=self.device,
        )

        # Swap in a compiled features extractor (one full graph, CUDA graphs enabled).
        policy.features_extractor = th.compile(  # type: ignore[assignment]
            policy.features_extractor,
            fullgraph=True,
            backend="inductor",
            options={"triton.cudagraphs": True},
            disable=False,
        )

        optimizer = policy.optimizer
        for _ in tqdm.trange(20000):
            optimizer.zero_grad()
            values = policy.predict_values(obs=seq_obs, state=initial_state, episode_starts=episode_starts)
            th.sum(values).backward()
            optimizer.step()
