# Basic I/O utils
import os
import sys
import yaml
import json
import time
import torch
import cv2
import random
import fnmatch
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict 
import softgym.envs.tshirt_descriptor as td
from copy import deepcopy

def generate_perlin_noise_2d(shape, res=(5*8, 5*8), ampl=0.02):
    """Generate a 2-D Perlin noise field.

    The raw noise is divided by its maximum and multiplied by ``ampl``, so the
    maximum of the result equals ``ampl`` (provided the raw maximum is
    positive). Negative values are not clipped and may exceed ``ampl`` in
    magnitude.

    Random gradients come from NumPy's global RNG (``np.random.rand``); seed
    it for reproducible output.

    :param shape: (rows, cols) of the output; each entry must be a multiple of
        the matching entry of ``res``
    :param res: number of noise lattice cells along each axis
    :param ampl: target maximum value of the returned field
    :return: np.ndarray of shape ``shape``
    :raises ValueError: if ``shape`` is not a multiple of ``res``
    """
    def fade(t):
        # Quintic interpolation curve 6t^5 - 15t^4 + 10t^3.
        return 6*t**5 - 15*t**4 + 10*t**3

    if shape[0] % res[0] or shape[1] % res[1]:
        raise ValueError(
            "shape %s must be a multiple of res %s" % (tuple(shape), tuple(res)))

    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])
    # Fractional position of every pixel inside its lattice cell. Built from
    # integer pixel indices instead of np.mgrid with a float step: mgrid sizes
    # each axis as ceil(stop / step), which floating-point rounding can push
    # one element past ``shape`` and break the broadcasts below.
    rows = (np.arange(shape[0]) * delta[0]) % 1
    cols = (np.arange(shape[1]) * delta[1]) % 1
    grid = np.stack(np.meshgrid(rows, cols, indexing='ij'), axis=-1)

    # Random unit gradient at every lattice corner, repeated to pixel resolution.
    angles = 2*np.pi*np.random.rand(res[0]+1, res[1]+1)
    gradients = np.dstack((np.cos(angles), np.sin(angles)))
    g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
    g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
    g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1)
    g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1)

    # Ramps: dot product of each corner gradient with the offset to that corner.
    n00 = np.sum(grid * g00, 2)
    n10 = np.sum(np.dstack((grid[:, :, 0]-1, grid[:, :, 1])) * g10, 2)
    n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1]-1)) * g01, 2)
    n11 = np.sum(np.dstack((grid[:, :, 0]-1, grid[:, :, 1]-1)) * g11, 2)

    # Interpolation along axis 0, then axis 1, using the fade curve.
    t = fade(grid)
    n0 = n00*(1-t[:, :, 0]) + t[:, :, 0]*n10
    n1 = n01*(1-t[:, :, 0]) + t[:, :, 0]*n11
    p = np.sqrt(2)*((1-t[:, :, 1])*n0 + t[:, :, 1]*n1)
    return p / p.max() * ampl


