import gymnasium as gym
import numpy as np
import os

from typing import  Optional
from numpy.typing import NDArray

from gymnasium.core import ActType, ObsType
from gymnasium import Env


from loguru import logger
from molecule_movement.logging import log_and_raise

from enum import Enum, auto

class Task(Enum):
    """Sub-task that the task-scheduling wrapper can assign to the agent.

    ``DONE`` carries the fixed value 99. The wrapper returns it when no
    scheduled task remains.
    """

    ORIENTATION = auto()
    TRANSLATION = auto()
    DONE = 99

    def to_string(self) -> str:
        """Return the member name in lowercase (e.g. ``"orientation"``)."""
        member_name = self.name
        return str.lower(member_name)

class TaskSchedulingWrapper(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Annotate a wrapped env with the sub-task the agent should work on.

    After ``reset``, ``step`` and ``increment_matching``, the wrapper writes a
    :class:`Task` into ``info["task"]`` and mirrors it in ``self.current_task``.
    The task is derived from the wrapped env's ``reached_goal_orientation`` and
    ``reached_goal_position`` callables, which are looked up via
    ``get_wrapper_attr``. Scheduling works as follows:

    * The ``priorization`` task is scheduled until it reports completion.
    * Then the other task is scheduled, if it is listed in ``tasks`` and not
      yet completed.
    * Otherwise ``Task.DONE`` is scheduled.

    With ``force_finish=True``, a task that has been scheduled is kept during
    ``step`` until that task itself reports completion. This holds even if the
    rules above would switch away from it earlier, for example when the
    prioritized task reports incomplete again.
    """

    def __init__(self,
                 env: Env[ObsType, ActType],
                 tasks: Optional[list[Task]] = None,
                 priorization: Task = Task.TRANSLATION,
                 force_finish: bool = False
                 ):
        """Wrap ``env`` and configure the task schedule.

        Args:
            env: Environment to wrap.
            tasks: Tasks that may be scheduled. ``None`` (the default) means
                ``[Task.ORIENTATION, Task.TRANSLATION]``. The sequence is
                copied, so later changes by the caller have no effect.
            priorization: Task scheduled first. It must be contained in
                ``tasks``.
            force_finish: If True, ``step`` keeps the current task until it is
                completed instead of switching as soon as the rules say so.

        Raises:
            ValueError: If ``priorization`` is not in ``tasks``.
        """
        # Resolve the default here instead of in the signature. A list literal
        # default is shared across all instances, and storing the caller's
        # list would alias it.
        tasks = [Task.ORIENTATION, Task.TRANSLATION] if tasks is None else list(tasks)
        # Use an explicit raise instead of ``assert``, which is stripped under
        # ``python -O``.
        if priorization not in tasks:
            raise ValueError(
                f"The task for priorization ({priorization=}) is not in the list of tasks ({tasks})"
            )
        # Record every constructor argument, including force_finish (which was
        # previously dropped), so a wrapper rebuilt from its recorded kwargs
        # behaves the same.
        gym.utils.RecordConstructorArgs.__init__(
            self, tasks=list(tasks), priorization=priorization, force_finish=force_finish
        )
        # Initialise the Wrapper exactly once. Wrapper.__init__ stores
        # ``self.env``.
        gym.Wrapper.__init__(self, env)
        assert isinstance(env, Env)
        self.tasks = tasks
        self.priorization = priorization
        self.force_finish = force_finish
        # Give current_task a value before the first reset, so that
        # update_current_task / current_task_completed never hit a missing
        # attribute.
        self.current_task = priorization

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, NDArray], dict]:
        """Reset the wrapped env and schedule the first task.

        Args:
            seed: Forwarded to the wrapped env's ``reset``.
            options: Forwarded to the wrapped env's ``reset``.

        Returns:
            ``(obs, info)`` from the wrapped env, with ``info["task"]`` set to
            the freshly computed task. ``force_finish`` is not applied here.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        return obs, self._assign_next_task(info)

    def step(self, action):
        """Step the wrapped env and update the scheduled task.

        Returns:
            The wrapped env's 5-tuple, passed through in the same order, with
            ``info["task"]`` set to the updated current task.
        """
        # Gymnasium's order is (obs, reward, terminated, truncated, info). The
        # previous names had the two flags swapped. That was harmless because
        # they were passed through in the same positions, but it was
        # misleading.
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.update_current_task()
        info["task"] = self.current_task
        return obs, reward, terminated, truncated, info

    def observation(self):
        """Return the result of the wrapped env's ``observation()`` callable."""
        return self.env.get_wrapper_attr("observation")()

    def compute_next_task(self) -> Task:
        """Return the task to work on, given the env's current completion state.

        Returns ``priorization`` until it is completed. After that, returns the
        other task if it is listed in ``self.tasks`` and not yet completed.
        Otherwise returns ``Task.DONE``. Both completion checks are queried on
        every call, orientation first and then translation.
        """
        # Evaluated eagerly and in this order to keep the original call order.
        completed = {
            Task.ORIENTATION: self.orientation_completed(),
            Task.TRANSLATION: self.translation_completed(),
        }
        primary = self.priorization
        if primary not in completed:
            # For example priorization == Task.DONE: there is nothing to
            # schedule.
            return Task.DONE
        secondary = Task.TRANSLATION if primary == Task.ORIENTATION else Task.ORIENTATION
        if not completed[primary]:
            return primary
        if secondary in self.tasks and not completed[secondary]:
            return secondary
        return Task.DONE

    def update_current_task(self) -> None:
        """Advance ``self.current_task`` according to the scheduling rules.

        Without ``force_finish``, the current task always follows
        :meth:`compute_next_task`. With ``force_finish``, a switch to a
        different task only happens once the current task reports completion.
        """
        next_task = self.compute_next_task()
        if (
            self.force_finish
            and next_task != self.current_task
            and not self.current_task_completed()
        ):
            # Keep working on the unfinished current task.
            return
        self.current_task = next_task

    def current_task_completed(self) -> bool:
        """Return whether ``self.current_task`` reports completion.

        ``Task.DONE`` always counts as completed.

        Raises:
            ValueError: If ``self.current_task`` is not a known task.
        """
        if self.current_task == Task.TRANSLATION:
            return self.translation_completed()
        if self.current_task == Task.ORIENTATION:
            return self.orientation_completed()
        if self.current_task == Task.DONE:
            return True
        # Use an explicit raise instead of ``assert False``. Under ``python -O``
        # the assert disappears and the method would silently return None.
        raise ValueError(f"Unknown task: {self.current_task!r}")

    def translation_completed(self) -> bool:
        """Return the result of the wrapped env's ``reached_goal_position()``."""
        return self.env.get_wrapper_attr("reached_goal_position")()

    def orientation_completed(self) -> bool:
        """Return the result of the wrapped env's ``reached_goal_orientation()``."""
        return self.env.get_wrapper_attr("reached_goal_orientation")()

    def increment_matching(self) -> tuple[dict[str, NDArray], dict]:
        """Forward ``increment_matching`` to the wrapped env and reschedule.

        Returns:
            The ``(obs, info)`` pair produced by the wrapped env's
            ``increment_matching``, with ``info["task"]`` set to the freshly
            computed task.
        """
        obs, info = self.env.get_wrapper_attr("increment_matching")()
        return obs, self._assign_next_task(info)

    def _assign_next_task(self, info: dict) -> dict:
        """Recompute the task, store it as current and write it into ``info``.

        Returns:
            The same ``info`` dict, mutated in place.
        """
        self.current_task = self.compute_next_task()
        info["task"] = self.current_task
        return info

    def _reset_task(self, info: dict) -> dict:
        """Set the current task back to ``priorization`` and write it into ``info``.

        Returns:
            The same ``info`` dict, mutated in place.
        """
        self.current_task = self.priorization
        info["task"] = self.current_task
        return info
