import torch
import numpy as np
import gymnasium as gym
# from matplotlib.pyplot import imshow
# import matplotlib as plt
from copy import deepcopy as dc
import copy
from Environment.environment import Environment, Goal, Action
from Environment.Environments.Gym.gym import GymObject
from Environment.Environments.Gym.GymEnvs.gym_pusher_specs import pusher_specs

import mujoco
import numpy as np

from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
import os

class PusherGoal(Goal):
    """Goal object for the pusher task: a 3-D target position for the Block.

    Keyword Args:
        all_names: ordered names of every object in the environment; used to
            stack object states when computing the achieved goal.
        goal_epsilon: Euclidean distance below which the goal counts as reached.
    """

    def __init__(self, **kwargs):
        self.name = "Goal"
        self.attribute = np.ones(3)  # current goal position; replaced via set_state / by the env's reset
        self.interaction_trace = list()
        # Row of the stacked object state that holds the achieved goal
        # (index 2 is "Block" in PusherGym.all_names).
        self.target_idx = 2
        self.partial = 3 # should be the length of the indices to use as goals, SET IN SUBCLASS
        self.bounds = np.array([(-0.3,-0.2,-0.280),(0,0.2,-0.270)])  # (lower, upper) goal-space bounds
        self.bounds_lower = np.array([-0.3, -0.2, -0.280])
        # Width of the goal sampling box per dimension.
        # NOTE(review): the z width (0.005) is half of the z bound interval (0.010),
        # so sampled goals only cover z in [-0.280, -0.275]; confirm this is intended.
        self.range = np.array([0.3, 0.4, 0.005])
        self.all_names = kwargs["all_names"]
        self.goal_epsilon = kwargs["goal_epsilon"]

    def generate_bounds(self):
        """Return the (lower, upper) goal bounds restricted to the first ``partial`` dims."""
        # Slice both ends so the returned bounds always have matching lengths.
        return self.bounds_lower[:self.partial], self.bounds[1][:self.partial]

    def sample_goal(self):
        """Sample a goal uniformly in ``[bounds_lower, bounds_lower + range)``."""
        return np.random.rand(*self.range.shape) * self.range + self.bounds_lower

    def get_achieved_goal(self, env):
        """Return the achieved goal computed from the current object states of ``env``.

        Each object's state (looked up in ``env.object_name_dict`` in
        ``all_names`` order) is zero-padded to the longest state length and
        the results are stacked row-wise before selecting the goal slice.
        """
        states = [env.object_name_dict[n].get_state() for n in self.all_names]
        longest = max(s.shape[0] for s in states)
        padded = [np.pad(s, (0, longest - s.shape[0])) for s in states]
        return self.get_achieved_goal_state(np.stack(padded, axis=0))

    def get_achieved_goal_state(self, object_state, fidx=None):
        """Select the achieved goal from stacked object states.

        Args:
            object_state: array whose second-to-last axis indexes objects (in
                ``all_names`` order) and whose last axis indexes state values;
                leading batch axes are allowed.
            fidx: unused; kept for interface compatibility.

        Returns:
            The first ``partial`` values of the ``target_idx`` object's state.
        """
        return object_state[...,self.target_idx,:self.partial]

    def add_interaction(self, reached_goal):
        """Record a "Target" interaction when the goal was reached."""
        if reached_goal:
            self.interaction_trace += ["Target"]

    def get_state(self):
        """Return the current goal position."""
        return self.attribute

    def set_state(self, goal=None):
        """Replace the goal position when ``goal`` is given; return the current goal."""
        if goal is not None:
            self.attribute = goal
        return self.attribute

    def check_goal(self, env):
        """Return whether the achieved goal lies within ``goal_epsilon`` of the goal.

        Uses the Euclidean norm over all goal dimensions. The result is a NumPy
        boolean (the output of a NumPy comparison), not a Python ``bool``.
        """
        return np.linalg.norm(self.get_achieved_goal(env) - self.attribute) < self.goal_epsilon

