#! /usr/bin/env python3
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Environment classes for blocks-world tasks, plus Fetch and highway RL wrappers."""
import os
import numpy as np
import gym
import highway_env
import highway_env.vehicle.behavior

from gym.envs.robotics import fetch_env

import jacinle.random as random
from jaclearn.rl.env import SimpleRLEnvBase

from .block import randomly_generate_world
from .represent import get_coordinates
from .represent import decorate

import mujoco_py

from gym import utils
from copy import deepcopy

# Names exported by ``from <this module> import *``.
__all__ = [
    'FinalBlocksWorldEnv',
    'StackBlocksWorldEnv',
]


class BlocksWorldEnv(SimpleRLEnvBase):
    """The Base BlocksWorld environment.

    Args:
        nr_blocks: The number of blocks. One extra object (the ground) is
            counted in ``nr_objects``.
        random_order: randomly permute the indexes of the blocks. This
            option prevents the models from memorizing the configurations.
        decorate: if True, the coordinates in the states will also include the
            world index (default: 0) and the block index (starting from 0).
        prob_unchange: The probability that an action is not effective.
        prob_fall: The probability that an action will make the object currently
            moving fall on the ground.
    """

    def __init__(self, nr_blocks, random_order=False, decorate=False,
                 prob_unchange=0.0, prob_fall=0.0):
        super().__init__()
        # The extra object accounts for the ground.
        self.nr_blocks, self.nr_objects = nr_blocks, nr_blocks + 1
        self.random_order = random_order
        # NOTE: the ``decorate`` parameter shadows the module-level function
        # only inside ``__init__``; the stored attribute is just a flag.
        self.decorate = decorate
        self.prob_unchange, self.prob_fall = prob_unchange, prob_fall

    def _restart(self):
        """Sample a fresh world and reset the episode bookkeeping."""
        self.world = randomly_generate_world(self.nr_blocks,
                                             random_order=self.random_order)
        initial_state = self._get_decorated_states()
        self._set_current_state(initial_state)
        self.is_over = False
        # ``_get_result`` is not defined in this class; a subclass (or the
        # base class) is expected to provide it.
        self.cached_result = self._get_result()

    def _get_decorated_states(self, world_id=0):
        """Return the coordinates of ``self.world``, decorated when enabled."""
        coords = get_coordinates(self.world)
        if not self.decorate:
            return coords
        return decorate(coords, self.nr_objects, world_id)


class FinalBlocksWorldEnv(BlocksWorldEnv):
    """The BlocksWorld environment for the final task.

    Two worlds are sampled on restart: ``start_world``, which actions modify,
    and ``final_world``, the target. The state stacks the decorated
    coordinates of both worlds (world index 0 for start, 1 for final). The
    episode succeeds once their sorted coordinates coincide.

    Args:
        nr_blocks: The number of blocks.
        random_order: randomly permute the object indexes of both worlds.
        shape_only: if True, only the shape of the configuration must match;
            otherwise the object indexes must match as well.
        fix_ground: when ``random_order`` is set, keep the ground at index 0.
        prob_unchange: The probability that an action is not effective.
        prob_fall: The probability that the moved object falls on the ground.
    """

    # Class-level default so the ``restart()`` assertions below fail with
    # their message (instead of AttributeError) before the first restart.
    start_world = None

    def __init__(self,
                 nr_blocks,
                 random_order=False,
                 shape_only=False,
                 fix_ground=False,
                 prob_unchange=0.0,
                 prob_fall=0.0):
        super().__init__(nr_blocks, random_order, True, prob_unchange, prob_fall)
        self.shape_only = shape_only
        self.fix_ground = fix_ground

    @staticmethod
    def _sample_order(n, ground_ind):
        """Return a random ordering of ``n`` objects with object 0 (the ground)
        placed at position ``ground_ind``.

        Objects ``1..n-1`` are taken from ``random.permutation(n - 1)``
        shifted by one.
        """
        order = [r + 1 for r in random.permutation(n - 1)]
        order.insert(ground_ind, 0)
        return order

    def _restart(self):
        """Sample new start/final worlds and rebuild their states."""
        self.start_world = randomly_generate_world(
            self.nr_blocks, random_order=False)
        self.final_world = randomly_generate_world(
            self.nr_blocks, random_order=False)
        self.world = self.start_world
        if self.random_order:
            n = self.world.size
            # Ground is fixed as index 0 if fix_ground is True. The same
            # ground position is shared by both worlds.
            ground_ind = 0 if self.fix_ground else random.randint(n)
            self.start_world.blocks.set_random_order(
                self._sample_order(n, ground_ind))
            self.final_world.blocks.set_random_order(
                self._sample_order(n, ground_ind))

        self._prepare_worlds()
        self.start_state = decorate(
            self._get_coordinates(self.start_world), self.nr_objects, 0)
        self.final_state = decorate(
            self._get_coordinates(self.final_world), self.nr_objects, 1)

        self.is_over = False
        self.cached_result = self._get_result()

    def _prepare_worlds(self):
        """Hook for subclasses to adjust the worlds after sampling."""
        pass

    def _action(self, action):
        """Apply ``action = (x, y)`` via ``start_world.move(x, y)``.

        With probability ``prob_unchange`` nothing moves. With probability
        ``prob_fall`` the destination is replaced by the ground.

        Returns:
            ``(reward, is_over)``.
        """
        assert self.start_world is not None, 'you need to call restart() first'

        if self.is_over:
            return 0, True
        # The initial ``cached_result`` is intentionally not consulted here,
        # so an episode that starts already solved is only rewarded after an
        # action has been taken.

        x, y = action
        assert 0 <= x <= self.nr_blocks and 0 <= y <= self.nr_blocks

        p = random.rand()
        if p >= self.prob_unchange:
            if p < self.prob_unchange + self.prob_fall:
                y = self.start_world.blocks.inv_index(0)  # fall to ground
            self.start_world.move(x, y)
            self.start_state = decorate(
                self._get_coordinates(self.start_world), self.nr_objects, 0)
        r, is_over = self._get_result()
        if is_over:
            self.is_over = True
        return r, is_over

    def _get_current_state(self):
        """Return the start-world state stacked on top of the final-world state."""
        assert self.start_world is not None, 'Should call restart() first.'
        return np.vstack([self.start_state, self.final_state])

    def _get_result(self):
        """Return ``(1, True)`` when start and final worlds match, else ``(0, False)``."""
        sorted_start_state = self._get_coordinates(self.start_world, sort=True)
        sorted_final_state = self._get_coordinates(self.final_world, sort=True)
        if (sorted_start_state == sorted_final_state).all():
            return 1, True
        return 0, False

    def _get_coordinates(self, world, sort=False):
        """Return the coordinates of ``world``, optionally sorted row-wise.

        If shape_only=True, only the shape of the blocks needs to be the same:
        non-absolute coordinates are used and rows are sorted as-is.
        If shape_only=False, the index of the blocks should also match: absolute
        coordinates are decorated (world index 0, object index) before sorting.
        """
        coordinates = get_coordinates(world, absolute=not self.shape_only)
        if sort:
            if not self.shape_only:
                coordinates = decorate(coordinates, self.nr_objects, 0)
            coordinates = np.array(sorted(map(tuple, coordinates)))
        return coordinates


class StackBlocksWorldEnv(BlocksWorldEnv):
    """BlocksWorld environment for the stacking task.

    A single world is sampled on restart. The episode succeeds when every
    block (every sorted coordinate row after the first) shares the same x
    coordinate, i.e. presumably all blocks form one stack. The observed state
    is the raw 2-column coordinates without the decoration columns.

    Args:
        nr_blocks: The number of blocks.
        random_order: randomly permute the object indexes of the world.
        shape_only: if True, use non-absolute coordinates and skip the
            (world index, object index) decoration when sorting.
        fix_ground: when ``random_order`` is set, keep the ground at index 0.
        prob_unchange: The probability that an action is not effective.
        prob_fall: The probability that the moved object falls on the ground.
    """

    # Class-level default so the ``restart()`` assertions below fail with
    # their message (instead of AttributeError) before the first restart.
    start_world = None

    def __init__(self,
                 nr_blocks,
                 random_order=False,
                 shape_only=False,
                 fix_ground=False,
                 prob_unchange=0.0,
                 prob_fall=0.0):
        super().__init__(nr_blocks, random_order, True, prob_unchange, prob_fall)
        self.shape_only = shape_only
        self.fix_ground = fix_ground

    @staticmethod
    def _sample_order(n, ground_ind):
        """Return a random ordering of ``n`` objects with object 0 (the ground)
        placed at position ``ground_ind``.

        Objects ``1..n-1`` are taken from ``random.permutation(n - 1)``
        shifted by one.
        """
        order = [r + 1 for r in random.permutation(n - 1)]
        order.insert(ground_ind, 0)
        return order

    def _restart(self):
        """Sample a new world and rebuild its state."""
        self.start_world = randomly_generate_world(
            self.nr_blocks, random_order=False)
        self.world = self.start_world
        if self.random_order:
            n = self.world.size
            # Ground is fixed as index 0 if fix_ground is True
            ground_ind = 0 if self.fix_ground else random.randint(n)
            self.start_world.blocks.set_random_order(
                self._sample_order(n, ground_ind))

        self._prepare_worlds()
        self.start_state = decorate(
            self._get_coordinates(self.start_world), self.nr_objects, 0)

        self.is_over = False
        self.cached_result = self._get_result()

    def _prepare_worlds(self):
        """Hook for subclasses to adjust the world after sampling."""
        pass

    def _action(self, action):
        """Apply ``action = (x, y)`` via ``start_world.move(x, y)``.

        If the world sampled at restart was already solved, the first call
        ends the episode with that cached result and no move is made.
        With probability ``prob_unchange`` nothing moves. With probability
        ``prob_fall`` the destination is replaced by the ground.

        Returns:
            ``(reward, is_over)``.
        """
        assert self.start_world is not None, 'you need to call restart() first'

        if self.is_over:
            return 0, True
        r, is_over = self.cached_result
        if is_over:
            self.is_over = True
            return r, is_over

        x, y = action
        assert 0 <= x <= self.nr_blocks and 0 <= y <= self.nr_blocks

        p = random.rand()
        if p >= self.prob_unchange:
            if p < self.prob_unchange + self.prob_fall:
                y = self.start_world.blocks.inv_index(0)  # fall to ground
            self.start_world.move(x, y)
            self.start_state = decorate(
                self._get_coordinates(self.start_world), self.nr_objects, 0)
        r, is_over = self._get_result()
        if is_over:
            self.is_over = True
        return r, is_over

    def _get_current_state(self):
        """Return the raw (x, y) coordinates, dropping the two decoration columns."""
        assert self.start_world is not None, 'Should call restart() first.'
        return self.start_state[:, 2:4]

    def _get_result(self):
        """Return ``(1, True)`` if all non-first rows share one x, else ``(0, False)``."""
        sorted_start_state = self._get_coordinates(self.start_world, sort=True)
        # The raw (x, y) pair is always in the last two columns, whether or not
        # the rows carry the (world index, object index) decoration. So x is
        # read from column -2. Column 2 raised IndexError for shape_only=True,
        # where the coordinates are undecorated and have only two columns.
        # NOTE(review): row 0 is skipped as the ground. With decoration, rows
        # are sorted by object index, so this is the ground only when it sits
        # at index 0 (fix_ground=True or random_order=False). With
        # shape_only=True, rows are sorted by (x, y). Confirm the skipped row
        # is the ground in those configurations.
        x_coor_blocks = sorted_start_state[1:, -2]
        if np.all(x_coor_blocks[0] == x_coor_blocks):
            return 1, True
        return 0, False

    def _get_coordinates(self, world, sort=False):
        """Return the coordinates of ``world``, optionally sorted row-wise.

        If shape_only=True, non-absolute coordinates are used and rows are
        sorted as-is. Otherwise, absolute coordinates are decorated (world
        index 0, object index) before sorting.
        """
        coordinates = get_coordinates(world, absolute=not self.shape_only)
        if sort:
            if not self.shape_only:
                coordinates = decorate(coordinates, self.nr_objects, 0)
            coordinates = np.array(sorted(map(tuple, coordinates)))
        return coordinates


class SimpleEnv(SimpleRLEnvBase):
    """A toy environment: guess the value at position 2 of a shuffled list.

    Each state is a shuffled permutation of ``[0, 1, 2, 3, 4]`` and is
    reshuffled after every action. A correct guess gives reward 1 and
    continues the episode; a wrong guess gives 0 and ends it. The tenth
    action always ends the episode with reward 5.
    """

    def __init__(self):
        super().__init__()
        self._data = None
        self._cnt = None
        self.is_over = False

    def _reset_data(self):
        """Draw a new shuffled permutation of 0..4 as the current state."""
        fresh = list(range(5))
        random.shuffle(fresh)
        self._data = fresh

    def _restart(self):
        """Start a new episode at step 1."""
        self._cnt = 1
        self.is_over = False
        self._reset_data()

    def _action(self, action):
        """Score ``action`` against the current answer; returns ``(reward, is_over)``."""
        assert self._cnt is not None, 'you need to call _restart() first'

        answer = self._data[2]
        guessed_right = bool(action == answer)
        if self._cnt != 10:
            reward = 1 if guessed_right else 0
            is_over = not guessed_right
            self._cnt += 1
        else:
            # The tenth step always terminates with a fixed bonus.
            reward, is_over = 5, True
        self._reset_data()

        return reward, is_over

    def _get_current_state(self):
        """Return the current shuffled list (the internal list itself)."""
        assert self._cnt is not None, 'Should call restart() first.'
        return self._data

# MuJoCo model files for the Fetch tasks, given as paths relative to a
# "fetch" asset directory (presumably resolved by gym's FetchEnv against
# its bundled assets -- confirm).
_FETCH_ASSET_DIR = 'fetch'
PICK_MODEL_XML_PATH = os.path.join(_FETCH_ASSET_DIR, 'pick_and_place.xml')
REACH_MODEL_XML_PATH = os.path.join(_FETCH_ASSET_DIR, 'reach.xml')
PUSH_MODEL_XML_PATH = os.path.join(_FETCH_ASSET_DIR, 'push.xml')


class FetchEnvWithStates(fetch_env.FetchEnv):
    """A ``FetchEnv`` whose simulator state can be snapshotted and restored."""

    def get_sim_state(self):
        """Return a deep copy of the current simulator state."""
        snapshot = self.sim.get_state()
        return deepcopy(snapshot)

    def set_sim_state(self, state):
        """Load ``state`` into the simulator and return the resulting observation.

        The simulator is reset first, and the restored state keeps the
        post-reset ``time``. Only ``qpos``, ``qvel``, ``act`` and
        ``udd_state`` are taken from ``state``.
        """
        self.sim.reset()
        post_reset_time = self.sim.get_state().time
        restored = mujoco_py.MjSimState(post_reset_time, state.qpos, state.qvel,
                                        state.act, state.udd_state)
        self.sim.set_state(restored)
        # Propagate the new state through the simulator before observing it.
        self.sim.forward()
        return self._get_obs()

class FetchPickAndPlaceEnv(FetchEnvWithStates, utils.EzPickle):
    """Fetch pick-and-place task with simulator state save/restore support.

    Args:
        reward_type: reward type forwarded to ``FetchEnv`` (default 'sparse').
    """

    def __init__(self, reward_type='sparse'):
        # Initial joint positions for the robot slides and the object joint
        # (the 7-vector is presumably xyz position + quaternion -- confirm).
        initial_qpos = {
            'robot0:slide0': 0.405,
            'robot0:slide1': 0.48,
            'robot0:slide2': 0.0,
            'object0:joint': [1.25, 0.53, 0.4, 1., 0., 0., 0.],
        }
        fetch_env.FetchEnv.__init__(
            self, PICK_MODEL_XML_PATH, has_object=True, block_gripper=False, n_substeps=20,
            gripper_extra_height=0.2, target_in_the_air=True, target_offset=0.0,
            obj_range=0.15, target_range=0.15, distance_threshold=0.05,
            initial_qpos=initial_qpos, reward_type=reward_type)
        # EzPickle rebuilds the env from the recorded constructor arguments.
        # Without passing reward_type, a pickled copy would silently revert
        # to the default 'sparse'.
        utils.EzPickle.__init__(self, reward_type=reward_type)

class ObservationWrapper(gym.Env):
    """Wrap an environment, turning dictionary observations into flat arrays.

    Args:
        env: wrapped environment whose observations are dicts of arrays
            (presumably 1-D, since their ``shape[0]`` gives the length).
        keys: dict keys whose arrays are concatenated, in order.
        relative: optional pair ``((key1, i1, j1), (key2, i2, j2))``. When
            given, ``obs[key1][i1:j1] - obs[key2][i2:j2]`` is appended.
        max_timesteps: an episode is marked done after this many steps.
    """

    def __init__(self, env, keys, relative=None, max_timesteps=10000):
        self.env = env
        self.keys = keys
        self.relative = relative
        self.max_timesteps = max_timesteps
        # Steps taken in the current episode. Initialised here because the
        # env is reset just below, so step() is valid without another reset().
        self.t = 0

        obs = self.env.reset()
        obs_dim = sum(obs[key].shape[0] for key in self.keys)
        print([(key, obs[key].shape[0]) for key in self.keys])
        print("obs_dim", obs_dim)
        if self.relative is not None:
            obs_dim += self.relative[0][2] - self.relative[0][1]
        self.action_space = self.env.action_space
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(obs_dim,))

    def reset(self):
        """Reset the wrapped env and the step counter; return the flat observation."""
        self.t = 0
        obs = self.env.reset()
        return self.flatten_obs(obs)

    def step(self, action):
        """Step the wrapped env; return ``(flat_obs, reward, done, info)``."""
        obs, rew, done, info = self.env.step(action)
        self.t += 1
        # End the episode once exactly max_timesteps steps have been taken
        # (``>`` previously allowed one extra step).
        done = done or self.t >= self.max_timesteps
        return self.flatten_obs(obs), rew, done, info

    def render(self, *args, **kwargs):
        """Forward rendering (including any mode argument) to the wrapped env."""
        return self.env.render(*args, **kwargs)

    def get_sim_state(self):
        """Return the wrapped env's simulator state."""
        return self.env.get_sim_state()

    def set_sim_state(self, state):
        """Restore the wrapped env's simulator state; return the flat observation."""
        return self.flatten_obs(self.env.set_sim_state(state))

    def close(self):
        """Close the wrapped env."""
        self.env.close()

    def flatten_obs(self, obs):
        """Concatenate ``obs[key]`` for each key, plus the optional relative slice."""
        flat_obs = np.concatenate([obs[key] for key in self.keys])
        if self.relative is not None:
            (key1, i1, j1) = self.relative[0]
            (key2, i2, j2) = self.relative[1]
            rel_obs = obs[key1][i1:j1] - obs[key2][i2:j2]
            flat_obs = np.concatenate([flat_obs, rel_obs])
        return flat_obs

