from abc import ABC
import time

import gym
import numpy as np

from tasks.dynamic_programming import policy_evaluation
import tasks.envs


class Task(ABC):
    """Base task: owns one environment together with the run settings.

    Attributes are set in ``__init__``: ``env`` (the environment returned by
    ``tasks.envs.make``), ``discount``, ``duration``, ``start_time`` (wall
    clock at construction) and ``seed`` (a seed still waiting to be handed to
    ``env.reset``, or ``None`` when nothing is pending).
    """

    def __init__(self, env_id: str, discount: float, duration: int, seed: int):
        """Create the environment, validate the settings and seed the RNGs.

        Args:
            env_id: Identifier forwarded to ``tasks.envs.make``.
            discount: Discount factor; must lie in ``[0.0, 1.0]``.
            duration: Must be strictly positive (units are not visible
                here -- TODO confirm against the caller).
            seed: Seed for the environment and for its action space.

        Raises:
            AssertionError: If ``discount`` or ``duration`` is out of range.
        """
        environment = tasks.envs.make(env_id)
        self.env = environment

        assert 0.0 <= discount <= 1.0
        self.discount = discount

        assert duration > 0
        self.duration = duration

        self.start_time = time.time()

        # Seed the environment right away when it offers ``seed()``;
        # otherwise keep the seed so the first ``reset`` can pass it along.
        try:
            environment.seed(seed)
        except AttributeError:
            self.seed = seed
        else:
            self.seed = None
        environment.action_space.seed(seed)

    def step(self, action):
        """Forward ``action`` to the environment and return its result."""
        outcome = self.env.step(action)
        return outcome

    def reset(self):
        """Reset the environment, consuming the pending seed if there is one.

        A pending seed is passed as ``reset(seed=...)`` exactly once and then
        cleared; later calls reset the environment without arguments.
        """
        pending_seed = self.seed
        if pending_seed is None:
            return self.env.reset()

        observation, info = self.env.reset(seed=pending_seed)
        # Only cleared after a successful reset, so a failed call can retry.
        self.seed = None
        return observation, info

    def close(self):
        """Close the environment and return whatever its ``close`` returns."""
        result = self.env.close()
        return result

    def time(self):
        """Return the seconds elapsed since this task was constructed."""
        now = time.time()
        return now - self.start_time


class PredictionTask(Task):
    """Value-prediction task under a uniform-random behavior policy.

    On construction, reference values ``v_pi`` are computed with
    ``policy_evaluation`` and the observation of every state yielded by
    ``env.states()`` is stacked into ``all_observations``. ``msve`` then
    scores an agent's predictions on those observations against ``v_pi``.
    """

    def __init__(self, env_id: str, discount: float, duration: int, seed: int):
        """Build the environment and precompute the evaluation targets.

        Args:
            env_id: Identifier forwarded to ``tasks.envs.make`` by the base
                class.
            discount: Discount factor; also passed to ``policy_evaluation``.
            duration: Stored by the base class; must be strictly positive.
            seed: Seed for the environment and its action space.

        Raises:
            AssertionError: If the base-class checks fail, or if the
                unwrapped environment's observation or action space is not
                ``gym.spaces.Discrete``.
        """
        # TODO: Pass in behavior policy as argument
        super().__init__(env_id, discount, duration, seed)
        env = self.env.unwrapped
        assert isinstance(env.observation_space, gym.spaces.Discrete)
        assert isinstance(env.action_space, gym.spaces.Discrete)
        n = env.action_space.n

        def behavior_policy(s):
            # Uniform distribution over the n actions, whatever the state.
            return np.ones(n) / n

        # NOTE: This only works for gym-classics environments
        self.v_pi = policy_evaluation(env, self.discount, behavior_policy, precision=1e-9)
        # One row per state, in the order produced by ``env.states()``.
        # NOTE(review): relies on ``self.env`` exposing ``observation()``
        # (presumably an observation wrapper) -- confirm in tasks.envs.
        self.all_observations = np.stack([self.env.observation(s) for s in env.states()])

    def policy(self):
        """Sample an action and report its behavior and target probabilities.

        Returns:
            A tuple ``(action, b_prob, t_prob)``. ``action`` is sampled from
            ``self.env.action_space``; ``b_prob`` and ``t_prob`` are both
            ``1 / number_of_actions``, matching the uniform behavior policy
            used to compute ``v_pi``.
        """
        # The probability of a uniformly drawn action is 1 / |A|: the divisor
        # is the number of actions, not the number of observations/states.
        # Assumes ``action_space.sample()`` draws uniformly (no action mask).
        b_prob = t_prob = 1.0 / self.env.unwrapped.action_space.n
        return self.env.action_space.sample(), b_prob, t_prob

    def msve(self, agent):
        """Return the mean squared error of the agent's predictions.

        Args:
            agent: Object with a ``predict`` method that is called on
                ``all_observations``.

        Returns:
            The mean of the squared differences between the predictions and
            ``v_pi``.
        """
        # NOTE(review): assumes ``predict`` returns an array shaped like
        # ``v_pi``; a trailing singleton axis would broadcast -- confirm.
        v = agent.predict(self.all_observations)
        return np.mean(np.square(v - self.v_pi))


class ControlTask(Task):
    """Task that accumulates the undiscounted return of each episode.

    ``reset`` starts a new episode, ``step`` adds every reward to a running
    total, and ``undiscounted_return`` exposes that total once the episode
    has terminated or been truncated.
    """

    def __init__(self, env_id: str, discount: float, duration: int, seed: int):
        """Initialise the base task and the per-episode bookkeeping."""
        super().__init__(env_id, discount, duration, seed)
        self.done = False
        # No running total until ``reset`` starts the first episode.
        self._undisc_return = None

    @property
    def undiscounted_return(self):
        """Sum of rewards of the episode; only readable once it is over."""
        assert self.done, "must wait until end of episode"
        return self._undisc_return

    def step(self, action):
        """Advance the environment one step and update the running return.

        Returns:
            The tuple ``(obs, reward, terminated, truncated, info)`` obtained
            from the base-class ``step``.

        Raises:
            AssertionError: If the current episode has already ended.
        """
        assert not self.done
        observation, reward, terminated, truncated, info = super().step(action)
        self._undisc_return += reward
        # The episode ends on either termination or truncation.
        episode_over = terminated or truncated
        self.done = episode_over
        return observation, reward, terminated, truncated, info

    def reset(self):
        """Start a new episode: clear the bookkeeping, then reset the base."""
        self.done = False
        self._undisc_return = 0.0
        return super().reset()
