
import sys
import random
import numpy as np
from collections import namedtuple, defaultdict
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pyflex
from softgym.envs.tshirt_base_env import TshirtBaselineFold
from softgym.envs.corl_baseline import GCFold
import softgym.envs.tshirt_descriptor as td
import cv2
from copy import deepcopy
from sklearn.neighbors import KDTree
import networkx as nx
from utils import remove_dups

# Transition record with fields obs, goal, act, rew, nobs, done
# (names suggest observation / goal / action / reward / next observation /
# episode-done flag -- confirm against the code that fills it).
Experience = namedtuple(
    'Experience',
    ['obs', 'goal', 'act', 'rew', 'nobs', 'done'],
)

def get_rgbd(env):
    """Render the current pyflex frame and return ``(img, depth, mask)``.

    Args:
        env: environment exposing ``camera_height``, ``camera_width`` and
            ``get_image(height, width)``.

    Returns:
        img: the image returned by ``env.get_image(env.camera_height,
            env.camera_width)``.
        depth: ``(camera_height, camera_width)`` array holding channel 3 of
            ``pyflex.render_sensor()``, with rows flipped.
        mask: boolean array, True where ``depth > 0``.
    """
    sensor = np.array(pyflex.render_sensor()).reshape(
        env.camera_height, env.camera_width, 4)
    # Flip rows; presumably the sensor buffer is stored bottom-up relative to
    # the image returned by env.get_image -- confirm.
    depth = sensor[::-1, :, 3]
    # Note: the sensor's RGB channels (0..2) were previously extracted but
    # never used; the returned image comes from env.get_image instead.
    img = env.get_image(env.camera_height, env.camera_width)
    mask = depth > 0
    return img, depth, mask

