import numpy as np
import gymnasium as gym
import os
from gymnasium import spaces
import robosuite
from robosuite.controllers import load_controller_config
import imageio
import copy
import cv2
from Environment.environment import Environment, EnvObject, Action, Goal, strip_instance
from Environment.Environments.RoboPushing.robopushing_specs import *
from Record.file_management import numpy_factored, display_frame
from collections import deque
import robosuite.macros as macros
# Override robosuite's simulation timestep at import time (presumably so it is
# in effect for every environment later created with robosuite.make).
macros.SIMULATION_TIMESTEP = 0.02

# Each entry of `variants` (from robopushing_specs) is unpacked by
# RoboPushing.__init__ as:
# control_freq, horizon, num_obstacles, standard_reward, goal_reward,
# obstacle_reward, out_of_bounds_reward, mode, hard_obstacles,
# cube_halfsize, goal_radius


# Values for a variant's `mode` field. JOINT_MODE, PLANAR_MODE and
# DISCRETE_MODE select controller/action handling in RoboPushing;
# HARD_MODE is not referenced in this file's visible code.
DEFAULT = 0
JOINT_MODE = 1
HARD_MODE = 2
PLANAR_MODE = 3
DISCRETE_MODE = 4

# Discrete action table: row k is the (x, y, z) translation used for discrete
# action k (±x, ±y with magnitude 0.7; ±z with magnitude 0.25). RoboPushing.set_action
# appends three zeros to the selected row before sending it to the controller.
a = [[-0.7,0,0],
   [0.7,0,0],
   [0,-0.7,0],
   [0,0.7,0],
   [0,0,-0.25],
   [0,0,.25]]
DISCRETE_MOVEMENTS = np.array(a).astype(float)


class RoboGoal(Goal):
    """Goal object for the pushing task.

    The desired goal is stored in ``self.attribute`` and compared against the
    first ``self.partial`` coordinates of the Target object's state.

    Keyword Args:
        bounds: pair ``(lower, upper)`` of goal bounds; only the first
            ``partial`` entries are used for sampling.
        all_names: ordered object names of the environment; object states are
            stacked in this order by ``get_achieved_goal``.
        goal_epsilon: Euclidean distance below which the goal counts as reached.
        obstacle_radius: minimum distance between a sampled goal and any obstacle.
    """

    def __init__(self, **kwargs):
        self.name = "Goal"
        self.attribute = np.ones(2)
        self.interaction_trace = list()
        # position of "Target" in all_names; assumes all_names ends with
        # "Target", "Goal", "Done", "Reward" (the order RoboPushing builds)
        self.target_idx = -4
        self.partial = 2 # should be the length of the indices to use as goals, SET IN SUBCLASS
        self.bounds = kwargs["bounds"]
        self.bounds_lower = self.bounds[0]
        self.range = (self.bounds[1] - self.bounds[0])[:self.partial]
        self.all_names = kwargs["all_names"]
        self.goal_epsilon = kwargs["goal_epsilon"]
        self.obstacle_radius = kwargs["obstacle_radius"]

    def generate_bounds(self):
        """Return ``(lower, upper)`` goal bounds restricted to the goal dimensions."""
        return self.bounds_lower[:self.partial], self.bounds[1][:self.partial]

    def sample_goal(self, obstacle_pos):
        """Sample a goal uniformly inside the bounds, away from obstacles.

        Args:
            obstacle_pos: None, or an array of obstacle positions whose last axis
                holds coordinates (e.g. shape ``(num_obstacles, D)`` with
                ``D >= partial``). Only the first ``partial`` coordinates are used
                in the distance check.

        Returns:
            np.ndarray of length ``partial``, re-sampled until it is at least
            ``obstacle_radius`` from every obstacle.
        """
        lower = np.asarray(self.bounds_lower)[:self.partial]
        goal = np.random.rand(self.partial) * self.range + lower
        if obstacle_pos is None or np.size(obstacle_pos) == 0:
            return goal
        # Compare in the goal's coordinates; full 3-D obstacle positions would
        # otherwise fail to broadcast against a ``partial``-length goal.
        obstacle_goal_coords = np.asarray(obstacle_pos)[..., :self.partial]
        while np.min(np.linalg.norm(goal - obstacle_goal_coords, axis=-1)) < self.obstacle_radius:
            goal = np.random.rand(self.partial) * self.range + lower
        return goal

    def get_achieved_goal(self, env):
        """Return the achieved goal (Target's goal coordinates) for ``env``.

        Every object's state is zero-padded to the longest state length and
        stacked in ``all_names`` order, then passed to ``get_achieved_goal_state``.
        """
        states = [np.asarray(env.object_name_dict[n].get_state()) for n in self.all_names]
        longest = max(s.shape[0] for s in states)
        state = np.stack([np.pad(s, (0, longest - s.shape[0])) for s in states], axis=0)
        return self.get_achieved_goal_state(state)

    def get_achieved_goal_state(self, object_state, fidx=None):
        """Extract the Target's first ``partial`` coordinates from stacked object states."""
        return object_state[..., self.target_idx, :self.partial]

    def add_interaction(self, reached_goal):
        """Record a "Target" interaction when the goal was reached."""
        if reached_goal:
            self.interaction_trace += ["Target"]

    def get_state(self):
        """Return the current goal vector."""
        return self.attribute

    def set_state(self, goal=None):
        """Set the goal to the first ``partial`` entries of ``goal`` (if given) and return it."""
        if goal is not None:
            self.attribute = goal[:self.partial]
        return self.attribute

    def check_goal(self, env):
        """Return True when the Target is within ``goal_epsilon`` (Euclidean) of the goal."""
        return np.linalg.norm(self.get_achieved_goal(env) - self.attribute) < self.goal_epsilon

