#!/usr/bin/env python
from __future__ import print_function
###
# install Kinova_gen3 software by following https://github.com/Kinovarobotics/ros_kortex
# type 'source devel/setup.bash' in Kinova_gen3 software workspace
# source your python2 virtualenv: source your_venv/bin/activate
# change to this package's directory and run 'chmod +x RL.py'
# type 'python RL.py'
###

import sys
import os
from .ROS_RL_env import ROSRLEnv
import rospy
import select, tty, termios
import time
import numpy as np
import pyrealsense2 as rs
import cv2
import gym
import copy
from kortex_driver.srv import *
from kortex_driver.msg import *
import sounddevice as sd
import pickle
from datetime import datetime
from PIL import Image, ImageEnhance


class ROSPretextEnv(ROSRLEnv):
    """Pretext-task environment pairing camera images with sound features.

    Each observation holds the current image, an MFCC sound feature for the
    ground-truth object ('sound_positive') and an MFCC sound feature for a
    randomly drawn object index ('sound_negative'). Only keyboard control is
    implemented: the 'r' key buffers the current observation and the 'z' key
    writes the buffer to disk.
    """

    # gen_obs() overwrites self.ground_truth with this value on every call
    # unless it is None. NOTE(review): this looks like a manual-collection /
    # debugging override for the object placed in the scene; set it to None
    # (class or instance attribute) to keep the ground truth chosen elsewhere,
    # presumably by the parent class's reset() -- confirm intended usage.
    FORCED_GROUND_TRUTH = 2

    def __init__(self):
        ROSRLEnv.__init__(self)

        # Observation space: image plus a positive/negative sound-feature pair.
        # 'ground_truth' reuses the space already defined by the parent class,
        # so it must be read before self.observation_space is replaced below.
        d = {
            'image': gym.spaces.Box(low=0, high=255, shape=self.config.img_dim, dtype='uint8'),
            'sound_positive': gym.spaces.Box(low=-np.inf, high=np.inf,
                                             shape=self.config.sound_dim, dtype=np.float32),
            'sound_negative': gym.spaces.Box(low=-np.inf, high=np.inf,
                                             shape=self.config.sound_dim, dtype=np.float32),
            'ground_truth': self.observation_space['ground_truth']
        }

        self.observation_space = gym.spaces.Dict(d)
        self.maxSteps = self.config.pretextEnvMaxSteps
        # Buffer of observations collected with the 'r' key during manual
        # triplet collection; flushed to disk by saveTriplets().
        self.saved_triplets = []

    def gen_obs(self):
        """Build an observation from the current image, robot state and sounds.

        The image from get_image() is also passed to saveEpisodeImage(), and
        self.odom is refreshed from get_odom().

        Returns:
            (obs, s): ``obs`` is a dict with 'image' (axes transposed by
            (2, 0, 1), i.e. HWC -> CHW assuming get_image() returns HWC),
            'sound_positive', 'sound_negative' and 'ground_truth' (int32);
            ``s`` is a dict holding the end-effector state under 'eeState'.
        """
        image = self.get_image()
        self.saveEpisodeImage(image)

        # update odom
        self.odom = self.get_odom()
        s = {'eeState': self.odom}

        if self.FORCED_GROUND_TRUTH is not None:
            self.ground_truth = self.FORCED_GROUND_TRUTH

        # float32 matches the dtype declared for the sound features in the
        # observation space.
        if self.ground_truth == 4:
            # No positive sound is generated for index 4: use an all-zero feature.
            sound_positive = np.zeros(shape=self.config.sound_dim, dtype=np.float32)
        else:
            sound_positive, _ = self.audio.genSoundFeat(objIndx=self.ground_truth, featType='MFCC',
                                                        rand_fn=self.np_random.randint)

        # The negative index is drawn after the positive feature so that the
        # sequence of np_random draws is unchanged.
        objIndx_negative = self.np_random.randint(0, self.config.taskNum)
        if objIndx_negative == self.ground_truth:
            # The drawn negative equals the positive object: use an all-zero feature.
            sound_negative = np.zeros(shape=self.config.sound_dim, dtype=np.float32)
        else:
            sound_negative, _ = self.audio.genSoundFeat(objIndx=objIndx_negative, featType='MFCC',
                                                        rand_fn=self.np_random.randint)

        if self.config.render:
            print('Sound positive is------------------------------', self.ground_truth)

        obs = {
            'image': np.transpose(image, (2, 0, 1)),  # channels-first for PyTorch convolution
            'sound_positive': sound_positive,
            'sound_negative': sound_negative,
            'ground_truth': np.int32(self.ground_truth),
        }

        return obs, s

    def saveTriplets(self):
        """Pickle the buffered triplets to a timestamped file and clear the buffer.

        The file is written to
        ``<pretextDataDir[0]>/train/data_<%m_%d_%Y_%H_%M_%S>_gt<ground_truth>.pickle``.
        An empty buffer is not written. Otherwise a repeated save within the
        same second would reuse the filename and overwrite the file just saved
        with an empty list.

        Returns:
            The path of the written file, or None if the buffer was empty.
        """
        if not self.saved_triplets:
            return None

        dirPath = os.path.join(self.config.pretextDataDir[0], 'train')
        if not os.path.isdir(dirPath):
            os.makedirs(dirPath)

        timestamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
        fileName = 'data_' + timestamp + '_gt' + str(self.ground_truth) + '.pickle'
        filePath = os.path.join(dirPath, fileName)

        with open(filePath, 'wb') as f:
            pickle.dump(self.saved_triplets, f, protocol=pickle.HIGHEST_PROTOCOL)
        # Clear in place so the list object is kept (list.clear() is not
        # available on Python 2).
        del self.saved_triplets[:]
        return filePath

    def step(self, action):
        """Apply one keyboard-driven step on the robot.

        ``action`` is ignored: only keyboard control is implemented. Key 'r'
        appends the new observation to the triplet buffer; key 'z' writes the
        buffer to disk via saveTriplets().

        Returns:
            (obs, reward, done, info) following the gym step convention.

        Raises:
            NotImplementedError: if neither config.pretextManualControl nor
                config.pretextManualCollect is set.
        """
        if not (self.config.pretextManualControl or self.config.pretextManualCollect):
            raise NotImplementedError
        # use keyboard to control the real robot
        dx, dy, dz, k = self.keyboardControl()

        self.desiredEndEffectorPos[0] = np.clip(self.desiredEndEffectorPos[0] + dx, a_min=self.config.xMin, a_max=self.config.xMax)
        self.desiredEndEffectorPos[1] = np.clip(self.desiredEndEffectorPos[1] + dy, a_min=self.config.yMin, a_max=self.config.yMax)
        # NOTE(review): unlike x and y, z is not clipped to workspace bounds.
        self.desiredEndEffectorPos[2] = self.desiredEndEffectorPos[2] + dz

        try:
            # apply action
            self.send_cartesian_pose()
            self.wait_for_action_end_or_abort()
            self.envStepCounter = self.envStepCounter + 1

            # get observations
            obs, s = self.gen_obs()

        except rospy.ROSInterruptException:
            # ROS was interrupted (e.g. node shutdown): terminate the process.
            sys.exit()

        if k == 'r':  # save this triplet to buffer
            self.saved_triplets.append(obs)
            print("Number of triplets collected", len(self.saved_triplets))
        elif k == 'z':  # save collected triplets in the buffer to disk
            savedPath = self.saveTriplets()
            if savedPath is None:
                print("No triplets to save")
            else:
                print("Triplets saved to", savedPath)

        print('Step', self.envStepCounter)

        r = [self.rewards()]  # calculate reward
        self.reward = sum(r)
        self.episodeReward = self.episodeReward + self.reward
        self.done = self.termination()

        infoDict = {}

        return obs, self.reward, self.done, infoDict  # reset will be called if done