def get_visible(camera_params, knots, coords, depth, rgb=None, zthresh=0.001, *, debug=False):
    """Get knots that are visible in the depth image.

    A knot counts as visible when its (u, v) pixel is not NaN, lies inside the
    depth image, and the point obtained from ``td.uv_to_world_pos`` at that
    pixel is not more than ``zthresh`` above the knot's own coordinate on
    index 1 (presumably the vertical axis - TODO confirm).

    ``knots`` is not modified.

    :param camera_params: dict with ``['default_camera']['height']`` and
        ``['width']``; also passed through to ``td.uv_to_world_pos``
    :param knots: sequence of (u, v) pixel coordinates; u indexes depth rows,
        v indexes depth columns
    :param coords: per-knot world coordinates, aligned with ``knots``
    :param depth: depth image, indexed as ``depth[u, v]``
    :param rgb: unused; kept for signature compatibility
    :param zthresh: tolerance for the re-projection comparison
    :param debug: if True, plot occluded / visible knots over the depth image
    :return: list of ints, one per knot (1 visible, 0 occluded or invalid)
    """
    cam = camera_params['default_camera']
    if depth.shape[0] < cam['height']:
        print('Warning: resizing depth')
        # cv2.resize takes dsize as (width, height).
        depth = cv2.resize(depth, (cam['width'], cam['height']))

    n_rows, n_cols = depth.shape[0], depth.shape[1]
    vis = []
    occluded_knots = []
    unoccluded_knots = []
    for i, uv in enumerate(knots):
        u_f, v_f = uv[0], uv[1]
        if np.isnan(u_f) or np.isnan(v_f):
            vis.append(0)
            continue
        u, v = int(np.rint(u_f)), int(np.rint(v_f))

        if not (0 <= u < n_rows and 0 <= v < n_cols):
            # pixel is outside of image bounds
            vis.append(0)
            continue

        # Project the pixel back using the depth image and compare to the
        # knot's known coordinate; a projected point noticeably above it means
        # something else is in front of the knot.
        proj_coords = td.uv_to_world_pos(camera_params, depth, u_f, v_f, particle_radius=0, on_table=False)[0:3]
        z_diff = proj_coords[1] - coords[i][1]

        if z_diff > zthresh:
            occluded_knots.append([u_f, v_f])
            vis.append(0)
        else:
            unoccluded_knots.append([u_f, v_f])
            vis.append(1)

    if debug:
        _, ax = plt.subplots(1, 3, dpi=200)
        ax[0].set_title('depth')
        ax[0].imshow(depth)
        ax[1].set_title('occluded points\nin red')
        ax[1].imshow(depth)
        if occluded_knots:
            occ = np.array(occluded_knots)
            ax[1].scatter(occ[:, 1], occ[:, 0], marker='.', s=1, c='r', alpha=0.4)
        ax[2].imshow(depth)
        ax[2].set_title('unoccluded points\nin blue')
        if unoccluded_knots:
            unocc = np.array(unoccluded_knots)
            ax[2].scatter(unocc[:, 1], unocc[:, 0], marker='.', s=1, alpha=0.4)
        plt.show()

    return vis

def remove_dups(camera_params, knots, coords, depth, rgb=None, zthresh=0.001, *, debug=False):
    """Return a copy of ``knots`` with invalid and occluded knots set to NaN.

    A knot is replaced by ``[nan, nan]`` when its pixel lies outside the depth
    image, or when the point obtained from ``td.uv_to_world_pos`` at that
    pixel is more than ``zthresh`` above the knot's own coordinate on index 1
    (presumably the vertical axis - TODO confirm). Knots that are already NaN
    are left as they are. The input ``knots`` is deep-copied and not modified.

    :param camera_params: dict with ``['default_camera']['height']`` and
        ``['width']``; also passed through to ``td.uv_to_world_pos``
    :param knots: sequence of (u, v) pixel coordinates; u indexes depth rows,
        v indexes depth columns
    :param coords: per-knot world coordinates, aligned with ``knots``
    :param depth: depth image, indexed as ``depth[u, v]``
    :param rgb: unused; kept for signature compatibility
    :param zthresh: tolerance for the re-projection comparison
    :param debug: if True, plot occluded / visible knots over the depth image
    :return: the filtered copy of ``knots``
    """
    knots = deepcopy(knots)
    cam = camera_params['default_camera']
    if depth.shape[0] < cam['height']:
        print('Warning: resizing depth')
        # cv2.resize takes dsize as (width, height).
        depth = cv2.resize(depth, (cam['width'], cam['height']))

    n_rows, n_cols = depth.shape[0], depth.shape[1]
    unoccluded_knots = []
    occluded_knots = []
    for i, uv in enumerate(knots):
        u_f, v_f = uv[0], uv[1]
        if np.isnan(u_f) or np.isnan(v_f):
            continue
        u, v = int(np.rint(u_f)), int(np.rint(v_f))

        if not (0 <= u < n_rows and 0 <= v < n_cols):
            # pixel is outside of image bounds
            knots[i] = [float('NaN'), float('NaN')]
            continue

        # Project the pixel back using the depth image and compare to the
        # knot's known coordinate; a projected point noticeably above it means
        # something else is in front of the knot.
        proj_coords = td.uv_to_world_pos(camera_params, depth, u_f, v_f, particle_radius=0, on_table=False)[0:3]
        z_diff = proj_coords[1] - coords[i][1]

        if z_diff > zthresh:
            # invalidate u, v and continue
            occluded_knots.append(deepcopy(knots[i]))
            knots[i] = [float('NaN'), float('NaN')]
            continue

        unoccluded_knots.append(deepcopy(knots[i]))

    if debug:
        _, ax = plt.subplots(1, 3, dpi=200)
        ax[0].set_title('depth')
        ax[0].imshow(depth)
        ax[1].set_title('occluded points\nin red')
        ax[1].imshow(depth)
        if occluded_knots:
            occ = np.array(occluded_knots)
            ax[1].scatter(occ[:, 1], occ[:, 0], marker='.', s=1, c='r', alpha=0.4)
        ax[2].imshow(depth)
        ax[2].set_title('unoccluded points\nin blue')
        if unoccluded_knots:
            unocc = np.array(unoccluded_knots)
            ax[2].scatter(unocc[:, 1], unocc[:, 0], marker='.', s=1, alpha=0.4)
        plt.show()

    return knots

