import os, json
import numpy as np
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import plotly.graph_objects as go

import torch
import torch.nn as nn
import torch.nn.functional as F

from model import BuildGrid

def hungarian_loss(x, y):
    """Huber loss between two sets of vectors under the optimal one-to-one matching.

    For every batch element, the rows of ``x`` are matched to the rows of ``y``
    with ``scipy.optimize.linear_sum_assignment``. The pairwise cost is the mean
    (over the last dimension) of ``nn.HuberLoss(reduction='none')``. The matched
    costs are summed and divided by the batch size.

    Args:
        x: tensor of shape (batch_size, num_afford, afford_size).
        y: tensor of the same shape as ``x``; cast to float before comparison.

    Returns:
        Scalar tensor: sum of matched pair costs divided by ``batch_size``.
        Gradients flow only through the matched entries of the cost matrix.
    """
    batch_size, num_afford, afford_size = x.shape
    # xx[b, i, j] = x[b, i] and yy[b, i, j] = y[b, j]: every (x-row, y-row) pair is compared.
    xx = x[..., None, :].expand((batch_size, num_afford, num_afford, afford_size))
    yy = y[..., None, :, :].expand((batch_size, num_afford, num_afford, afford_size)).float()
    # [batch_size, num_afford, num_afford]
    pairwise_cost = torch.mean(nn.HuberLoss(reduction='none')(xx, yy), dim=-1)

    # The assignment is not differentiable; solve it on a detached CPU copy.
    assignments = [linear_sum_assignment(cost) for cost in pairwise_cost.detach().cpu().numpy()]
    device = pairwise_cost.device
    batch_idx = torch.arange(batch_size, device=device).repeat_interleave(num_afford)
    row_idx = torch.as_tensor(np.concatenate([rows for rows, _ in assignments]), device=device)
    col_idx = torch.as_tensor(np.concatenate([cols for _, cols in assignments]), device=device)
    # One advanced-indexing gather instead of a Python loop over B*N scalars.
    return torch.sum(pairwise_cost[batch_idx, row_idx, col_idx]) / batch_size

def hungarian_loss_meanIoU(x, y):
    """Mean IoU between two sets of voxel masks under the IoU-maximising matching.

    Both inputs are cast with ``.int()`` (which truncates toward zero, so
    fractional values below 1 count as empty), and nonzero entries are treated
    as occupied. The pairwise IoU between every row of ``x`` and every row of
    ``y`` is computed. Two empty masks have IoU 1 (the 0/0 NaN is replaced).
    ``linear_sum_assignment`` then picks the matching that maximises total IoU.

    Args:
        x: tensor of shape (batch_size, num_afford, num_vox).
        y: tensor of the same shape as ``x``.

    Returns:
        Scalar tensor: the mean matched IoU over batch and afford slots
        (higher is better, despite the function name).
    """
    batch_size, num_afford, num_vox = x.shape
    xx = x[..., None, :].expand((batch_size, num_afford, num_afford, num_vox)).int()
    yy = y[..., None, :, :].expand((batch_size, num_afford, num_afford, num_vox)).int()
    # [batch_size, num_afford, num_afford]
    pairwise_cost = torch.sum(torch.logical_and(xx, yy), dim=-1) / torch.sum(torch.logical_or(xx, yy), dim=-1)
    # Negate so that minimising cost maximises IoU; empty-vs-empty (0/0) counts as IoU 1.
    pairwise_cost = -torch.nan_to_num(pairwise_cost, nan=1)

    assignments = [linear_sum_assignment(cost) for cost in pairwise_cost.detach().cpu().numpy()]
    device = pairwise_cost.device
    batch_idx = torch.arange(batch_size, device=device).repeat_interleave(num_afford)
    row_idx = torch.as_tensor(np.concatenate([rows for rows, _ in assignments]), device=device)
    col_idx = torch.as_tensor(np.concatenate([cols for _, cols in assignments]), device=device)
    # Single gather of all matched entries instead of a per-element Python loop.
    return -torch.sum(pairwise_cost[batch_idx, row_idx, col_idx]) / batch_size / num_afford


