import json
import os
import copy
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True)

def rename_files(directory, max_digits):
    """Zero-pad the scene number of every ``scene_<n>.json`` file in *directory*.

    Each file whose name starts with ``"scene_"`` and ends with ``".json"`` is
    renamed in place to ``scene_<n>.json`` with ``<n>`` left-padded with zeros
    to ``max_digits`` characters (numbers that are already wider keep all their
    digits; leading zeros beyond that width are dropped by the ``int`` parse).
    Other files are left alone.

    Args:
        directory: Folder containing the scene files.
        max_digits: Target width of the zero-padded number.

    Raises:
        ValueError: if the text between the first ``"_"`` and the next ``"."``
            of a matching file name is not an integer.
    """
    # NOTE(review): if two names map to the same padded name (e.g. scene_1.json
    # and scene_01.json), os.rename may overwrite one of them on POSIX.
    for name in os.listdir(directory):
        if not (name.startswith("scene_") and name.endswith(".json")):
            continue
        scene_number = int(name.split("_")[1].split(".")[0])
        padded = str(scene_number).zfill(max_digits)
        source = os.path.join(directory, name)
        target = os.path.join(directory, f"scene_{padded}.json")
        os.rename(source, target)

def fix_scene_id(scene_id, max_digits):
    """Return *scene_id* with its number zero-padded to ``max_digits`` characters.

    Every ``"scene_"`` substring is removed, the remainder is left-padded with
    ``str.zfill`` and the ``"scene_"`` prefix is put back, e.g.
    ``fix_scene_id("scene_7", 3) == "scene_007"``. Wider numbers are unchanged.
    """
    digits = scene_id.replace("scene_", "")
    return "scene_" + digits.zfill(max_digits)

def process_gnn_data(dataset_path):
    """Build the per-object annotation table used by the GNN and save it as JSON.

    Steps:
      1. Read ``<dataset_path>/data/data_annotated.json``.
      2. Zero-pad the scene files in ``<dataset_path>/scenes`` (renamed on disk)
         and the ``scene_id`` column to the same width so they match.
      3. Keep only the annotation columns used downstream, drop rows with any
         missing value and sort by ``scene_id``.
      4. Add placeholder columns, then fill them per row with
         ``add_dim_pose_columns`` and ``process_ic``.
      5. Write ``<dataset_path>/data/processed_gnn_data.json``.

    Args:
        dataset_path: Root folder containing ``data/`` and ``scenes/``.

    Returns:
        The processed DataFrame.
    """
    grasps = ["Top", "Front", "Rear", "Right", "Left"]
    scenes_path = os.path.join(dataset_path, "scenes")
    data_annotated = pd.read_json(os.path.join(dataset_path, "data/data_annotated.json"))
    # Width of the largest scene index, assuming scenes are numbered 0..N-1.
    max_digits = len(str(data_annotated.scene_id.nunique() - 1))
    rename_files(scenes_path, max_digits)

    columns = (["scene_id", "object_id", "generation_type", "frame_id",
                "abs_x", "abs_y", "abs_z", "abs_yaw", "feasibility", "failing_stage"]
               + grasps
               + [g + "_GO" for g in grasps]
               + ["Nb_" + g + "_grasps" for g in grasps]
               + ["planning_time"])
    # .copy(): assign into an independent frame rather than a column subset of
    # the raw table (otherwise pandas raises SettingWithCopyWarning).
    data_annotated = data_annotated[columns].copy()
    data_annotated["scene_id"] = data_annotated["scene_id"].parallel_apply(
        lambda x: fix_scene_id(x, max_digits))
    data_annotated = data_annotated.dropna().sort_values(by=["scene_id"]).reset_index(drop=True)

    # Placeholder columns, overwritten row by row below. The insertion order is
    # the same as before so the saved JSON keeps its column layout.
    data_annotated["dim"] = 0
    data_annotated["pose"] = 0
    for suffix in ["_IK", "_scores", "_IK_scores", "_cause"]:
        for g in grasps:
            data_annotated[g + suffix] = 0

    data_annotated = data_annotated.parallel_apply(add_dim_pose_columns, args=(scenes_path,), axis=1)
    data_annotated = data_annotated.parallel_apply(process_ic, axis=1)
    data_annotated.to_json(os.path.join(dataset_path, "data/processed_gnn_data.json"))
    return data_annotated

