import abc
import time
import os
import gymnasium as gym
import numpy.random
import torch
import pickle
import numpy as np
from gymnasium.envs.mujoco import walker2d_v3
from imitations.data.types import TrajectoryWithRew

###############################################################################
# TORQUE CONSTRAINTS
###############################################################################

# Value that Walker2dTest.step compares the absolute value of every action
# component against when deciding whether a step counts as a violation.
ACTION_TORQUE_THRESHOLD = 0.5
# Walker2dTest.step ends the episode (and zeroes the reward) once its
# per-episode violation counter exceeds this number.
VIOLATIONS_ALLOWED = 100


class Walker2dTest(walker2d_v3.Walker2dEnv):
    """Walker2d test environment with a torque constraint.

    Every step whose action has at least one component with magnitude above
    ``ACTION_TORQUE_THRESHOLD`` counts as a violation.  Once more than
    ``VIOLATIONS_ALLOWED`` violations have accumulated in an episode, the
    episode is ended and the reward for that step is set to 0.
    """

    def reset(self, **kwargs):
        """Reset the base environment and clear the per-episode counters.

        Args:
            **kwargs: Forwarded unchanged to the base ``reset`` (for example
                ``seed`` and ``options``), so callers and wrappers that pass
                them keep working.

        Returns:
            Whatever the base class ``reset`` returns.
        """
        ob = super().reset(**kwargs)
        self.current_timestep = 0
        self.violations = 0
        return ob

    def step(self, action):
        """Advance one step and apply the torque-violation rule.

        Args:
            action: Array-like of torques passed to the base environment.

        Returns:
            Tuple ``(next_ob, reward, terminated, truncated, infos)``.
        """
        result = super().step(action)
        if len(result) == 5:
            next_ob, reward, terminated, truncated, infos = result
        else:
            # Legacy 4-tuple API: mirror the convention used by the other
            # environments in this file (both flags follow ``done``).
            next_ob, reward, done, infos = result
            terminated = truncated = done

        # Handle the edge case where the base constructor calls step() with
        # a random action before reset() has created the counters.
        if not (hasattr(self, "current_timestep")
                and hasattr(self, "violations")):
            return next_ob, reward, terminated, truncated, infos

        self.current_timestep += 1
        # A torque constraint is violated when a torque is too large.
        if np.any(np.abs(action) > ACTION_TORQUE_THRESHOLD):
            self.violations += 1
        if self.violations > VIOLATIONS_ALLOWED:
            terminated = True
            truncated = True
            reward = 0
        return next_ob, reward, terminated, truncated, infos


###############################################################################

# Module-wide switch read by the step() methods of the WalkerWithPos family:
# 'old' selects old_reward(), 'new' selects new_reward().
REWARD_TYPE = 'old'  # Which reward to use, traditional or new one?


# =========================================================================== #
#                   Walker With Global Position Coordinates                   #
# =========================================================================== #

class WalkerWithPos(walker2d_v3.Walker2dEnv):
    """Walker2d whose observation is the full ``qpos`` and ``qvel`` vectors.

    The reward is chosen by the module-level ``REWARD_TYPE`` switch
    (``'old'`` or ``'new'``).
    """

    def _get_obs(self):
        """Return all of ``qpos`` followed by all of ``qvel`` as one array."""
        return np.concatenate([
            self.sim.data.qpos.flat,
            self.sim.data.qvel.flat,
        ])

    def old_reward(self, xposbefore, xposafter, action):
        """Reward absolute x displacement per unit time, plus an alive bonus.

        Args:
            xposbefore: x position before the simulation step.
            xposafter: x position after the simulation step.
            action: Action applied during the step (used for the control cost).

        Returns:
            Tuple ``(reward, info)`` where ``info`` holds ``reward_run``,
            ``reward_ctrl`` and ``x_position``.
        """
        reward_ctrl = -1e-3 * np.square(action).sum()
        # abs(): movement in either direction along x is rewarded.
        reward_run = abs(xposafter - xposbefore) / self.dt
        alive_bonus = 1
        reward = reward_ctrl + reward_run + alive_bonus

        info = dict(
            reward_run=reward_run,
            reward_ctrl=reward_ctrl,
            x_position=xposafter
        )

        return reward, info

    def new_reward(self, xposbefore, xposafter, action):
        """Reward absolute distance from the x origin, minus a control cost.

        Args:
            xposbefore: x position before the step (not used by this reward).
            xposafter: x position after the simulation step.
            action: Action applied during the step (used for the control cost).

        Returns:
            Tuple ``(reward, info)`` where ``info`` holds ``reward_run``,
            ``reward_ctrl``, ``reward_dist`` and ``x_position``.
        """
        reward_ctrl = -1e-3 * np.square(action).sum()
        reward_dist = abs(xposafter)
        # Reported in ``info`` only; it is not part of the returned reward.
        reward_run = reward_dist / self.dt

        reward = reward_dist + reward_ctrl
        info = dict(
            reward_run=reward_run,
            reward_ctrl=reward_ctrl,
            reward_dist=reward_dist,
            x_position=xposafter
        )

        return reward, info

    def step(self, action):
        """Simulate one step and compute the configured reward.

        Args:
            action: Action passed to ``do_simulation``.

        Returns:
            Tuple ``(ob, reward, terminated, truncated, info)``.  Both flags
            are True when height or angle leave their allowed ranges.

        Raises:
            ValueError: If ``REWARD_TYPE`` is neither ``'old'`` nor ``'new'``.
        """
        # Resolve the reward function first so a bad REWARD_TYPE fails with
        # a clear message before the simulation state is advanced.
        if REWARD_TYPE == 'new':
            reward_fn = self.new_reward
        elif REWARD_TYPE == 'old':
            reward_fn = self.old_reward
        else:
            raise ValueError(
                f"Unknown REWARD_TYPE {REWARD_TYPE!r}; expected 'old' or 'new'"
            )

        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        xposafter, height, ang = self.sim.data.qpos[0:3]
        ob = self._get_obs()

        reward, info = reward_fn(xposbefore, xposafter, action)

        done = not (0.8 < height < 2.0 and -1.0 < ang < 1.0)
        terminated = done
        truncated = done
        return ob, reward, terminated, truncated, info