def ari(gt_slot, recon_slot, device=torch.device('cpu'), num_max_afford=25):
    """Adjusted Rand Index between ground-truth and predicted segmentations.

    Args:
        gt_slot: tensor whose first dimension is the batch; flattened to
            (batch, num_points). Values are integer ids < ``num_max_afford``.
        recon_slot: tensor of predicted ids, flattened the same way. Ids are
            shifted by +1 before one-hot encoding, so they must be
            < ``num_max_afford - 1``. NOTE(review): the shift presumably keeps
            predicted ids away from id 0 — confirm against callers.
        device: device the computation runs on.
        num_max_afford: number of one-hot classes used for both inputs.

    Returns:
        Tensor of shape (batch,) with one ARI per batch element.

        When both segmentations of an element consist of a single cluster, the
        formula is 0/0; that case is defined as 1. Other degenerate inputs still
        give NaN, for example every point in its own cluster in both
        segmentations, or no points at all.
    """
    batch_size = gt_slot.shape[0]
    # reshape (unlike view) also accepts non-contiguous inputs.
    gt_slot = gt_slot.reshape(batch_size, -1).long().to(device)
    recon_slot = recon_slot.reshape(batch_size, -1).long().to(device) + 1
    true_mask_oh = F.one_hot(gt_slot, num_classes=num_max_afford).float()
    pred_mask_oh = F.one_hot(recon_slot, num_classes=num_max_afford).float()
    n_points = torch.sum(true_mask_oh, dim=[1, 2])
    # Contingency table: nij[b, t, p] = number of points with true id t and predicted id p.
    nij = torch.einsum('bji,bjk->bki', pred_mask_oh, true_mask_oh)
    pred_sizes = torch.sum(nij, dim=1)
    true_sizes = torch.sum(nij, dim=2)

    rindex = torch.sum(nij * (nij - 1), dim=[1, 2])
    aindex = torch.sum(pred_sizes * (pred_sizes - 1), dim=1)
    bindex = torch.sum(true_sizes * (true_sizes - 1), dim=1)
    expected_rindex = aindex * bindex / (n_points * (n_points - 1))
    max_rindex = (aindex + bindex) / 2
    ari_score = (rindex - expected_rindex) / (max_rindex - expected_rindex)

    # Single cluster on both sides makes numerator and denominator zero; define ARI = 1.
    both_single_cluster = ((true_sizes > 0).sum(dim=1) == 1) & ((pred_sizes > 0).sum(dim=1) == 1)
    return torch.where(both_single_cluster, torch.ones_like(ari_score), ari_score)

def idx2afford(idx):
    """Return the affordance name for integer label ``idx`` (label 0 is ``"null"``)."""
    names = (
        "null "
        "sittable support framework containment liquidcontainment "
        "handle display cutting backrest armrest "
        "pressable openable hanging wrapgrasp illumination "
        "lyable headrest step pourable twistable "
        "rollable lever pinchable audible"
    ).split()
    return names[idx]

def onehot2afford(onehot, threshold=.5):
    """Return the affordance name at the position of the largest entry of ``onehot``.

    ``onehot`` is presumably a 1-D score/one-hot vector (``np.argmax`` returns a
    flat index for multi-dimensional input) — TODO confirm with callers.
    NOTE(review): ``threshold`` is accepted but currently unused.
    """
    best_idx = np.argmax(onehot)
    return idx2afford(best_idx)


def optimizer_to_device(optimizer, device):
    """Move every tensor stored in ``optimizer.state`` to ``device``, in place.

    Non-tensor state entries are left untouched. Returns None.
    """
    for param_state in optimizer.state.values():
        moved = {key: value.to(device)
                 for key, value in param_state.items()
                 if torch.is_tensor(value)}
        param_state.update(moved)

def visualize_voxels(voxels, afford_label=None, out_file=None, show=False):
    r''' Draws a voxel grid as a 3D matplotlib figure, optionally saving and/or showing it.
    Args:
        voxels (tensor or array-like): voxel data convertible with ``np.asarray``;
            must be 3-D because it is transposed with axes (2, 0, 1). e.g., shape: (32,32,32)
        afford_label (string): figure title; omitted when falsy
        out_file (string): output file; nothing is saved when None
        show (bool): whether the plot should be shown
    '''
    voxel_array = np.asarray(voxels)
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    if afford_label:
        ax.set_title(afford_label, fontdict={"fontsize": 48})
    # Move the last axis to the front before handing the grid to matplotlib.
    ax.voxels(voxel_array.transpose(2, 0, 1), edgecolor='k')
    for set_label, text in ((ax.set_xlabel, 'X'), (ax.set_ylabel, 'Y'), (ax.set_zlabel, 'Z')):
        set_label(text)
    ax.view_init(elev=30, azim=45)
    if out_file is not None:
        plt.savefig(out_file)
    if show:
        plt.show()
    # Always close the figure so repeated calls do not keep figures open.
    plt.close(fig)

