import bisect
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from torch import Tensor


def _compute_conf_thresh(data):
    """Look up the error threshold associated with the batch's dataset.

    The dataset is identified by the first entry of ``data['dataset_name']``,
    compared case-insensitively.

    Args:
        data (dict): batch dict holding a ``'dataset_name'`` sequence of strings.

    Returns:
        float: ``5e-4`` for ``'scannet'``, ``1e-4`` for ``'megadepth'``.

    Raises:
        ValueError: if the dataset name is not one of the two above.
    """
    dataset_name = data['dataset_name'][0].lower()
    known_thresholds = {'scannet': 5e-4, 'megadepth': 1e-4}
    if dataset_name not in known_thresholds:
        raise ValueError(f'Unknown dataset: {dataset_name}')
    return known_thresholds[dataset_name]


# --- VISUALIZATION --- #
def make_trend_figures(data, x_interval, path, x_name, y_name, title, ylim=None):
    """Plot one curve against an evenly spaced x axis and save it to ``path``.

    Drawing goes through the ``pyplot`` state machine, i.e. onto whatever
    figure is current (a new one is created by matplotlib if none exists).

    Args:
        data: sequence of y values; sample ``i`` is drawn at ``i * x_interval``.
        x_interval: spacing between consecutive samples on the x axis.
        path: destination handed to ``plt.savefig``.
        x_name: label of the x axis.
        y_name: label of the y axis.
        title: figure title.
        ylim: optional ``(low, high)`` y-axis limits. When the data exceeds
            ``high`` the upper limit is raised to ``max(data) + 1.0`` so the
            curve is not cut off.

    Returns:
        None. The figure is written to ``path`` and the current figure closed.
    """
    x_axis = np.arange(len(data)) * x_interval
    plt.plot(x_axis, data)
    plt.xlabel(x_name)
    plt.ylabel(y_name)
    plt.title(title)
    # Identity test rather than `!= None`: the inequality is evaluated
    # element-wise for array-like limits and cannot serve as a condition.
    if ylim is not None:
        # Empty data has no maximum; keep the requested limits untouched.
        if len(data) > 0 and np.max(data) > ylim[1]:
            ylim = (ylim[0], np.max(data) + 1.0)
        plt.ylim(ylim)
    plt.savefig(path)
    plt.close()

def make_multi_trend_figures(data, x_interval, path, x_name, y_name, title, legends, ylim=None):
    """Plot every column of ``data`` as its own curve and save the figure.

    Drawing goes through the ``pyplot`` state machine, i.e. onto whatever
    figure is current (a new one is created by matplotlib if none exists).

    Args:
        data: 2-D array indexed as ``data[:, i]``; each column is one curve and
            row ``r`` is drawn at ``r * x_interval``.
        x_interval: spacing between consecutive rows on the x axis.
        path: destination handed to ``plt.savefig``.
        x_name: label of the x axis.
        y_name: label of the y axis.
        title: figure title.
        legends: legend entries, passed to ``plt.legend`` in column order.
        ylim: optional ``(low, high)`` y-axis limits. When the data exceeds
            ``high`` the upper limit is raised to ``max(data) + 1.0`` so no
            curve is cut off.

    Returns:
        None. The figure is written to ``path`` and the current figure closed.
    """
    x_axis = np.arange(len(data)) * x_interval
    for col in range(data.shape[1]):
        plt.plot(x_axis, data[:, col])
    plt.legend(legends)
    plt.xlabel(x_name)
    plt.ylabel(y_name)
    plt.title(title)
    # Identity test rather than `!= None`: the inequality is evaluated
    # element-wise for array-like limits and cannot serve as a condition.
    if ylim is not None:
        # An array without rows or columns has no maximum; keep the limits.
        has_values = len(data) > 0 and data.shape[1] > 0
        if has_values and np.max(data) > ylim[1]:
            ylim = (ylim[0], np.max(data) + 1.0)
        plt.ylim(ylim)
    plt.savefig(path)
    plt.close()


def make_original_figures(image, instance, path=None, dpi=100):
    """Show ``image`` and ``instance`` side by side in a frameless 1x2 figure.

    Args:
        image: array drawn on the left with the ``'gray'`` colormap.
        instance: array drawn on the right with the ``'jet'`` colormap; must
            have the same ``shape`` as ``image``.
        path: if truthy, the figure is saved there and closed.
        dpi: resolution passed to ``plt.subplots``.

    Returns:
        The figure when ``path`` is falsy, otherwise ``None``.
    """
    assert image.shape == instance.shape
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    left_ax.imshow(image, cmap='gray')
    right_ax.imshow(instance, cmap='jet')

    # strip ticks and borders from both panels
    for panel in (left_ax, right_ax):
        panel.get_yaxis().set_ticks([])
        panel.get_xaxis().set_ticks([])
        for border in panel.spines.values():
            border.set_visible(False)
    plt.tight_layout(pad=1)

    # hand the figure back unless a destination was given
    if not path:
        return fig
    plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
    plt.close()
    return None