def getDenseCorrespondenceSourceDir():
    """
    Return the root directory of the dense-correspondence source tree.

    The location can be overridden with the ``DC_SOURCE_DIR`` environment
    variable; otherwise the built-in default path is returned.

    :return: directory path
    :rtype: str
    """
    return os.environ.get("DC_SOURCE_DIR", "/home/exx/projects/fabric_descriptors")

def getDictFromYamlFilename(filename):
    """
    Read data from a YAML file.

    :param filename: path to the YAML file
    :return: the parsed YAML document
    """
    # The file is opened in a ``with`` block so the handle is always closed.
    # yaml.load without an explicit Loader is deprecated since PyYAML 5.1 and
    # raises TypeError in PyYAML 6; FullLoader is what yaml.load used by
    # default before the argument became mandatory.
    # NOTE: FullLoader is not meant for untrusted input; use yaml.safe_load there.
    with open(filename, "r") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)

def getDictFromJSONFilename(filename):
    """
    Load the JSON document stored in ``filename``.

    :param filename: path to a JSON file
    :return: the decoded JSON value (typically a dict)
    """
    with open(filename, "r") as fh:
        parsed = json.load(fh)
    return parsed

def add_dense_correspondence_to_python_path():
    """
    Extend ``sys.path`` with the dense-correspondence source directories.

    Appends the source root and its ``pytorch-segmentation-detection``
    subdirectory, then prepends that subdirectory's ``vision`` folder.
    """
    root = getDenseCorrespondenceSourceDir()
    seg_det_dir = os.path.join(root, 'pytorch-segmentation-detection')
    for extra_dir in (root, seg_det_dir):
        sys.path.append(extra_dir)

    # The original authors note it is critical that this directory come
    # first on the path, so it is inserted at index 0 rather than appended.
    sys.path.insert(0, os.path.join(seg_det_dir, 'vision'))

def convert_to_absolute_path(path):
    """
    Converts a potentially relative path to an absolute path by pre-pending the home directory.

    A leading ``~`` is expanded first. A path that names an existing directory
    is returned unchanged, even if it is relative to the current working
    directory. An absolute path is also returned unchanged.

    :param path: absolute or relative path
    :type path: str
    :return: ``path`` itself if it is an existing directory or absolute,
        otherwise ``path`` joined onto the home directory
    :rtype: str
    """
    # Expand "~" up front; otherwise "~/foo" would be joined into "<home>/~/foo".
    path = os.path.expanduser(path)
    if os.path.isdir(path) or os.path.isabs(path):
        return path

    home_dir = os.path.expanduser("~")
    return os.path.join(home_dir, path)

def get_current_time_unique_name():
    """
    Return the current local time formatted as ``YYYYmmdd-HHMMSS``.

    The name is unique only at one-second resolution: two calls within the
    same second return the same string.

    :return: timestamp string
    :rtype: str
    """
    stamp_format = "%Y%m%d-%H%M%S"
    return time.strftime(stamp_format)

def save_to_yaml(data, filename):
    """
    Write ``data`` to ``filename`` as YAML, overwriting any existing file.

    Output uses block style (``default_flow_style=False``).
    """
    dump_options = {'default_flow_style': False}
    with open(filename, 'w') as stream:
        yaml.dump(data, stream, **dump_options)

def get_padded_string(idx, width=6):
    """
    Return ``str(idx)`` left-padded with zeros to at least ``width`` characters.

    Uses ``str.zfill``, so a leading sign stays in front of the padding.
    """
    text = str(idx)
    return text.zfill(width)

def uv_to_flattened_pixel_locations(uv_tuple, image_width):
    """
    Flatten (u, v) pixel coordinates into row-major linear indices.

    :param uv_tuple: pair ``(u, v)``; entries may be ints or tensors
    :param image_width: number of pixels per image row
    :return: ``v * image_width + u``
    """
    u_coord, v_coord = uv_tuple[0], uv_tuple[1]
    return v_coord * image_width + u_coord

