import traceback

import hydra

from utils.utils_fn import set_seed


class Training(object):
    """Hydra-configured training loop.

    Builds every component from ``cfg``: training and evaluation
    environments, feedback collector, replay buffer, agent, reward model,
    pretrainer and logger. It then runs the pretrainer and repeats
    evaluation -> feedback collection -> reward training -> policy training
    until ``num_train_steps`` is exceeded.
    """

    def __init__(self, cfg) -> None:
        """Instantiate all training components from ``cfg``.

        Args:
            cfg: Hydra config (as received by ``main``). Reads ``device``,
                ``seed``, ``eval_frequency``, ``num_eval_episodes``,
                ``num_train_steps`` and the ``data_cfg``,
                ``replay_buffer_cfg``, ``agent_cfg``, ``reward_cfg``,
                ``pretrain_cfg`` and ``logger_cfg`` sub-configs.
        """
        self.cfg = cfg
        self.device = self.cfg.device

        ### Seed (via utils.set_seed) before any component is instantiated
        self.seed = int(self.cfg.seed)
        set_seed(self.seed)

        ### Create Env: separate training and evaluation instances, both
        ### built from the same task config and the same seed
        self.env = hydra.utils.instantiate(
            self.cfg.data_cfg.task, seed=self.seed, device=self.device
        )
        self.eval_env = hydra.utils.instantiate(
            self.cfg.data_cfg.task, seed=self.seed, device=self.device
        )

        ### Policy eval parameters
        self.eval_frequency = self.cfg.eval_frequency
        self.num_eval_episodes = self.cfg.num_eval_episodes

        ### Feedback collector
        self.feedback_collector = hydra.utils.instantiate(
            self.cfg.data_cfg.feedback_collector,
            env=self.env,
            device=self.device,
        )

        self.replay_buffer = hydra.utils.instantiate(
            self.cfg.replay_buffer_cfg.buffer, env=self.env, device=self.device
        )

        ### Create Agent
        # _recursive_=False makes Hydra pass nested sub-configs through
        # un-instantiated (presumably so the agent builds them itself).
        self.agent = hydra.utils.instantiate(
            self.cfg.agent_cfg.agent_cfg,
            device=self.device,
            env=self.env,
            _recursive_=False,
        )

        ### Create Advantage Function (reward model)
        self.reward_model = hydra.utils.instantiate(
            self.cfg.reward_cfg.reward,
            env=self.env,
            agent=self.agent,
            device=self.device,
        )

        ### Add pretraining
        self.pretrainer = hydra.utils.instantiate(
            self.cfg.pretrain_cfg.pretrainer, device=self.device
        )

        ### Training Prep
        self.num_train_steps = self.cfg.num_train_steps

        ### Create Logger (the full cfg is passed positionally to its target)
        self.logger = hydra.utils.instantiate(self.cfg.logger_cfg, self.cfg)

    def eval_policy(self, step):
        """Evaluate the agent on ``eval_env`` if ``step`` is an evaluation step.

        An evaluation step is a non-negative multiple of ``eval_frequency``.
        A missing or non-positive ``eval_frequency`` disables periodic
        evaluation; previously such a value raised ZeroDivisionError or
        TypeError.

        Args:
            step: Current training step.

        Returns:
            Whatever ``eval_env.evaluate`` returns, or an empty dict when no
            evaluation was run.
        """
        frequency = self.eval_frequency
        if not frequency or frequency <= 0:
            return {}
        # NOTE(review): ``train`` advances the step by whatever
        # ``agent.update`` returns. If that amount does not divide
        # ``eval_frequency``, some (or all) evaluation points never land
        # exactly on a multiple and are skipped.
        if step < 0 or step % frequency != 0:
            return {}

        self.logger.set_phase("eval")
        return self.eval_env.evaluate(
            agent=self.agent,
            reward=self.reward_model,
            num_episodes=self.num_eval_episodes,
            step=step,
            logger=self.logger,
        )

    def train(self):
        """Run pretraining followed by the main training loop.

        The loop runs while ``train_step <= num_train_steps``. Each iteration
        optionally evaluates, collects feedback, trains the reward model and
        updates the policy. It then advances ``train_step`` by the step count
        returned from ``agent.update``. The logger is finished after the
        loop ends.

        Raises:
            RuntimeError: If ``agent.update`` returns a missing or
                non-positive step count, which would otherwise make the loop
                spin forever.
        """
        ### Pretraining
        self.logger.set_phase("pretrain")
        self.agent.start_training(env=self.env)

        # The pretrainer receives this object as ``training_alg``; its return
        # value is the step at which the main loop starts.
        self.train_step = self.pretrainer.pretrain(
            training_alg=self,
            env=self.env,
            agent=self.agent,
            replay_buffer=self.replay_buffer,
            feedback_collector=self.feedback_collector,
            reward_model=self.reward_model,
            logger=self.logger,
            eval_env=self.eval_env,
        )

        ### Training Loop
        while self.train_step <= self.num_train_steps:

            # Training Evaluations (no-op unless this is an evaluation step)
            self.eval_policy(step=self.train_step)

            ### Phase 1: Feedback collection
            self.logger.set_phase("feedback_collection")
            self.feedback_collector.collect_feedback(
                step=self.train_step,
                agent=self.agent,
                reward_model=self.reward_model,
                replay_buffer=self.replay_buffer,
                logger=self.logger,
            )

            ### Phase 2: Train Advantage
            self.logger.set_phase("reward_training")
            self.reward_model.train_reward(
                feedback_collector=self.feedback_collector,
                agent=self.agent,
                replay_buffer=self.replay_buffer,
                step=self.train_step,
                logger=self.logger,
            )

            ### Phase 3: Policy training
            self.logger.set_phase("policy_training")
            num_steps = self.agent.update(
                replay_buffer=self.replay_buffer,
                reward_model=self.reward_model,
                feedback_collector=self.feedback_collector,
                env=self.env,
                logger=self.logger,
                step=self.train_step,
            )

            # Without progress the while-condition never changes.
            if num_steps is None or num_steps <= 0:
                raise RuntimeError(
                    "agent.update() must return a positive number of steps, "
                    f"got {num_steps!r}"
                )
            self.train_step += num_steps

        self.logger.finish()


@hydra.main(
    version_base=None, config_path="config", config_name="experiments_paws.yaml"
)
def main(cfg):
    """Hydra entry point: build a ``Training`` run from ``cfg`` and train it.

    Any exception raised during construction or training is caught. Its
    traceback is printed in full and the exception is not re-raised.
    NOTE(review): swallowing is presumably deliberate (e.g. so a multirun
    sweep continues past a failed job), but it likely means failed runs exit
    with status 0 -- confirm.
    """
    try:
        training = Training(cfg)
        training.train()
    except Exception as ex:
        print("-- exception occured. traceback :")
        traceback.print_tb(ex.__traceback__)
        print(ex, flush=True)
        print("--------------------------------\n")
        # The three-argument form works on every Python 3 version. The
        # single-argument ``print_exception(ex)`` form only exists on 3.10+
        # and would raise TypeError inside this handler on older interpreters.
        traceback.print_exception(type(ex), ex, ex.__traceback__)


if __name__ == "__main__":
    # hydra.main parses the command line and supplies ``cfg``.
    main()
