import gym
import gym.wrappers

import numpy as np


def get_signal_segments(seq):
    """Split ``seq`` into maximal runs of consecutive, equal values.

    Args:
        seq: An indexable, sliceable sequence (list, tuple, str, 1-D numpy
            array, ...) whose elements support ``!=``.

    Returns:
        list[dict]: One dict per run, in order, with keys ``"pos"`` (index of
        the run's first element), ``"len"`` (number of elements in the run)
        and ``"value"`` (the run's value, taken from its last element). An
        empty ``seq`` yields ``[]``.

    Example:
        >>> get_signal_segments("aab")
        [{'pos': 0, 'len': 2, 'value': 'a'}, {'pos': 2, 'len': 1, 'value': 'b'}]
    """
    # len() instead of truthiness so numpy arrays are accepted as well.
    if len(seq) == 0:
        return []

    seg_list = []
    seg_pos = 0
    pre_value = seq[0]
    # Start at index 1: comparing seq[0] with itself would emit a zero-length
    # segment for values that are unequal to themselves (e.g. NaN).
    for i, value in enumerate(seq[1:], start=1):
        if value != pre_value:
            seg_list.append({"pos": seg_pos, "len": i - seg_pos, "value": pre_value})
            seg_pos = i
        pre_value = value
    # Close the final (possibly only) run.
    seg_list.append({"pos": seg_pos, "len": len(seq) - seg_pos, "value": pre_value})
    return seg_list


def denoise_binary(seq, min_len=9, ignore_start=True, ignore_end=True):
    """Remove short glitches from a binary signal.

    The signal is split into runs of equal values. Runs shorter than
    ``min_len`` are relabelled:

    * an isolated short run has its value inverted (``1 - value``);
    * a chain of adjacent short runs is set entirely to ``1.0`` if the runs
      with value > 0.5 cover more samples than the others, else ``0.0``.

    Values are assumed to be 0/1 (implied by ``1 - value`` and ``> 0.5``).

    Args:
        seq: Sequence of 0/1 values, at least ``min_len`` long.
        min_len: Runs shorter than this are treated as noise.
        ignore_start: If true, the first run is never relabelled.
        ignore_end: If true, the last run is never relabelled.

    Returns:
        list: The denoised signal, same length as ``seq``.

    Raises:
        AssertionError: If ``len(seq) < min_len``.
    """
    assert len(seq) >= min_len
    segments = get_signal_segments(seq)

    first = 1 if ignore_start else 0
    stop = len(segments) - 1 if ignore_end else len(segments)

    # Gather short runs; short runs that touch each other share one group.
    groups = []
    expected_pos = 0
    for segment in segments[first:stop]:
        if segment["len"] >= min_len:
            continue
        if groups and segment["pos"] == expected_pos:
            groups[-1].append(segment)
        else:
            groups.append([segment])
        expected_pos = segment["pos"] + segment["len"]

    for group in groups:
        if len(group) == 1:
            # A lone glitch is simply flipped.
            lone = group[0]
            lone["value"] = 1 - lone["value"]
            continue
        # Majority vote (by sample count) across the chain; ties go to 0.0.
        ones = sum(s["len"] for s in group if s["value"] > 0.5)
        zeros = sum(s["len"] for s in group if not s["value"] > 0.5)
        winner = 1.0 if ones > zeros else 0.0
        for s in group:
            s["value"] = winner

    result = []
    for s in segments:
        result.extend([s["value"]] * s["len"])
    return result


class BaseRecorder:
    """No-op base class for objects observing a recorder-capable env.

    Constructing a recorder stores ``env`` on the instance and then registers
    the recorder with it via ``env.register_recorder(name, self)``. Every
    hook below does nothing; subclasses override the ones they need.
    """

    def __init__(self, name, env):
        self.env = env
        env.register_recorder(name, self)

    def before_step(self, action):
        """Hook receiving the ``action`` about to be stepped; no-op here."""

    def after_step(self, observation, reward, done, info, action):
        """Hook receiving a step's results and its ``action``; no-op here."""

    def before_reset(self, **kwargs):
        """Hook receiving the reset keyword arguments; no-op here."""

    def after_reset(self, obs, **kwargs):
        """Hook receiving the reset observation and kwargs; no-op here."""


class RecorderEnv(gym.Wrapper):
    """Gym wrapper that calls registered recorders around ``reset`` and ``step``.

    Recorders are stored by id in ``self.recorders`` and are notified in
    insertion order through their ``before_reset``/``after_reset`` and
    ``before_step``/``after_step`` hooks.

    Args:
        env: The environment to wrap.
        max_episode_steps: If given, every ``gym.wrappers.TimeLimit`` in
            ``env``'s wrapper chain gets ``float(max_episode_steps)`` as its
            limit; if the chain contains no ``TimeLimit``, ``env`` is wrapped
            in a new one with this limit.
    """

    def __init__(self, env, max_episode_steps=None):
        self.max_episode_steps = max_episode_steps
        self.recorders = {}

        if max_episode_steps is not None:
            is_time_limit = False
            dummy_env = env
            # Walk the whole wrapper chain so every TimeLimit gets updated.
            while isinstance(dummy_env, gym.Wrapper):
                if isinstance(dummy_env, gym.wrappers.TimeLimit):
                    dummy_env._max_episode_steps = float(max_episode_steps)
                    is_time_limit = True
                dummy_env = dummy_env.env
            if not is_time_limit:
                env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
        super(RecorderEnv, self).__init__(env)

    def register_recorder(self, recorder_id, recorder: BaseRecorder):
        """Register ``recorder`` under ``recorder_id``; ids must be unique."""
        assert recorder_id not in self.recorders
        self.recorders[recorder_id] = recorder

    def deregister_recorder(self, recorder_id):
        """Remove the recorder registered under ``recorder_id`` (KeyError if absent)."""
        del self.recorders[recorder_id]

    def reset(self, **kwargs):
        """Reset the wrapped env, notifying recorders before and after."""
        # Iterate over snapshots so a hook may (de)register recorders without
        # triggering "dictionary changed size during iteration".
        for recorder in list(self.recorders.values()):
            recorder.before_reset(**kwargs)
        obs = self.env.reset(**kwargs)
        for recorder in list(self.recorders.values()):
            recorder.after_reset(obs, **kwargs)
        return obs

    def step(self, action):
        """Step the wrapped env, notifying recorders before and after.

        Expects the wrapped env to return a 4-tuple
        ``(observation, reward, done, info)``, which is returned unchanged.
        """
        for recorder in list(self.recorders.values()):
            recorder.before_step(action)
        observation, reward, done, info = self.env.step(action)
        for recorder in list(self.recorders.values()):
            recorder.after_step(observation, reward, done, info, action)
        return observation, reward, done, info


