import gymnasium as gym
import numpy as np

import os
from typing import Optional, Type

from loguru import logger

import shapely
from shapely import Point

from molecule_movement.envs.MoleculeEnvironment import MoleculeEnvironment
from molecule_movement import Molecule, Goal, Matching, AngleSymmetry
from molecule_movement.constants import MoietyType


class StarTrainingEnv(MoleculeEnvironment):
    """Training environment with one molecule and a star-shaped ring of goals.

    A single molecule starts at the centre of the surface. Goals are placed at
    ``NUM_ARMS`` evenly spaced positions on a circle around that centre, and the
    molecule is matched to one of them chosen uniformly at random. Exactly one
    entry is expected in ``system_configuration``.
    """

    # Number of goal positions ("arms") placed evenly around the surface centre.
    # Override on a subclass or instance to change the layout.
    NUM_ARMS: int = 15

    def __init__(self,
                 **kwargs):
        """Build the environment and cache fields of the single moiety configuration.

        All keyword arguments are logged and forwarded unchanged to
        ``MoleculeEnvironment.__init__``.

        Raises:
            AssertionError: if ``system_configuration`` does not hold exactly
                one entry.
        """
        logger.info(kwargs)
        super().__init__(**kwargs)
        assert len(self.system_configuration) == 1, (
            "StarTrainingEnv expects exactly one system configuration, "
            f"got {len(self.system_configuration)}"
        )
        self.type = next(iter(self.system_configuration.keys()))
        configuration = self.system_configuration[self.type]
        self.shape = configuration.moiety.shape
        self.response_map = configuration.response_map
        self.molecule_action_space = configuration.action_space
        self.moiety_point_symmetry = configuration.moiety.point_symmetry

    def _create_initial_distribution(self, seed: Optional[int] = None):
        """Place a single molecule named "i" at the centre of the surface.

        The molecule is created with ``orientation="random"``. ``seed`` is
        accepted for interface compatibility and is not used here.
        """
        centre = Point(self.surface_width / 2, self.surface_height / 2)
        self.molecules = [
            Molecule(centre,
                     self.shape,
                     self.response_map,
                     self.num_sensors,
                     orientation="random",
                     name="i",
                     type=self.type,
                     molecule_point_symmetry=self.moiety_point_symmetry,
                     substrate_point_symmetry=self.substrate_point_symmetry,
                     action_space=self.molecule_action_space)
        ]

    def _set_goals(self, seed: Optional[int] = None) -> None:
        """Place ``NUM_ARMS`` goals evenly on a circle around the surface centre.

        The first goal sits at ``(width / 2, 0.15 * height / 2)``. Every other
        goal is that position rotated about the centre ``(width / 2, height / 2)``
        by a multiple of ``360 / NUM_ARMS`` degrees. Each goal gets its own
        rotation from ``AngleSymmetry.random_angle()``.

        Args:
            seed: Accepted for interface compatibility; not used here.
        """
        # Import the submodule explicitly; `import shapely` alone does not
        # guarantee that `shapely.affinity` has been loaded.
        from shapely import affinity

        # Rotate about the same centre used for the molecule's start position
        # (true division, so odd/float surface sizes stay consistent).
        centre = (self.surface_width / 2, self.surface_height / 2)
        # NOTE(review): molecule symmetry is hard-coded to 4 here, whereas the
        # molecules themselves use self.moiety_point_symmetry - confirm intended.
        angle_symmetry = AngleSymmetry(molecule_point_symmetry=4, substrate_point_symmetry=self.substrate_point_symmetry)
        initial_goal_position = Point(self.surface_width / 2, 0.15 * self.surface_height / 2)
        self.goals = [Goal(initial_goal_position, self.shape, angle_symmetry.random_angle(), self.type)]
        for arm in range(1, self.NUM_ARMS):
            # Exact spacing even when NUM_ARMS does not divide 360.
            angle = 360 * arm / self.NUM_ARMS
            goal_position = affinity.rotate(initial_goal_position, angle, origin=centre)
            self.goals.append(Goal(goal_position, self.shape, angle_symmetry.random_angle(), self.type))

    def reached_goal_orientation(self) -> bool:
        """Return True when the molecule's orientation equals its goal's rotation.

        Both values are truncated with ``int()`` before an exact comparison.
        NOTE(review): no wrap-around (e.g. 0 vs 360) or symmetry equivalence is
        applied here - confirm orientations are normalised upstream.
        """
        return int(self.current_molecule.orientation) == int(self.current_matching.goal.rotation)

    def _set_obstacles(self, seed: Optional[int] = None) -> None:
        """Delegate obstacle placement unchanged to ``MoleculeEnvironment``."""
        super()._set_obstacles(seed)

    def _get_matching(self, options: Optional[dict] = None, seed: Optional[int] = None):
        """Match the single molecule to one goal chosen uniformly at random.

        NOTE(review): draws from numpy's global RNG and ignores ``seed`` and
        ``options``; reproducibility depends on global seeding - confirm.
        """
        self.matching = [Matching(self.molecules[0], self.goals[np.random.choice(len(self.goals))])]

    def _parse_molecule_data(self):
        """Intentionally a no-op: this environment has no molecule data to parse."""

    def terminated(self):
        """Return ``(done, info)`` for the current step.

        Circular moieties are handled here: ``done`` is true when the molecule
        crashed or reached its goal position. ``info`` reports crash, goal and
        orientation-at-goal flags (``"destroyed"`` is always False), and each
        flag is emitted as a trace-level stats log. All other moiety types defer
        to ``MoleculeEnvironment.terminated``.
        """
        if self.type != MoietyType.Circular:
            return super().terminated()

        crashed = self.current_molecule.crashed
        # Evaluate once; presumed side-effect free (previously called up to 3x).
        reached = self.reached_goal_position()
        info = {
            "crashed": bool(crashed),
            "reached_goal": bool(reached),
            # Orientation only counts (and is only evaluated) at the goal position.
            "reached_goal_orientation": bool(reached and self.reached_goal_orientation()),
            "destroyed": False,
        }
        logger.bind(task="stats", crashed=int(info["crashed"])).trace("")
        logger.bind(task="stats", reached=int(info["reached_goal"])).trace("")
        logger.bind(task="stats", reached_orientation_at_goal=int(info["reached_goal_orientation"])).trace("")
        return crashed or reached, info