def get_masked(img):
    """Return a uint8 foreground mask for a goal image.

    Only used for masking goal images; observations are masked via depth
    instead. The image is converted BGR -> HSV, pixels with saturation of at
    least 15 are kept (``cv2.inRange``), and a 3x3 morphological close is
    applied to the result.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower = np.array([0., 15., 0.])
    upper = np.array([255, 255., 255.])
    in_range = cv2.inRange(hsv, lower, upper)
    return cv2.morphologyEx(in_range, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))

def get_rand_action(env, state, action_type='corlbaseline'):
    """Sample a random action whose pick pixel lies on the visible cloth.

    The pick pixel is drawn from the depth-based foreground mask returned by
    ``get_rgbd`` via ``td.random_sample_from_masked_image``.

    Args:
        env: environment rendered through ``get_rgbd``.
        state: unused; kept for interface compatibility.
        action_type: ``'uniform'`` or ``'corlbaseline'``.

    Returns:
        ``'uniform'``: ``np.array([pick_idx, place_idx])`` where the place
        pixel is drawn uniformly over the whole image.
        ``'corlbaseline'``: ``np.array([direction, dist, u, v])`` where
        ``(u, v)`` are the first/second indices of the pick pixel,
        ``direction`` in 0..7 is the angle of the vector ``(100 - u, 100 - v)``
        rounded to the nearest 45-degree bin, and ``dist`` is drawn from
        {0, 1, 2}.

    Raises:
        ValueError: if ``action_type`` is not recognised (this used to
            silently return None).
    """
    obs, depth, mask = get_rgbd(env)
    pick_idx = td.random_sample_from_masked_image(mask, 1)

    if action_type == 'uniform':
        height, width = obs.shape[0], obs.shape[1]
        flat_idx = np.random.choice(height * width, 1, replace=False)
        place_idx = np.unravel_index(flat_idx, (height, width))
        return np.array([pick_idx, place_idx])

    if action_type == 'corlbaseline':
        u, v = pick_idx[0][0], pick_idx[1][0]
        # NOTE(review): (100, 100) is a hard-coded reference point; it is the
        # image centre only for a 200x200 camera -- confirm for this env.
        angle = np.rad2deg(np.arctan2(100 - u, 100 - v))
        # Map the angle to one of 8 bins (360 / 8 = 45 degrees); the modulo
        # folds negative angles and the +180 rounding edge into 0..7.
        direction = int(np.round((angle / 360.0) * 8)) % 8
        dist = np.random.randint(3)
        return np.array([direction, dist, u, v])

    raise ValueError(
        "Unknown action_type {!r}; expected 'uniform' or 'corlbaseline'".format(action_type))

class EdgeMasker(object):
    """Build pixel masks of cloth edges for action selection.

    Two kinds of edge pixels are combined:

    * a band along the boundary of the foreground mask (``get_fg_edge_mask``);
    * visible particles that are image-space neighbours of another visible
      particle yet far apart along the particle mesh (shortest mesh path with
      at least ``edgethresh`` vertices) -- presumably where the cloth folds or
      overlaps itself.

    Args:
        env: environment; for ``env_type == 'towel'`` its
            ``config['ClothSize']`` gives the (width, height) particle grid.
        env_type: ``'towel'`` or ``'tshirt'``.
        tshirtmap_path: text file of ``"v1 v2"`` mesh edges, one per line
            (required for ``'tshirt'``).
        edgethresh: minimum number of vertices on the mesh shortest path for
            two image-space neighbours to count as a long pair.

    Raises:
        ValueError: if ``env_type`` is not supported.
    """

    # NOTE(review): uv coordinates are assumed to span [0, 719] (a 720x720
    # camera) -- confirm. They are rescaled onto _OUT_SIZE x _OUT_SIZE masks.
    _UV_MAX = 719
    _OUT_SIZE = 200

    def __init__(self, env, env_type, tshirtmap_path=None, edgethresh=10):
        self.env_type = env_type
        self.edgethresh = edgethresh
        self.tshirtmap_path = tshirtmap_path

        # Particle adjacency: particle index -> list of neighbouring indices.
        if self.env_type == 'towel':
            self.edge_map = self.get_towel_edge_map(env.config['ClothSize'])
        elif self.env_type == 'tshirt':
            self.edge_map = self.get_tshirt_edge_map()
        else:
            # Without an edge map every later call fails with AttributeError;
            # fail fast with a clear message instead.
            raise ValueError(
                "Unsupported env_type {!r}; expected 'towel' or 'tshirt'".format(env_type))

    def get_fg_edge_mask(self, mask):
        """Return a uint8 (0/255) band of pixels along the foreground boundary.

        Canny edges of the binarised ``mask`` are dilated with a 9x9 kernel and
        then zeroed outside the foreground, so the band lies on the inner side
        of the silhouette.
        """
        fg_edge_mask = np.zeros_like(mask, dtype=np.uint8)
        fg_edge_mask[mask != 0] = 255
        canny = cv2.Canny(fg_edge_mask, 100.0, 200.0)
        edges = cv2.dilate(canny, np.ones((9, 9), np.uint8), iterations=1)
        edges[mask == 0] = 0
        edges[edges > 0] = 255
        return edges

    def get_closest_indices(self, uv_vis, uv_vis_nonan, k=9):
        """Return the nearest image-space neighbours of every visible point.

        Args:
            uv_vis: (N, 2) image coordinates of all particles; NaN rows mark
                particles that are not visible.
            uv_vis_nonan: the non-NaN rows of ``uv_vis``, in their original
                order.
            k: neighbours queried per point, including the point itself.

        Returns:
            Integer array of shape ``(len(uv_vis_nonan), k' - 1)`` with indices
            into ``uv_vis``, where ``k' = min(k, number of visible points)``.
            The closest hit (normally the point itself) is dropped.
        """
        # Build the tree on visible rows only (recent scikit-learn versions
        # reject NaN input, and a hidden particle is never a useful neighbour),
        # then map tree-local indices back to rows of uv_vis.
        visible = np.flatnonzero(~np.any(np.isnan(uv_vis), axis=1))
        tree = KDTree(uv_vis[visible], leaf_size=2)
        # Clamp k so the query cannot ask for more neighbours than exist.
        _, local_ind = tree.query(uv_vis_nonan, k=min(k, len(visible)))
        return visible[local_ind][:, 1:]  # remove self

    def get_edge_idxs(self, coords, uv_vis, ind):
        """Find visible particles lying on a cloth edge.

        For each visible particle and each of its image-space neighbours (a row
        of ``ind``), take the shortest path through the particle mesh. A pair
        whose path has at least ``self.edgethresh`` vertices is "long": close on
        screen but far apart on the cloth. For each long pair, the particle with
        the larger ``coords[:, 1]`` (the higher one) is kept as an edge particle
        and the other is dropped. Later pairs can override earlier ones.

        Args:
            coords: particle positions, indexed as ``coords[i, 1]``.
            uv_vis: (N, 2) image coordinates, NaN for hidden particles.
            ind: output of ``get_closest_indices``; row ``r`` belongs to the
                ``r``-th non-NaN row of ``uv_vis``.

        Returns:
            edge_idxs: list of selected particle indices.
            long_idxs: ``(row, col)`` positions in ``ind`` of the long pairs.
            geopath: ``geopath[row][col]`` is the mesh path for ``ind[row, col]``;
                ``[]`` when the pair was skipped or no path exists.
            edge_idx_tuples: long pairs touching at least one selected particle.
        """
        # Mesh graph over all particles; edge insertion order follows edge_map.
        graph = nx.Graph()
        graph.add_nodes_from(np.arange(len(coords)))
        for v1, neighbours in self.edge_map.items():
            graph.add_edges_from((v1, v2) for v2 in neighbours)

        long_edges = []
        long_idxs = []
        geopath = [[] for _ in range(ind.shape[0])]
        row = 0  # row of ind, i.e. index among visible particles only
        for idx, (u, v) in enumerate(uv_vis):
            if np.isnan(u) or np.isnan(v):
                continue
            for col, nb_idx in enumerate(ind[row]):
                u1, v1 = uv_vis[nb_idx]
                # .get avoids inserting empty entries into the defaultdict.
                if (np.isnan(u1) or np.isnan(v1)
                        or not self.edge_map.get(idx)
                        or not self.edge_map.get(nb_idx)):
                    geopath[row].append([])
                    continue
                try:
                    path = nx.shortest_path(graph, source=idx, target=nb_idx)
                except nx.NetworkXException as e:
                    # e.g. no path between disconnected mesh components.
                    print(e)
                    # Keep geopath[row] aligned column-for-column with ind[row].
                    geopath[row].append([])
                    continue
                geopath[row].append(path)
                if len(path) >= self.edgethresh:
                    long_edges.append((idx, nb_idx))
                    long_idxs.append((row, col))
            row += 1
            if row >= len(ind):
                break

        # For each long pair keep the higher vertex and drop the lower one.
        edge_set = set()
        for idx1, idx2 in long_edges:
            if coords[idx1, 1] > coords[idx2, 1]:
                higher, lower = idx1, idx2
            else:
                higher, lower = idx2, idx1
            edge_set.add(higher)
            edge_set.discard(lower)

        edge_idx_tuples = [(idx1, idx2) for idx1, idx2 in long_edges
                           if idx1 in edge_set or idx2 in edge_set]
        return list(edge_set), long_idxs, geopath, edge_idx_tuples

    def get_towel_edge_map(self, clothsize):
        """Return the adjacency of a regular ``clothsize = (width, height)`` grid.

        Particle ``r * width + c`` is linked to its horizontal, vertical,
        diagonal and anti-diagonal grid neighbours. Links are stored in both
        directions in a ``defaultdict(list)``.
        """
        cloth_width, cloth_height = clothsize
        all_idx = np.arange(cloth_height * cloth_width).reshape([cloth_height, cloth_width])
        # (sender block, receiver offset) for each link direction.
        links = [
            (all_idx[:, :-1], 1),                  # horizontal
            (all_idx[:-1, :], cloth_width),        # vertical
            (all_idx[:-1, :-1], 1 + cloth_width),  # diagonal
            (all_idx[1:, :-1], 1 - cloth_width),   # anti-diagonal
        ]
        senders = np.concatenate([s.reshape(-1, 1) for s, _ in links], axis=0)
        receivers = np.concatenate([s.reshape(-1, 1) + off for s, off in links], axis=0)

        # Make the map symmetric: each link appears in both directions.
        edges = np.concatenate([
            np.concatenate([senders, receivers], axis=0),
            np.concatenate([receivers, senders], axis=0),
        ], axis=1)
        edge_map = defaultdict(list)
        for v1, v2 in edges:
            edge_map[v1].append(v2)
        return edge_map

    def get_tshirt_edge_map(self):
        """Read the T-shirt mesh adjacency from ``self.tshirtmap_path``.

        Each non-blank line holds two whitespace-separated vertex indices.
        Self-loops are ignored and every edge is stored in both directions.

        Raises:
            ValueError: if ``tshirtmap_path`` is None or a line does not hold
                exactly two integers.
        """
        if self.tshirtmap_path is None:
            raise ValueError("tshirtmap_path is required for env_type 'tshirt'")
        edge_map = defaultdict(list)
        with open(self.tshirtmap_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                v1, v2 = (int(x) for x in fields)
                if v1 == v2:
                    continue
                edge_map[v1].append(v2)
                edge_map[v2].append(v1)
        return edge_map

    def get_act_mask(self, env, coords, rgb, depth, mask, visualize=False):
        """Build the action masks for the current observation.

        Args:
            env: environment providing ``camera_params``.
            coords: particle positions.
            rgb: current RGB image (used by ``remove_dups`` and for plotting).
            depth: current depth image (used by ``remove_dups``).
            mask: foreground mask at camera resolution.
            visualize: if True, show a matplotlib plot of the detected edge
                particles (blocks in ``plt.show()``).

        Returns:
            all_mask: union of ``fge_mask`` and the cloth-edge particles.
            fge_mask: foreground edge band, resized to 200x200 (0/255).
            ce_mask: cloth-edge particle pixels, 200x200 (0/255).
        """
        # Visible particle uv positions; hidden particles are NaN rows.
        uv = td.particle_uv_pos(env.camera_params, None, particle_pos=coords)
        uv_vis = remove_dups(env.camera_params, uv, coords, depth, rgb)
        uv_vis_nonan = uv_vis[~np.any(np.isnan(uv_vis), axis=1)]

        # Foreground edge mask (cv2.resize returns a new array).
        fge_mask = cv2.resize(self.get_fg_edge_mask(mask), (self._OUT_SIZE, self._OUT_SIZE))
        fge_mask[fge_mask > 0] = 255

        ce_mask = np.zeros_like(fge_mask)  # cloth edge mask
        all_mask = fge_mask.copy()
        edge_idxs, edge_idx_tuples = [], []
        if len(uv_vis_nonan) != 0:
            ind = self.get_closest_indices(uv_vis, uv_vis_nonan)
            edge_idxs, _, _, edge_idx_tuples = self.get_edge_idxs(coords, uv_vis, ind)

            uv_ints = np.rint(
                uv_vis[edge_idxs] / self._UV_MAX * (self._OUT_SIZE - 1)).astype(int)
            # Drop out-of-range points instead of letting negative indices
            # wrap around to the far side of the mask or large ones raise.
            inside = np.all((uv_ints >= 0) & (uv_ints < self._OUT_SIZE), axis=1)
            uv_ints = uv_ints[inside]
            ce_mask[uv_ints[:, 0], uv_ints[:, 1]] = 255
            all_mask[uv_ints[:, 0], uv_ints[:, 1]] = 255

        if visualize:
            self._plot_edges(rgb, uv_vis, edge_idxs, edge_idx_tuples)

        return all_mask, fge_mask, ce_mask

    def _plot_edges(self, rgb, uv_vis, edge_idxs, edge_idx_tuples):
        """Debug plot: RGB image next to the edge particles and their long pairs."""
        uv = uv_vis / self._UV_MAX * (self._OUT_SIZE - 1)
        small_rgb = cv2.resize(rgb, (self._OUT_SIZE, self._OUT_SIZE))

        _, ax = plt.subplots(1, 2)
        ax[0].set_title('RGB')
        ax[0].imshow(small_rgb)

        ax[1].set_title('edge particles')
        ax[1].imshow(small_rgb)
        ax[1].scatter(uv[edge_idxs, 1], uv[edge_idxs, 0], s=3, c='r', label='edge particles')
        for idx1, idx2 in edge_idx_tuples:
            ax[1].plot([uv[idx1, 1], uv[idx2, 1]], [uv[idx1, 0], uv[idx2, 0]],
                       color='g', alpha=0.2)
        ax[1].legend()

        plt.tight_layout()
        plt.show()
