import torch as th
import numpy as np
import copy
import json
import os
import gym
import gym.wrappers
import gym.spaces
import gym.envs

import time

# Minimum gait length used by the (currently commented-out) stride counter in
# TorqueEvaEnv.count_stride, where it is compared against steps_since * dt.
# Presumably in seconds if dt is the per-step simulation time — TODO confirm.
MIN_GAIT_LEN = 0.55

class TorqueEvaEnv(gym.Wrapper):
    """Evaluation wrapper that logs the action components of one joint pair.

    After ``skip_steps`` warm-up steps, every step appends the action entries
    for the left and right joints whose names contain ``pp_joint_name`` to
    ``self.torques``. The episode is forced to end after
    ``episode_steps + skip_steps`` steps. On the step that ends the episode,
    the recorded list is exposed as ``info["torques"]``.

    Args:
        env: Environment to wrap. Its unwrapped env must expose
            ``robot.ordered_joints`` (objects with a ``joint_name``). When
            ``dt`` is not given, it must also expose ``scene.dt``.
        pp_joint_name: Substring identifying the joint to record on each side.
        episode_steps: Number of steps recorded after the warm-up.
        skip_steps: Number of initial steps that are not recorded.
        dt: Simulation time step. If None, it is read from
            ``unwrapped.scene.dt`` on the first reset.
    """

    def __init__(
        self, env, pp_joint_name, episode_steps=2000, skip_steps=120, dt=None
    ):
        super().__init__(env)
        self.total_step = episode_steps + skip_steps
        self.skip_steps = skip_steps
        self.dt = dt
        self.pp_joint_name = pp_joint_name

        # Resolved on every reset from the robot's ordered joint list.
        self.pp_left_joint_idx = None
        self.pp_right_joint_idx = None

        # Only used by the disabled stride counter at the bottom of the class.
        self.prev_contact = None
        self.stride_num = None
        self.steps_since = None

        # Per-episode step counter and recorded action pairs (set in reset).
        self.steps = None
        self.torques = None

        if hasattr(self.unwrapped, "evaluation_mode"):
            self.unwrapped.evaluation_mode()

        # Disable every TimeLimit wrapper below this one, so that episode
        # length is decided by total_step here rather than by the registry.
        while isinstance(env, gym.Wrapper):
            if isinstance(env, gym.wrappers.TimeLimit):
                env._max_episode_steps = float("inf")
            env = env.env

    def _find_joint_idx(self, names, candidate_inds, side):
        """Return the first index in ``candidate_inds`` whose name contains
        ``pp_joint_name``.

        Raises:
            ValueError: If no candidate joint name matches.
        """
        for idx in candidate_inds:
            if self.pp_joint_name in names[idx]:
                return idx
        raise ValueError(
            f"No {side} joint name contains {self.pp_joint_name!r}; "
            f"available joints: {names}"
        )

    def reset(self, *args, **kwargs):
        """Reset the wrapped env, locate the tracked joints, and clear the log.

        Returns:
            The observation returned by the wrapped env's ``reset``.

        Raises:
            ValueError: If no left or right joint name contains
                ``pp_joint_name``.
        """
        obs = self.env.reset(*args, **kwargs)
        if self.dt is None:
            self.dt = self.unwrapped.scene.dt

        # A plain list is enough: wrapping joint objects in np.array only
        # builds an object array and can misbehave for sequence-like objects.
        names = [j.joint_name for j in self.unwrapped.robot.ordered_joints]

        left_joint_inds = [i for i, name in enumerate(names) if "left" in name]
        # Every joint that is neither a left joint nor an abdomen joint is
        # treated as a right-side joint.
        right_joint_inds = [
            i
            for i, name in enumerate(names)
            if "left" not in name and "abdomen" not in name
        ]

        self.pp_left_joint_idx = self._find_joint_idx(
            names, left_joint_inds, "left"
        )
        self.pp_right_joint_idx = self._find_joint_idx(
            names, right_joint_inds, "right"
        )

        self.torques = []
        self.steps = 0
        return obs

    def step(self, action):
        """Step the wrapped env and record the tracked joints' action entries.

        Returns:
            The wrapped env's ``(obs, rew, done, info)`` tuple. ``done`` is
            forced to True once ``total_step`` steps have been taken. Whenever
            the episode ends (through that limit, or because the wrapped env
            reported ``done`` earlier), ``info["torques"]`` holds the list
            recorded so far.
        """
        obs, rew, done, info = self.env.step(action)
        self.steps += 1
        if self.steps > self.skip_steps:
            # Assumes `action` is a numpy array / torch tensor. List indexing
            # then returns a copy, so later in-place edits to `action` do not
            # alter the log.
            self.torques.append(
                action[[self.pp_left_joint_idx, self.pp_right_joint_idx]]
            )
        if self.steps == self.total_step:
            done = True
        if done:
            # Attach the log on any episode end, not only at the step limit.
            # Otherwise, torques recorded before an early termination of the
            # wrapped env would be discarded by the next reset.
            info["torques"] = self.torques
        return obs, rew, done, info

    # NOTE: disabled stride counter, kept for reference. It relies on
    # prev_contact / steps_since / stride_num and the module-level
    # MIN_GAIT_LEN, none of which are initialised in reset().
    # def count_stride(self):
    #     side_contact = self.unwrapped.robot.feet_contact[0]
    #     strike = not self.prev_contact and side_contact
    #     self.prev_contact = side_contact
    #     self.steps_since += 1
    #
    #     if strike and self.steps_since * self.dt > MIN_GAIT_LEN:
    #         self.steps_since = 0
    #         self.stride_num += 1
    #         return True
    #     return False