def make_weight_figures(weight, ref, path=None, dpi=100):
    """Show ``weight`` and ``ref`` side by side, both with the ``'jet'`` colormap.

    Args:
        weight: array drawn on the left.
        ref: array drawn on the right; must have the same ``shape`` as ``weight``.
        path: if truthy, the figure is saved there and closed.
        dpi: resolution passed to ``plt.subplots``.

    Returns:
        The figure when ``path`` is falsy, otherwise ``None``.
    """
    assert weight.shape == ref.shape
    fig, panels = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    for panel, content in zip(panels, (weight, ref)):
        panel.imshow(content, cmap='jet')

    # strip ticks and borders from both panels
    for panel in panels:
        panel.get_yaxis().set_ticks([])
        panel.get_xaxis().set_ticks([])
        for border in panel.spines.values():
            border.set_visible(False)
    plt.tight_layout(pad=1)

    # hand the figure back unless a destination was given
    if not path:
        return fig
    plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
    plt.close()
    return None
    

def _box_edge_lines(fig, ax, to_figure, boxes, color):
    """Build the four edge segments of every box as figure-space ``Line2D`` artists.

    Args:
        fig: figure whose ``transFigure`` the segments are expressed in.
        ax: axes whose data coordinates the boxes are given in.
        to_figure: inverted ``fig.transFigure`` (display -> figure coordinates).
        boxes: array indexed as ``boxes[:, [2, 0]]`` / ``boxes[:, [3, 1]]``;
            those column pairs are read as the ``(x, y)`` of two opposite corners.
        color: per-box colors, indexed by box position.

    Returns:
        list: segments grouped edge by edge (the first edge of every box, then
        the second edge of every box, and so on).
    """
    corner_a = to_figure.transform(ax.transData.transform(boxes[:, [2, 0]]))
    corner_b = to_figure.transform(ax.transData.transform(boxes[:, [3, 1]]))
    # Each entry names the corner supplying (x-start, x-end, y-start, y-end).
    edge_specs = (
        (corner_a, corner_b, corner_a, corner_a),
        (corner_b, corner_b, corner_a, corner_b),
        (corner_b, corner_a, corner_b, corner_b),
        (corner_a, corner_a, corner_b, corner_a),
    )
    return [
        matplotlib.lines.Line2D(
            (x_from[i, 0], x_to[i, 0]), (y_from[i, 1], y_to[i, 1]),
            transform=fig.transFigure, c=color[i], linewidth=0.8)
        for x_from, x_to, y_from, y_to in edge_specs
        for i in range(len(corner_a))
    ]


def make_matching_figure_with_boxes(
    img0, img1, mkpts0, mkpts1, boxes_0, boxes_1,
    color, kpts0=None, kpts1=None,
    text=None, dpi=75, path=None, ptSize=4
):
    """Draw an image pair with match lines, matched points and per-match boxes.

    Matches and boxes are only drawn when there is at least one match; the
    text overlay is always drawn.

    Args:
        img0: left image, shown with the ``'gray'`` colormap.
        img1: right image, shown with the ``'gray'`` colormap.
        mkpts0: ``(N, 2)`` matched points in ``img0``, columns used as ``(x, y)``.
        mkpts1: ``(N, 2)`` matched points in ``img1``; same length as ``mkpts0``.
        boxes_0: boxes drawn on the left image; columns ``[2, 0]`` and ``[3, 1]``
            are read as the ``(x, y)`` of two opposite corners.
        boxes_1: boxes drawn on the right image, same layout as ``boxes_0``.
        color: per-match colors; entry ``i`` colors match ``i`` and box ``i``.
        kpts0: accepted for signature parity; not used by this function.
        kpts1: accepted for signature parity; not used by this function.
        text: lines of text written in the top-left corner (default: none).
        dpi: resolution passed to ``plt.subplots``.
        path: if truthy, the figure is saved there and closed.
        ptSize: marker size of the matched points.

    Returns:
        The figure when ``path`` is falsy, otherwise ``None``.
    """
    # A fresh list per call instead of a shared mutable default.
    if text is None:
        text = []

    # draw image pair
    assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    axes[0].imshow(img0, cmap='gray')
    axes[1].imshow(img1, cmap='gray')
    for ax in axes:
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    # draw matches
    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
        # The canvas must be drawn once so the data->display transforms are final.
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
        # Lines live in figure coordinates so they can span both axes.
        line_matches = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
                                                (fkpts0[i, 1], fkpts1[i, 1]),
                                                transform=fig.transFigure, c=color[i], linewidth=1)
                        for i in range(len(mkpts0))]

        line_boxes_0 = _box_edge_lines(fig, axes[0], transFigure, boxes_0, color)
        line_boxes_1 = _box_edge_lines(fig, axes[1], transFigure, boxes_1, color)

        fig.lines = line_matches + line_boxes_0 + line_boxes_1

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=ptSize)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=ptSize)

    # put txts: black on a bright top-left patch, white otherwise
    txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
    fig.text(
        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
        fontsize=15, va='top', ha='left', color=txt_color)

    # save or return figure
    if path:
        plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
        plt.close()
    else:
        return fig


