import gym
from typing import List
from modules.utils.Utils import printProgressBar
from abc import abstractmethod
from modules.agents.Agent import Agent
from modules.agents.tabular.OTTRPO import DiscreteOTTRPOAgent
from modules.train.Oracle import Oracle
from modules.train.FiniteSpaceOracle import OracleMC
from modules.utils.Distances import *

class SimEnv:
    """Base description of a simulation setup: environment, agents, oracle
    and training parameters.

    Subclasses implement the factory methods ``_singleton_env``,
    ``_singleton_agents`` and ``_singleton_oracle`` together with the
    training/reporting hooks below. The public accessors ``env()``,
    ``agents()`` and ``oracle()`` build each object on first access and
    return the cached object afterwards.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are documentation only -- a subclass that
    forgets one of them can still be instantiated.
    """

    # "Not built yet" defaults. The first call to the matching accessor stores
    # the created object on the instance, shadowing these class attributes.
    _env = None
    _agents = None
    _oracle = None

    def env(self):
        """Return the environment, creating it via ``_singleton_env`` on first use."""
        if self._env is None:
            self._env = self._singleton_env()
        return self._env

    def agents(self):
        """Return the agent list, creating it via ``_singleton_agents`` on first use."""
        if self._agents is None:
            self._agents = self._singleton_agents()
        return self._agents

    def oracle(self):
        """Return the oracle, creating it via ``_singleton_oracle`` on first use."""
        if self._oracle is None:
            self._oracle = self._singleton_oracle()
        return self._oracle

    def test_agent(self, agent, env, n_tests=1000, render=True):
        """Run ``agent`` on ``env`` for ``n_tests`` episodes and report statistics.

        Args:
            agent: object exposing ``take_action(state)``.
            env: gym-style environment providing ``reset``, ``step`` and
                ``render`` (classic or gym >= 0.26 API).
            n_tests: number of episodes to simulate; must be positive.
            render: if True, render every step and print an end-of-episode
                marker; otherwise display a progress bar.

        Returns:
            dict mapping each event name of ``reward_to_event_map()`` plus
            ``"Trajectory length"`` to its average per episode. The same
            values are also printed.

        Raises:
            ValueError: if ``n_tests`` is not positive (the averages would
                otherwise divide by zero).
        """
        if n_tests <= 0:
            raise ValueError(f"n_tests must be positive, got {n_tests}")
        events = self.reward_to_event_map()
        stats = {key: 0 for key in events}
        stats["Trajectory length"] = 0
        for episode in range(n_tests):
            self._test_episode(agent, env, events, stats, render)
            if render:
                print(f"######### END OF EPISODE {episode} #########")
            else:
                printProgressBar(episode + 1, n_tests, prefix='Policy simulation: ', suffix='', decimals=1, length=100, fill='█', printEnd="\r")
        averages = {key: total / n_tests for key, total in stats.items()}
        print(f"Stats averaged over {n_tests} runs:")
        for key, value in averages.items():
            print(f"{key}:", value)
        return averages

    @abstractmethod
    def _singleton_oracle(self):
        "Make oracle"
        pass

    @abstractmethod
    def _singleton_env(self):
        "Make env"
        pass

    @abstractmethod
    def _singleton_agents(self):
        "Make agents"
        pass

    @abstractmethod
    def batch_size(self) -> int:
        "Training params: batch size"
        pass

    @abstractmethod
    def n_batch(self) -> int:
        "Number of batches"
        pass

    @abstractmethod
    def reward_to_event_map(self) -> dict:
        "Return reward: event name map"
        pass

    @staticmethod
    def _unwrap_reset(reset_result):
        """Return the observation from ``env.reset()`` under either gym API.

        gym >= 0.26 returns ``(obs, info)`` with ``info`` a dict; older
        versions return the observation alone.
        """
        if (isinstance(reset_result, tuple) and len(reset_result) == 2
                and isinstance(reset_result[1], dict)):
            return reset_result[0]
        return reset_result

    @staticmethod
    def _unpack_step(step_result):
        """Return ``(obs, reward, done)`` from ``env.step()`` under either gym API.

        gym >= 0.26 returns ``(obs, reward, terminated, truncated, info)``;
        older versions return ``(obs, reward, done, info)``.
        """
        if len(step_result) == 5:
            obs, reward, terminated, truncated, _ = step_result
            return obs, reward, terminated or truncated
        obs, reward, done, _ = step_result
        return obs, reward, done

    def _test_episode(self, agent, env, events, stats, render=True):
        """Play a single episode with ``agent`` and accumulate counts into ``stats``.

        At every step, each event in ``events`` whose reward value equals the
        step's reward has its counter in ``stats`` incremented by one.
        ``stats["Trajectory length"]`` is increased by (number of steps + 1).
        """
        state = self._unwrap_reset(env.reset())
        n_steps = 0
        done = False
        while not done:
            if render:
                env.render()
            action = agent.take_action(state)
            n_steps += 1
            state, reward, done = self._unpack_step(env.step(action))
            for key, value in events.items():
                if reward == value:
                    stats[key] += 1

        # +1 kept from the original: presumably counts visited states including
        # the initial one rather than actions taken -- TODO confirm intent.
        stats["Trajectory length"] += n_steps + 1