def process_ic(row):
    """Split each grasp's obstruction list into IK-failure, ignored and obstacle entries.

    For each grasp ``g`` in Top/Front/Rear/Right/Left, ``row[g + "_GO"]`` is a
    list whose entries are read as ``entry[0]`` (a name) and ``entry[1]`` (a
    count compared with ``row["Nb_" + g + "_grasps"]``). Fields written:

    * Empty list: ``g_IK = 1`` and ``g_IK_scores = 0``. ``g_GO``, ``g_scores``
      and ``g_cause`` are left untouched.
    * ``"no_ik"`` entry: ``g_IK_scores = count / Nb``. If ``count == Nb``, then
      ``g_IK = 0`` and ``g_cause = "IK"``, and the remaining entries are skipped.
      Otherwise ``g_IK = 1``.
    * ``"robot"``, or a name containing ``"link"``: ``g_IK = 1``; the entry is
      discarded.
    * Any other name: ``g_IK = 1``; the name and count are collected.

    For a non-empty list, ``g_GO``/``g_scores`` are then replaced by the
    collected names/counts. ``g_cause`` becomes ``"Collision"`` when
    ``row[g] == 0`` and ``g_IK == 1``.

    The row is updated in place and also returned, for ``DataFrame.apply(axis=1)``.
    """
    for grasp in ("Top", "Front", "Rear", "Right", "Left"):
        go_key = f"{grasp}_GO"
        ik_key = f"{grasp}_IK"
        obstructions = row[go_key]
        if not obstructions:
            row[ik_key] = 1
            row[f"{grasp}_IK_scores"] = 0
            continue

        n_grasps_key = f"Nb_{grasp}_grasps"
        kept_names, kept_scores = [], []
        for entry in obstructions:
            name, count = entry[0], entry[1]
            if name == "no_ik":
                total = row[n_grasps_key]
                row[f"{grasp}_IK_scores"] = count / total
                if count == total:
                    # Every candidate grasp failed IK: no obstacle can be blamed.
                    row[ik_key] = 0
                    row[f"{grasp}_cause"] = "IK"
                    break
                row[ik_key] = 1
                continue
            row[ik_key] = 1
            # Self-collisions with the robot or its links are not scene obstacles.
            if name != "robot" and "link" not in name:
                kept_names.append(name)
                kept_scores.append(count)

        if row[grasp] == 0 and row[ik_key] == 1:
            row[f"{grasp}_cause"] = "Collision"
        row[go_key] = kept_names
        row[f"{grasp}_scores"] = kept_scores
    return row

def add_dim_pose_columns(row, scenes_path):
    """Fill ``row["dim"]`` and ``row["pose"]`` from the object's entry in its scene file.

    Reads ``<scenes_path>/<row.scene_id>.json`` and looks up
    ``objects[row.object_id]``. ``dim`` is that object's ``dimensions`` value.
    ``pose`` is the first three entries of ``abs_pose`` followed by its last
    entry (the value other helpers in this file use as yaw).

    Returns:
        The same row, updated in place.
    """
    scene_file = os.path.join(scenes_path, row.scene_id + ".json")
    with open(scene_file, "r") as handle:
        scene = json.load(handle)
    obj = scene["objects"][row.object_id]
    abs_pose = obj["abs_pose"]
    row["dim"] = obj["dimensions"]
    row["pose"] = abs_pose[:3] + [abs_pose[-1]]
    return row

def compute_distance(pose1, pose2):
    """Return the Euclidean distance between two equal-length coordinate sequences."""
    delta = np.subtract(pose1, pose2)
    return np.linalg.norm(delta)

def compute_threshold(dim1, dim2, margin=0.6):
    """Return the centre-distance threshold below which two objects count as neighbours.

    The threshold is ``(max(dim1) + max(dim2) + margin) / 2``, i.e. half the sum
    of the two largest extents plus some slack.

    Args:
        dim1, dim2: Non-empty sequences of object extents (``process_ic_data``
            passes the first two dimensions of each object).
        margin: Extra slack added before halving, in the same units as the
            dimensions. Defaults to the historical value 0.6.
    """
    return (max(dim1) + max(dim2) + margin) / 2

