import time
import datetime
import os
import numpy as np
import cv2
import torch
import glob
import json
import utils
# import h5py

class Logger:
    """Write the artifacts of one data-logging session to a directory tree.

    Layout under ``base_directory``: ``info/``, ``data/color-images/``,
    ``data/depth-images/``, ``data/color-heightmaps/``, ``data/depth-heightmaps/``,
    ``models/``, ``visualizations/``, ``recordings/`` and ``transitions/data/``.
    """

    def __init__(self, continue_logging, logging_directory, args=None, dir_name=''):
        """Set up (or reopen) a logging session and create its directory tree.

        Args:
            continue_logging: If true, ``logging_directory`` is itself the session
                directory of an earlier run and is reused as-is.
            logging_directory: Parent directory for a new session, or the existing
                session directory when ``continue_logging`` is true.
            args: Optional object accepted by ``vars()`` (e.g. an argparse
                namespace); its attributes are written to ``commandline_args.json``.
            dir_name: Name of the new session directory. If empty, a name from
                ``utils.timeStamped('')`` is used. Ignored when continuing.
        """
        self.continue_logging = continue_logging
        if self.continue_logging:
            self.base_directory = logging_directory
            print('Pre-loading data logging session: %s' % (self.base_directory))
        else:
            if not dir_name:
                dir_name = utils.timeStamped('')
            self.base_directory = os.path.join(logging_directory, dir_name)
            print('Creating data logging session: %s' % (self.base_directory))
        self.info_directory = os.path.join(self.base_directory, 'info')
        self.color_images_directory = os.path.join(self.base_directory, 'data', 'color-images')
        self.depth_images_directory = os.path.join(self.base_directory, 'data', 'depth-images')
        self.color_heightmaps_directory = os.path.join(self.base_directory, 'data', 'color-heightmaps')
        self.depth_heightmaps_directory = os.path.join(self.base_directory, 'data', 'depth-heightmaps')
        self.models_directory = os.path.join(self.base_directory, 'models')
        self.visualizations_directory = os.path.join(self.base_directory, 'visualizations')
        self.recordings_directory = os.path.join(self.base_directory, 'recordings')
        self.transitions_directory = os.path.join(self.base_directory, 'transitions')

        # exist_ok=True works for both fresh and continued sessions without an
        # exists()/makedirs() race. transitions/data is listed explicitly (which
        # also creates transitions/) so that a continued session whose
        # transitions/ exists without data/ still gets it for save_transition.
        for directory in (
            self.info_directory,
            self.color_images_directory,
            self.depth_images_directory,
            self.color_heightmaps_directory,
            self.depth_heightmaps_directory,
            self.models_directory,
            self.visualizations_directory,
            self.recordings_directory,
            os.path.join(self.transitions_directory, 'data'),
        ):
            os.makedirs(directory, exist_ok=True)

        if args is not None:
            # Overwrites any existing file, including when continuing a session.
            params_path = os.path.join(self.base_directory, 'commandline_args.json')
            with open(params_path, 'w') as f:
                json.dump(vars(args), f, sort_keys=True)

    @staticmethod
    def _to_uint16_depth(depth, scale):
        """Quantize a float depth array to uint16 after multiplying by ``scale``.

        Values are rounded, then clipped to [0, 65535] before the cast. NumPy does
        not define float->uint16 conversion for out-of-range values, so without the
        clip negative or too-large depths would be silently corrupted. In-range
        values are unaffected by the clip.
        """
        scaled = np.round(np.asarray(depth) * scale)
        return np.clip(scaled, 0, np.iinfo(np.uint16).max).astype(np.uint16)

    def save_camera_info(self, intrinsics, pose, depth_scale):
        """Write camera intrinsics, pose and depth scale as text files under ``info/``."""
        np.savetxt(os.path.join(self.info_directory, 'camera-intrinsics.txt'), intrinsics, delimiter=' ')
        np.savetxt(os.path.join(self.info_directory, 'camera-pose.txt'), pose, delimiter=' ')
        np.savetxt(os.path.join(self.info_directory, 'camera-depth-scale.txt'), [depth_scale], delimiter=' ')

    def save_heightmap_info(self, boundaries, resolution):
        """Write heightmap boundaries and resolution as text files under ``info/``."""
        np.savetxt(os.path.join(self.info_directory, 'heightmap-boundaries.txt'), boundaries, delimiter=' ')
        np.savetxt(os.path.join(self.info_directory, 'heightmap-resolution.txt'), [resolution], delimiter=' ')

    def save_images(self, iteration, color_image, depth_image, mode):
        """Save a raw color/depth frame pair as PNGs.

        The color image is treated as RGB and converted to BGR for ``cv2.imwrite``
        (``%06d.<mode>.color.png``). The depth image is stored as uint16 in units
        of 1e-4 (presumably meters), clipped to [0, 65535]
        (``%06d.<mode>.depth.png``).
        """
        color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(self.color_images_directory, '%06d.%s.color.png' % (iteration, mode)), color_image)
        depth_image = self._to_uint16_depth(depth_image, 10000)  # Save depth in 1e-4 meters
        cv2.imwrite(os.path.join(self.depth_images_directory, '%06d.%s.depth.png' % (iteration, mode)), depth_image)

    def save_heightmaps(self, iteration, color_heightmap, depth_heightmap, mode, poststring=None, debug=False):
        """Save a color/depth heightmap pair as PNGs.

        Files are named ``%06d.<mode>[.<poststring>].{color,depth}.png``. The color
        heightmap is treated as RGB and converted to BGR for ``cv2.imwrite``. The
        depth heightmap is stored as uint16 in units of 1e-5 (presumably meters),
        clipped to [0, 65535].

        Args:
            iteration: Integer used as the zero-padded filename prefix.
            color_heightmap: RGB image array.
            depth_heightmap: Float depth array.
            mode: Tag inserted into the filenames.
            poststring: Optional extra tag appended after ``mode``.
            debug: If true, show the original, quantized and re-read depth
                heightmaps side by side with matplotlib (blocks until closed).
        """
        tag = mode if poststring is None else '%s.%s' % (mode, poststring)
        color_filename = '%06d.%s.color.png' % (iteration, tag)
        depth_filename = '%06d.%s.depth.png' % (iteration, tag)

        # save color
        cv2.imwrite(os.path.join(self.color_heightmaps_directory, color_filename),
                    cv2.cvtColor(color_heightmap, cv2.COLOR_RGB2BGR))

        # save depth in 1e-5 units
        quantized_depth = self._to_uint16_depth(depth_heightmap, 100000)
        depth_heightmap_path = os.path.join(self.depth_heightmaps_directory, depth_filename)
        cv2.imwrite(depth_heightmap_path, quantized_depth)

        if debug:
            # Compare the in-memory quantization with what round-trips through the PNG.
            converted_depth_heightmap = quantized_depth.astype(np.float32) / 100000
            saved_reloaded_depth_heightmap = np.array(
                cv2.imread(depth_heightmap_path, cv2.IMREAD_ANYDEPTH)).astype(np.float32) / 100000
            import matplotlib.pyplot as plt
            f = plt.figure()
            f.add_subplot(1, 3, 1)
            plt.imshow(depth_heightmap)
            f.add_subplot(1, 3, 2)
            plt.imshow(converted_depth_heightmap)
            f.add_subplot(1, 3, 3)
            plt.imshow(saved_reloaded_depth_heightmap)
            plt.show(block=True)

    def write_to_log(self, log_name, log, pickle=False, fmt='%.3f'):
        """Write ``log`` to files under ``transitions/``.

        Args:
            log_name: Base name of the output file(s).
            log: Array-like log contents.
            pickle: If true, save with ``np.savez``. Use this for data that
                ``np.savetxt`` cannot write, e.g. embeddings with more than one
                dimension. NumPy appends ``.npz``, so the file is
                ``<log_name>.log.txt.npz``.
            fmt: ``np.savetxt`` format for ``<log_name>.log.txt``. Unless it is
                ``'%s'``, a CSV copy ``<log_name>.log.csv`` with ``log_name`` as
                the header is also written when the squeezed log is not a scalar.
        """
        stem = os.path.join(self.transitions_directory, '%s.log' % log_name)
        if pickle:
            np.savez(stem + '.txt', log)
            return
        np.savetxt(stem + '.txt', log, delimiter=' ', fmt=fmt)
        if fmt != '%s':
            shortlog = np.squeeze(log)
            if shortlog.ndim > 0:
                np.savetxt(stem + '.csv', shortlog, delimiter=', ', header=log_name, fmt=fmt)

    def save_model(self, model, name):
        """Save ``model.state_dict()`` to ``models/snapshot.<name>.pth``."""
        torch.save(model.state_dict(), os.path.join(self.models_directory, 'snapshot.%s.pth' % (name)))

    def save_backup_model(self, model, name):
        """Save ``model.state_dict()`` to ``models/snapshot-backup.<name>.pth``."""
        torch.save(model.state_dict(), os.path.join(self.models_directory, 'snapshot-backup.%s.pth' % (name)))

    def save_visualizations(self, iteration, affordance_vis, name):
        """Write ``affordance_vis`` to ``visualizations/%06d.<name>.png``.

        ``affordance_vis`` is passed unchanged to ``cv2.imwrite``, so it presumably
        is already in OpenCV's (BGR) channel order.
        """
        cv2.imwrite(os.path.join(self.visualizations_directory, '%06d.%s.png' % (iteration, name)), affordance_vis)

    def make_new_recording_directory(self, iteration):
        """Create (if needed) and return ``recordings/<iteration zero-padded to 6>``.

        Intended usage, e.g. to record RGB-D video while executing a primitive::

            recording_directory = logger.make_new_recording_directory(iteration)
            camera.start_recording(recording_directory)
            camera.stop_recording()
        """
        recording_directory = os.path.join(self.recordings_directory, '%06d' % (iteration))
        os.makedirs(recording_directory, exist_ok=True)
        return recording_directory

    def save_transition(self, iteration, transition):
        """Save ``transition.state`` and ``transition.next_state`` as depth PNGs.

        They are written to ``transitions/data/%06d.0.depth.png`` and
        ``transitions/data/%06d.1.depth.png`` as uint16 in units of 1e-5, clipped
        to [0, 65535]. This presumably assumes both states are depth heightmaps
        in meters; confirm against the caller.
        """
        data_directory = os.path.join(self.transitions_directory, 'data')
        for index, state in enumerate((transition.state, transition.next_state)):
            depth_path = os.path.join(data_directory, '%06d.%d.depth.png' % (iteration, index))
            cv2.imwrite(depth_path, self._to_uint16_depth(state, 100000))