def make_matching_figure(
        img0, img1, mkpts0, mkpts1, color,
        kpts0=None, kpts1=None, text=None, dpi=75, path=None, ptSize=4):
    """Draw an image pair with lines connecting matched points.

    Args:
        img0: left image, shown with the ``'gray'`` colormap.
        img1: right image, shown with the ``'gray'`` colormap.
        mkpts0: ``(N, 2)`` matched points in ``img0``, columns used as ``(x, y)``.
        mkpts1: ``(N, 2)`` matched points in ``img1``; same length as ``mkpts0``.
        color: per-match colors; entry ``i`` colors line ``i`` and its end points.
        kpts0: optional points drawn on the left image as small white dots.
        kpts1: optional points drawn on the right image; must be given
            whenever ``kpts0`` is.
        text: lines of text written in the top-left corner (default: none).
        dpi: resolution passed to ``plt.subplots``.
        path: if truthy, the figure is saved there and closed.
        ptSize: marker size of the matched points.

    Returns:
        The figure when ``path`` is falsy, otherwise ``None``.
    """
    # A fresh list per call instead of a shared mutable default.
    if text is None:
        text = []

    # draw image pair
    assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    axes[0].imshow(img0, cmap='gray')
    axes[1].imshow(img1, cmap='gray')
    for ax in axes:   # clear all frames
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if kpts0 is not None:
        assert kpts1 is not None
        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)

    # draw matches
    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
        # The canvas must be drawn once so the data->display transforms are final.
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
        # Lines live in figure coordinates so they can span both axes.
        fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
                                            (fkpts0[i, 1], fkpts1[i, 1]),
                                            transform=fig.transFigure, c=color[i], linewidth=1)
                     for i in range(len(mkpts0))]

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=ptSize)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=ptSize)

    # put txts: black on a bright top-left patch, white otherwise
    txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
    fig.text(
        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
        fontsize=15, va='top', ha='left', color=txt_color)

    # save or return figure
    if path:
        plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
        plt.close()
    else:
        return fig


def _make_evaluation_figure(data, b_id, alpha='dynamic'):
    """Build the match-evaluation figure for one sample of a batch.

    Matches belonging to sample ``b_id`` are colored by their epipolar error
    relative to the dataset threshold, and precision / recall are written on
    the figure.

    Args:
        data (dict): batch dict; reads ``'m_bids'``, ``'image0'``, ``'image1'``,
            ``'mkpts0_f'``, ``'mkpts1_f'``, ``'epi_errs'``, ``'conf_matrix_gt'``,
            ``'dataset_name'`` and, when present, ``'scale0'`` / ``'scale1'``.
            Values are presumably torch tensors (``.cpu()`` is called on them).
        b_id (int): index of the sample inside the batch.
        alpha: line opacity, or ``'dynamic'`` to derive it from the match count.

    Returns:
        The figure produced by ``make_matching_figure``.
    """
    def _to_numpy(tensor):
        return tensor.cpu().numpy()

    def _to_intensity(tensor):
        # scale by 255 and round to integer values
        return (_to_numpy(tensor) * 255).round().astype(np.int32)

    in_sample = data['m_bids'] == b_id
    conf_thr = _compute_conf_thresh(data)

    img0 = _to_intensity(data['image0'][b_id][0])
    img1 = _to_intensity(data['image1'][b_id][0])
    kpts0 = _to_numpy(data['mkpts0_f'][in_sample])
    kpts1 = _to_numpy(data['mkpts1_f'][in_sample])

    # for megadepth, matches are visualized on the resized image
    if 'scale0' in data:
        kpts0 = kpts0 / _to_numpy(data['scale0'][b_id])[[1, 0]]
        kpts1 = kpts1 / _to_numpy(data['scale1'][b_id])[[1, 0]]

    epi_errs = _to_numpy(data['epi_errs'][in_sample])
    correct_mask = epi_errs < conf_thr
    if len(correct_mask) > 0:
        precision = np.mean(correct_mask)
    else:
        precision = 0
    n_correct = np.sum(correct_mask)
    n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
    # recall may exceed 1: conf_matrix_gt is computed from groundtruth depths
    # and camera poses, whereas correctness here uses the epipolar distance.
    if n_gt_matches == 0:
        recall = 0
    else:
        recall = n_correct / (n_gt_matches)

    # matching info
    if alpha == 'dynamic':
        alpha = dynamic_alpha(len(correct_mask))
    color = error_colormap(epi_errs, conf_thr, alpha=alpha)

    text = [
        f'#Matches {len(kpts0)}',
        f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
        f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
    ]

    # make the figure
    return make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)

