import os
import pickle
import numpy as np

from ray.tune import result as tune_result
from src.rllib.agents.trainer import Trainer, with_common_config


class _MockTrainer(Trainer):
    """Mock trainer for use in tests.

    Every call to ``step`` returns the same fixed metrics. Config flags
    allow a test to inject a failure when ``iteration == 1`` and to ask
    for a checkpoint every ``user_checkpoint_freq`` iterations.
    """

    _name = "MockTrainer"
    _default_config = with_common_config({
        # Raise from step() when self.iteration == 1.
        "mock_error": False,
        # Keep raising even after load_checkpoint() has been called.
        "persistent_error": False,
        # Not read anywhere in this class.
        "test_variable": 1,
        "num_workers": 0,
        # When > 0, step() flags SHOULD_CHECKPOINT on every multiple of it.
        "user_checkpoint_freq": 0,
        "framework": "tf",
    })

    @classmethod
    def default_resource_request(cls, config):
        """Return None regardless of ``config``."""
        return None

    def _init(self, config, env_creator):
        """Reset mock state; ``config`` and ``env_creator`` are ignored."""
        self.restored = False
        self.info = None

    def step(self):
        """Return one fixed result dict, optionally failing first.

        Raises:
            Exception: "mock error" when ``mock_error`` is set, the
                current iteration is 1, and either ``persistent_error``
                is set or no checkpoint has been loaded yet.
        """
        cfg = self.config

        # Injected failure: only on iteration 1, and only once unless the
        # error is configured as persistent.
        if cfg["mock_error"] and self.iteration == 1:
            if cfg["persistent_error"] or not self.restored:
                raise Exception("mock error")

        metrics = {
            "episode_reward_mean": 10,
            "episode_len_mean": 10,
            "timesteps_this_iter": 10,
            "info": {},
        }

        # Request a checkpoint on every positive multiple of the frequency.
        freq = cfg["user_checkpoint_freq"]
        if freq > 0 and self.iteration > 0 and self.iteration % freq == 0:
            metrics[tune_result.SHOULD_CHECKPOINT] = True
        return metrics

    def save_checkpoint(self, checkpoint_dir):
        """Pickle ``self.info`` into ``checkpoint_dir`` and return the path."""
        target = os.path.join(checkpoint_dir, "mock_agent.pkl")
        with open(target, "wb") as out:
            pickle.dump(self.info, out)
        return target

    def load_checkpoint(self, checkpoint_path):
        """Restore ``self.info`` from a pickle file and mark as restored.

        NOTE: unpickling can execute arbitrary code; only feed this files
        produced by ``save_checkpoint``.
        """
        with open(checkpoint_path, "rb") as src:
            self.info = pickle.load(src)
        self.restored = True

    def _register_if_needed(self, env_object, config):
        """No-op; both arguments are ignored."""
        pass

    def set_info(self, info):
        """Store ``info`` on the trainer and return it unchanged."""
        self.info = info
        return info

    def get_info(self, sess=None):
        """Return the stored info; ``sess`` is accepted but unused."""
        return self.info


class _SigmoidFakeData(_MockTrainer):
    """Trainer that returns sigmoid-shaped (tanh) learning curves.

    The reported reward and episode length are
    ``height * tanh(max(0, iteration - offset) / width)``.

    This can be helpful for evaluating early stopping algorithms."""

    _name = "SigmoidFakeData"
    _default_config = with_common_config({
        # Divisor applied to the shifted iteration count before tanh.
        "width": 100,
        # Multiplier applied to the tanh output.
        "height": 100,
        # Iterations subtracted before the curve starts (clamped at 0).
        "offset": 0,
        # Reported verbatim as time_this_iter_s.
        "iter_time": 10,
        # Reported verbatim as timesteps_this_iter.
        "iter_timesteps": 1,
        "num_workers": 0,
    })

    def step(self):
        """Return the curve value for the current iteration as a result dict."""
        cfg = self.config
        # Iterations before `offset` are clamped to 0, giving tanh(0) == 0.
        shifted = max(0, self.iteration - cfg["offset"])
        curve = np.tanh(float(shifted) / cfg["width"])
        value = curve * cfg["height"]
        return {
            "episode_reward_mean": value,
            "episode_len_mean": value,
            "timesteps_this_iter": cfg["iter_timesteps"],
            "time_this_iter_s": cfg["iter_time"],
            "info": {},
        }


class _ParameterTuningTrainer(_MockTrainer):
    """Mock trainer whose reported reward is ``reward_amt * iteration``.

    The episode length is reported as ``reward_amt``; timing and timestep
    counts are taken verbatim from the config.
    """

    _name = "ParameterTuningTrainer"
    _default_config = with_common_config({
        "reward_amt": 10,
        # dummy_param / dummy_param2 are not read anywhere in this class.
        "dummy_param": 10,
        "dummy_param2": 15,
        "iter_time": 10,
        "iter_timesteps": 1,
        "num_workers": 0,
    })

    def step(self):
        """Return a result dict whose reward scales with the iteration."""
        cfg = self.config
        return {
            "episode_reward_mean": cfg["reward_amt"] * self.iteration,
            "episode_len_mean": cfg["reward_amt"],
            "timesteps_this_iter": cfg["iter_timesteps"],
            "time_this_iter_s": cfg["iter_time"],
            "info": {},
        }


def _trainer_import_failed(trace):
    """Build a placeholder Trainer class that fails on setup.

    Intended as a stand-in for when an optional dependency (PyTorch etc.)
    is not installed.

    Args:
        trace: Value passed to ``ImportError`` when ``setup`` is called
            (presumably the traceback text of the failed import; confirm
            against callers).

    Returns:
        A ``Trainer`` subclass whose ``setup`` raises ``ImportError(trace)``.
    """
    empty_config = with_common_config({})

    class _TrainerImportFailed(Trainer):
        _name = "TrainerImportFailed"
        _default_config = empty_config

        def setup(self, config):
            # Surface the original import failure when setup() runs.
            raise ImportError(trace)

    return _TrainerImportFailed