def switch_dimensions(row):
    """Describe both objects of a pair sample with length and width swapped.

    For ``o1`` and ``o2``:
      * ``<o>_dim``: entries 0 and 1 are swapped.
      * ``<o>_pose``: the last entry (yaw) is decreased by pi/2 and wrapped to
        ``[0, 2*pi)``.
    The grasp labels are permuted to match the rotated frame:
    ``Front <- Right``, ``Rear <- Left``, ``Right <- Rear``, ``Left <- Front``.
    Finally, ``"-dimswitch-"`` is appended to ``row["augmentation"]``.

    New lists are written into the row; the original dimension/pose lists are
    not modified. Rows produced by ``DataFrame.apply(axis=1)`` reference the
    same list objects as the source frame, so mutating them in place would
    silently alter the un-augmented data as well.

    Returns:
        The updated row.
    """
    for obj in ("o1", "o2"):
        dim = list(row[obj + "_dim"])
        dim[0], dim[1] = dim[1], dim[0]
        row[obj + "_dim"] = dim
        pose = list(row[obj + "_pose"])
        pose[-1] = (pose[-1] - np.pi / 2) % (2 * np.pi)
        row[obj + "_pose"] = pose
    front, rear, right, left = row["Front"], row["Rear"], row["Right"], row["Left"]
    row["Front"], row["Rear"], row["Right"], row["Left"] = right, left, rear, front
    row["augmentation"] += "-dimswitch-"
    return row

def process_ic_data(dataset_path):
    """Build the object-pair obstruction dataset and save it as JSON and tensors.

    For every row of ``data/processed_gnn_data.json``, the row's scene is scanned
    for objects whose centre is not farther from the row's object than
    ``compute_threshold`` of their planar (first two) dimensions. The object
    itself, its ``frame_id`` and ``"base"`` are skipped. Each such pair becomes
    one sample with, per grasp ``g``:

    * ``g``: 1 if the neighbour is listed in the row's ``g_GO``, else 0.
    * ``m<g>``: 0 when ``g_IK == 0``, else 1.
    * ``g_score``: the neighbour's ``g_scores`` value divided by
      ``Nb_<g>_grasps`` (0 when it is not listed or ``g_IK == 0``).

    Samples whose masks are all 0 are dropped (the index is not reset). The
    table is written to ``data/processed_ic_data.json`` and converted with
    ``to_tensors`` into ``data/inputs.pt``, ``data/labels.pt`` and
    ``data/masks.pt``.

    Args:
        dataset_path: Root folder containing ``data/`` and ``scenes/``.

    Returns:
        The filtered pair DataFrame.
    """
    grasps = ["Top", "Front", "Rear", "Right", "Left"]
    data = pd.read_json(os.path.join(dataset_path, "data/processed_gnn_data.json"))
    columns = (["scene_id", "o1", "o2", "o1_pose", "o2_pose", "o1_dim", "o2_dim"]
               + grasps + ["m" + g for g in grasps] + [g + "_score" for g in grasps])
    d = {c: [] for c in columns}

    # Rows are usually grouped by scene, so keep the last parsed scene instead of
    # re-reading the same JSON file for every row.
    cached_scene_id, scene = None, None
    for i in tqdm(range(len(data))):
        scene_id = data.scene_id.iloc[i]
        if scene_id != cached_scene_id:
            with open(os.path.join(dataset_path, "scenes", scene_id + ".json"), "r") as f:
                scene = json.load(f)
            cached_scene_id = scene_id
        obj_id = data.object_id.iloc[i]
        frame_id = data.frame_id.iloc[i]
        objects = scene["objects"]
        main = objects[obj_id]
        # Per-row grasp annotations, looked up once instead of per neighbour.
        grasp_info = {
            g: (data[g + "_IK"].iloc[i], data[g + "_GO"].iloc[i],
                data[g + "_scores"].iloc[i], data["Nb_" + g + "_grasps"].iloc[i])
            for g in grasps
        }

        for o, other in objects.items():
            if o == obj_id or o == frame_id or o == "base":
                continue
            distance = compute_distance(main["abs_pose"][:3], other["abs_pose"][:3])
            threshold = compute_threshold(main["dimensions"][:2], other["dimensions"][:2])
            if distance > threshold:
                continue
            d["scene_id"].append(scene_id)
            d["o1"].append(obj_id)
            d["o2"].append(o)
            d["o1_pose"].append(main["abs_pose"][:3] + main["abs_pose"][-1:])
            d["o2_pose"].append(other["abs_pose"][:3] + other["abs_pose"][-1:])
            # Copies: the cached scene is reused, so samples must not share list objects.
            d["o1_dim"].append(list(main["dimensions"]))
            d["o2_dim"].append(list(other["dimensions"]))
            for g in grasps:
                ik, obstacles, scores, n_grasps = grasp_info[g]
                if ik == 0:
                    label, mask, score = 0, 0, 0
                elif o in obstacles:
                    label, mask, score = 1, 1, scores[obstacles.index(o)] / n_grasps
                else:
                    label, mask, score = 0, 1, 0
                d[g].append(label)
                d["m" + g].append(mask)
                d[g + "_score"].append(score)

    data = pd.DataFrame(d)
    mask_columns = ["m" + g for g in grasps]
    # Keep pairs with at least one supervised grasp (also well-defined on an empty frame).
    data = data[(data[mask_columns] != 0).any(axis=1)]
    data.to_json(os.path.join(dataset_path, "data/processed_ic_data.json"))

    inputs, labels, masks = to_tensors(data)
    torch.save(inputs, os.path.join(dataset_path, "data/inputs.pt"))
    torch.save(labels, os.path.join(dataset_path, "data/labels.pt"))
    torch.save(masks, os.path.join(dataset_path, "data/masks.pt"))
    return data
    
