import dataclasses
import glob
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

import numpy as np
import pandas as pd
import torch as th
from gym_sokoban.envs.boxoban_env import BoxobanEnv
from stable_baselines3 import RecurrentPPO
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.recurrent.policies import (
    BaseRecurrentActorCriticPolicy,
    RecurrentActorCriticPolicy,
)
from stable_baselines3.common.recurrent.torch_layers import RecurrentState
from stable_baselines3.common.type_aliases import (
    GymEnv,
    MaybeCallback,
    Schedule,
    non_null,
)
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv
from stable_baselines3.common.vec_env.util import obs_as_tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from learned_planners.common import catch_different_env_types_warning
from learned_planners.environments import BoxobanConfig
from learned_planners.train import RecurrentPPOConfig, TrainConfig, create_vec_env_and_eval_callbacks, make_model


class AStarSokobanDataset(Dataset):
    """Map-style dataset that replays logged A* solutions in a ``BoxobanEnv``.

    Indexing the dataset resets the wrapped environment to one level, steps through the logged action
    sequence and returns ``(observations, actions, values, episode_starts)`` tensors for that level.
    Levels that have no usable solution, or whose replay does not pass the sanity checks, are counted in
    ``self.fails`` and replaced by a uniformly re-sampled level.
    """

    # Indexed by the digit found in the A* log; the looked-up value is what gets passed to ``env.step``.
    action_map = [0, 3, 1, 2]

    def __init__(self, env: VecEnv, astar_logs_path: Optional[Union[Path, str]] = None, levels_per_file: int = 1000):
        """Wrap a single-env ``DummyVecEnv`` around a ``BoxobanEnv`` and load the A* logs.

        Args:
            env: Vectorized environment; must hold exactly one ``BoxobanEnv`` inside a ``DummyVecEnv``.
            astar_logs_path: Directory with the ``*.csv`` A* logs. Defaults to ``<train_data_dir>/logs_final``
                of the wrapped environment.
            levels_per_file: Divisor used to split a flat dataset index into (file index, level index).
        """
        super().__init__()
        self.n_envs = env.num_envs
        assert self.n_envs == 1, "AStarSokobanDataset only supports single envs. Batching can be done using a dataloader."
        assert isinstance(env.unwrapped, DummyVecEnv), "AStarSokobanDataset only supports DummyVecEnv"
        self.unwrapped_env = env.unwrapped.envs[0].unwrapped
        assert isinstance(self.unwrapped_env, BoxobanEnv), "AStarSokobanDataset only supports BoxobanEnv"
        self.levels_per_file = levels_per_file
        if astar_logs_path is None:
            levels_dir = Path(self.unwrapped_env.train_data_dir)
            astar_logs_path = levels_dir / "logs_final"
        self.setup_dataset_labels(astar_logs_path)
        # Return of an episode that takes zero steps: finishing bonus plus one box bonus per box.
        self.return_without_steps = (
            self.unwrapped_env.reward_finished + self.unwrapped_env.reward_box_on_target * self.unwrapped_env.num_boxes
        )
        self.step_reward = self.unwrapped_env.penalty_for_step

        # Number of samples that had to be replaced by a re-sampled level.
        self.fails = 0

    @staticmethod
    def map_action(action_sequence: str):
        """Translate a logged action string into a list of environment actions.

        Each character of the stripped string is parsed as an integer and looked up in ``action_map``.

        Returns:
            The list of mapped actions, or ``None`` when the entry is missing (``None``/NaN, which is what
            pandas produces for an empty CSV cell), contains a non-digit character, or contains a digit
            that is not a valid index into ``action_map``.
        """
        if action_sequence is None or (isinstance(action_sequence, float) and np.isnan(action_sequence)):
            return None
        try:
            return [AStarSokobanDataset.action_map[int(c)] for c in action_sequence.strip()]
        except (ValueError, IndexError):
            return None

    def setup_dataset_labels(self, astar_logs_path: Union[Path, str]) -> None:
        """Load every ``*.csv`` log in ``astar_logs_path`` (sorted by path) into ``self.astar_logs_df``.

        Column names are stripped of surrounding whitespace and the ``Actions`` column is converted with
        ``map_action``. ``self.total_levels`` is set to the total number of loaded rows.

        Raises:
            FileNotFoundError: If ``astar_logs_path`` does not exist.
            AssertionError: If the number of CSV files differs from the number of level files of the env.
        """
        if isinstance(astar_logs_path, str):
            astar_logs_path = Path(astar_logs_path)
        if not astar_logs_path.exists():
            raise FileNotFoundError(f"astar_logs_path {astar_logs_path} does not exist")
        self.astar_logs_df = []
        self.total_levels = 0
        for file in sorted(glob.glob(str(astar_logs_path) + "/*.csv")):
            df = pd.read_csv(file, index_col=0)
            df = df.rename(columns=lambda x: x.strip())
            df["Actions"] = df["Actions"].apply(AStarSokobanDataset.map_action)
            self.astar_logs_df.append(df)
            self.total_levels += len(df)
        assert len(self.astar_logs_df) == len(self.unwrapped_env.level_files), "Mismatch in number of level files"  # type: ignore

    def __len__(self):
        """Return the total number of rows loaded from the A* logs."""
        return self.total_levels

    def _resample(self):
        """Record a failed sample and return the item at a uniformly random dataset index instead."""
        self.fails += 1
        random_idx = th.randint(0, self.total_levels, (1,)).item()
        return self[random_idx]

    def __getitem__(self, idx):
        """Return the replayed A* trajectory for dataset index ``idx``.

        The index is split into a file index (``idx // levels_per_file``) and a level index
        (``idx % levels_per_file``). If that level has no log row, is marked unsolved (``Steps == -1``)
        or has no parsed action list, a random other level is returned instead.
        """
        level_file_idx = idx // self.levels_per_file
        level_idx = idx % self.levels_per_file
        try:
            row = self.astar_logs_df[level_file_idx].loc[level_idx]
        except (KeyError, IndexError):
            # KeyError: the level has no row in its log file. IndexError: the file index lies beyond the
            # loaded logs (possible if the logs hold more rows than ``levels_per_file`` per file).
            return self._resample()
        if isinstance(row, pd.DataFrame):
            # Duplicate index labels make ``.loc`` return a frame; use the first matching row.
            row = row.iloc[0]
        actions = row["Actions"]
        if row["Steps"] == -1 or not isinstance(actions, list):
            # ``map_action`` stores None for entries it could not parse; those cannot be replayed.
            return self._resample()
        first_obs = self.unwrapped_env.reset(options={"level_file_idx": level_file_idx, "level_idx": level_idx})[0]
        return self.get_all_step_results(first_obs, actions, level_file_idx, level_idx)

    def get_all_step_results(self, first_obs, actions, level_file_idx, level_idx):
        """Step the environment through ``actions`` and build the training tensors.

        The value target before each action starts at ``return_without_steps + step_reward * len(actions)``
        and is reduced by every reward actually received. The sample is accepted only if that running value
        ends within ``1e-4`` of zero and the environment reports ``done`` on the last step and on no earlier
        step; otherwise a random other level is returned.

        Args:
            first_obs: Observation returned by resetting the environment to the level.
            actions: Environment actions to replay, in order.
            level_file_idx: Index of the level file (identifies the level; not used in the computation).
            level_idx: Index of the level inside its file (identifies the level; not used in the computation).

        Returns:
            Tuple ``(obs, actions, values, episode_starts)`` of tensors with one entry per action;
            ``episode_starts`` is True only at position 0.
        """
        obs = [first_obs]
        dones = []
        return_val = self.return_without_steps + self.step_reward * len(actions)
        values = [return_val]
        for a in actions:
            ob, reward, done, _, _ = self.unwrapped_env.step(a)
            obs.append(ob)
            dones.append(done)
            return_val -= float(reward)
            values.append(return_val)

        # Drop the entries that belong to the state after the last action: there is no action paired with
        # them. This is done unconditionally (not inside an ``assert``) so it also happens under ``python -O``.
        obs.pop()
        end_value = values.pop()

        # ``not (x < eps)`` rather than ``x >= eps`` so that a NaN value is rejected as well.
        value_mismatch = not (abs(end_value) < 1e-4)
        done_only_at_end = bool(dones) and bool(dones[-1]) and not any(dones[:-1])
        if value_mismatch or not done_only_at_end:
            return self._resample()

        values = th.tensor(values)
        episode_starts = th.zeros(len(obs), dtype=th.bool)
        episode_starts[0] = 1

        return th.tensor(np.stack(obs)), th.tensor(actions), values, episode_starts