class RoboObject(EnvObject):
    """Factored object whose state is a slice of the owning environment's observation.

    Args:
        name: object name, forwarded to ``EnvObject``.
        env: owning environment; ``env.env.get_obs()`` supplies the fallback state.
        bounds: ``(start, stop)`` indices of this object's slice of that observation.
    """

    def __init__(self, name, env, bounds):
        super().__init__(name)
        self.env = env
        self.state = None
        self.bounds = bounds

    def set_state(self, state=None):
        """Store ``state`` when given, then return the current state.

        Called without an argument, this returns the cached state, filling it
        lazily from the environment observation the first time.
        """
        if state is not None:
            self.state = state
        elif self.state is None:
            observation = self.env.env.get_obs()
            self.state = observation[self.bounds[0]:self.bounds[1]]
        return self.state

class RoboPushing(Environment):
    """Factored-state wrapper around the robosuite ``Push`` task.

    Exposes the action, the gripper, optional obstacles, the pushed cube
    ("Target"), the goal, and Done/Reward as named objects. It also tracks
    per-step interaction traces derived from simulator contacts. The
    environment resets itself when an episode ends (``self_reset``).
    """

    def __init__(self, variant="default", horizon=30, renderable=False, fixed_limits=False, flat_obs=False, append_id=False):
        super().__init__()
        self.self_reset = True
        self.fixed_limits = fixed_limits
        self.variant = variant
        control_freq, var_horizon, num_obstacles, standard_reward, \
            goal_reward, obstacle_reward, out_of_bounds_reward, mode, \
            hard_obstacles, cube_halfsize, goal_radius = variants[variant]
        # a negative horizon means "use the variant's horizon"
        horizon = var_horizon if horizon < 0 else horizon
        self.mode = mode
        self.hard_obstacles = hard_obstacles
        self.goal_reward, self.goal_radius = goal_reward, goal_radius
        controller = "JOINT_POSITION" if self.mode == JOINT_MODE else "OSC_POSE" # TODO: handles only two action spaces at the moment
        self.env = robosuite.make(
                "Push",
                robots=["Panda"],
                controller_configs=load_controller_config(default_controller=controller),
                has_renderer=False,
                has_offscreen_renderer=renderable,
                render_visual_mesh=renderable,
                render_collision_mesh=False,
                camera_names=["frontview"] if renderable else None,
                camera_heights=1080, # TODO: renders in HD right now
                camera_widths=1080,
                control_freq=control_freq,
                horizon=horizon,
                use_object_obs=True,
                use_camera_obs=renderable,
                hard_reset = False,
                num_obstacles=num_obstacles,
                standard_reward=float(standard_reward),
                goal_reward=float(goal_reward),
                obstacle_reward=float(obstacle_reward),
                out_of_bounds_reward=float(out_of_bounds_reward),
                hard_obstacles=self.hard_obstacles,
                keep_gripper_in_cube_plane=mode == PLANAR_MODE,
                cube_halfsize=cube_halfsize,
                goal_radius=goal_radius,
            )
        # environment properties
        self.num_actions = -1 # this must be defined, -1 for continuous. Only needed for primitive actions
        self.name = "RobosuitePushing" # required for an environment
        self.discrete_mode = mode == DISCRETE_MODE
        self.discrete_actions = self.discrete_mode
        self.frameskip = control_freq
        self.timeout_penalty = -horizon
        self.planar_mode = mode == PLANAR_MODE
        self.pos_size = 3
        self.obstacle_radius = self.env.OBSTACLE_HALF_SIDELENGTH * np.sqrt(2)
        self.num_obstacles = num_obstacles

        # spaces
        if self.discrete_mode:
            self.action_shape = (1,)
            self.num_actions = 6 # up down left right forward backward
            self.action_space = spaces.Discrete(self.num_actions)
            limit = 1
        else:
            low, high = self.env.action_spec
            limit = 7 if self.mode == JOINT_MODE else 3
            self.action_shape = (limit,)
            self.action_space = spaces.Box(low=low[:limit], high=high[:limit])
        self.action = Action(not self.discrete_actions, self.action_shape[0])
        self.goal_space = spaces.Box(low=ranges["Goal"][0], high=ranges["Goal"][1])
        self.renderable = renderable
        self.flat_obs = flat_obs
        self.append_id = append_id
        self.goal_based = True

        # running values
        self.itr = 0
        self.total_itr = 0
        self.total_reward = 0
        self.non_passive_trace_count = 0

        # state components
        self.extracted_state = None

        # factorized state properties
        obstacle_list = list() if self.num_obstacles <= 0 else ["Obstacle"]
        self.object_names = ["Action", "Gripper"] + obstacle_list + ["Target", "Goal", 'Done', "Reward"] # must be initialized, a list of names that controls the ordering of things
        # Obstacles get an index suffix only when there is more than one; this
        # matches the instancing rule used to build self.all_names below, so a
        # single obstacle is consistently keyed "Obstacle" everywhere.
        obstacle_keys = [("Obstacle" + str(i) if num_obstacles > 1 else "Obstacle") for i in range(num_obstacles)]
        self.object_sim_names = {"Action": "", "Gripper": "gripper0_pushing_gripper", "Target": "cube", "Goal": "goal", "Done": "", "Reward": ""}
        self.object_sim_names = {**{key: "obstacle" + str(i) for i, key in enumerate(obstacle_keys)}, **self.object_sim_names}
        self.sim_objects = {"Action": None, "Gripper": self.env.robots[0].gripper, "Target": self.env.cube, "Goal": self.env.goal, "Done": "", "Reward": ""}
        self.sim_objects = {**{key: self.env.obstacles[i] for i, key in enumerate(obstacle_keys)}, **self.sim_objects}
        self.object_obs_names = {"Action": "", "Gripper": "robot0_eef_pos", "Target": "cube_pos", "Goal": "goal_pos", "Done": "", "Reward": ""}
        self.object_obs_names = {**{key: "obstacle" + str(i) + "_pos" for i, key in enumerate(obstacle_keys)}, **self.object_obs_names}
        # reverse map sim name -> object name, skipping objects with no sim body
        self.sim_object_names = {simname: name for name, simname in self.object_sim_names.items() if len(simname) > 0}
        self.object_sizes = {"Action": limit, "Gripper": 3, "Target": 3, "Obstacle": 3, "Goal": 2, "Done": 1, "Reward": 1} # must be initialized, a dictionary of name to length of the state
        self.flat_indices = []
        if self.discrete_mode:
            self.object_range = discrete_ranges
            self.object_dynamics = discrete_dynamics
            self.object_range_true = discrete_ranges
            self.object_dynamics_true = discrete_dynamics
            self.position_masks = discrete_position_masks
        else:
            self.object_range = ranges if not self.fixed_limits else ranges_fixed # the minimum and maximum values for a given feature of an object
            self.object_dynamics = dynamics if not self.fixed_limits else dynamics_fixed
            self.object_range_true = ranges
            self.object_dynamics_true = dynamics
            self.position_masks = position_masks

        # obstacles and objects
        # Action, Gripper, Target, Goal, Done, Reward plus one per obstacle
        self.num_objects = 3 + num_obstacles + 3
        # copy so this instance's obstacle count does not mutate the shared specs dict
        self.object_instanced = copy.copy(instanced)
        self.object_instanced["Obstacle"] = num_obstacles
        self.all_names = [
            (name + str(i) if self.object_instanced[name] > 1 else name)
            for name in self.object_names
            for i in range(self.object_instanced[name])
        ]
        self.goal = RoboGoal(all_names=self.all_names, goal_epsilon=self.goal_radius, bounds=ranges["Goal"], obstacle_radius=self.obstacle_radius)
        self.instance_length = len(self.all_names)
        self.object_name_dict = dict()
        len_up_to = 0
        for name in self.all_names:
            if name not in ["Action", "Reward", "Done", "Goal"]:
                size = self.object_sizes[strip_instance(name)]
                self.object_name_dict[name] = RoboObject(name, self, (len_up_to, len_up_to + size))
                len_up_to += size
        self.object_name_dict = {**{"Action": self.action, "Reward": self.reward, "Done": self.done, "Goal": self.goal}, **self.object_name_dict}
        self.objects = [self.object_name_dict[n] for n in self.all_names]
        self.obstacles = [n for n in self.all_names if "Obstacle" in n]
        self.valid_names = self.all_names
        self.passive_trace = np.eye(len(self.all_names))
        # Gripper (index 1) is always driven by Action (index 0)
        self.passive_trace[1, 0] = 1
        self.trace_graph = np.zeros((self.num_objects, self.num_objects))
        self.trace_graph[0, 0] = 1

        # position mask
        self.length, self.width = 0.3, 0.3

        obs = self.reset()
        self.trace = self.get_full_current_trace()

        self.reward_collect = 0
        # reuse the observation from the reset above rather than resetting again
        if self.flat_obs: self.observation_space = spaces.Box(low=-1, high=1, shape=obs.shape)
        else: self.observation_space = spaces.Box(low=-1, high=1, shape=[9])

    def set_named_state(self, obs_dict, set_objects=True):
        """Copy values from a robosuite observation dict into the named objects.

        Physical objects and the Goal read ``obs_dict[self.object_obs_names[name]]``.
        Action/Reward/Done read ``obs_dict[name]`` when that key is present.
        """
        if not set_objects:
            return
        for name in self.all_names:
            if name not in ["Action", "Reward", "Done"]:
                self.object_name_dict[name].set_state(obs_dict[self.object_obs_names[name]])
            elif name in obs_dict:
                self.object_name_dict[name].attribute = obs_dict[name]

    def set_action(self, action):
        """Convert a policy action into the full action vector sent to robosuite."""
        if self.mode == JOINT_MODE:
            return action
        if self.mode == PLANAR_MODE:
            return np.concatenate([action[:2], [0, 0, 0, 0]])
        if self.mode == DISCRETE_MODE:
            # accept a plain int as well as 0-d / shape-(1,) arrays (action_shape is (1,))
            index = int(np.asarray(action).reshape(-1)[0])
            return np.concatenate([DISCRETE_MOVEMENTS[index], [0, 0, 0]])
        return np.concatenate([action, [0, 0, 0]])

    def step(self, action, render=False): # render will NOT change renderable, so it will still render or not render
        """Advance one control step.

        Returns:
            ``(state, reward, done, info)`` where reward is 1 if the goal is
            reached, else 0. When done, the environment resets itself after the
            returned values (and info) have been computed.
        """
        # step internal robosuite environment
        self.reset_traces()
        use_act = self.set_action(action)
        next_obs, reward, done, info = self.env.step(use_act)
        # robosuite's done is treated as a time-limit truncation
        info["TimeLimit.truncated"] = bool(done)
        # NOTE(review): self.reward is the Reward state object (its value lives in
        # self.reward.attribute), so unless that class defines __eq__ this compares
        # an object to a float and the branch never runs. The comment says
        # "terminate" but the body clears done; confirm intent before changing it.
        if self.reward == self.goal_reward: # don't wait at the goal, just terminate
            done = False
            info["TimeLimit.truncated"] = False
        # set state
        next_obs["Action"], next_obs["Done"], next_obs["Reward"] = action, done, reward

        self.set_named_state(next_obs) # sets the state objects
        # the exposed reward is goal success (0/1), replacing robosuite's reward
        reward = self.goal.check_goal(self).astype(int)
        self.reward_collect += reward
        self.reward.attribute = reward

        # handle rendering
        self.frame = next_obs["frontview_image"][::-1] if self.renderable else None

        # handle specialized values
        self.assign_traces()
        self.trace = self.get_factor_graph(complete_graph=True)
        info["trace"] = self.trace
        self.factor_graph = self.get_factor_graph(complete_graph=False)
        self.trace_graph += self.trace

        full_state = self.get_state()

        # step timers
        self.itr += 1
        self.total_itr += 1

        if np.linalg.norm(self.trace - self.passive_trace) != 0:
            self.non_passive_trace_count += 1

        # Build info and capture reward/done BEFORE the self-reset so that
        # achieved_goal / success / factor_graph describe the returned full_state
        # rather than the first state of the next episode.
        info = self.get_info(info)
        step_reward, step_done = self.reward.attribute, self.done.attribute
        if step_done:
            self.reset()
            self.itr = 0
        return full_state, step_reward, step_done, info

    def get_state(self, render=False):
        """Return the current observation.

        With ``flat_obs``: a 1-D concatenation of every object's state.
        Otherwise: ``{"raw_state": frame, "factored_state": {...}}``, where the
        factored state maps object names to states plus "VALID_NAMES" and "TRACE".
        """
        if self.flat_obs:
            return np.concatenate([obj.get_state() for obj in self.objects])
        factored_state = {obj.name: obj.get_state() for obj in self.objects}
        factored_state["VALID_NAMES"] = self.valid_binary(self.valid_names)
        factored_state["TRACE"] = self.trace
        return {"raw_state": self.frame, "factored_state": factored_state}

    def get_info(self, info=None):
        """Build the info dict (goal quantities, traces, validity).

        Keys already present in ``info`` take precedence over the computed ones.
        """
        achieved_goal = desired_goal = success = None
        if "Goal" in self.object_name_dict:
            achieved_goal = self.goal.get_achieved_goal(self)
            desired_goal = self.goal.get_state()
            success = self.goal.check_goal(self)
        new_info = {"TimeLimit.truncated": self.done.attribute, "trace": self.trace, "factor_graph": self.factor_graph, "valid": self.valid_binary(self.valid_names), "achieved_goal": achieved_goal, "desired_goal": desired_goal, "success": success}
        if info is not None: return {**new_info, **info}
        return new_info

    def assign_traces(self):
        """Fill each object's interaction_trace from the current simulator contacts."""
        for name in self.all_names:
            if name == "Gripper":
                # the gripper is always driven by the action
                self.object_name_dict[name].interaction_trace.append("Action")
            elif name in ["Action", "Done", "Reward"] or "Obstacle" in name:
                continue
            else:
                self.object_name_dict[name].interaction_trace = self.check_contacts(name)

    def reset_traces(self):
        """Clear every object's interaction_trace."""
        for name in self.all_names:
            self.object_name_dict[name].interaction_trace = list()

    def strip_collision(self, name):
        """Strip a "_collision" or "_g0" geom suffix; return None if neither is present."""
        if name.find("_collision") != -1:
            return name[:name.find("_collision")]
        if name.find("_g0") != -1:
            return name[:name.find("_g0")]
        return None

    def check_contacts(self, object1_name):
        """Return the object names currently in contact with ``object1_name``."""
        object1 = self.sim_objects[object1_name]
        contact_names = [self.strip_collision(c) for c in self.env.get_contacts(object1)]
        return [self.sim_object_names[cn] for cn in contact_names if cn in self.sim_object_names]

    def reset(self, seed=-1, options=None, goal=None, **kwargs):
        """Reset the simulation and return the initial observation.

        Args:
            goal: optional goal position. Its first ``self.goal.partial`` entries
                are written into the goal body position and the Goal object.
        """
        obs = self.env.reset()
        self.valid_names = self.all_names # TODO: implement valid names on resets
        # clear first so the gripper's "Action" entry is not appended on top of stale traces
        self.reset_traces()
        self.assign_traces()
        self.trace = self.get_factor_graph(complete_graph=True)
        self.factor_graph = self.get_factor_graph(complete_graph=False)
        self.reward_collect = 0
        self.set_named_state(obs)
        if goal is not None:
            # Applied after set_named_state, which would otherwise overwrite the
            # goal with obs["goal_pos"] captured before this override.
            goal = np.asarray(goal)[:self.goal.partial]
            self.env.sim.model.body_pos[self.env.goal_body_id, :self.goal.partial] = goal
            self.goal.attribute = goal
        self.frame = obs["frontview_image"][::-1] if self.renderable else None
        self.itr = 0

        # TODO: reset should handle goal resampling and set_named_state should set it, e.g.
        # obstacle_pos = None if self.num_obstacles == 0 else np.stack([self.object_name_dict[name].get_state() for name in self.obstacles], axis=0)
        # self.goal.set_state(self.goal.sample_goal(obstacle_pos))
        return self.get_state()

    def render(self):
        """Return the most recent rendered frame (None when not renderable)."""
        return self.frame