def to_tensors(data):
    """Convert the pair DataFrame into ``(inputs, labels, masks)`` tensors.

    Args:
        data: DataFrame with list columns ``o1_dim``/``o2_dim`` (3 values each)
            and ``o1_pose``/``o2_pose`` (4 values each), plus the per-grasp
            ``<g>_score`` and ``m<g>`` columns for g in Top/Front/Rear/Right/Left.

    Returns:
        inputs: float tensor of shape ``(N, 14)`` laid out as
            ``[o1_dim(3), o1_pose(4), o2_dim(3), o2_pose(4)]``.
        labels: ``(N, 5)`` tensor of grasp scores in Top/Front/Rear/Right/Left order.
        masks: ``(N, 5)`` tensor of grasp masks in the same order; its dtype is
            inferred by ``torch.tensor`` from the values (int64 for 0/1 ints).
        An empty ``data`` yields tensors with zero rows and the same widths.
    """
    grasps = ["Top", "Front", "Rear", "Right", "Left"]
    label_columns = [g + "_score" for g in grasps]
    mask_columns = ["m" + g for g in grasps]

    inputs = torch.zeros((len(data), 14))
    if len(data) == 0:
        # torch.tensor([]) is 1-D and cannot fill the (0, k) slices below.
        return (inputs,
                torch.zeros((0, len(label_columns))),
                torch.zeros((0, len(mask_columns)), dtype=torch.long))
    inputs[:, :3] = torch.tensor(data.o1_dim.values.tolist())
    inputs[:, 3:7] = torch.tensor(data.o1_pose.values.tolist())
    inputs[:, 7:10] = torch.tensor(data.o2_dim.values.tolist())
    inputs[:, 10:14] = torch.tensor(data.o2_pose.values.tolist())
    labels = torch.tensor(data[label_columns].values.tolist())
    masks = torch.tensor(data[mask_columns].values.tolist())
    return inputs, labels, masks