class RecurrentAstarSL(RecurrentPPO):
    """Recurrent policy trained by supervised learning on A* Sokoban solutions.

    The class reuses the ``RecurrentPPO`` constructor but replaces the reinforcement-learning loop:
    ``collect_rollouts``, ``train`` and ``learn`` raise ``RuntimeError``; use ``train_sl`` instead.
    """

    def __init__(
        self,
        policy: Union[str, Type[BaseRecurrentActorCriticPolicy[RecurrentState]]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 128,
        batch_envs: int = 128,
        batch_time: Optional[int] = None,
        n_epochs: int = 10,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        max_grad_norm: Optional[float] = 0.5,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        """Forward every argument unchanged to ``RecurrentPPO.__init__``.

        Within this class, ``batch_envs`` is used as the DataLoader batch size, ``n_epochs`` as the number
        of passes over the dataset and ``max_grad_norm`` for gradient clipping (see ``train_sl``).
        """
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_envs=batch_envs,
            batch_time=batch_time,
            n_epochs=n_epochs,
            clip_range=clip_range,
            clip_range_vf=clip_range_vf,
            max_grad_norm=max_grad_norm,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

    @th.no_grad()
    def collect_rollouts(  # type: ignore[override]
        self,
        env,
        callback,
        rollout_buffer,
        n_rollout_steps,
    ):
        """Disabled: supervised training does not collect rollouts. Always raises ``RuntimeError``."""
        raise RuntimeError("RecurrentAstarSL.collect_rollouts should not be called")

    def train_sl(self, levels_per_file: int = 1000, callback: MaybeCallback = None) -> None:
        """Train the policy on A* demonstrations for ``self.n_epochs`` passes over the dataset.

        Every batch contributes a policy loss (negative mean log-probability of the A* actions) plus a
        value loss (MSE against the dataset's value targets). One optimizer update counts as one
        "timestep" for ``num_timesteps`` and for the learning-rate schedule. Losses are logged and the
        callback's ``on_step`` is invoked once per epoch.

        Args:
            levels_per_file: Passed to ``AStarSokobanDataset`` to map flat indices to (file, level).
            callback: Callback(s) initialised through ``_init_callback``.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        assert self.env is not None, "You must pass an environment to the model before training it"
        # Build the dataset once: constructing it reads every A* log CSV.
        dataset = AStarSokobanDataset(self.env, levels_per_file=levels_per_file)

        def collate_fn(batch):
            # Regroup the per-sample tuples into one sequence per field: (obs, actions, values, starts).
            batch_by_keys = list(zip(*batch))
            # Move the last observation axis to position 1 — presumably (T, H, W, C) -> (T, C, H, W); TODO confirm.
            batch_by_keys[0] = [th.permute(obs, (0, 3, 1, 2)) for obs in batch_by_keys[0]]
            # pad_sequence (batch_first defaults to False) zero-pads each field to the longest sequence in the batch.
            return tuple(th.nn.utils.rnn.pad_sequence(t, padding_value=0) for t in batch_by_keys)

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_envs,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=16,
            persistent_workers=False,
        )

        # Total number of optimizer updates. ``len(dataloader)`` includes the final partial batch, so
        # ``num_timesteps`` never exceeds this total and the remaining progress handed to the learning-rate
        # schedule stays non-negative.
        self._total_timesteps = self.n_epochs * len(dataloader)
        self.num_timesteps = 0
        callback = self._init_callback(callback)

        lstm_states = non_null(self._last_lstm_states)

        callback.on_training_start(locals(), globals())

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            value_losses, policy_losses, total_losses = [], [], []
            print(f"Epoch {epoch}/{self.n_epochs}")

            for obs, gt_actions, gt_values, ep_starts in tqdm(dataloader):
                obs_tensor = obs_as_tensor(obs, self.device)
                gt_actions = gt_actions.to(self.device)
                gt_values = gt_values.to(self.device)
                ep_starts = ep_starts.to(self.device)
                # The same initial ``lstm_states`` object is passed for every batch.
                pred_values, log_probs, _ = self.policy.evaluate_actions(obs_tensor, gt_actions, lstm_states, ep_starts)
                # Compute the loss of the policy (cross entropy) and value (MSE) networks with torch.
                # NOTE(review): no mask is applied here, so zero-padded steps appear to enter both means —
                # confirm whether they should be excluded.
                policy_loss = -log_probs.mean()

                value_loss = F.mse_loss(pred_values.squeeze(-1), gt_values)
                loss = policy_loss + value_loss

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                if self.max_grad_norm is not None:
                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

                value_losses.append(value_loss.item())
                policy_losses.append(policy_loss.item())
                total_losses.append(loss.item())
                self._n_updates += 1
                self.num_timesteps += 1

                # Update optimizer learning rate
                self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
                self._update_learning_rate(self.policy.optimizer)

            # Log losses
            self.logger.record("train/policy_loss", np.mean(policy_losses))
            self.logger.record("train/value_loss", np.mean(value_losses))
            self.logger.record("train/loss", np.mean(total_losses))
            self.logger.record("train/epoch", epoch + 1)
            self.logger.dump(step=epoch + 1)

            # Evaluate the policy
            callback.update_locals(locals())
            callback.on_step()

        self.policy.optimizer.zero_grad(set_to_none=True)
        callback.on_training_end()

    def train(self, callback: Optional[BaseCallback] = None) -> None:
        """Disabled: use ``train_sl``. Always raises ``RuntimeError``."""
        raise RuntimeError("RecurrentAstarSL.train should not be called")

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "RecurrentPPO",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ):
        """Disabled: use ``train_sl``. Always raises ``RuntimeError``."""
        raise RuntimeError("RecurrentAstarSL.learn should not be called")


@dataclasses.dataclass
class RecurrentAstarSLConfig(RecurrentPPOConfig):
    """Configuration that builds a ``RecurrentAstarSL`` algorithm instance."""

    _alg_class = RecurrentAstarSL
    batch_time: Optional[int] = 20
    batch_envs: int = 128

    def make(
        self,
        policy: type[ActorCriticPolicy] | type[RecurrentActorCriticPolicy] | str,
        env: GymEnv,
        n_steps: int,
        seed: int,
        device: th.device,
        policy_kwargs: dict | None,
        # Avoid adding kwargs on purpose -- these are all the args that RecurrentAstarSLConfig accepts.
    ) -> RecurrentAstarSL:
        """Instantiate ``RecurrentAstarSL`` from this config.

        The given ``policy_kwargs`` are copied (never mutated) and extended with the optimizer's policy
        kwargs. ``n_steps`` is accepted but not forwarded to the algorithm.
        """
        # Work on a copy so the caller's dict is left untouched.
        merged_policy_kwargs = policy_kwargs.copy() if policy_kwargs else {}
        merged_policy_kwargs.update(self.optimizer.policy_kwargs())

        alg_kwargs = dict(
            policy=policy,
            env=env,
            learning_rate=self.optimizer.lr,
            n_epochs=self.n_epochs,
            clip_range=self.clip_range,
            clip_range_vf=self.clip_range_vf,
            max_grad_norm=self.max_grad_norm,
            policy_kwargs=merged_policy_kwargs,
            seed=seed,
            device=device,
            batch_time=self.batch_time,
            batch_envs=self.batch_envs,
        )
        return RecurrentAstarSL(**alg_kwargs)  # type: ignore


@dataclasses.dataclass
class TrainAstarSLConfig(TrainConfig):
    """Training configuration for supervised learning on A* solutions."""

    # Forwarded to ``train_sl`` as ``levels_per_file`` by ``train``.
    levels_per_file: int = 1000

    def __post_init__(self):
        """Check that the configured environment is a ``BoxobanConfig``."""
        env_config = self.env
        assert isinstance(env_config, BoxobanConfig)

    def run(self, run_dir: Path):
        """Run the module-level ``train`` with this config and return its result."""
        outcome = train(self, run_dir)
        return outcome


def train(args: TrainAstarSLConfig, run_dir: Path):
    """Build the environment and model described by `args` and run supervised A* training.

    A checkpoint callback (saving into `run_dir` every `args.checkpoint_freq` calls) is combined with the
    evaluation callbacks and passed to `RecurrentAstarSL.train_sl`; the model's logger is closed afterwards.
    """
    vec_env, eval_callbacks = create_vec_env_and_eval_callbacks(args, run_dir, eval_freq=args.checkpoint_freq)
    model = make_model(args, run_dir, vec_env, eval_callbacks)
    assert isinstance(model, RecurrentAstarSL)

    # Checkpointing comes first, followed by whatever evaluation callbacks were created.
    all_callbacks = [CheckpointCallback(args.checkpoint_freq, str(run_dir), verbose=1), *eval_callbacks]

    with catch_different_env_types_warning(vec_env, eval_callbacks):
        model.train_sl(levels_per_file=args.levels_per_file, callback=all_callbacks)
    model.logger.close()
