

def test_episode(policy, collector, n_episode):
    """Evaluate ``policy`` by collecting ``n_episode`` episodes.

    The collector's environment and buffer are reset first and the policy is
    put into evaluation mode via ``policy.eval()`` before collecting.

    Returns:
        Whatever ``collector.collect(n_episode=n_episode)`` returns.
    """
    # Start from a clean environment and an empty buffer.
    collector.reset_env()
    collector.reset_buffer()

    # Switch the policy to evaluation mode before gathering episodes.
    policy.eval()

    return collector.collect(n_episode=n_episode)



class Trainer:
    """Run an alternating collect / update / test loop for a policy.

    Within an epoch the trainer repeatedly collects ``step_per_collect``
    environment steps with ``train_collector``, updates the policy on the
    collector's buffer and clears that buffer.  After each epoch the policy
    is evaluated with ``test_collector`` and, when the result is a new best,
    passed to ``save_best_fn``.
    """

    # In ``test_step`` a result with a higher mean reward only replaces the
    # current best when its reward standard deviation is below this limit.
    BEST_REWARD_STD_LIMIT = 500

    def __init__(self, policy, train_collector, test_collector,
        max_epoch, step_per_epoch, repeat_per_collect, episode_per_test,
        batch_size, step_per_collect, save_best_fn, logger, **kwargs):
        """Store the training configuration and zero all progress counters.

        Args:
            policy: Object providing ``train()``, ``eval()`` and ``update()``.
            train_collector: Collector used to gather training data.
            test_collector: Collector used for evaluation episodes.
            max_epoch: Upper bound used by the epoch loop in ``run``.
            step_per_epoch: Environment steps to collect per epoch.
            repeat_per_collect: Passed as ``repeat`` to ``policy.update``.
            episode_per_test: Episodes collected for every evaluation.
            batch_size: Passed as ``batch_size`` to ``policy.update``.
            step_per_collect: Passed as ``n_step`` to the training collector.
            save_best_fn: Optional callable invoked with the policy whenever
                a new best result is recorded; skipped when falsy.
            logger: Object providing ``log_train_data``, ``log_test_data``
                and ``log_update_data``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.policy = policy

        self.train_collector = train_collector
        self.test_collector = test_collector

        self.best_reward = 0.0
        self.best_reward_std = 0.0
        self.start_epoch = 0
        self.env_step = 0
        self.max_epoch = max_epoch
        self.step_per_epoch = step_per_epoch

        # Sizes of one collection round, one policy update and one evaluation.
        self.step_per_collect = step_per_collect
        self.repeat_per_collect = repeat_per_collect
        self.episode_per_test = episode_per_test

        self.batch_size = batch_size
        self.save_best_fn = save_best_fn
        self.logger = logger
        self.last_rew = 0.0

        self.epoch = self.start_epoch
        self.best_epoch = self.start_epoch
        self.iter_num = 0
        self.gradient_step = 0

    def policy_update_fn(self):
        """Update the policy on the training buffer, then clear the buffer.

        The update result is logged against the current ``gradient_step``
        before the counter is advanced.
        """
        assert self.train_collector is not None
        result = self.policy.update(
            self.train_collector.buffer,
            batch_size=self.batch_size,
            repeat=self.repeat_per_collect,
        )
        self.train_collector.reset_buffer()
        self.logger.log_update_data(result, self.gradient_step)

        # Advance by the longest list-valued entry in the result, and by at
        # least one when the result contains no lists.
        list_lengths = [len(v) for v in result.values() if isinstance(v, list)]
        self.gradient_step += max([1, *list_lengths])

    def reset(self):
        """Reset counters and record an initial evaluation as the best result.

        The policy is evaluated once with the test collector; that result
        becomes the initial best and the policy is handed to ``save_best_fn``
        when one is configured.
        """
        self.env_step = 0

        self.last_rew = 0.0

        self.train_collector.reset_stat()
        self.test_collector.reset_stat()
        test_result = test_episode(
            self.policy, self.test_collector,
            self.episode_per_test
        )
        self.best_epoch = self.start_epoch
        # Stored as a plain float, matching what test_step() stores.
        self.best_reward = float(test_result["rew"])
        self.best_reward_std = test_result["rew_std"]

        print("##################reset#########################")
        print(test_result["rews"])
        print(test_result["rew"])
        print("##################reset#########################")

        if self.save_best_fn:
            self.save_best_fn(self.policy)

        self.epoch = self.start_epoch
        self.iter_num = 0

    def run(self):  # type: ignore
        """Reset the trainer, then train and evaluate epoch by epoch."""
        self.reset()

        # NOTE(review): starting from epoch 0 this bound yields
        # ``max_epoch - 1`` training epochs -- confirm that is intended.
        while self.epoch < self.max_epoch - 1:
            self.epoch += 1

            # set policy in train mode
            self.policy.train()

            # Collect and update until this epoch's step budget is reached.
            # A strict comparison avoids an extra round when the collected
            # count lands exactly on step_per_epoch.
            collected = 0
            while collected < self.step_per_epoch:
                _, result = self.train_step()
                collected += result["n/st"]
                self.logger.log_train_data(result, self.env_step)
                self.policy_update_fn()

            # test
            test_stat = self.test_step()
            print(self.epoch, test_stat)
            self.logger.log_test_data(test_stat, self.env_step)

    def test_step(self):
        """Perform one testing step.

        Returns:
            A dict with the keys ``test_reward``, ``test_reward_std``,
            ``best_reward``, ``best_reward_std`` and ``best_epoch``.
        """
        assert self.episode_per_test is not None
        assert self.test_collector is not None
        test_result = test_episode(
            self.policy, self.test_collector,
            self.episode_per_test
        )
        print(test_result["rews"], test_result['rew'])
        rew, rew_std = test_result["rew"], test_result["rew_std"]

        # A new best needs a strictly higher mean reward and a standard
        # deviation under the limit; a negative best_epoch always qualifies.
        is_new_best = self.best_epoch < 0 or (
            self.best_reward < rew and rew_std < self.BEST_REWARD_STD_LIMIT
        )
        if is_new_best:
            self.best_epoch = self.epoch
            self.best_reward = float(rew)
            self.best_reward_std = rew_std
            if self.save_best_fn:
                self.save_best_fn(self.policy)

        test_stat = {
            "test_reward": rew,
            "test_reward_std": rew_std,
            "best_reward": self.best_reward,
            "best_reward_std": self.best_reward_std,
            "best_epoch": self.best_epoch
        }
        return test_stat

    def train_step(self):
        """Perform one training step.

        Collects ``step_per_collect`` steps with the training collector and
        adds the number of collected steps to ``env_step``.  When at least
        one episode finished, ``rew`` and ``rew_std`` are added to the result
        from ``result["rews"]`` (which must provide ``mean()`` and ``std()``)
        and ``last_rew`` is refreshed.

        Returns:
            A ``(data, result)`` tuple: ``data`` is a dict of formatted
            strings summarising progress, ``result`` is the collector output.
        """
        # Collection is driven by step_per_collect, so that is the setting
        # this step depends on.
        assert self.step_per_collect is not None
        assert self.train_collector is not None
        result = self.train_collector.collect(
            n_step=self.step_per_collect
        )
        if result["n/ep"] > 0:
            rew = result["rews"]
            result.update(rew=rew.mean(), rew_std=rew.std())
            self.last_rew = result["rew"]
        self.env_step += int(result["n/st"])
        data = {
            "env_step": str(self.env_step),
            "rew": f"{self.last_rew:.2f}",
            "n/ep": str(int(result["n/ep"])),
            "n/st": str(int(result["n/st"])),
        }

        return data, result