def get_corners(dimensions, pose):
    """Return the 8 corners of a z-rotated box as an ``(8, 3)`` array.

    Args:
        dimensions: Full extents along the box's local x, y, z (each is halved).
        pose: Sequence whose first three entries are the box centre and whose
            last entry is the rotation angle about the z axis (radians).

    Corner order, as signs of the local (x, y, z) half-extents:
        0 (-,-,-)  1 (-,+,-)  2 (-,+,+)  3 (-,-,+)
        4 (+,-,+)  5 (+,-,-)  6 (+,+,-)  7 (+,+,+)
    The triangle index lists used by the plotting functions depend on this order.
    """
    half_extents = (dimensions[0] / 2, dimensions[1] / 2, dimensions[2] / 2)
    yaw = pose[-1]
    rotation = np.array([[np.cos(yaw), -1 * np.sin(yaw), 0],
                         [np.sin(yaw), np.cos(yaw), 0],
                         [0, 0, 1]])
    centre = np.array([pose[0], pose[1], pose[2]])
    sign_table = [(-1, -1, -1), (-1, 1, -1), (-1, 1, 1), (-1, -1, 1),
                  (1, -1, 1), (1, -1, -1), (1, 1, -1), (1, 1, 1)]

    corners = np.zeros((8, 3))
    for idx, signs in enumerate(sign_table):
        local = np.array([s * h for s, h in zip(signs, half_extents)])
        corners[idx, :] = np.matmul(rotation, local) + centre
    return corners

def transform_pose(pose, frame_pose, transform):
    """Express a pose in another frame, then apply per-axis scale factors.

    Args:
        pose: World pose; entries 0-2 are the position and the last entry is yaw.
        frame_pose: Reference frame in the same layout (z-rotation by its last entry).
        transform: Three per-axis factors applied to x, y, z after the change of
            frame (e.g. 1 / -1 for axis flips).

    Returns:
        ``np.ndarray`` ``[x', y', z', yaw']`` with ``yaw' = pose[-1] - frame_pose[-1]``,
        plus ``pi`` when ``transform[1] == -1``. The angle is not wrapped.
    """
    # NOTE(review): a flip of y alone (transform = [1, -1, 1]) is a mirror, which
    # maps yaw to -yaw; adding pi is only correct when x is flipped too (a 180
    # degree rotation). Confirm which transforms callers pass.
    frame_yaw = frame_pose[-1]
    cos_yaw, sin_yaw = np.cos(frame_yaw), np.sin(frame_yaw)
    frame_to_world = np.array([[cos_yaw, -sin_yaw, 0, frame_pose[0]],
                               [sin_yaw, cos_yaw, 0, frame_pose[1]],
                               [0, 0, 1, frame_pose[2]],
                               [0, 0, 0, 1]])
    world_to_frame = np.linalg.inv(frame_to_world)
    result = world_to_frame @ np.array([pose[0], pose[1], pose[2], 1])
    # The homogeneous-coordinate slot is reused to store the relative yaw.
    result[-1] = pose[-1] - frame_yaw

    axis_scale = np.array([[transform[0], 0, 0],
                           [0, transform[1], 0],
                           [0, 0, transform[2]]])
    result[:3] = axis_scale @ result[:3]
    if transform[1] == -1:
        result[-1] += np.pi
    return result

def _obj_vertex_index(token, n_vertices):
    """Return the 0-based vertex index referenced by one OBJ face token.

    Accepts ``v``, ``v/vt``, ``v//vn`` and ``v/vt/vn`` (and a trailing ``/``).
    Negative references count back from the ``n_vertices`` vertices read so far.
    """
    ref = int(token.split("/")[0])
    if ref > 0:
        return ref - 1
    if ref < 0:
        return n_vertices + ref
    raise ValueError(f"invalid OBJ vertex index 0 in face token {token!r}")