class SimpleHighwayEnv(SimpleRLEnvBase):
    """Fixed-horizon RL environment over a flattened Fetch pick-and-place task.

    Despite the name, the environment driven by this class is
    ``FetchPickAndPlaceEnv`` wrapped in ``ObservationWrapper``. The
    ``highway-v0`` setup is kept in ``_make_highway_env`` but is not used.

    Args:
        num: unused; kept for compatibility with existing callers.
    """

    def __init__(self, num):
        super().__init__()
        self._car_num = 1
        # Only the Fetch task is instantiated. A configured highway-v0 env
        # used to be built here and immediately overwritten without being
        # closed; its configuration lives in ``_make_highway_env``.
        self._env = ObservationWrapper(
            FetchPickAndPlaceEnv(), ['observation', 'desired_goal'],
            # Appends desired_goal[0:3] - observation[3:6]
            # (presumably goal position minus object position -- confirm).
            relative=(('desired_goal', 0, 3), ('observation', 3, 6)))
        self._cnt = None
        self._obs = None
        self._render = False  # this default does not take effect here; it is set from the main script
        self._duration = 200
        self._penalty = 0  # cjy: changed from 25 to 0
        self._test = False

        # collect information for final models
        self._test_data_file = "./test_data_correct.txt"
        self._action_list = None
        self._reward_list = None
        self._speed_list = None
        self._distance_list = None

    @staticmethod
    def _make_highway_env():
        """Build the configured ``highway-v0`` environment (not used by this class)."""
        env = gym.make("highway-v0")
        env.configure({
            "vehicles_count": 20,  # cjy
            "observation": {
                "type": "Kinematics",
                "vehicles_count": 5,
                "features_range": {
                    "x": [-100, 100],
                    "y": [-100, 100],
                    "vx": [-20, 20],
                    "vy": [-20, 20]
                },
                "absolute": False,
                "normalize": False,
                "order": "sorted",
            },
            "action": {
                "type": "ContinuousAction"
            },
            "duration": 200,
            "right_lane_reward": 0,
            "high_speed_reward": 0.8,
            "reward_speed_range": [25, 30],
            "collision_reward": -3,  # needs to change
            "vehicles_density": 1,
            # "offroad_terminal": True,
        })
        return env

    def get_target_score(self):
        """Return the target score: 85% of the episode duration."""
        return self._duration * 0.85

    def _restart(self):
        """Reset the wrapped env, the step counter and (in test mode) the logs."""
        self._cnt = 1
        self._obs = self._env.reset()
        if self._test:
            self._action_list = []
            self._reward_list = []
            self._speed_list = []
            self._distance_list = []

    def _action(self, action):
        """Step the wrapped env and return ``(reward, is_over)``.

        The episode is also forced to end on step ``self._duration``.
        """
        assert self._cnt is not None, 'you need to call _restart() first'

        obs, reward, is_over, info = self._env.step(action)  # obs is the next state
        if self._render:
            self._env.render()

        self._obs = obs

        if self._cnt >= self._duration:
            is_over = True
        self._cnt += 1

        # test file (disabled)
        # if self._test:
        #     self._action_list.append(action)
        #     self._reward_list.append(reward)
        #     self._speed_list.append(self._obs[0][3])
        #     self._distance_list.append(self._obs[0][1])
        #     # episode end
        #     if is_over:
        #         # data processing
        #         length = len(self._action_list)
        #         accelerate = self._action_list.count(3)
        #         decelerate = self._action_list.count(4)
        #         left_turn = self._action_list.count(0)
        #         right_turn = self._action_list.count(2)
        #         total_reward = sum(self._reward_list)
        #         avg_speed = sum(self._speed_list) / length
        #         distance = self._distance_list[-1] - self._distance_list[0]
        #         succ = int(self._cnt == self._duration)

        #         # data output
        #         print(f"***--- New Episode ---***")
        #         print(f"[*] Length: {length}, {'Success' if bool(succ) else 'Crash'}")
        #         print(f"[*] Score: {total_reward}")
        #         print(f"[*] Average speed: {avg_speed}" )
        #         print(f"[*] Passed distance: {distance}")
        #         print(f"[*] Operations: {accelerate} acceleration, {decelerate} deceleration, "
        #               f"{left_turn} left turns, {right_turn} right turns")

        #         # data dump
        #         with open(self._test_data_file, "a+", encoding="utf8") as file:
        #             data = ""
        #             data += f"{length},{succ},{total_reward},{avg_speed},{distance},"
        #             data += f"{accelerate},{decelerate},{left_turn},{right_turn}\n"
        #             file.write(data)

        return reward, is_over

    def _get_current_state(self):
        """Return the latest flattened observation."""
        assert self._cnt is not None, 'Should call restart() first.'
        return self._obs