def init_object_bounds(env):
    """Build the per-object size, range and dynamics tables for the pusher task.

    The 23-value observation is split into factors as
    ``Pusher = obs[0:17]``, ``Block = obs[17:20]``, ``Goal = obs[20:23]``;
    object order is Action, Pusher, Block, Goal, Reward, Done.

    Args:
        env: object exposing ``action_space`` and ``observation_space`` with
            ``low``/``high`` arrays.

    Returns:
        Tuple ``(object_sizes, object_range, object_dynamics,
        object_range_true, object_dynamics_true)`` where the last two are deep
        copies of ``object_range`` and ``object_dynamics``.
    """
    act_low, act_high = env.action_space.low, env.action_space.high
    obs_low, obs_high = env.observation_space.low, env.observation_space.high
    pusher, block, goal = slice(0, 17), slice(17, 20), slice(20, 23)

    object_sizes = {
        "Action": 7,
        "Pusher": 17,
        "Block": 3,
        "Goal": 3,
        "Reward": 1,
        "Done": 1,
    }
    object_range = {
        "Action": (act_low, act_high),
        "Pusher": (obs_low[pusher], obs_high[pusher]),
        "Block": (obs_low[block], obs_high[block]),
        "Goal": (obs_low[goal], obs_high[goal]),
        "Reward": (-10, 10),
        "Done": (-10, 10),
    }
    # Per-step change bounds: action/pusher may change by up to twice their
    # range bound, block/goal by a tenth of it.
    object_dynamics = {
        "Action": (act_low * 2, act_high * 2),
        "Pusher": (obs_low[pusher] * 2, obs_high[pusher] * 2),
        "Block": (obs_low[block] / 10, obs_high[block] / 10),
        "Goal": (obs_low[goal] / 10, obs_high[goal] / 10),
        "Reward": (-10, 10),
        "Done": (-10, 10),
    }

    # Independent copies so the "true" tables can be modified without aliasing.
    return object_sizes, object_range, object_dynamics, copy.deepcopy(object_range), copy.deepcopy(object_dynamics)

# Default viewer camera passed to MujocoEnv: no tracked body (-1), placed 4.0 units away.
DEFAULT_CAMERA_CONFIG = dict(
    trackbodyid=-1,
    distance=4.0,
)