def get_robot_mesh_points(path):
    """Load a Wavefront OBJ mesh as the arrays ``plotly`` ``Mesh3d`` expects.

    Only ``v`` (vertex) and ``f`` (face) lines are used. Polygonal faces are
    fan-triangulated around their first vertex. Faces with fewer than three
    vertices are skipped, and extra vertex components (``w`` or colours) are
    ignored.

    Args:
        path: Path to the ``.obj`` file.

    Returns:
        ``(x, y, z, I, J, K)``: vertex coordinate arrays and 0-based triangle
        vertex index arrays. All six are empty when the file has no vertices
        or faces.

    Raises:
        ValueError: on a malformed number or a face index of 0.
    """
    vertices = []
    faces = []
    with open(path, "r") as file:
        for line in file:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == "v":
                vertices.append([float(t) for t in tokens[1:4]])
            elif tokens[0] == "f":
                idx = [_obj_vertex_index(t, len(vertices)) for t in tokens[1:]]
                # Fan triangulation (0, k, k+1); a triangle yields itself.
                faces.extend([idx[0], idx[k], idx[k + 1]] for k in range(1, len(idx) - 1))

    # reshape keeps the (n, 3) layout even when a list is empty.
    vertices = np.array(vertices, dtype=float).reshape(-1, 3)
    faces = np.array(faces, dtype=int).reshape(-1, 3)
    x, y, z = vertices.T
    I, J, K = faces.T
    return x, y, z, I, J, K

def visualize_action_predictions(data, preds, robot_mesh_path=None):
    """Plot every object box, coloured by its thresholded action prediction.

    Args:
        data: Graph sample providing ``x`` (first 3 columns used as box
            dimensions), ``pos`` (one pose row per object) and ``movable_mask``.
        preds: Per-object predictions; column 0 is thresholded at 0.5.
        robot_mesh_path: Optional OBJ file drawn as a translucent grey mesh.

    Movable objects are green when ``preds[obj, 0] > 0.5`` and red otherwise.
    Non-movable objects are drawn light grey and translucent. ``data`` is not
    modified.
    """
    traces = []
    if robot_mesh_path is not None:
        x, y, z, i, j, k = get_robot_mesh_points(robot_mesh_path)
        traces.append(go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color='grey', opacity=0.2))

    # Local CPU copies: reassigning data.x / data.pos would move the caller's
    # tensors off their device as a side effect.
    feats = data.x.cpu()
    poses = data.pos.cpu()
    feasible = torch.where(preds.cpu() > 0.5, 1., 0.)
    for obj in range(poses.shape[0]):
        if not data.movable_mask[obj].item():
            color = "#E9E9E9"
            opacity = 0.2
        else:
            p = feasible[obj, 0].item()
            # Red for p = 0 (infeasible), green for p = 1 (feasible).
            color = "#%02X%02X%02X" % (int(255 * (1 - p)), int(255 * p), 0)
            opacity = 1.

        corners = get_corners(feats[obj, :3].tolist(), poses[obj, :].tolist())
        traces.append(go.Mesh3d(x=corners[:, 0], y=corners[:, 1], z=corners[:, 2],
                                i=[7, 2, 0, 0, 4, 4, 6, 6, 4, 0, 0, 0],
                                j=[3, 3, 1, 2, 5, 6, 7, 2, 0, 3, 6, 1],
                                k=[4, 7, 2, 3, 6, 7, 2, 1, 5, 4, 5, 6],
                                opacity=opacity, color=color, flatshading=True, showscale=True))

    fig = go.Figure(data=traces)
    fig.update_layout(title="Action Feasibility")
    axis = dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='')
    fig.update_layout(autosize=False, width=1000, height=1000, margin=dict(l=50, r=50, b=50, t=50, pad=4),
                      scene=dict(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis),))
    fig.show()