class SimpleRecorder(BaseRecorder):
    """Record per-step data for one left/right joint pair of the env's robot.

    After each reset, the first left joint and the first right joint whose
    ``joint_name`` contains ``pp_joint_name`` are located. After each step,
    the following are appended to per-episode lists:

    * their position/velocity (``get_state()``);
    * the corresponding action entries;
    * the body roll (``robot_body.pose().rpy()[0]``);
    * ``feet_contact[0]``;
    * the left/right torques (action times torque limit);
    * the raw action.

    All lists are cleared on reset.

    Args:
        name: Id under which the recorder is registered with ``env``.
        env: The ``RecorderEnv`` to attach to. ``env.unwrapped.robot`` must
            expose ``ordered_joints``, ``power``, ``robot_body`` and
            ``feet_contact`` as used below.
        pp_joint_name: Substring selecting the tracked joint on each side.
    """

    def __init__(self, name, env: RecorderEnv, pp_joint_name):
        self.pp_joint_name = pp_joint_name
        # Joint lookups below are filled in by after_reset().
        self.pp_left_joint = None
        self.pp_right_joint = None
        self.pp_left_joint_idx = None
        self.pp_right_joint_idx = None
        self.left_joint_inds = None
        self.right_joint_inds = None
        self.torque_limits = None
        self._reset_buffers()
        # Register last so the env only ever sees a fully initialised recorder.
        super(SimpleRecorder, self).__init__(name, env)

    def _reset_buffers(self):
        """Start fresh, empty per-step logs."""
        self.ql = []
        self.qr = []
        self.qdotl = []
        self.qdotr = []
        self.actl = []
        self.actr = []
        self.body_roll = []
        self.foot_contact = []
        self.rltorques = []
        self.actions = []

    def _find_joint_idx(self, joints, candidate_inds, side):
        """Return the first index in ``candidate_inds`` whose joint name contains ``pp_joint_name``."""
        for idx in candidate_inds:
            if self.pp_joint_name in joints[idx].joint_name:
                return idx
        raise ValueError(
            "no {} joint whose name contains {!r}".format(side, self.pp_joint_name)
        )

    def after_reset(self, obs, **kwargs):
        """Resolve the tracked joints and torque limits, then clear all logs."""
        robot = self.env.unwrapped.robot
        joints = list(robot.ordered_joints)

        self.left_joint_inds = [
            i for i, j in enumerate(joints) if "left" in j.joint_name
        ]
        # Every joint that is neither "left" nor "abdomen" counts as right side.
        self.right_joint_inds = [
            i
            for i, j in enumerate(joints)
            if "left" not in j.joint_name and "abdomen" not in j.joint_name
        ]

        self.pp_left_joint_idx = self._find_joint_idx(joints, self.left_joint_inds, "left")
        self.pp_right_joint_idx = self._find_joint_idx(joints, self.right_joint_inds, "right")
        self.pp_left_joint = joints[self.pp_left_joint_idx]
        self.pp_right_joint = joints[self.pp_right_joint_idx]

        # Limits come from the left joints only and are reused for the right
        # side in after_step(); presumably both sides list mirrored joints in
        # the same order — TODO confirm.
        self.torque_limits = np.array(
            [
                robot.power * j.power_coef
                if hasattr(j, "power_coef")
                else j.torque_limit
                for j in (joints[i] for i in self.left_joint_inds)
            ]
        )
        self._reset_buffers()

    def after_step(self, observation, reward, done, info, action):
        """Append this step's joint, body, contact, torque and action data."""
        robot = self.env.unwrapped.robot
        ql, qdotl = self.pp_left_joint.get_state()
        qr, qdotr = self.pp_right_joint.get_state()
        self.ql.append(ql)
        self.qr.append(qr)
        self.qdotl.append(qdotl)
        self.qdotr.append(qdotr)
        self.actl.append(action[self.pp_left_joint_idx])
        self.actr.append(action[self.pp_right_joint_idx])
        self.body_roll.append(robot.robot_body.pose().rpy()[0])
        self.foot_contact.append(robot.feet_contact[0])
        # asarray makes index-list selection work for list actions too; it
        # returns numpy arrays unchanged.
        act = np.asarray(action)
        self.rltorques.append(
            [act[self.left_joint_inds] * self.torque_limits,
             act[self.right_joint_inds] * self.torque_limits]
        )
        self.actions.append(action)
