from __future__ import annotations
from model_checking.labelling import get_property
import importlib
from itertools import compress
from gym.envs.atari import AtariEnv
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from collections import defaultdict 

import random
import omnisafe
from typing import Any, ClassVar
import torch
import torch.nn.functional as F
from omnisafe.envs.core import CMDP, env_register, env_unregister

class LabellingEnv(gym.Wrapper):
    """Wrap an environment so each step's ``info`` carries labeller outputs.

    Every labeller is reset right after the wrapped environment, and queried
    after every step. The list of results (one entry per labeller, in the
    order given) is stored under ``info['labels']``.
    """

    def __init__(self, env, labellers):
        super().__init__(env)
        # Stored explicitly so ``self.env`` is always the environment passed in.
        self.env = env
        self.labellers = labellers

    def reset(self):
        """Reset the wrapped env, then each labeller; return the env's reset output as-is."""
        initial = self.env.reset()
        for lab in self.labellers:
            lab.reset()
        return initial

    def step(self, action):
        """Step the wrapped env and attach ``info['labels']``.

        The wrapped env must use the 4-tuple ``(obs, reward, done, info)``
        step API. Labellers see ``info`` before the ``'labels'`` key is added.
        """
        observation, rew, finished, extra = self.env.step(action)
        extra['labels'] = [
            lab.label(observation, rew, finished, extra) for lab in self.labellers
        ]
        return observation, rew, finished, extra

