from collections.abc import Callable
from dataclasses import asdict

import numpy as np

from tianshou.data import (
    Collector,
    CollectStats,
    ReplayBuffer,
    SequenceSummaryStats,
)
from tianshou.policy import BasePolicy
from tianshou.trainer.base import OnpolicyTrainer
from tianshou.trainer.utils import test_episode
from tianshou.utils import (
    BaseLogger,
    LazyLogger,
)


class GRPOTrainer(OnpolicyTrainer):
    """On-policy trainer that passes a ``group_num`` setting to the train collector.

    The constructor accepts the same arguments it forwards to
    :class:`OnpolicyTrainer`, plus ``group_num``. The only other difference
    visible in this class is that :meth:`train_step` passes
    ``group_num=self.group_num`` to ``train_collector.collect``; the train
    collector therefore has to accept that keyword.
    """

    def __init__(
        self,
        policy: BasePolicy,
        max_epoch: int,
        batch_size: int | None,
        group_num: int = 5,
        train_collector: Collector | None = None,
        test_collector: Collector | None = None,
        buffer: ReplayBuffer | None = None,
        step_per_epoch: int | None = None,
        repeat_per_collect: int | None = None,
        episode_per_test: int | None = None,
        update_per_step: float = 1.0,
        step_per_collect: int | None = None,
        episode_per_collect: int | None = None,
        train_fn: Callable[[int, int], None] | None = None,
        test_fn: Callable[[int, int | None], None] | None = None,
        stop_fn: Callable[[float], bool] | None = None,
        save_best_fn: Callable[[BasePolicy], None] | None = None,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
        resume_from_log: bool = False,
        reward_metric: Callable[[np.ndarray], np.ndarray] | None = None,
        logger: BaseLogger = LazyLogger(),
        verbose: bool = True,
        show_progress: bool = True,
        test_in_train: bool = True,
        save_fn: Callable[[BasePolicy], None] | None = None,
    ):
        """Initialise the parent trainer and remember ``group_num``.

        :param group_num: value handed to ``train_collector.collect`` as the
            ``group_num`` keyword in every training step (its exact meaning is
            defined by the collector -- presumably the number of rollouts per
            group; confirm against the collector implementation).

        All remaining parameters are forwarded unchanged to the parent
        constructor.
        """
        # NOTE: the arguments are passed positionally, so they bind to the
        # parent's parameters by position; `group_num` is deliberately left out.
        super().__init__(
            policy,
            max_epoch,
            batch_size,
            train_collector,
            test_collector,
            buffer,
            step_per_epoch,
            repeat_per_collect,
            episode_per_test,
            update_per_step,
            step_per_collect,
            episode_per_collect,
            train_fn,
            test_fn,
            stop_fn,
            save_best_fn,
            save_checkpoint_fn,
            resume_from_log,
            reward_metric,
            logger,
            verbose,
            show_progress,
            test_in_train,
            save_fn,
        )
        self.group_num = group_num

    def train_step(self) -> tuple[CollectStats, bool]:
        """Collect one batch of training data and decide whether to stop early.

        The step runs ``train_fn`` (if given), collects with the train
        collector, advances ``env_step`` and, when at least one episode
        finished, records the last return/length, applies ``reward_metric``
        (if given) and logs the collect statistics.

        Early stopping is a two-stage check that only happens when episodes
        were collected and both ``test_in_train`` and ``stop_fn`` are set:
        ``stop_fn`` is first evaluated on the mean training return; if it
        holds, a test run is performed and ``stop_fn`` is evaluated on the
        mean test return. Only if that also holds is the stop flag ``True``
        and ``best_reward`` / ``best_reward_std`` updated; otherwise the
        policy is switched back to training mode.

        :return: the collect statistics and a flag telling the caller whether
            training should stop.
        """
        assert self.episode_per_test is not None
        assert self.train_collector is not None

        if self.train_fn:
            self.train_fn(self.epoch, self.env_step)

        collect_stats = self.train_collector.collect(
            n_step=self.step_per_collect,
            n_episode=self.episode_per_collect,
            group_num=self.group_num,
        )
        self.env_step += collect_stats.n_collected_steps

        # Without a finished episode there is nothing to log or to test against.
        if not (collect_stats.n_collected_episodes > 0):
            return collect_stats, False

        assert collect_stats.returns_stat is not None  # for mypy
        assert collect_stats.lens_stat is not None  # for mypy
        # last_rew / last_len are taken before reward_metric rewrites the stats.
        self.last_rew = collect_stats.returns_stat.mean
        self.last_len = collect_stats.lens_stat.mean

        if self.reward_metric:  # TODO: move inside collector
            transformed_returns = self.reward_metric(collect_stats.returns)
            collect_stats.returns = transformed_returns
            collect_stats.returns_stat = SequenceSummaryStats.from_sequence(
                transformed_returns,
            )

        self.logger.log_train_data(asdict(collect_stats), self.env_step)

        # Stage 1: the (possibly metric-transformed) training return must pass stop_fn.
        if not (self.test_in_train and self.stop_fn):
            return collect_stats, False
        if not self.stop_fn(collect_stats.returns_stat.mean):
            return collect_stats, False

        # Stage 2: confirm on freshly collected test episodes.
        assert self.test_collector is not None
        test_stats = test_episode(
            self.policy,
            self.test_collector,
            self.test_fn,
            self.epoch,
            self.episode_per_test,
            self.logger,
            self.env_step,
        )
        assert test_stats.returns_stat is not None  # for mypy

        if not self.stop_fn(test_stats.returns_stat.mean):
            # Test did not confirm; put the policy back into training mode
            # and keep going.
            self.policy.train()
            return collect_stats, False

        self.best_reward = test_stats.returns_stat.mean
        self.best_reward_std = test_stats.returns_stat.std
        return collect_stats, True