def visualize_grasp_predictions(data, preds, robot_mesh_path=None):
    """Plot every object box with each side face coloured by its grasp prediction.

    Args:
        data: Graph sample providing ``x`` (first 3 columns used as box
            dimensions), ``pos`` (one pose row per object) and ``movable_mask``.
        preds: Per-object predictions; columns 1..5 are thresholded at 0.7.
        robot_mesh_path: Optional OBJ file drawn as a translucent grey mesh.

    Non-movable objects are drawn light grey and translucent. On movable objects,
    the top, rear (-x), front (+x), +y and -y faces are green when the matching
    prediction column (1, 3, 2, 5, 4 respectively) exceeds 0.7 and red
    otherwise; the bottom face is grey. Column semantics are presumably
    ``[?, Top, Front, Rear, Right, Left]`` (TODO confirm against the model
    head). ``data`` is not modified.
    """
    box_i = [7, 2, 0, 0, 4, 4, 6, 6, 4, 0, 0, 0]
    box_j = [3, 3, 1, 2, 5, 6, 7, 2, 0, 3, 6, 1]
    box_k = [4, 7, 2, 3, 6, 7, 2, 1, 5, 4, 5, 6]
    # Face pairs produced by the triangles above, in order:
    # top, rear (-x), front (+x), +y side, -y side, bottom.
    face_pred_columns = (1, 3, 2, 5, 4)

    traces = []
    if robot_mesh_path is not None:
        x, y, z, i, j, k = get_robot_mesh_points(robot_mesh_path)
        traces.append(go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color='grey', opacity=0.2))

    # Local CPU copies: do not move the caller's tensors off their device.
    feats = data.x.cpu()
    poses = data.pos.cpu()
    grasp_ok = preds.cpu() > 0.7

    for obj in range(poses.shape[0]):
        corners = get_corners(feats[obj, :3].tolist(), poses[obj, :].tolist())
        if not data.movable_mask[obj].item():
            traces.append(go.Mesh3d(x=corners[:, 0], y=corners[:, 1], z=corners[:, 2],
                                    i=box_i, j=box_j, k=box_k,
                                    opacity=0.2, color="#E9E9E9", flatshading=True, showscale=True))
            continue
        facecolor = []
        for col in face_pred_columns:
            face = "#00FF00" if grasp_ok[obj, col].item() else "#FF0000"
            facecolor.extend([face, face])
        facecolor.extend(["#E9E9E9", "#E9E9E9"])  # no grasp is predicted for the bottom face
        traces.append(go.Mesh3d(x=corners[:, 0], y=corners[:, 1], z=corners[:, 2],
                                i=box_i, j=box_j, k=box_k,
                                facecolor=facecolor, opacity=1., flatshading=True, showscale=True))

    fig = go.Figure(data=traces)
    fig.update_layout(title="Grasp Feasibility")
    axis = dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='')
    fig.update_layout(autosize=False, width=1000, height=1000, margin=dict(l=50, r=50, b=50, t=50, pad=4),
                      scene=dict(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis),))
    fig.show()