# Code copied from: https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/pusher_v4.py
class PusherEnv(MujocoEnv, utils.EzPickle):
    """Pusher task adapted from Gymnasium's ``pusher_v4`` with a contact-based interaction trace.

    Observation (23 values): ``tips_arm`` center of mass (3), arm joint
    positions ``qpos[:7]`` (7), arm joint velocities ``qvel[:7]`` (7),
    ``object`` center of mass (3) and ``goal`` center of mass (3).

    ``step`` reports in ``info`` a 6x6 interaction ``trace`` (see
    :meth:`get_contacts`) and ``img``: the rendered frame when
    ``render_mode == "rgb_array"``, otherwise ``None``.

    Keyword Args:
        xml_file: model file name, resolved as
            ``./Environment/Environments/Gym/GymEnvs/xmls/<xml_file>`` relative
            to the current working directory.
        **kwargs: remaining arguments are forwarded to ``MujocoEnv.__init__``.
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)
        # xml_file is consumed here rather than forwarded to MujocoEnv.
        xml_file = kwargs.pop("xml_file")

        # +-0.5 for the 3-D positions, +-pi for the 14 joint positions/velocities.
        low = np.array([-0.5] * 3 + [-np.pi] * 14 + [-0.5] * 6)
        high = np.array([0.5] * 3 + [np.pi] * 14 + [0.5] * 6)
        observation_space = Box(low=low, high=high, shape=(23,), dtype=np.float64)
        MujocoEnv.__init__(
            self,
            os.path.join("./Environment", "Environments", "Gym", "GymEnvs", "xmls", xml_file),
            5,  # frame_skip
            observation_space=observation_space,
            default_camera_config=DEFAULT_CAMERA_CONFIG,
            **kwargs,
        )
        # Body name -> trace row/column (1 = Pusher, 2 = Block, 3 = Goal);
        # rows 0, 4, 5 (Action, Reward, Done) have no MuJoCo bodies.
        self.trace_map = {
            "r_upper_arm_roll_link": 1,
            "r_upper_arm_link": 1,
            "r_elbow_flex_link": 1,
            "r_forearm_roll_link": 1,
            "r_forearm_link": 1,
            "r_wrist_flex_link": 1,
            "r_wrist_roll_link": 1,
            "tips_arm": 1,
            "object": 2,
            "goal": 3,
        }

        # MuJoCo body id -> trace index.
        self.id_map = {
            self.data.body(name).id: self.trace_map[name] for name in self.trace_map.keys()
        }

    def step(self, a):
        """Apply action ``a`` for ``frame_skip`` simulator steps.

        As in Gymnasium's pusher_v4, the reward is computed from body positions
        *before* simulating: ``-|object - goal| - 0.5 * |object - tips_arm|
        - 0.1 * sum(a**2)``.

        Returns:
            ``(obs, reward, False, False, info)`` with ``info`` holding
            ``reward_dist``, ``reward_ctrl``, ``trace`` and ``img``.
        """
        vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm")
        vec_2 = self.get_body_com("object") - self.get_body_com("goal")

        reward_near = -np.linalg.norm(vec_1)
        reward_dist = -np.linalg.norm(vec_2)
        reward_ctrl = -np.square(a).sum()
        reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

        self.do_simulation(a, self.frame_skip)
        img = self.render() if self.render_mode == "rgb_array" else None
        trace = self.get_contacts()

        ob = self._get_obs()
        # termination/truncation are always False here; the time limit is
        # handled by a `TimeLimit` wrapper (PusherGym adds one).
        return (
            ob,
            reward,
            False,
            False,
            dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl, trace=trace, img=img),
        )

    def get_contacts(self):
        """Return the 6x6 interaction trace for the current MuJoCo contacts.

        Rows/columns are ordered Action, Pusher, Block, Goal, Reward, Done.
        The diagonal is always 1. For every active contact whose two geoms
        belong to bodies in ``id_map``, the symmetric entries for their trace
        indices are set to 1. ``trace[1, 0]`` (Pusher <- Action) is always 1.
        """
        trace = np.eye(6)
        geom_bodyid = self.model.geom_bodyid  # contacts are reported per geom; map to owning body
        for i in range(self.data.ncon):
            con = self.data.contact[i]
            body1 = int(geom_bodyid[con.geom1])
            body2 = int(geom_bodyid[con.geom2])
            if body1 in self.id_map and body2 in self.id_map:
                idx1, idx2 = self.id_map[body1], self.id_map[body2]
                trace[idx1, idx2] = 1
                trace[idx2, idx1] = 1
        trace[1,0] = 1
        return trace

    def reset_model(self):
        """Reset the arm, place the cylinder randomly and keep the goal from ``init_qpos``.

        The goal position is read from ``init_qpos[-2:]`` (callers may write a
        goal there before resetting). The cylinder is resampled in
        x in [-0.3, 0], y in [-0.2, 0.2] until it is more than 0.17 from the goal.
        Velocities get small uniform noise, except the last 4 (object and goal),
        which are zeroed.
        """
        # Work on a copy so the sampled cylinder position is not written back into init_qpos.
        qpos = self.init_qpos.copy()

        self.goal_pos = qpos[-2:].copy()
        while True:
            self.cylinder_pos = np.concatenate(
                [
                    self.np_random.uniform(low=-0.3, high=0, size=1),
                    self.np_random.uniform(low=-0.2, high=0.2, size=1),
                ]
            )
            if np.linalg.norm(self.cylinder_pos - self.goal_pos) > 0.17:
                break

        qpos[-4:-2] = self.cylinder_pos
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        qvel[-4:] = 0
        self.set_state(qpos, qvel)
        return self._get_obs()

    def set_state_from_factors(self, pusher, cylinder, goal):
        """Set the simulator state from factored Pusher, Block and Goal states.

        Args:
            pusher: 17 values in observation layout:
                ``[tips_arm com (3), qpos[:7] (7), qvel[:7] (7)]``.
            cylinder: Block state; its first two values become the object's qpos.
            goal: Goal state; its first two values become the goal's qpos.

        The object and goal velocities are not part of the factors and are set to zero.
        NOTE(review): cylinder[:2] / goal[:2] are body center-of-mass coordinates
        written directly into slide-joint positions; this is only exact if those
        bodies sit at the origin of their joint frames in the XML -- confirm.
        """
        pusher = np.asarray(pusher)
        # Skip the 3-D tip position: joint positions are pusher[3:10], velocities pusher[10:17].
        qpos = np.concatenate([pusher[3:10], cylinder[:2], goal[:2]])
        qvel = np.concatenate([pusher[10:17], np.zeros(4)])
        self.set_state(qpos, qvel)

    def _get_obs(self):
        """Return the 23-value observation (see class docstring for the layout)."""
        return np.concatenate(
            [
                self.get_body_com("tips_arm"),
                self.data.qpos.flat[:7],
                self.data.qvel.flat[:7],
                self.get_body_com("object"),
                self.get_body_com("goal"),
            ]
        )

    def get_obs(self):
        """Public alias of :meth:`_get_obs`."""
        return self._get_obs()


class PusherGym(Environment): # wraps openAI gym environment
    """Factored-state wrapper around :class:`PusherEnv`.

    The state is split into the objects Action, Pusher (obs[0:17]),
    Block (obs[17:20]), Goal, Reward and Done; interaction traces are 6x6
    matrices indexed in that order.

    Args:
        frameskip: number of ``PusherEnv.step`` calls per :meth:`step`.
        horizon: episode length used for the ``TimeLimit`` wrapper and the
            self-reset; ``-1`` or ``None`` disables both.
        variant: key into ``pusher_specs`` giving
            ``(xml_file, sparse_reward, goal_epsilon)``.
        fixed_limits, renderable, render_masks: forwarded to ``Environment``;
            ``renderable`` also wraps the env in ``HumanRendering``.
    """

    def __init__(self, frameskip = 1, horizon=200, variant="", fixed_limits=False, renderable=False, render_masks=False):
        super().__init__(frameskip, horizon, variant, fixed_limits, renderable, render_masks)
        # Select the EGL offscreen backend before the MuJoCo renderer is created.
        os.environ["MUJOCO_GL"] = "egl"
        self.xml_file, self.sparse_reward, self.goal_epsilon = pusher_specs[variant]
        # gym wrapper specific properties
        gymenv = PusherEnv(render_mode="rgb_array", xml_file=self.xml_file)
        if renderable:
            gymenv = gym.wrappers.HumanRendering(gymenv)
        if horizon is not None and horizon != -1:
            gymenv = gym.wrappers.TimeLimit(gymenv, horizon)
        self.env = gymenv

        # environment properties
        self.name = "GymPusher" # required for an environment
        self.discrete_actions = isinstance(gymenv.action_space, gym.spaces.Discrete)
        self.num_actions = self.env.action_space.n if self.discrete_actions else -1
        self.frameskip = frameskip
        self.self_reset = True
        self.fixed_limits = fixed_limits
        self.transpose = True

        # spaces
        self.action_space = self.env.action_space # action space is -2,2, which could result in issues if expecting -1,1
        self.action_shape = (1,) if self.discrete_actions else self.action_space.shape
        self.observation_space = self.env.observation_space
        self.pos_size = 3 # the size of the position vector, if used for that object

        # state components
        self.action = Action(not self.discrete_actions, self.action_shape[0])

        # running values
        self.itr = 0
        self.total_itr = 0
        self.max_steps = horizon if horizon is not None and horizon > 0 else 1e12
        self.non_passive_trace_count = 0
        self.total_reward = 0

        # factorized state properties
        self.all_names = ["Action", "Pusher", "Block", "Goal", "Reward", "Done"] # TODO: probably can handle multiple blocks
        self.goal = PusherGoal(all_names=self.all_names, goal_epsilon = self.goal_epsilon)
        self.valid_names = ["Action", "Pusher", "Block", "Goal", "Reward", "Done"] # handle at reset
        self.num_objects = 6
        self.object_names = ["Action", "Pusher", "Block", "Goal", "Reward", "Done"]
        (self.object_sizes, self.object_range, self.object_dynamics,
         self.object_range_true, self.object_dynamics_true) = init_object_bounds(self) # TODO: probably can handle multiple blocks
        self.object_instanced = {name: 1 for name in self.all_names}
        self.object_proximal = {"Action": False, "Pusher": True, "Block": True, "Goal": True, "Reward": False, "Done": False}
        self.object_name_dict = {
            "Action": self.action,
            "Pusher": GymObject("Pusher", self, (0,17)),
            "Block": GymObject("Block", self, (17,20)),
            "Goal": self.goal,
            "Reward": self.reward,
            "Done": self.done
            }
        self.objects = [self.object_name_dict[n] for n in self.all_names]
        self.instance_length = 6

        # proximity state components
        self.position_masks = {name: 3 for name in self.all_names}
        self.goal_based = True
        self.goal_idx = 3
        self.goal_trace_idx = self.goal_idx-1 # if not negative, then usually self.goal_idx - 1
        self.goal_space = gym.spaces.Box(low=self.goal.generate_bounds()[0], high=self.goal.generate_bounds()[1])
        # Trace with no interactions beyond self-edges and Action -> Pusher.
        self.passive_trace = np.eye(6)
        self.passive_trace[1,0] = 1

        self.reset()

    def seed(self, seed):
        super().seed(seed)

    def reset(self, render=False):
        """Sample a new goal, reset the simulator and rebuild the factored state.

        Returns:
            ``({"raw_state": frame, "factored_state": dict}, info)``; ``info``
            holds the contact ``trace`` and the rendered ``img``.
        """
        goal_pos = self.goal.sample_goal()
        self.goal.attribute = goal_pos
        # PusherEnv.reset_model reads the goal from init_qpos[-2:], so the MuJoCo
        # goal is placed at the sampled goal rather than the XML default.
        self.env.init_qpos[-2:] = goal_pos[:2]
        obs, self.info = self.env.reset() # TODO: mujoco automatically renders, but not implemented here yet, treats frame as flat state
        self.info["trace"] = self.env.get_contacts()
        self.info["trace"][1,0] = 1
        self.itr = 0
        self.info["img"] = self.env.render()
        self.frame = self.info["img"]
        # Pass the current action *value*, not the Action object itself.
        self.extracted_state = self._dict_state(obs, self.reward.attribute, self.done.attribute, self.action.attribute)
        return {"raw_state": self.frame, "factored_state": self.extracted_state}, self.info

    def _sparse_goal_reward(self, observation):
        """Return 1.0 if the block in ``observation`` is within ``goal_epsilon`` of the goal, else 0.0.

        Reads the block position from the raw observation slice ``[17:20]``
        (the slice the Block factor is built with). This avoids
        ``self.goal.check_goal(self)``, because the factored objects are only
        refreshed after the frameskip loop and would report the previous
        step's block position.
        """
        block_pos = np.asarray(observation)[17:20][:self.goal.partial]
        return float(np.linalg.norm(block_pos - self.goal.attribute) < self.goal.goal_epsilon)

    def step(self, action, render=False):
        """Apply ``action`` for ``frameskip`` simulator steps.

        Rewards are summed over the frameskip; ``done`` is the OR of every
        sub-step's termination/truncation. When ``itr`` reaches ``max_steps``
        the environment resets itself and ``info["Timelimit.truncated"]`` is True.

        Returns:
            ``({"raw_state": frame, "factored_state": dict}, self.reward,
            bool(self.done), info)``.
        """
        # TODO: rendering can only be toggled at initialization
        self.action.attribute = action
        rew, done = 0.0, False
        for _ in range(self.frameskip):
            observation, reward, term, trunc, info = self.env.step(action)
            if self.sparse_reward:
                reward = self._sparse_goal_reward(observation)
            rew += reward
            self.total_reward += reward
            done = done or term or trunc
        self.reward.attribute, self.done.attribute = rew, done # TODO: termination truncation distinction not supported
        if np.ndim(action) == 0:
            # Scalar (discrete) actions are stored as 1-element arrays; update the
            # Action object's value rather than replacing the object.
            action = np.array([action])
            self.action.attribute = action
        self.itr += 1
        self.total_itr += 1
        info["Timelimit.truncated"] = False
        if self.itr == self.max_steps:
            _, info = self.reset()
            info["Timelimit.truncated"] = True
        if self.total_itr % 1000 == 0:
            print("non_passive_trace_frequency ", self.non_passive_trace_count / self.total_itr)
            print("total reward ", self.total_reward / (self.total_itr) * self.max_steps)

        # NOTE(review): after a self-reset the factored state below is still built
        # from the terminal observation, while trace/img come from the reset -- confirm intended.
        extracted_state = self._dict_state(observation, self.reward.attribute, self.done.attribute, action, trace=info["trace"])
        self.extracted_state = extracted_state
        self.frame = info.pop("img")
        self.info = info

        # NOTE(review): this returns the Reward object and bool() of the Done object,
        # not reward/done values; confirm callers expect that.
        return {"raw_state": self.frame, "factored_state": extracted_state}, self.reward, bool(self.done), info

    def extracted_state_dict(self):
        """Return a deep copy of the latest factored state."""
        return dc(self.extracted_state)

    def set_interaction_from_trace(self, trace):
        """Push each row of ``trace`` to the matching object (mutates ``trace[1, 0]``)."""
        trace[1,0] = 1 # assumes controllable object is second row
        for i, name in enumerate(self.all_names):
            self.object_name_dict[name].set_interaction_from_trace(trace[i], self.all_names)

    def _dict_state(self, observation, reward, done, action, goal=None, trace=None):
        """Write ``observation`` etc. into the objects and return the factored state dict.

        When ``trace`` is None it is recomputed from the current contacts.
        Every trace that differs from ``passive_trace`` is counted and printed.
        """
        self.action.attribute = action
        factored_state = {"Pusher": self.object_name_dict["Pusher"].set_state(observation),
                "Block": self.object_name_dict["Block"].set_state(observation),
                "Reward": np.array([reward]),
                "Done": np.array([int(done)]),
                "Goal": self.goal.set_state(goal),
                "Action": self.action.attribute}
        if trace is None:
            trace = self.env.get_contacts()
            trace[1,0] = 1
        if np.linalg.norm(trace - self.passive_trace) != 0:
            self.non_passive_trace_count += 1
            print(trace, self.total_itr)
        self.set_interaction_from_trace(trace)
        return factored_state

    def get_state(self):
        """Return the current raw frame and a copy of the factored state."""
        return {'raw_state': self.frame, 'factored_state': self.extracted_state_dict()}

    def get_info(self):
        """Return the info dict from the last reset/step."""
        return self.info

    def set_from_factored_state(self, factored_state):
        """Delegate to the base-class implementation with ``factored_state``.

        NOTE(review): the base ``Environment.set_from_factored_state`` signature
        is not visible here; it is assumed to accept the factored state.
        """
        return super().set_from_factored_state(factored_state)
