import logging

from train.task import Task
from train.proc.proc_base import Subprocess, repeated_run
from train.reinforcment_learning.run import ReinforcementLearning
from train.common.utils import import_module


class ReinforcementLearningProc(Subprocess):
    """Subprocess step that runs reinforcement learning on the task's model.

    Attributes set by ``__init__``:
        envs: passed through unchanged to ``ReinforcementLearning.train``.
        task: the ``Task`` whose ``get_model()`` supplies the starting model.
        rl: the ``ReinforcementLearning`` instance that does the training.
        model: ``None`` until ``run`` executes; afterwards ``self.rl.model``.
    """

    def __init__(self, config, envs, task: Task):
        """Register the step under the name 'Reinforcement Learning'.

        Args:
            config: configuration object forwarded to the ``Subprocess``
                base initializer; ``run`` reads its ``rl`` and ``debug``
                sections through ``self.config`` (presumably stored by the
                base class -- confirm in ``Subprocess``).
            envs: environments handed to the trainer in ``run``.
            task: provider of the model to train.
        """
        super().__init__('Reinforcement Learning', config)
        self.task = task
        self.envs = envs
        self.model = None
        self.rl = ReinforcementLearning()

    @repeated_run
    def run(self):
        """Fetch the task's model, train it with RL, and mark the step done.

        When ``config.debug.enabled`` is truthy, the RL timestep settings on
        ``config.rl`` are overwritten in place with the debug values before
        training; the overrides stay on the config object afterwards.
        """
        logging.info('Proc: Getting pre-trained model...')
        self.model = self.task.get_model()
        # NOTE(review): presumably switches the model into training mode --
        # confirm against the model class returned by the task.
        self.model.train()

        rl_settings = self.config.rl
        dataset_module = import_module(rl_settings.dataset)

        debug_settings = self.config.debug
        if not debug_settings.enabled:
            logging.info('Proc: Running reinforcement learning training...')
        else:
            logging.info('Proc: Running reinforcement learning training in debug mode...')
            # Replace the configured timestep counts with the debug ones.
            rl_settings.total_timesteps = debug_settings.num_rl_timesteps
            rl_settings.train_value_timesteps = debug_settings.num_rl_value_timesteps

        # Both modes train with the same arguments; only the config differs.
        self.rl.train(rl_settings, self.envs, self.model, dataset_module)

        # Keep whatever model the trainer exposes once training has finished.
        self.model = self.rl.model
        self._complete()
