import torch
import sys
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

sys.path.insert(0, '..')
from utils.distance import hungarian, chamfer

from scipy.stats import spearmanr, kendalltau

def plot_pc(*pc, ax=None, d_pred=None):
    """Scatter-plot one or two 2-D point clouds on a fixed [-2, 2] x [-2, 2] square.

    Only columns 0 and 1 of each cloud are drawn. When two clouds are given,
    ``hungarian`` (from ``utils.distance``) is run on them, arrows are drawn
    from a random ``len(pc[0]) // 10`` points of the first cloud towards their
    partners in the second cloud, and the axes title shows the distance ``d``
    and ``d_norm = d / len(pc[0])``.

    Args:
        *pc: One or two point clouds (NumPy arrays or torch tensors; tensors
            are detached, moved to CPU and converted to NumPy).
        ax: Axes to draw into; a new figure/axes is created when ``None``.
        d_pred: Optional predicted distance appended to the title. It is only
            shown when two clouds are given, because the title is only set then.

    Returns:
        The matplotlib Axes that was drawn into.

    Raises:
        ValueError: If the number of point clouds is not 1 or 2.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(pc) not in (1, 2):
        raise ValueError(f'plot_pc expects 1 or 2 point clouds, got {len(pc)}')

    # hungarian() is fed through torch.from_numpy, which rejects tensors.
    pc = tuple(p.detach().cpu().numpy() if isinstance(p, torch.Tensor) else p for p in pc)

    dist = None
    if len(pc) == 2:
        dist, matching = hungarian(torch.from_numpy(pc[0]), torch.from_numpy(pc[1]), return_matching=True)

    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    for _pc in pc:
        ax.scatter(_pc[:, 0].flatten(), _pc[:, 1].flatten())
    # Fixed view; only needs to be applied once, not per cloud.
    ax.set_aspect('equal')
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)

    if len(pc) == 2:
        n_points = len(pc[0])
        # NOTE(review): matching[0][1] is assumed to give, for every point of
        # pc[0], the index of its matched point in pc[1] -- confirm against
        # utils.distance.hungarian.
        target_idx = matching[0][1]
        dx = pc[1][:, 0][target_idx] - pc[0][:, 0]
        dy = pc[1][:, 1][target_idx] - pc[0][:, 1]
        # Draw only a random ~10% of the arrows to keep the plot readable.
        idxs = np.random.choice(n_points, n_points // 10, replace=False)
        ax.quiver(pc[0][:, 0][idxs], pc[0][:, 1][idxs], dx[idxs], dy[idxs],
                  angles='xy', scale_units='xy', scale=1)

    if dist is not None:
        title = f'd : {dist.item():.2f} | ' + r'$d_{norm}$: ' + f'{dist.item()/len(pc[0]):.2f}'

        if d_pred is not None:
            title += '\n'
            title += r'$\hat{d}$' + f' : {d_pred:.2f} | ' + r'$\hat{d}_{norm}$: ' + f'{d_pred/len(pc[0]):.2f}'
        ax.set_title(title)

    if fig is not None:
        fig.tight_layout()

    return ax

def plot_pc_3D(*pc, ax=None, d_pred=None):
    """Scatter-plot one or two 3-D point clouds on a 3-D axes.

    Columns 0, 1 and 2 of each cloud are drawn with ``'o'`` markers of size 20.
    When two clouds are given, ``hungarian`` (from ``utils.distance``) is run
    on them. Black arrows are drawn from a random ``len(pc[0]) // 10`` points
    of the first cloud towards their partners in the second cloud, and the
    title shows ``EMD : <distance>``.

    Args:
        *pc: One or two point clouds (NumPy arrays or torch tensors; tensors
            are detached, moved to CPU and converted to NumPy).
        ax: 3-D axes to draw into; a new one is created when ``None``.
        d_pred: Optional predicted distance appended to the title. It is only
            shown when two clouds are given.

    Returns:
        The matplotlib Axes that was drawn into.

    Raises:
        ValueError: If the number of point clouds is not 1 or 2.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(pc) not in (1, 2):
        raise ValueError(f'plot_pc_3D expects 1 or 2 point clouds, got {len(pc)}')

    # hungarian() is fed through torch.from_numpy, which rejects tensors.
    pc = tuple(p.detach().cpu().numpy() if isinstance(p, torch.Tensor) else p for p in pc)

    dist = None
    if len(pc) == 2:
        dist, matching = hungarian(torch.from_numpy(pc[0]), torch.from_numpy(pc[1]), return_matching=True)

    if ax is None:
        _, ax = plt.subplots(subplot_kw=dict(projection="3d"))

    for _pc in pc:
        ax.scatter(_pc[:, 0], _pc[:, 1], _pc[:, 2], marker='o', s=20)
    ax.set_aspect('auto')

    if len(pc) == 2:
        n_points = len(pc[0])
        # NOTE(review): matching[0][1] is assumed to give, for every point of
        # pc[0], the index of its matched point in pc[1] -- confirm against
        # utils.distance.hungarian.
        target_idx = matching[0][1]
        dx = pc[1][:, 0][target_idx] - pc[0][:, 0]
        dy = pc[1][:, 1][target_idx] - pc[0][:, 1]
        dz = pc[1][:, 2][target_idx] - pc[0][:, 2]
        # Draw only a random ~10% of the arrows to keep the plot readable.
        idxs = np.random.choice(n_points, n_points // 10, replace=False)
        ax.quiver(pc[0][:, 0][idxs], pc[0][:, 1][idxs], pc[0][:, 2][idxs],
                  dx[idxs], dy[idxs], dz[idxs], color='black')

    if dist is not None:
        title = f'EMD : {dist.item():.2f}'

        if d_pred is not None:
            title += '\n'
            title += r'$\hat{d}$' + f' : {d_pred:.2f} | ' + r'$\hat{d}_{norm}$: ' + f'{d_pred/len(pc[0]):.2f}'
        ax.set_title(title)

    # Fixed camera: elev=270, azim=135; z tick labels hidden.
    ax.view_init(270, 135)
    ax.set_zticklabels([])

    return ax

def plot_pcs(pc_source, pc_target, idxs, pred_dists=None):
    """Plot selected point clouds side by side, one subplot per index.

    Args:
        pc_source: Batch of point clouds indexable by ``idx`` (NumPy array or
            torch tensor). A last dimension of 3 selects ``plot_pc_3D``;
            otherwise ``plot_pc`` is used.
        pc_target: Batch of target clouds aligned with ``pc_source``, or
            ``None`` to plot the source clouds alone.
        idxs: Batch indices to plot; one subplot is created for each.
        pred_dists: Optional per-element predicted distances. Elements must
            support ``.item()``. They are passed as ``d_pred`` and only
            take effect when ``pc_target`` is given.

    Returns:
        The matplotlib Figure.
    """
    is_3d = pc_source.shape[-1] == 3
    plot_fn = plot_pc_3D if is_3d else plot_pc

    # .numpy() only works on CPU tensors that do not require grad.
    if isinstance(pc_source, torch.Tensor):
        pc_source = pc_source.detach().cpu().numpy()
    if isinstance(pc_target, torch.Tensor):
        pc_target = pc_target.detach().cpu().numpy()

    n_cols = len(idxs)
    # squeeze=False keeps `axs` 2-D even for a single subplot, so indexing
    # works when len(idxs) == 1.
    if is_3d:
        fig, axs = plt.subplots(1, n_cols, figsize=(3*n_cols, 3), subplot_kw={'projection': '3d'}, squeeze=False)
    else:
        fig, axs = plt.subplots(1, n_cols, figsize=(3*n_cols, 4), squeeze=False)

    for ax, idx in zip(axs[0], idxs):
        if pc_target is None:
            plot_fn(pc_source[idx], ax=ax)
        elif pred_dists is None:
            plot_fn(pc_source[idx], pc_target[idx], ax=ax)
        else:
            plot_fn(pc_source[idx], pc_target[idx], ax=ax, d_pred=pred_dists[idx].item())

    fig.tight_layout()

    return fig

def plot_pcs_recon_v1(pc_source, pc_target, idxs, dist_fn=None, dist_fn_true=None):
    """Plot source clouds (top row) above their target clouds (bottom row).

    Column ``i`` shows ``pc_source[idxs[i]]`` over ``pc_target[idxs[i]]``. The
    top subplot is titled with the reference distance ``d`` between the pair
    and, when ``dist_fn`` is given, the predicted distance.

    Args:
        pc_source: Batch of point clouds (NumPy array or torch tensor); a last
            dimension of 3 selects 3-D plotting.
        pc_target: Batch of target clouds aligned with ``pc_source``.
        idxs: Batch indices to plot, one column each.
        dist_fn: Optional predictor, called as
            ``dist_fn(src, None, tgt, None, False)``; its result must support
            ``.item()``. If it exposes ``parameters()``, inputs are moved to
            the device of its first parameter.
        dist_fn_true: Reference metric name, ``'chamfer'`` or ``'hungarian'``.

    Returns:
        The matplotlib Figure.

    Raises:
        NotImplementedError: If ``dist_fn_true`` is not a supported name.
    """
    if dist_fn_true == 'chamfer':
        _dist_fn_true = chamfer
    elif dist_fn_true == 'hungarian':
        _dist_fn_true = hungarian
    else:
        # `NotImplemented` is a constant, not an exception; raising it would
        # fail with an unrelated TypeError.
        raise NotImplementedError(
            f"dist_fn_true must be 'chamfer' or 'hungarian', got {dist_fn_true!r}")

    is_3d = pc_source.shape[-1] == 3
    plot_fn = plot_pc_3D if is_3d else plot_pc

    # .numpy() only works on CPU tensors that do not require grad.
    if isinstance(pc_source, torch.Tensor):
        pc_source = pc_source.detach().cpu().numpy()
    if isinstance(pc_target, torch.Tensor):
        pc_target = pc_target.detach().cpu().numpy()

    n_cols = len(idxs)
    # squeeze=False keeps `axs` 2-D even when only one column is requested.
    fig, axs = plt.subplots(2, n_cols, figsize=(3*n_cols, 3*2), squeeze=False,
                            subplot_kw={'projection': '3d'} if is_3d else None)

    # Resolve the predictor's device once. Callables without parameters()
    # (or modules without parameters) receive the CPU tensors unchanged.
    device = None
    if dist_fn is not None:
        try:
            device = next(dist_fn.parameters()).device
        except (AttributeError, StopIteration):
            device = None

    for i, idx in enumerate(idxs):
        src = torch.from_numpy(pc_source[idx])
        tgt = torch.from_numpy(pc_target[idx])
        dist = _dist_fn_true(src, tgt)
        title = f'd : {dist.item():.2f}'

        if dist_fn is not None:
            if device is not None:
                src_in, tgt_in = src.to(device), tgt.to(device)
            else:
                src_in, tgt_in = src, tgt
            apprx_dist = dist_fn(src_in, None, tgt_in, None, False)
            title += ' | ' + r'$\hat{d}$ :' + f'{apprx_dist.item():.2f}'

        axs[0][i].set_title(title)
        plot_fn(pc_source[idx], ax=axs[0][i])
        plot_fn(pc_target[idx], ax=axs[1][i])

    fig.tight_layout()

    return fig

def plot_pcs_recon_v2(pc_source, pc_target, idxs, dist_fn_pred=None, dist_fn_true=None):
    """Overlay each selected source cloud on its target, one subplot per index.

    Each subplot is drawn by ``plot_pc``/``plot_pc_3D`` with both clouds. Its
    title is then replaced by ``'<metric> : <true distance>'``, optionally
    followed by the predicted distance, in font size 20. The metric label is
    ``'EMD'`` for ``'hungarian'`` and ``'CD'`` for ``'chamfer'``.

    Args:
        pc_source: Batch of point clouds (NumPy array or torch tensor); a last
            dimension of 3 selects 3-D plotting.
        pc_target: Batch of target clouds aligned with ``pc_source``.
        idxs: Batch indices to plot, one subplot each.
        dist_fn_pred: Optional predictor, called as
            ``dist_fn_pred(src, None, tgt, None, False)``; its result must
            support ``.item()``. If it exposes ``parameters()``, inputs are
            moved to the device of its first parameter.
        dist_fn_true: Reference metric name, ``'chamfer'`` or ``'hungarian'``.

    Returns:
        The matplotlib Figure.

    Raises:
        NotImplementedError: If ``dist_fn_true`` is not a supported name.
    """
    # The on-plot label must name the metric that is actually computed.
    if dist_fn_true == 'chamfer':
        _dist_fn_true, metric_label = chamfer, 'CD'
    elif dist_fn_true == 'hungarian':
        _dist_fn_true, metric_label = hungarian, 'EMD'
    else:
        # `NotImplemented` is a constant, not an exception; raising it would
        # fail with an unrelated TypeError.
        raise NotImplementedError(
            f"dist_fn_true must be 'chamfer' or 'hungarian', got {dist_fn_true!r}")

    is_3d = pc_source.shape[-1] == 3
    plot_fn = plot_pc_3D if is_3d else plot_pc

    # .numpy() only works on CPU tensors that do not require grad.
    if isinstance(pc_source, torch.Tensor):
        pc_source = pc_source.detach().cpu().numpy()
    if isinstance(pc_target, torch.Tensor):
        pc_target = pc_target.detach().cpu().numpy()

    n_cols = len(idxs)
    # squeeze=False always yields a (1, n_cols) array; take its single row.
    fig, axs = plt.subplots(1, n_cols, figsize=(4*n_cols, 4), squeeze=False,
                            subplot_kw={'projection': '3d'} if is_3d else None)
    axs = axs[0]

    # Only the device lookup may legitimately fail (plain callables have no
    # parameters(); parameter-free modules raise StopIteration). Errors from
    # the predictor call itself are no longer swallowed.
    device = None
    if dist_fn_pred is not None:
        try:
            device = next(dist_fn_pred.parameters()).device
        except (AttributeError, StopIteration):
            device = None

    for ax, idx in zip(axs, idxs):
        src = torch.from_numpy(pc_source[idx])
        tgt = torch.from_numpy(pc_target[idx])
        dist = _dist_fn_true(src, tgt)
        title = f'{metric_label} : {dist.item():.2f}'

        if dist_fn_pred is not None:
            if device is not None:
                src_in, tgt_in = src.to(device), tgt.to(device)
            else:
                src_in, tgt_in = src, tgt
            apprx_dist = dist_fn_pred(src_in, None, tgt_in, None, False)
            title += ' | ' + r'$\hat{d}$ :' + f'{apprx_dist.item():.2f}'

        # plot_fn sets its own title for two clouds; overwrite it with ours.
        plot_fn(pc_source[idx], pc_target[idx], ax=ax)
        ax.set_title(title)
        ax.title.set_size(20)

    return fig

plot_pcs_recon = plot_pcs_recon_v2

def plot_corr_v1(true, pred):
    """Scatter true vs. predicted distances with a y = x reference line.

    The reference line spans the union of the autoscaled x and y ranges, the
    aspect ratio is equal, and the title reports the Pearson correlation.

    Args:
        true: Ground-truth distances.
        pred: Predicted distances, same length as ``true``.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots()
    ax.scatter(true, pred)

    bounds = np.concatenate([ax.get_xlim(), ax.get_ylim()])
    diagonal = [bounds.min(), bounds.max()]
    ax.plot(diagonal, diagonal, 'k-', alpha=0.75, zorder=0)

    ax.set_aspect('equal')
    ax.set_xlabel('True distance')
    ax.set_ylabel('Predicted distance')

    pearson = np.corrcoef(true, pred)[0, 1]
    ax.set_title(f'Corr = {pearson:.4f}')

    return fig

def plot_corr_v2(true, pred):
    """Seaborn joint plot of true vs. predicted distances.

    Marginals are drawn with ``bins=50, fill=True``. A y = x reference line
    spans the union of the joint axes' x and y ranges. The top marginal is
    titled with Pearson r, Spearman rho and Kendall tau.

    Args:
        true: Ground-truth distances.
        pred: Predicted distances, same length as ``true``.

    Returns:
        The underlying matplotlib Figure.
    """
    grid = sns.jointplot(x=true, y=pred, kind='scatter', height=5,
                         marginal_kws=dict(bins=50, fill=True))
    joint_ax = grid.ax_joint
    joint_ax.set_xlabel('True distance', fontsize=15)
    joint_ax.set_ylabel('Predicted distance', fontsize=15)

    bounds = np.concatenate([joint_ax.get_xlim(), joint_ax.get_ylim()])
    diagonal = [bounds.min(), bounds.max()]
    joint_ax.plot(diagonal, diagonal, 'k-', alpha=0.75, zorder=0)

    pearson = np.corrcoef(true, pred)[0, 1]
    spearman = spearmanr(true, pred).correlation
    kendall = kendalltau(true, pred).correlation
    header = f'r = {pearson:.4f}' + rf' | $\rho$ = {spearman:.4f}' + rf' | $\tau$ = {kendall:.4f}'
    grid.ax_marg_x.set_title(header, fontsize=15)

    fig = grid.fig
    fig.tight_layout()

    return fig

plot_corr = plot_corr_v2

if __name__ == '__main__':
    # Smoke test: correlation plot of two independent standard-normal samples.
    true_vals, pred_vals = torch.randn(1000), torch.randn(1000)
    corr_fig = plot_corr_v2(true_vals, pred_vals)
    corr_fig.tight_layout()
    corr_fig.savefig('test.png')