def visualize_voxels_cuboids(voxels, cuboid_vertices, afford_label=None, cuboid_color=None, xyz_range=None, out_file=None):
    r''' Renders occupied voxels as a plotly point cloud overlaid with cuboid wireframes.
    Args:
        voxels (tensor): occupancy grid; cells with value > 0 are drawn. It indexes a
            ``BuildGrid([32, 32, 32])`` coordinate grid, so presumably shape (32, 32, 32)
            — TODO confirm.
        cuboid_vertices (array): ``cuboid_vertices[c]`` holds the vertices of cuboid ``c``,
            split into (y, z, x) columns. Vertices 0-3 and 4-7 are drawn as closed loops and
            vertex k is joined to vertex k+4, so presumably 8 vertices per cuboid.
        afford_label (string): figure title
        cuboid_color (sequence): one plotly color per cuboid; when None a default
            palette is cycled
        xyz_range (sequence): [xmin, xmax, ymin, ymax, zmin, zmax]; when None or empty it
            is computed from the occupied voxels with 0.1 padding
        out_file (string): image path for ``fig.write_image``; nothing is written when None
    Returns:
        The xyz_range used for the axes.
    '''
    # Fallback colors (plotly's default qualitative palette) used when cuboid_color is None.
    default_palette = ('#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
                       '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52')
    build_grid = BuildGrid([32, 32, 32])
    voxel_grid = build_grid([1, 32, 32, 32])[0]
    voxel_sparse = voxel_grid[voxels > 0].numpy()
    # Grid coordinates are split as (y, z, x) columns.
    ys, zs, xs = np.split(voxel_sparse, [1, 2], axis=-1)

    data = [go.Scatter3d(x=xs.flatten(), y=ys.flatten(), z=zs.flatten(),
                         mode='markers', marker=dict(size=3, color='#1f609e'),
                         showlegend=False)]
    for cuboid_idx in range(cuboid_vertices.shape[0]):
        if cuboid_color is None:
            color = default_palette[cuboid_idx % len(default_palette)]
        else:
            color = cuboid_color[cuboid_idx]
        cuboid_ys, cuboid_zs, cuboid_xs = np.split(cuboid_vertices[cuboid_idx], [1, 2], axis=-1)
        # Two closed loops: vertices 0-3 and vertices 4-7.
        for loop in ([0, 1, 2, 3, 0], [4, 5, 6, 7, 4]):
            data.append(go.Scatter3d(x=[cuboid_xs[i][0] for i in loop],
                                     y=[cuboid_ys[i][0] for i in loop],
                                     z=[cuboid_zs[i][0] for i in loop],
                                     mode='lines+markers',
                                     line=dict(color=color)))
        # Edges connecting vertex k with vertex k+4 (every 4th vertex from k).
        for start_idx in range(4):
            data.append(go.Scatter3d(x=cuboid_xs[start_idx::4].flatten(),
                                     y=cuboid_ys[start_idx::4].flatten(),
                                     z=cuboid_zs[start_idx::4].flatten(),
                                     mode='lines',
                                     line=dict(color=color)))
    fig = go.Figure(data)
    fig.update_traces(showlegend=False)
    # Explicit None/empty check: `not xyz_range` raises for numpy arrays.
    if xyz_range is None or len(xyz_range) == 0:
        xyz_range = [np.min(xs) - .1, np.max(xs) + .1,
                     np.min(ys) - .1, np.max(ys) + .1,
                     np.min(zs) - .1, np.max(zs) + .1]
    fig.update_layout(
        title={
            'text': afford_label,
            'font': dict(size=48),
            'x': .5
        },
        scene=dict(
            xaxis=dict(range=list(xyz_range[:2])),
            yaxis=dict(range=list(xyz_range[2:4])),
            zaxis=dict(range=list(xyz_range[4:])),
            aspectratio=dict(x=1, y=1, z=1))
    )
    if out_file is not None:
        fig.write_image(out_file, scale=3)
    return xyz_range

def merge_test_image(instance_path, num_slots = 6):
    """Tile the ground-truth, reconstruction and per-slot renders into one row.

    Reads ``gt.png``, ``recon.png`` and ``recon_<i>.png`` for i in [0, num_slots)
    from ``instance_path``. The combined figure is saved as ``merged.png`` in the
    same directory.
    """
    panels = [('gt.png', 'Voxel'), ('recon.png', 'Recon.')]
    panels += [(f'recon_{slot}.png', f'Slot {slot + 1}') for slot in range(num_slots)]
    fig, axes = plt.subplots(1, num_slots + 2, figsize=(7, 1), dpi=600)
    for axis, (file_name, title) in zip(axes, panels):
        axis.imshow(mpimg.imread(os.path.join(instance_path, file_name)))
        axis.set_title(title)
        # Pure image panels: no grid lines, ticks or frame.
        axis.grid(False)
        axis.axis('off')
    fig.savefig(os.path.join(instance_path, "merged.png"))
    fig.clf()  # clear the figure
    plt.close(fig)