@env_register
@env_unregister
class Seaquest(CMDP):
  """Seaquest as an OmniSafe CMDP with a RAM-based observation and an automaton cost.

  Observations concatenate a fixed subset of Atari RAM bytes, their change since
  the previous frame, and the current property-automaton state divided by the
  number of automaton states.

  Each agent step repeats one discrete Atari action for up to ``repeat`` frames.
  The set of atomic propositions observed during those frames is fed once to the
  property's ``cost_function.step`` to produce the cost.

  Actions are real-valued vectors of length ``n_actions``. They are treated as
  logits: a softmax is applied and a discrete action is sampled from it.
  """

  # Lock guarding ALE construction. It is created lazily on first use and stored
  # on the class so that every instance in this process shares the same lock.
  LOCK = None

  _support_envs: ClassVar[list[str]] = ['SeaquestCMDP-v0', 'SeaquestCMDP-v1']  # Supported task names

  # Property module (under ``model_checking.properties``) that defines the cost
  # function and automaton for each supported env id.
  _PROPERTY_FILES: ClassVar[dict[str, str]] = {
      'SeaquestCMDP-v0': 'property_1',
      'SeaquestCMDP-v1': 'property_3',
  }

  need_auto_reset_wrapper = True  # Whether `AutoReset` Wrapper is needed
  need_time_limit_wrapper = True  # Whether `TimeLimit` Wrapper is needed


  def __init__(self, env_id: str, seed=0, repeat=4, size=128, gray=False, noops=30, lives='unused',
      sticky=True, actions='all', length=108000, resize='opencv', cost_val=1.0, reset_on_death=False, **kwargs) -> None:
    """Build the Atari env, its labellers, and the property automaton.

    Args:
      env_id: One of ``_support_envs``; selects the property (cost) module.
      seed: Seed for the internal RNG used for no-op resets and action sampling.
      repeat: Number of emulator frames per agent step.
      size: Not used; the observation size is fixed by the RAM indices.
      gray: Stored but not used by this class.
      noops: Random NOOP steps after reset are drawn from ``[0, noops)``; 0 disables.
      lives: 'unused' ignores life loss. Any other allowed value terminates the
        episode when a life is lost.
      sticky: Use a 0.25 repeat-action probability in the ALE.
      actions: 'all' for the full action space, 'needed' for the minimal one.
      length: Frame budget per episode; reaching it sets ``truncated``.
      resize: 'opencv' or 'pillow'; only the chosen backend is imported.
      cost_val: Scale applied to the automaton cost.
      reset_on_death: On life loss, reset the automaton and discard labels
        collected earlier in the same step.
      **kwargs: Ignored.

    Raises:
      RuntimeError: If ``env_id`` has no associated property module.
    """
    assert lives in ('unused', 'discount', 'reset'), lives
    assert actions in ('all', 'needed'), actions
    assert resize in ('opencv', 'pillow'), resize
    self.env_id = env_id
    self._num_envs = 1
    try:
      property_file = self._PROPERTY_FILES[env_id]
    except KeyError:
      raise RuntimeError(f"cost function not specified for env id {env_id}") from None
    # Store the lock on the class, not the instance. Assigning to ``self.LOCK``
    # would give each instance a private lock and defeat the serialisation.
    cls = type(self)
    if cls.LOCK is None:
      import multiprocessing as mp
      cls.LOCK = mp.get_context('spawn').Lock()
    self._resize = resize
    if resize == 'opencv':
      import cv2
      self._cv2 = cv2
    elif resize == 'pillow':
      from PIL import Image
      self._image = Image
    # RAM byte positions that make up the observation.
    self._rep_indices = np.array([30,34,70,97,71,75,86, 87,102,103,59,62])
    self._repeat = repeat
    self._size = len(self._rep_indices)
    self._gray = gray
    self._noops = noops
    self._lives = lives
    self._sticky = sticky
    self._length = length
    self._random = np.random.RandomState(seed)
    self._cost_val = cost_val
    self._reset_on_death = reset_on_death
    with self.LOCK:
      self._env = AtariEnv(
          game='seaquest',
          obs_type='rgb',
          frameskip=1, repeat_action_probability=0.25 if sticky else 0.0,
          full_action_space=(actions == 'all'))
    assert self._env.unwrapped.get_action_meanings()[0] == 'NOOP'
    self._n_actions = self._env.action_space.n
    # NOTE(review): the observation holds raw RAM bytes and their deltas, which
    # are not scaled into [0, 1]. These bounds therefore do not describe the real
    # values; confirm whether downstream code relies on them.
    self._observation_space = spaces.Box(low=0, high=1., shape=(self._size * 2 + 1,), dtype=np.float32)
    self._action_space = spaces.Box(low=-5, high=2, shape=(self._n_actions,), dtype=np.float32)
    # Labelling function: one labeller per atomic proposition, in this order.
    self._atomic_propositions = ['surface', 'diver', 'early-surface', 'out-of-oxygen', 'death']
    labellers = [get_property(self._env, 'seaquest', label) for label in self._atomic_propositions]
    self._env = LabellingEnv(self._env, labellers)
    self._property = property_file
    self._cost_function = self._load_attr(f'model_checking.properties.{self._property}:cost_function')
    self._automaton = self._load_attr(f'model_checking.properties.{self._property}:automaton')
    self._n_automaton_states = len(self._automaton.states)
    self._automaton_state = self._automaton.initial
    # Row 0: current RAM subset; row 1: change since the previous frame.
    self._buffer = np.zeros((2, self._size), np.float32)
    self._ale = self._env.unwrapped.ale
    self._last_lives = None
    self._step = 0

  @staticmethod
  def _load_attr(spec: str) -> Any:
    """Import the module named before ':' in ``spec`` and return the attribute after it."""
    module_name, attr = spec.split(':')
    return getattr(importlib.import_module(module_name), attr)

  def set_seed(self, seed: int) -> None:
    """Re-seed the RNG used for no-op resets and action sampling."""
    self._random = np.random.RandomState(seed)

  def step(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:
    """Sample a discrete action from logits and repeat it for up to ``repeat`` frames.

    Args:
      action: Action logits. After softmax over the last dim they are used as a
        probability vector, so the tensor must be 1-D of length ``n_actions``.

    Returns:
      ``(obs, reward, cost, terminated, truncated, info)``, where:
      - ``reward`` is the summed frame reward.
      - ``cost`` is the automaton cost scaled by ``cost_val``.
      - ``terminated`` is set on game over, or on life loss unless
        ``lives == 'unused'``.
      - ``truncated`` is set once the frame budget ``length`` is reached.
    """
    action_probs = F.softmax(action, dim=-1).detach().cpu().numpy()
    # Use the seeded RNG so that ``set_seed`` / ``reset(seed=...)`` make action
    # sampling reproducible (the global ``np.random`` ignores them).
    discrete_action = self._random.choice(self._n_actions, p=action_probs)

    total = 0.0
    total_cost = 0.0
    dead = False
    over = False  # stays defined even if ``repeat`` is 0
    labels = set()
    for _ in range(self._repeat):
      _, reward, over, info = self._env.step(discrete_action)

      ram = self._get_ram()
      self._buffer[1] = ram - self._buffer[0]
      self._buffer[0] = ram

      self._step += 1
      total += reward

      frame_labels = set(compress(self._atomic_propositions, info['labels']))
      current = self._ale.lives()
      lost_life = current < self._last_lives
      # Record every change, including extra lives gained. Otherwise a later
      # death that returns to the old count would go undetected.
      self._last_lives = current
      if lost_life and self._reset_on_death:
        # Start the property over: drop labels collected before the death.
        self._automaton_state = self._cost_function.reset()
        labels = frame_labels
      else:
        labels |= frame_labels
      if lost_life and self._lives != 'unused':
        dead = True
        break
      if over:
        break
    cost, self._automaton_state = self._cost_function.step(labels)
    total_cost += cost * self._cost_val

    terminated = torch.as_tensor(dead or over)
    # bool() keeps the flag boolean when ``length`` is 0 or None, instead of
    # handing the falsy value itself to torch (None would raise).
    truncated = torch.as_tensor(bool(self._length) and self._step >= self._length)
    reward = torch.as_tensor(total)
    cost = torch.as_tensor(total_cost)
    obs = torch.as_tensor(self._obs())
    info = {}
    info.update({'final_observation': obs})
    return obs, reward, cost, terminated, truncated, info

  def reset(self, seed: int | None = None,options: dict[str, Any] | None = None) -> tuple[torch.Tensor, dict]:
    """Reset the game and automaton, then apply a random number of NOOP steps.

    Args:
      seed: If given, re-seeds the internal RNG first.
      options: Ignored.

    Returns:
      ``(obs, {})`` for the first observation of the episode.
    """
    if seed is not None:
      self.set_seed(seed)
    self._env.reset()
    self._automaton_state = self._cost_function.reset()
    if self._noops:
      for _ in range(self._random.randint(self._noops)):
        _, _, dead, _ = self._env.step(0)
        if dead:
          self._env.reset()
    self._last_lives = self._ale.lives()
    ram = self._get_ram()
    self._buffer[0] = ram
    self._buffer[1].fill(0)
    self._step = 0
    obs = torch.as_tensor(self._obs())
    return obs, {}

  def _obs(self):
    """Flatten the RAM buffer and append the normalised automaton state (float32)."""
    buffer = self._buffer.flatten()
    virtual_automaton_state = [self._automaton_state / self._n_automaton_states]
    return np.concatenate([buffer, virtual_automaton_state], axis=0).astype(np.float32)

  def _get_ram(self):
    """Return the selected RAM bytes from the cached ALE handle.

    The ALE is read directly rather than via attribute forwarding on the
    gymnasium wrapper, which newer gymnasium versions no longer provide.
    """
    full_ram = np.asarray(self._ale.getRAM())
    return full_ram[self._rep_indices]

  @property
  def max_episode_steps(self) -> int:
    """The max agent steps per episode (frame budget divided by action repeat)."""
    return int(self._length/self._repeat)

  def render(self) -> Any:
    """Rendering is not supported; does nothing."""
    pass

  def close(self) -> None:
    """Close the wrapped environment."""
    return self._env.close()