import torch
from torch.utils import data

import numpy as np
import os
import os.path as osp
import glob
import cv2

import matplotlib.pyplot as plt

class DataGenerator(object):
    """Build train/test ``DataLoader``s from recorded PR2 demonstrations.

    Every root in ``DIRS`` contains one sub-folder per concept item (``left``,
    ``right``, ``front``, ... see ``self.groups``). Each demonstration folder inside
    those holds a ``joint_names*`` file, numbered ``joint_position_<n>.txt`` and
    ``joint_effort_<n>.txt`` files (one value per line, ordered like the joint-names
    file) and numbered ``kinect2_qhd_image_color_rect_<n>.jpg`` images.
    """

    # joints without limits: their angle is wrapped into [0, 2*pi) before normalising
    CONTINUOUS_JOINTS = ('r_forearm_roll_joint', 'r_wrist_roll_joint')

    def __init__(self, batch_size, data_split=0.8, plot=False, augment_counter=0, DIRS=""):
        """
        Args:
            batch_size: batch size of the returned data loaders.
            data_split: fraction of each concept folder's demonstrations used for training.
            plot: if True, every loaded image is shown with matplotlib.
            augment_counter: number of extra noisy/cropped copies added per demonstration.
            DIRS: iterable of root directories; a single path string is also accepted.
        """
        # length of the smoothing window applied to every trajectory channel
        self.WINDOW_LEN = 11

        self.data_split = data_split
        self.batch_size = batch_size
        self.augment_counter = augment_counter
        self.plot = plot
        # longest raw effort trajectory seen (kept for inspection; padding uses traj_len)
        self.max_len = 0
        # every trajectory is padded / truncated to this many time steps
        self.traj_len = 240

        self.joints_of_interest = ['r_shoulder_pan_joint',
                                   'r_shoulder_lift_joint',
                                   'r_upper_arm_roll_joint',
                                   'r_elbow_flex_joint',
                                   'r_forearm_roll_joint',
                                   'r_wrist_flex_joint',
                                   'r_wrist_roll_joint']

        # values taken from the PR2 URDF file and PR2 manual
        all_norm_ranges = {
            'torso_lift_joint': {'max': 0.33, 'min': 0.0},
            'r_shoulder_pan_joint': {'max': 0.714601836603, 'min': -2.2853981634},
            'r_shoulder_lift_joint': {'max': 1.3963, 'min': -0.5236},
            'r_upper_arm_roll_joint': {'max': 0.8, 'min': -3.9},
            'r_elbow_flex_joint': {'max': 0.0, 'min': -2.3213},
            'r_forearm_roll_joint': {'max': 6.283185307179586, 'min': 0},
            'r_wrist_flex_joint': {'max': 0.0, 'min': -2.18},
            'r_wrist_roll_joint': {'max': 6.283185307179586, 'min': 0},
            'r_gripper_joint': {'max': 0.09, 'min': 0.0},
        }
        self.ranges = [all_norm_ranges[x] for x in self.joints_of_interest]

        # concept groups: the label of a demonstration is the index of its folder's
        # item within its group (e.g. "right" -> group 0, label 1)
        self.groups = {0: ["left", "right"],
                       1: ["front", "back"],
                       2: ["soft", "hard"],
                       3: ["short", "long"],
                       4: ["slow", "fast"]}

        # a lone string used to be iterated character by character; treat it as one root
        if isinstance(DIRS, str):
            DIRS = [DIRS] if DIRS else []

        self.folder_names = [osp.join(DIR, item)
                             for DIR in DIRS
                             for items in self.groups.values()
                             for item in items]

    def smooth(self, x, window_len=11, window='hanning'):
        """Smooth a 1-D signal by convolving it with a normalised window.

        The signal is extended with reflected copies of its first and last
        ``window_len - 1`` samples so that transients at both ends are reduced.

        Args:
            x: 1-D numpy array.
            window_len: size of the smoothing window; should be an odd integer.
            window: one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman';
                'flat' gives a moving average.

        Returns:
            The smoothed signal, of length ``len(x) + window_len - 1``. Output sample
            ``k`` is computed from the window over input samples
            ``k - window_len + 1 .. k`` (reflected at the start). If
            ``window_len < 3``, ``x`` is returned unchanged.

        Raises:
            ValueError: if ``x`` is not 1-D, is shorter than ``window_len``, or
                ``window`` is unknown.

        Example:
            t = np.linspace(-2, 2, 50)
            y = smooth(np.sin(t) + np.random.randn(len(t)) * 0.1)

        See also: numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman,
        numpy.convolve, scipy.signal.lfilter.
        """
        if x.ndim != 1:
            raise ValueError("smooth only accepts 1 dimension arrays.")

        if x.size < window_len:
            raise ValueError("Input vector needs to be bigger than window size.")

        if window_len < 3:
            return x

        if window not in ('flat', 'hanning', 'hamming', 'bartlett', 'blackman'):
            raise ValueError("Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

        # reflect window_len - 1 samples at each end
        s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
        if window == 'flat':  # moving average
            w = np.ones(window_len, 'd')
        else:
            # the name is whitelisted above, so a plain attribute lookup replaces eval()
            w = getattr(np, window)(window_len)

        return np.convolve(w / w.sum(), s, mode='valid')

    def normalize_eff(self, x):
        """Map joint efforts linearly from [-10, 10] onto [-1, 1] (values are not clipped)."""
        min_value = -10
        max_value = 10

        norm_const = max_value - min_value

        # normalise in the [0, 1] range
        x = (x - min_value) / norm_const

        # then stretch to the [-1, 1] range
        x = (x - 0.5) / 0.5

        return x

    def normalize(self, x):
        """Normalise joint positions into [-1, 1] using the per-joint limits in ``self.ranges``.

        ``x`` is a 2-D array whose column ``i`` holds ``self.joints_of_interest[i]``.
        Its columns are modified in place; the rescaled array is returned.
        Continuous joints are first wrapped into [0, 2*pi).
        """
        for i in range(x.shape[1]):
            joint_range = self.ranges[i]

            # cater for the joints with no limit: wrap by a full turn (2*pi, not 6.28)
            if self.joints_of_interest[i] in self.CONTINUOUS_JOINTS:
                x[:, i] = x[:, i] % (2 * np.pi)

            x[:, i] = (x[:, i] - joint_range['min']) / (joint_range['max'] - joint_range['min'])

        # normalise in the [-1, 1] range
        x = (x - 0.5) / 0.5

        return x

    def augment(self, pos, eff):
        """Return a noisy, prefix-cropped copy of one demonstration.

        Adds uniform noise in [-0.1, 0.1) to positions and [-0.5, 0.5) to efforts,
        then drops the same random number (0-19) of leading time steps from both.
        """
        # add random noise
        pos = pos + np.random.uniform(low=-0.1, high=0.1, size=pos.shape).astype(np.float32)
        eff = eff + np.random.uniform(low=-0.5, high=0.5, size=eff.shape).astype(np.float32)

        # cut a prefix
        prefix_idx = np.random.randint(20)
        pos = pos[prefix_idx:]
        eff = eff[prefix_idx:]

        return pos, eff

    def process_img(self, image_paths, train_flag, augment=False):
        """Load, rescale and store the images of one demonstration.

        The first image is also used to locate the cube. Pixels whose hue lies in
        (330, 360) degrees are averaged into a normalised (x, y) position.

        Args:
            image_paths: paths of the images to load.
            train_flag: store into the train split if True, otherwise the test split.
            augment: add uniform noise in [-0.05, 0.05) to the cube position.

        Returns:
            True if the images and cube position were appended. Returns False (and
            appends nothing) if no image was given or fewer than 30 cube pixels were
            found in the first image; callers must then skip the trajectory too,
            so that all per-sample lists stay aligned.

        Raises:
            IOError: if an image cannot be read.
        """
        orig_dim = 256
        desired_dim = 128
        scale = desired_dim / float(orig_dim)
        min_cube_pixels = 30

        cube_pos = None
        image_list = []
        for idx, image_path in enumerate(image_paths):
            bgr = cv2.imread(image_path)
            if bgr is None:
                raise IOError("Could not read image {0}".format(image_path))
            bgr = cv2.resize(bgr / 255., (0, 0), fx=scale, fy=scale).astype(np.float32)

            # (H, W, C) -> (C, W, H); note this is swapaxes, not a CHW transpose
            image_list.append(np.swapaxes(bgr, 0, 2))

            if self.plot:
                plt.figure(figsize=(5, 5))
                plt.imshow(cv2.cvtColor((bgr.copy() * 255).astype(np.uint8), cv2.COLOR_BGR2RGB), vmin=0, vmax=255)
                plt.show()

            if idx == 0:
                # for float32 input OpenCV reports hue in degrees, [0, 360]
                hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
                mask = np.logical_and(hsv[:, :, 0] > 330, hsv[:, :, 0] < 360).astype(np.uint8)

                # morphological opening removes isolated speckle pixels from the mask
                kernel = np.ones((3, 3), np.uint8)
                mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

                idxs = np.argwhere(mask == 1)

                # require enough matching pixels (count them, not their coordinate sum)
                if len(idxs) < min_cube_pixels:
                    return False

                pos_x = np.mean(idxs[:, 1]) / desired_dim
                pos_y = np.mean(idxs[:, 0]) / desired_dim

                if augment:
                    noise = 0.05
                    pos_x += np.random.uniform(low=-noise, high=noise)
                    pos_y += np.random.uniform(low=-noise, high=noise)

                cube_pos = [pos_x, pos_y]

        if cube_pos is None:
            return False

        # ADD the data point to the relevant train/test split
        if train_flag:
            self.train_cube_pos.append(cube_pos)
            self.train_images.append(image_list)
        else:
            self.test_cube_pos.append(cube_pos)
            self.test_images.append(image_list)
        return True

    def _fit_length(self, traj):
        """Truncate a (time, channel) array to ``self.traj_len`` rows, or pad it by
        repeating its last row with uniform noise in [-0.01, 0.01)."""
        if len(traj) == 0:
            raise ValueError("Cannot pad an empty trajectory.")
        missing = self.traj_len - len(traj)
        if missing <= 0:
            return traj[:self.traj_len]
        pad = np.repeat(traj[-1][None], missing, axis=0)
        pad += np.random.uniform(low=-0.01, high=0.01, size=pad.shape)
        return np.concatenate((traj, pad))

    def _smooth_columns(self, traj):
        """Smooth each column of a (time, channel) array with ``self.smooth``,
        keeping the first ``len(traj)`` outputs (a trailing window)."""
        smoothed = traj.copy()
        for i in range(traj.shape[1]):
            smoothed[:, i] = self.smooth(traj[:, i], window_len=self.WINDOW_LEN)[:len(traj)]
        return smoothed

    def _make_labels(self, label_group_idx, label):
        """One entry per concept group: ``label`` for the demonstrated group, 100 for
        the others (presumably an ignore value for the loss; confirm with training code)."""
        return [label if idx == label_group_idx else 100 for idx in range(len(self.groups))]

    def process_traj(self, position_traj, effort_traj, train_flag, label_group_idx, label):
        """Normalise, pad/truncate and smooth one demonstration, then store it with its labels.

        The stored array has shape (traj_len, 2 * n_joints): positions then efforts.
        ``position_traj`` is modified in place by ``normalize``; pass a copy if needed.
        """
        self.max_len = max(self.max_len, len(effort_traj))

        # NORMALISE all values in the [-1, 1] range, then EXTEND / cut each
        # trajectory independently to traj_len steps
        position_traj_norm = self._fit_length(self.normalize(position_traj))
        effort_traj_norm = self._fit_length(self.normalize_eff(effort_traj))

        # SMOOTH out noise
        both_traj = np.concatenate((self._smooth_columns(position_traj_norm),
                                    self._smooth_columns(effort_traj_norm)), axis=1)
        labels = self._make_labels(label_group_idx, label)

        # ADD the data point to the relevant train/test split
        if train_flag:
            self.train_traj.append(both_traj)
            self.train_labels.append(labels)
        else:
            self.test_traj.append(both_traj)
            self.test_labels.append(labels)

    @staticmethod
    def _sorted_files(folder, prefix, ext):
        """Return the ``<folder>/<prefix><n><ext>`` paths on disk, sorted by the integer n."""
        paths = glob.glob(osp.join(folder, prefix + '*'))
        numbers = sorted(int(p.split('_')[-1].replace(ext, '')) for p in paths)
        return [osp.join(folder, prefix + str(n) + ext) for n in numbers]

    @staticmethod
    def _load_joint_values(files, indices):
        """Read one value per line from each file, keep the lines at ``indices``, and
        return a (len(files), len(indices)) float32 array."""
        rows = []
        for path in files:
            with open(path, "r") as f:
                # [:-1] drops the empty string produced by the trailing newline
                rows.append(np.take(f.read().split("\n")[:-1], indices))
        return np.array(rows, dtype=np.float32)

    def _stack_traj(self, trajs):
        """Stack (traj_len, channel) trajectories into an (N, channel, traj_len) float32 array.

        This also works for an empty split.
        """
        n_channels = 2 * len(self.joints_of_interest)
        stacked = np.asarray(trajs, dtype=np.float32).reshape(-1, self.traj_len, n_channels)
        return np.swapaxes(stacked, 1, 2)

    def generate_data(self):
        """Load every demonstration folder, split it into train/test and build data loaders.

        For each concept folder, ``int(data_split * n)`` of its entries are chosen at
        random for training. Entries whose name contains "others" are still counted
        in ``n`` but are skipped. Demonstrations whose first image shows no cube are
        dropped entirely.

        Returns:
            (trainloader, testloader, train_traj, test_traj, train_images, test_images,
             train_cube_pos, test_cube_pos, train_labels, test_labels). The trajectory
            arrays have shape (N, 2 * n_joints, traj_len).
        """
        self.train_traj = []
        self.train_images = []
        self.train_cube_pos = []
        self.train_labels = []

        self.test_traj = []
        self.test_images = []
        self.test_cube_pos = []
        self.test_labels = []

        for folder_name in self.folder_names:
            print("Loading {0}".format(folder_name))

            group_item = osp.basename(osp.normpath(folder_name))
            matching_groups = [i for i, items in self.groups.items() if group_item in items]
            if not matching_groups:
                raise ValueError("Folder {0} does not name a known group item".format(folder_name))
            label_group_idx = matching_groups[0]
            label = self.groups[label_group_idx].index(group_item)

            demo_folders = os.listdir(folder_name)
            number_of_folders = len(demo_folders)
            train_n = int(self.data_split * number_of_folders)
            train_indices = np.random.choice(number_of_folders, train_n, replace=False)
            train_folders = {demo_folders[i] for i in train_indices}

            for demo_folder in demo_folders:
                if "others" in demo_folder:
                    continue

                demo_path = osp.join(folder_name, demo_folder)
                train_flag = demo_folder in train_folders

                # LOAD the names of all joints for which efforts were recorded
                joint_names_files = glob.glob(osp.join(demo_path, "joint_names*"))
                if not joint_names_files:
                    raise IOError("No joint_names file in {0}".format(demo_path))
                with open(joint_names_files[0], "r") as f:
                    # escape the last empty line when splitting on \n
                    joint_names = f.read().split("\n")[:-1]
                indices_of_interest = [joint_names.index(x) for x in self.joints_of_interest]

                # keep only the first and the last image of the demonstration
                image_files = self._sorted_files(demo_path, 'kinect2_qhd_image_color_rect_', '.jpg')
                image_files = [image_files[0], image_files[-1]]

                # LOAD all effort / position values for the selected joints
                effort_traj = self._load_joint_values(
                    self._sorted_files(demo_path, 'joint_effort_', '.txt'), indices_of_interest)
                position_traj = self._load_joint_values(
                    self._sorted_files(demo_path, 'joint_position_', '.txt'), indices_of_interest)

                # INITIAL traj; skip the whole demonstration if no cube was found so the
                # image, cube-position, trajectory and label lists stay aligned
                if not self.process_img(image_files, train_flag):
                    continue
                self.process_traj(position_traj.copy(), effort_traj.copy(), train_flag, label_group_idx, label)

                # AUGMENTED trajs
                for _ in range(self.augment_counter):
                    augmented_position_traj, augmented_effort_traj = self.augment(position_traj.copy(), effort_traj.copy())
                    if self.process_img(image_files, train_flag, augment=False):
                        self.process_traj(augmented_position_traj, augmented_effort_traj, train_flag, label_group_idx, label)

        self.train_traj = self._stack_traj(self.train_traj)
        self.test_traj = self._stack_traj(self.test_traj)

        # INIT data loaders
        trainset = data.TensorDataset(torch.tensor(self.train_traj),
                                      torch.tensor(np.asarray(self.train_images, dtype=np.float32)),
                                      torch.tensor(self.train_cube_pos),
                                      torch.tensor(self.train_labels))
        trainloader = data.DataLoader(trainset, batch_size=self.batch_size, shuffle=True, num_workers=0)

        testset = data.TensorDataset(torch.tensor(self.test_traj),
                                     torch.tensor(np.asarray(self.test_images, dtype=np.float32)),
                                     torch.tensor(self.test_cube_pos),
                                     torch.tensor(self.test_labels))
        testloader = data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=0)

        return trainloader, testloader, self.train_traj, self.test_traj, self.train_images, self.test_images, self.train_cube_pos, self.test_cube_pos, self.train_labels, self.test_labels