import gymnasium as gym
import numpy as np
import os

from typing import Generic, Optional
from numpy.typing import NDArray

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType, Wrapper
from gymnasium import Env

from ..colour_utils import Highlight

from molecule_movement import Matching
from molecule_movement.envs import ExperimentalEnvironment
from molecule_movement.matching import RandomMatching, GreedyMatching, HungarianMatching
from molecule_movement.scheduling import SATBasedScheduling
from molecule_movement.exceptions import InfeasibleError, AssumptionsError

from shapely import Point, LineString
import shapely

from loguru import logger
from molecule_movement.logging import log_and_raise

from pytictoc import TicToc

from PIL import Image

from dataclasses import dataclass

from datetime import datetime


def is_left(matching: Matching, c: Point) -> bool:
    """Tell whether ``c`` is strictly on the left of a matching's travel direction.

    The directed line runs from ``matching.molecule.center`` to
    ``matching.goal.position``. The sign of the 2-D cross product
    ``(goal - start) x (c - start)`` decides the side. "Left" means the
    counter-clockwise side (i.e. visually left when the y axis points up).
    Points exactly on the line are reported as not left.
    """
    start = matching.molecule.center
    end = matching.goal.position
    dx = end.x - start.x
    dy = end.y - start.y
    cross = dx * (c.y - start.y) - dy * (c.x - start.x)
    return cross > 0

@dataclass
class MatchingConflicts:
    """One scheduling attempt: the matching that was tried and the conflicts it produced.

    ``conflicts`` maps an explanation key to the conflicting matching. ``__str__``
    parses keys of the form ``"goal-<owner>-<other>"`` ("goal of <other> lies in
    the corridor of <owner>") and ``"start-<owner>-<other>"`` ("start of <other>
    lies in the corridor of <owner>"). Keys with any other prefix are ignored.
    """

    matching: list[Matching]
    conflicts: dict[str, Matching]

    def __post_init__(self):
        # Explanation analysis is currently disabled; call
        # ``self.__analyze_explanations()`` here to populate
        # ``conflicting_goals`` / ``conflicting_starts`` on construction.
        pass

    def __analyze_explanations(self):
        """Split the conflicts into goal and start conflicts based on the key prefix."""
        self.conflicting_goals = []
        self.conflicting_starts = []
        for explanation, conflict in self.conflicts.items():
            kind = str(explanation).split("-")[0]
            if kind == "goal":
                self.conflicting_goals.append(conflict)
            elif kind == "start":
                self.conflicting_starts.append(conflict)

    def __str__(self) -> str:
        """Return one human-readable line per ``goal``/``start`` conflict."""
        lines = []
        for explanation in self.conflicts:
            split = str(explanation).split("-")
            if split[0] == "goal":
                lines.append(f"Goal of {split[2]} lies in corridor of {split[1]}\n")
            elif split[0] == "start":
                lines.append(f"Start of {split[2]} lies in corridor of {split[1]}\n")
        return "".join(lines)