#########################
# TAXI
#########################
class SimEnvTaxi(SimEnv):
    """Simulation setup for the gym ``Taxi-v3`` environment."""

    def _singleton_env(self):
        """Create the ``Taxi-v3`` gym environment."""
        return gym.make("Taxi-v3")

    def _singleton_agents(self) -> List[Agent]:
        """Build the agents evaluated on this environment.

        An alternative configuration (epsilon 0.01 with ``taxi_distance``)
        was previously tried and is currently disabled.
        """
        environment = self.env()
        agent_params = {
            "p": 1,
            "epsilon": 0.1,
            "dist": binary_distance,
            "n_state": environment.observation_space.n,
            "n_action": environment.action_space.n,
        }
        return [DiscreteOTTRPOAgent(None, agent_params)]

    def _singleton_oracle(self) -> Oracle:
        """Create the Monte-Carlo oracle bound to this environment."""
        return OracleMC(
            self.env(),
            discount_factor=0.9,
            learning_rate=0.5,
            q_init=0,
        )

    def batch_size(self) -> int:
        """Training batch size."""
        return 32

    def n_batch(self) -> int:
        """Number of training batches."""
        return 1000

    def reward_to_event_map(self) -> dict:
        """Map event names to the reward value that signals them."""
        event_rewards = {
            "Illegal action": -10,
            "Successfull dropoff": 20,
        }
        return event_rewards


#########################
# Cliff
#########################
class SimEnvCliff(SimEnv):
    """Simulation setup for the gym ``CliffWalking-v0`` environment."""

    def _singleton_env(self):
        """Create the ``CliffWalking-v0`` gym environment."""
        return gym.make("CliffWalking-v0")

    def _singleton_agents(self) -> List[Agent]:
        """Build the agents evaluated on this environment."""
        environment = self.env()
        agent_params = {
            "p": 1,
            "epsilon": 0.01,
            "dist": binary_distance,
            "n_state": environment.observation_space.n,
            "n_action": environment.action_space.n,
        }
        return [DiscreteOTTRPOAgent(None, agent_params)]

    def _singleton_oracle(self) -> Oracle:
        """Create the Monte-Carlo oracle bound to this environment."""
        # NOTE(review): positional arguments presumably map to
        # (discount_factor, learning_rate, q_init) as in the Taxi setup --
        # confirm against OracleMC's signature before switching to keywords.
        return OracleMC(self.env(), 0.999999, 0.2, 0)

    def batch_size(self) -> int:
        """Training batch size."""
        return 1

    def n_batch(self) -> int:
        """Number of training batches."""
        return 20000

    def reward_to_event_map(self) -> dict:
        """Map event names to the reward value that signals them."""
        event_rewards = {
            "Failed": -100,
        }
        return event_rewards

class ENVIRONMENTS:
    """Registry of the available simulation setups.

    Each attribute is a ``(display_name, SimEnv instance)`` pair. The
    instances are created when this module is imported; their environment,
    agents and oracle are only built on first access through ``env()``,
    ``agents()`` and ``oracle()``.
    """

    TAXI = "Taxi", SimEnvTaxi()
    CLIFF = "Cliff", SimEnvCliff()


