import logging
import os
from enum import Enum
import copy
from train.task import Task
from train.proc.proc_base import Subprocess, repeated_run
from train.behavioral_cloning.run_train import BehaviouralCloning
from train.behavioral_cloning.eval_hooks.eval_env_meta import EvalEnv
from train.common.misc import dotdict


class BehaviouralCloningData(Enum):
    """Selects which recorded data a behavioural cloning run trains on.

    The value decides which task directory the training data is read from
    and which sub-directory of the recording directory is used.
    """

    # Data looked up by the task's target (task.data_dir_target / task.target).
    General = 0
    # Data looked up by the task's id (task.data_dir_transition / task.id).
    Transition = 1


class BehaviouralCloningProc(Subprocess):
    """Pipeline step that trains a task's model with behavioural cloning.

    On ``run`` the step fetches the model from the task, points the
    behavioural cloning config (``config.bc``) at the data matching
    ``data_type``, optionally attaches an in-environment evaluation hook and
    keeps the best model reported by the trainer in ``self.model``.
    """

    def __init__(self, config, task: Task,
                 data_type: BehaviouralCloningData = BehaviouralCloningData.General,
                 env=None, n_workers: int = None):
        """Set up the step; no training happens until ``run`` is called.

        Args:
            config: Configuration object forwarded to ``Subprocess``. ``run``
                reads and writes its ``bc`` section and reads
                ``subtask.rootdir``.
            task: Task providing the model, the data directories and the
                target / id / data_id used to locate data and recordings.
            data_type: Which kind of recorded data to train on.
            env: Environment handed to the evaluation hook. When ``None`` no
                evaluation hook is created.
            n_workers: Worker count forwarded to the evaluation hook.
        """
        super().__init__('Behavioural Cloning', config)
        self.task = task
        self.bc = BehaviouralCloning()
        self.data_type = data_type
        # Populated by run(): first with the task's model, then with the best
        # model found during training.
        self.model = None
        self.env = env
        self.n_workers = n_workers

    @repeated_run
    def run(self):
        """Train the task's model by behavioural cloning and keep the best one.

        Raises:
            NotImplementedError: If ``data_type`` is not a known
                ``BehaviouralCloningData`` member.
        """
        # Fetch the model to train from the task and call its train() method
        # (presumably switches it to training mode -- verify against the model).
        self.model = self.task.get_model()
        self.model.train()

        if self.data_type == BehaviouralCloningData.General:
            data = os.path.join(self.task.data_dir_target, self.task.target)
            recording_dir = os.path.join(self.config.bc.recording_dir, self.task.target)
        elif self.data_type == BehaviouralCloningData.Transition:
            data = os.path.join(self.task.data_dir_transition, self.task.id)
            recording_dir = os.path.join(self.config.bc.recording_dir, self.task.id)
        else:
            raise NotImplementedError('Proc: Unknown data enum for behavioural cloning.')

        # NOTE(review): everything below mutates self.config.bc in place, and
        # the treechop overrides are never reverted. If the config object is
        # shared with other steps they will see these values -- confirm that
        # this is intended.
        bc_config = self.config.bc
        bc_config.param_root = recording_dir
        # Written through __dict__ rather than by plain attribute assignment;
        # presumably to bypass the config object's own attribute handling --
        # verify against the config type before changing.
        bc_config.__dict__['data'] = data
        bc_config.env = self.task.data_id

        # we use MineRLTreechop-v0 to pre-train agents gathering logs as in our testing it turned out that
        # treechop pretrained agents outperform agents trained purely on log-gathering-sequences from
        # MineRLObtainDiamond-v0. Not a strict requirement though.
        if bc_config.env == "MineRLTreechop-v0":
            bc_config.__dict__['data'] = self.config.subtask.rootdir
            bc_config.__dict__['dataset'] = bc_config.dataset_treechop
            bc_config.__dict__['model'] = bc_config.model_treechop
            bc_config.__dict__['train_strategy'] = bc_config.train_strategy_treechop

        eval_hook = None
        if bc_config.eval_in_environment and self.env is not None:
            # exist_ok avoids the check-then-create race when several
            # processes set up the same recording directory.
            os.makedirs(recording_dir, exist_ok=True)
            eval_hook = EvalEnv(self.env, n_workers=self.n_workers, trials=bc_config.n_eval_trials,
                                max_steps=bc_config.eval_max_steps, verbosity=0,
                                watch_item=self.task.target, recording_dir=recording_dir)

        logging.info('Proc: Cloning behaviour...')
        self.bc.train_model(bc_config, model=self.model, eval_hook=eval_hook)
        # Replace the model with the best one the trainer recorded.
        self.model = self.bc.best_model
        self._complete()