def visualize_go_predictions(data, IK_preds, GO_preds, main_obj, robot_mesh_path=None):
    """Show, per grasp direction, the predicted obstruction of each neighbour of ``main_obj``.

    One 3-D subplot is drawn for each grasp (Top, Front, Rear, Right, Left).

    * If ``IK_preds[main_obj, g] > 0.5``, ``main_obj`` is drawn opaque blue.
      Every object linked to it by a proximity edge (``edge_index[1] == main_obj``)
      is drawn opaque and coloured by its ``GO_preds`` value on a 0-1 ``YlOrRd``
      scale. All remaining objects are translucent.
    * Otherwise every object is drawn plain (blue if movable, light grey if
      not), with only ``main_obj`` opaque.

    Args:
        data: Graph sample with ``x``, ``pos``, ``movable_mask``, ``edge_index``
            and a boolean edge mask ``proximity_mask``.
        IK_preds: Per-object predictions indexed ``[object, grasp]``.
        GO_preds: Per-edge predictions indexed ``[edge, grasp]``, aligned with
            ``data.edge_index``.
        main_obj: Index of the object whose grasps are shown.
        robot_mesh_path: Optional OBJ file drawn as a translucent grey mesh.

    ``data`` is not modified, so the function can be called repeatedly on the same sample.
    """
    grasps = ["Top", "Front", "Rear", "Right", "Left"]
    box_i = [7, 2, 0, 0, 4, 4, 6, 6, 4, 0, 0, 0]
    box_j = [3, 3, 1, 2, 5, 6, 7, 2, 0, 3, 6, 1]
    box_k = [4, 7, 2, 3, 6, 7, 2, 1, 5, 4, 5, 6]

    def box(corners, **style):
        # One box mesh from get_corners' 8 corners using the shared triangulation.
        return go.Mesh3d(x=corners[:, 0], y=corners[:, 1], z=corners[:, 2],
                         i=box_i, j=box_j, k=box_k, flatshading=True, **style)

    if robot_mesh_path is not None:
        rx, ry, rz, ri, rj, rk = get_robot_mesh_points(robot_mesh_path)
    feats = data.x.cpu()
    poses = data.pos.cpu()
    # Keep the filtered edges local: overwriting data.edge_index would make a
    # second call apply proximity_mask to an already-filtered edge list.
    edge_index = data.edge_index[:, data.proximity_mask].cpu()
    edges = torch.where(edge_index[1] == main_obj)[0]
    neighbors = edge_index[0, edges].tolist()
    GO_preds = GO_preds[data.proximity_mask][edges].cpu()

    d = {}
    for g, grasp in enumerate(grasps):
        traces = []
        if robot_mesh_path is not None:
            traces.append(go.Mesh3d(x=rx, y=ry, z=rz, i=ri, j=rj, k=rk, color='grey', opacity=0.2))
        show_obstructions = IK_preds[main_obj, g] > 0.5
        for obj in range(poses.shape[0]):
            corners = get_corners(feats[obj, :3].tolist(), poses[obj, :].tolist())
            plain_color = "#0C76BD" if data.movable_mask[obj].item() else "#E9E9E9"
            if obj == main_obj:
                color = "#0C76BD" if show_obstructions else plain_color
                traces.append(box(corners, opacity=1., color=color, showscale=True))
            elif show_obstructions and obj in neighbors:
                value = GO_preds[neighbors.index(obj), g].item()
                # Uniform per-vertex intensity (Mesh3d has 8 vertices here).
                traces.append(box(corners, opacity=1., intensity=[value] * len(corners),
                                  cmin=0., cmax=1., colorscale='YlOrRd', showscale=False))
            else:
                traces.append(box(corners, opacity=0.1, color=plain_color, showscale=True))
        d[grasp] = traces

    fig = make_subplots(rows=1, cols=5, subplot_titles=tuple(grasps),
                        specs=[[{'type': 'scene'} for _ in grasps]])
    for col, grasp in enumerate(grasps, start=1):
        fig.add_traces(d[grasp], rows=1, cols=col)
    fig.update_layout(title="Grasp Obstructions")
    axis = dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='')
    fig.update_layout(autosize=False, width=1750, height=500, margin=dict(l=50, r=50, b=50, t=50, pad=4),
                      scene=dict(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis),))
    fig.update_scenes(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis))
    fig.show()

def visualize_scene(data, robot_mesh_path=None):
    """Plot every object of a scene as an opaque box.

    Movable objects (``data.movable_mask``) are blue and the others light grey.
    Box sizes come from the first 3 columns of ``data.x`` and poses from
    ``data.pos``.

    Args:
        data: Graph sample providing ``x``, ``pos`` and ``movable_mask``.
        robot_mesh_path: Optional OBJ file drawn as a translucent grey mesh.
    """
    traces = []
    if robot_mesh_path is not None:
        mx, my, mz, mi, mj, mk = get_robot_mesh_points(robot_mesh_path)
        traces.append(go.Mesh3d(x=mx, y=my, z=mz, i=mi, j=mj, k=mk, color='grey', opacity=0.2))
    colors = {False: "#E9E9E9", True: "#0C76BD"}
    for obj in range(data.pos.shape[0]):
        corners = get_corners(data.x[obj, :3].tolist(), data.pos[obj, :].tolist())
        traces.append(go.Mesh3d(x=corners[:, 0], y=corners[:, 1], z=corners[:, 2],
                                i=[7, 2, 0, 0, 4, 4, 6, 6, 4, 0, 0, 0],
                                j=[3, 3, 1, 2, 5, 6, 7, 2, 0, 3, 6, 1],
                                k=[4, 7, 2, 3, 6, 7, 2, 1, 5, 4, 5, 6],
                                opacity=1., color=colors[data.movable_mask[obj].item()], flatshading=True))
    fig = go.Figure(data=traces)
    axis = dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='')
    fig.update_layout(autosize=False, width=1000, height=1000, margin=dict(l=50, r=50, b=50, t=50, pad=4),
                      scene=dict(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis),))
    fig.show()