def _make_confidence_figure(data, b_id):
    """Placeholder for the confidence figure of one sample; not implemented.

    Raises:
        NotImplementedError: always.
    """
    # TODO: Implement confidence figure
    raise NotImplementedError

def make_depth_predictions(d_map0, d_map1, mode='depth'):
    """Make one side-by-side figure per batch entry of two map stacks.

    Args:
        d_map0: batch of maps; ``d_map0.shape[0]`` is the batch size.
        d_map1: batch of maps indexed in parallel with ``d_map0``.
        mode (str): plot mode; only ``'depth'`` is supported.

    Returns:
        dict: ``{mode: [figure, ...]}`` with one figure per batch entry.
    """
    assert mode in ['depth']
    figures = {mode: []}
    collected = figures[mode]
    for idx in range(d_map0.shape[0]):
        if mode != 'depth':
            raise ValueError(f"Unknown plot mode: {mode}")
        collected.append(make_weight_figures(d_map0[idx], d_map1[idx]))
    return figures

def make_matching_figures(data, config, mode='evaluation'):
    """Make one matching figure per sample of a batch.

    Args:
        data (Dict): a batch updated by PL_LoFTR; ``data['image0'].size(0)``
            gives the batch size.
        config (Dict): matcher config; ``config.TRAINER.PLOT_MATCHES_ALPHA`` is
            read in ``'evaluation'`` mode.
        mode (str): ``'evaluation'`` or ``'confidence'``.
    Returns:
        figures (Dict[str, List[plt.figure]]): ``{mode: [figure, ...]}``.
    """
    assert mode in ['evaluation', 'confidence']  # 'confidence'
    figures = {mode: []}
    collected = figures[mode]
    batch_size = data['image0'].size(0)
    for b_id in range(batch_size):
        if mode == 'evaluation':
            fig = _make_evaluation_figure(
                data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
        elif mode == 'confidence':
            fig = _make_confidence_figure(data, b_id)
        else:
            raise ValueError(f'Unknown plot mode: {mode}')
        collected.append(fig)
    return figures


def dynamic_alpha(n_matches,
                  milestones=(0, 300, 1000, 2000),
                  alphas=(1.0, 0.8, 0.4, 0.2)):
    """Pick a line opacity that decreases as the number of matches grows.

    Between two consecutive milestones the opacity is linearly interpolated
    between the corresponding alphas; at or beyond the last milestone the
    last alpha is returned.

    Args:
        n_matches (int): number of matches to be drawn.
        milestones: ascending match counts at which ``alphas`` apply. Any
            sequence (list or tuple) is accepted.
        alphas: opacity at each milestone, same length as ``milestones``. Any
            sequence (list or tuple) is accepted.

    Returns:
        float: the opacity; ``1.0`` when ``n_matches`` is 0.
    """
    if n_matches == 0:
        return 1.0
    # Copy to a list so tuples can be paired with the `[None]` sentinel below.
    alphas = list(alphas)
    # Pair each alpha with the next one; the last range is open-ended.
    ranges = list(zip(alphas, alphas[1:] + [None]))
    loc = bisect.bisect_right(milestones, n_matches) - 1
    start_alpha, end_alpha = ranges[loc]
    if end_alpha is None:
        return start_alpha
    remaining = milestones[loc + 1] - n_matches
    span = milestones[loc + 1] - milestones[loc]
    return end_alpha + remaining / span * (start_alpha - end_alpha)


def error_colormap(err, thr, alpha=1.0):
    """Map errors to RGBA colors running from green (zero error) to red.

    Errors are normalized by ``2 * thr`` and clipped, so anything at or above
    twice the threshold is fully red.

    Args:
        err: array of error values.
        thr: error threshold used for normalization.
        alpha: opacity written to the last channel; must be in ``(0, 1]``.

    Returns:
        Array with a trailing axis of size 4 (R, G, B, A), values in ``[0, 1]``.
    """
    assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
    # 1 for a perfect match, 0 once the error reaches twice the threshold
    closeness = 1 - np.clip(err / (thr * 2), 0, 1)
    red = 2 - closeness * 2
    green = closeness * 2
    blue = np.zeros_like(closeness)
    opacity = np.ones_like(closeness) * alpha
    rgba = np.stack([red, green, blue, opacity], -1)
    return np.clip(rgba, 0, 1)