class SATBasedSchedulingWrapper(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Wrapper that installs a SAT-based schedule as the env's matching on every reset.

    On ``reset`` an initial molecule/goal matching is obtained. It comes from the
    env's ``matching`` attribute if present, otherwise from a matching factory.
    The matching is checked with :class:`SATBasedScheduling`. While conflicts
    remain, each conflicting molecule/goal pair is given infinite weight in the
    factory and a new matching is computed. The final schedule is passed to the
    env's ``_set_matching``.

    Args:
        env: The environment to wrap.
        corridor_width: Optional corridor width. When ``None`` it is read from the
            wrapped env (a missing attribute is only logged).
        store_explanations: Mapping from storage type to directory. Handled keys
            are ``"img"`` (PNG snapshots per iteration) and ``"tensorboard"``
            (images via a torch ``SummaryWriter``). ``None`` disables storage.
    """

    def __init__(self,
                 env: Env[ObsType, ActType],
                 corridor_width: Optional[float] = None,
                 store_explanations: Optional[dict[str, str | os.PathLike]] = None):
        # Normalise here rather than using a ``dict()`` default, which would be a
        # single object shared by every wrapper created without the argument.
        if store_explanations is None:
            store_explanations = {}
        # Record the constructor arguments so ``self.spec`` can re-create this wrapper.
        gym.utils.RecordConstructorArgs.__init__(
            self, corridor_width=corridor_width, store_explanations=store_explanations
        )
        # One Wrapper initialisation is enough; it also stores ``env`` as ``self.env``.
        gym.Wrapper.__init__(self, env)
        self.store_explanations = store_explanations
        self.conflict_history = list()
        self.tb_writer = None

        if corridor_width is not None:
            self.corridor_width = corridor_width
            try:
                logger.info(f"Changing corridor_width to {self.corridor_width=} (changed from previously set {env.get_wrapper_attr('corridor_width')=})")
            except AttributeError as e:
                # The wrapped env has no corridor_width of its own.
                logger.info(f"Set corridor_width to {self.corridor_width=}")
                logger.trace(e)
        else:
            try:
                self.corridor_width = env.get_wrapper_attr("corridor_width")
            except Exception as e:
                # Best effort: nothing in this class reads corridor_width afterwards,
                # so a missing value is only logged.
                logger.error(e)
        assert isinstance(env, Env)
        # Called directly (not inside ``assert``) so the check also runs under ``python -O``.
        self.__check_storage_types()

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, NDArray], dict]:
        """Reset the wrapped env and replace its matching with a SAT-based schedule.

        ``options["corridor_config"]`` takes precedence over the env's own
        ``corridor_config``.

        Returns:
            The env's freshly computed ``observation()`` and the ``info`` dict
            returned by the wrapped ``reset``.

        Raises:
            AssertionError: if no corridor configuration is available.
            InfeasibleError: if no conflict-free matching can be found.
        """
        _obs, info = self.env.reset(seed=seed, options=options)
        self.corridor_config = options.get("corridor_config", None) if options else None
        if not self.corridor_config:
            self.corridor_config = self.env.get_wrapper_attr("corridor_config")
        assert self.corridor_config, "Need to provide corridor_config to compute SAT-based schedule."
        self.__setup_directories()
        self.env.unwrapped.get_wrapper_attr("_set_matching")(self.__compute_schedule())
        # FIXME: this is not needed when simulating, only when manipulating on the STM.
        # It might do the trick, but needs testing.
        if isinstance(self.env.unwrapped, ExperimentalEnvironment):
            # NOTE(review): the value is unused; the lookup only fails if the attribute is missing.
            self.env.get_wrapper_attr('active_moiety_index')
            self.env.get_wrapper_attr("_set_center_of_scan_position")(self.env.get_wrapper_attr("current_molecule").center)
            self.env.get_wrapper_attr('_get_current_stm_state')()
            self.env.get_wrapper_attr('_update_current_moiety')()
        # The observation is recomputed because the matching was replaced above.
        return self.env.get_wrapper_attr('observation')(), info

    def __compute_schedule(self) -> list[Matching]:
        """Iterate matching + SAT scheduling until the schedule has no conflicts.

        The factory is the env's ``matching_factory`` or, if absent, a
        :class:`HungarianMatching` fallback. Each iteration with conflicts is
        appended to ``self.conflict_history`` and rendered/stored.

        Returns:
            The conflict-free schedule produced by :class:`SATBasedScheduling`.

        Raises:
            InfeasibleError: if the factory cannot produce another matching.
        """
        molecules = self.env.get_wrapper_attr("molecules")
        goals = self.env.unwrapped.get_wrapper_attr("goals")
        obstacles = self.env.unwrapped.get_wrapper_attr("obstacles")

        try:
            matching_factory = self.env.get_wrapper_attr("matching_factory")
        except AttributeError:
            matching_factory = HungarianMatching(molecules, goals, obstacles, respect_obstacles=True, corridor_config=self.corridor_config)

        try:
            matching = self.env.get_wrapper_attr("matching")
            matching_provided = True
        except AttributeError as e:
            logger.warning(e)
            matching_provided = False

        t = TicToc()
        t.tic()
        if not matching_provided:
            matching = matching_factory.compute_matching()
        matching_length_initial = sum(m.length for m in matching)

        iteration_count = 0
        while True:
            scheduler = SATBasedScheduling(molecules, goals, obstacles, matching, corridor_config=self.corridor_config, separate_conflicts=True)
            self.render_conflicts(matching, matching_highlight={name: Highlight(True, colour="", draw_corridor=True) for name in scheduler.conflicting_molecules_names})
            # SAT-schedule caching across iterations is disabled (cache=None).
            schedule, conflicts, info = scheduler.compute_schedule(graph_storage=self.store_explanations.get("img", None), cache=None, cache_sat_schedules=False)
            try:
                logger.trace(f"Iteration {iteration_count}: Current SAT schedule cache size: {len(info['sat_schedule_cache'])}")
            except Exception:
                logger.trace(f"Iteration {iteration_count}")

            if not conflicts:
                # Conflict-free: the schedule becomes the final matching.
                self.render_conflicts(schedule)
                img = self.env.unwrapped.get_wrapper_attr("get_image")()
                self._store_explanations(img, iteration_count)
                matching = schedule
                break

            # Forbid every conflicting molecule/goal pair in the next matching.
            unsat_matching_dict = dict()
            for conflict in conflicts.values():
                matching_factory.set_inf_weight(conflict.molecule, conflict.goal)
                unsat_matching_dict[conflict.molecule.name] = Highlight(True, "red", 2, True)

            self.conflict_history.append(MatchingConflicts(matching=matching, conflicts=conflicts))
            self.render_conflicts(matching, matching_highlight=unsat_matching_dict)
            img = self.env.unwrapped.get_wrapper_attr("get_image")()
            self._store_explanations(img, iteration_count)
            try:
                matching = matching_factory.compute_matching()
            except InfeasibleError as e:
                time_needed = round(t.tocvalue(), 3)
                logger.bind(
                    task="scheduling",
                    newline=True,
                    sat_based_iterations=-1,
                    matching_length_initial=matching_length_initial,
                    matching_length_final=-1,
                    sat_based_time_needed=time_needed,
                ).info(f"Could not provide a SAT-based schedule: {e}")
                raise
            iteration_count += 1

        time_needed = round(t.tocvalue(), 3)
        logger.bind(
            task="scheduling",
            newline=True,
            sat_based_iterations=iteration_count + 1,
            matching_length_initial=matching_length_initial,
            matching_length_final=sum(m.length for m in matching),
            sat_based_time_needed=time_needed,
        ).trace(f"Needed {iteration_count + 1} iterations to compute the SAT-based schedule in {time_needed} seconds.")
        return matching

    def render_conflicts(self, matching: list[Matching], matching_highlight: Optional[dict[str, Highlight]] = None) -> None:
        """Draw ``matching`` on the human renderer, highlighting selected molecules.

        Does nothing unless the env's ``render_mode`` is ``"human"``. Molecules
        whose name is a key of ``matching_highlight`` use that highlight. All
        others get a plain highlight without corridor. ``None`` means no special
        highlights.
        """
        if self.env.render_mode != "human":
            return
        if matching_highlight is None:
            matching_highlight = {}
        renderer = self.env.get_wrapper_attr("renderer")
        self.env.unwrapped.render(surpress_matching=True)
        for index, single in enumerate(matching, start=1):
            name = single.molecule.name
            if name in matching_highlight:
                highlight = matching_highlight[name]
            else:
                highlight = Highlight(enabled=True, colour="", draw_corridor=False)
            renderer.render_matching(single, highlight=highlight, index=index, corridor_config=self.corridor_config)
        renderer.update()

    def _store_explanations(self, img: np.ndarray, index: int) -> None:
        """Persist ``img`` for iteration ``index`` to every configured storage.

        Only active when storage is configured and the env renders in
        ``"human"`` or ``"rgb_array"`` mode.
        """
        if not self.store_explanations:
            return
        render_mode = self.env.unwrapped.get_wrapper_attr("render_mode")
        if render_mode not in ("human", "rgb_array"):
            return
        for storage_type, directory in self.store_explanations.items():
            if storage_type == "img":
                seed = str(self.env.unwrapped.get_wrapper_attr("seed"))
                # NOTE(review): isoformat() contains ':' which is not valid in Windows file names.
                Image.fromarray(img).save(os.path.join(directory, f"{datetime.now().isoformat()}_SATBasedScheduling_seed_{seed}_iteration_{index:03}.png"))
            elif storage_type == "tensorboard":
                self.tb_writer.add_image('sat_based_scheduling/schedule', img, index, dataformats="HWC")

    def close(self):
        """Close the TensorBoard writer (if one was opened), then the wrapped env."""
        writer = getattr(self, "tb_writer", None)
        if writer is not None:
            writer.close()
            self.tb_writer = None
        return super().close()

    def __setup_directories(self) -> None:
        """Create output directories and (re)open the TensorBoard writer for this reset."""
        for storage_type, directory in self.store_explanations.items():
            if storage_type == "tensorboard":
                try:
                    from torch.utils.tensorboard import SummaryWriter
                    # Release the writer from a previous reset before opening a new one.
                    previous = getattr(self, "tb_writer", None)
                    if previous is not None:
                        previous.close()
                        self.tb_writer = None
                    # ``spec`` is None for envs not created through ``gym.make``.
                    spec = self.env.unwrapped.spec
                    env_id = spec.id if spec is not None else type(self.env.unwrapped).__name__
                    directory = os.path.join(directory, f"{env_id.replace('/', '_')}_seed_{self.env.unwrapped.get_wrapper_attr('seed')}")
                    os.makedirs(directory, exist_ok=True)
                    self.tb_writer = SummaryWriter(log_dir=directory)
                except Exception as e:
                    log_and_raise(e, "Could not setup directory for tensorboard")
            elif storage_type == "img":
                try:
                    os.makedirs(directory, exist_ok=True)
                except Exception as e:
                    log_and_raise(e, f"Could not create directory {directory}")

    def __check_storage_types(self) -> bool:
        """Fail early if TensorBoard storage is requested but torch is not installed."""
        if "tensorboard" in self.store_explanations:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401
            except ModuleNotFoundError as e:
                log_and_raise(e, "Could not import torch")
        return True