def flattened_pixel_locations_to_u_v(flat_pixel_locations, image_width):
    """
    Split flattened row-major pixel indices back into (u, v) coordinates.

    :param flat_pixel_locations: A torch.LongTensor (e.g. of shape ``[n, 1]``)
     where each element is a flattened pixel index ``v * image_width + u``,
     i.e. some integer between 0 and 307,199 for a 640x480 image
    :type flat_pixel_locations: torch.LongTensor
    :param image_width: number of pixels per image row
    :return: tuple ``(u, v)`` of LongTensors with the input's shape, where
     ``u = flat % image_width`` and ``v = flat // image_width``
    """
    # Integer floor division keeps the computation in int64. The previous
    # ``torch.floor(flat / image_width)`` went through float32 true division,
    # which cannot represent indices above 2**24 exactly and could misround.
    v = torch.div(flat_pixel_locations, image_width, rounding_mode='floor').long()
    u = flat_pixel_locations % image_width
    return (u, v)

def reset_random_seed(seed=1):
    """
    Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    :param seed: seed to apply to all three generators (default 1)
    :type seed: int
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def get_model_param_file_from_directory(model_folder, iteration=None):
    """
    Gets a model checkpoint (e.g. 003500.pth) and its matching optimizer
    state file (e.g. 003500.pth.opt) from the specified folder.

    :param model_folder: location of the folder containing the param files 001000.pth. Can be absolute or relative path. If relative then it is relative to pdc/trained_models/
    :type model_folder: str
    :param iteration: which index to use, e.g. 3500; if None the checkpoint
        with the highest iteration number in the folder is used. Only files
        named ``<digits>.pth`` are considered
    :type iteration: int or None
    :return: model_param_file, optim_param_file, iteration
    :rtype: str, str, int
    :raises FileNotFoundError: if iteration is None and the folder contains
        no ``<digits>.pth`` file
    """
    if not os.path.isdir(model_folder):
        path = getDenseCorrespondenceSourceDir()
        model_folder = os.path.join(path, "trained_models", model_folder)

    suffix = ".pth"
    if iteration is None:
        # Only checkpoints named by iteration number qualify; other .pth files
        # (e.g. "best.pth") cannot be parsed as an iteration. Compare
        # numerically so names wider than the 6-digit padding still sort right.
        candidates = [name for name in fnmatch.filter(os.listdir(model_folder), '*' + suffix)
                      if name[:-len(suffix)].isdecimal()]
        if not candidates:
            raise FileNotFoundError(
                "no <iteration>.pth checkpoint found in %s" % model_folder)
        model_param_file = max(candidates, key=lambda name: int(name[:-len(suffix)]))
        iteration = int(model_param_file[:-len(suffix)])
    else:
        model_param_file = get_padded_string(iteration, width=6) + suffix

    # The optimizer state always belongs to the same iteration as the model.
    optim_param_file = model_param_file + ".opt"

    print("model_param_file", model_param_file)
    model_param_file = os.path.join(model_folder, model_param_file)
    optim_param_file = os.path.join(model_folder, optim_param_file)

    return model_param_file, optim_param_file, iteration

def remove_occlusions(l1, l2):
    """
    Remove occluded indexes from knots_info.

    Each element of ``l1`` and ``l2`` is indexed as ``entry[0][0]`` and
    ``entry[0][1]``, which are checked for NaN. An index is dropped from BOTH
    lists if either list has a NaN there. The inputs are not modified; shallow
    copies (``l1[:]``, ``l2[:]``) are filtered and returned.

    :return: (filtered copy of l1, filtered copy of l2)
    """
    def _is_nan_entry(entry):
        uv = entry[0]
        # NaN is the only value that compares unequal to itself.
        return uv[0] != uv[0] or uv[1] != uv[1]

    bad_indices = {idx for idx, entry in enumerate(l1) if _is_nan_entry(entry)}
    bad_indices |= {idx for idx, entry in enumerate(l2) if _is_nan_entry(entry)}

    kept1 = l1[:]
    kept2 = l2[:]
    # Delete from the highest index down so earlier deletions do not shift
    # positions that still need removing.
    for idx in sorted(bad_indices, reverse=True):
        del kept1[idx]
        del kept2[idx]
    return kept1, kept2