class WalkerWithPosNoise(WalkerWithPos):
    """``WalkerWithPos`` with random noise added to the state after each step.

    Despite the parameter names, the noise is drawn from a *uniform*
    distribution on ``[-(noise_mean + noise_std), noise_mean + noise_std)``,
    independently for every ``qpos`` and ``qvel`` entry.

    Observation and reward functions are inherited from ``WalkerWithPos``.
    """

    def __init__(self, noise_mean: float = 0, noise_std: float = 1e-3,
                 noise_seed: int = 0):
        """Store the noise settings and build the environment.

        Args:
            noise_mean: Added to ``noise_std`` to form the half-width of the
                uniform noise interval.
            noise_std: Added to ``noise_mean`` to form the half-width of the
                uniform noise interval.
            noise_seed: Seed for the private noise random generator.
        """
        self.noise_mean = noise_mean
        self.noise_std = noise_std
        self.noise_seed = noise_seed
        # Seed from the argument so that different seeds give different
        # (but reproducible) noise streams.
        self.rdm = np.random.RandomState(noise_seed)
        # NOTE(review): attributes are set before the base constructor runs,
        # presumably so they exist if it calls step() -- confirm.
        super().__init__()

    def step(self, action):
        """Simulate one step, perturb the resulting state, compute the reward.

        Args:
            action: Action passed to ``do_simulation``.

        Returns:
            Tuple ``(ob, reward, terminated, truncated, info)`` computed from
            the perturbed state.

        Raises:
            ValueError: If ``REWARD_TYPE`` is neither ``'old'`` nor ``'new'``.
        """
        # Resolve the reward function first so a bad REWARD_TYPE fails with
        # a clear message before the simulation state is advanced.
        if REWARD_TYPE == 'new':
            reward_fn = self.new_reward
        elif REWARD_TYPE == 'old':
            reward_fn = self.old_reward
        else:
            raise ValueError(
                f"Unknown REWARD_TYPE {REWARD_TYPE!r}; expected 'old' or 'new'"
            )

        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)

        # Draw qpos noise first, then qvel noise, from the same generator.
        bound = self.noise_mean + self.noise_std
        noise_qpos = self.rdm.uniform(-bound, bound, self.model.nq)
        noise_qvel = self.rdm.uniform(-bound, bound, self.model.nv)

        # Write the perturbed state back so everything below sees it.
        self.set_state(qpos=self.sim.data.qpos + noise_qpos,
                       qvel=self.sim.data.qvel + noise_qvel)

        xposafter, height, ang = self.sim.data.qpos[0:3]
        ob = self._get_obs()

        reward, info = reward_fn(xposbefore, xposafter, action)

        done = not (0.8 < height < 2.0 and -1.0 < ang < 1.0)
        terminated = done
        truncated = done
        return ob, reward, terminated, truncated, info

    def rollout(self, folder_path=None):
        """Load every ``.pkl`` file in a folder as a ``TrajectoryWithRew``.

        Each pickle is expected to be a mapping with ``'observations'``,
        ``'actions'`` and ``'reward'`` entries.  The last action is dropped,
        and the rewards are cut to the same length.

        Args:
            folder_path: Directory to scan.  Defaults to
                ``./expert_data/BlockedWalker`` relative to the current
                working directory.

        Returns:
            List of ``TrajectoryWithRew`` objects, ordered by file name.
        """
        if folder_path is None:
            # Built with os.path.join so the default works on every OS.
            folder_path = os.path.join('.', 'expert_data', 'BlockedWalker')

        all_trajectories = []

        # Sorted for a reproducible order; os.listdir order is arbitrary.
        for file_name in sorted(os.listdir(folder_path)):
            if not file_name.endswith('.pkl'):
                continue

            file_path = os.path.join(folder_path, file_name)
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at trusted expert-data files.
            with open(file_path, 'rb') as file:
                data = pickle.load(file)

            n_steps = len(data['actions']) - 1
            trajectory = TrajectoryWithRew(
                obs=data['observations'],
                acts=data['actions'][:n_steps],
                rews=data['reward'][:n_steps],
                terminal=True,
                infos=None,
            )
            all_trajectories.append(trajectory)

        return all_trajectories


class WalkerWithPosTest(WalkerWithPos):
    """``WalkerWithPos`` test variant that ends the episode at ``x <= -3``."""

    def step(self, action):
        """Run the parent step, then enforce the x-position constraint.

        Args:
            action: Action passed to the parent ``step``.

        Returns:
            Tuple ``(ob, reward, terminated, truncated, info)``.  When the x
            position after the step is at or below -3, both flags are True
            and the reward is 0; otherwise the parent's values are returned
            unchanged.
        """
        ob, reward, terminated, truncated, info = super().step(action)

        # If agent violates constraint, terminate the episode
        if self.sim.data.qpos[0] <= -3:
            print("Violated constraint in the test environment; terminating episode")
            terminated = True
            truncated = True
            reward = 0

        return ob, reward, terminated, truncated, info
