import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
import torch

def masked_loss(pred, gt, criterion, mask, reduction='mean'):
    """Compute ``criterion`` only over the positions that ``mask`` keeps.

    ``pred`` and ``gt`` are both multiplied by ``mask`` before ``criterion`` is
    applied, so entries where the mask is 0 are compared as 0 vs 0.
    See: https://discuss.pytorch.org/t/ignore-padding-area-in-loss-computation/95804/6

    Parameters:
        pred - predicted label
        gt - groundtruth label
        criterion - element-wise loss function constructed with no reduction
        mask - mask with value 1 at pixels to keep, and value 0 at pixels to remove
        reduction - 'mean' (default): the loss summed over every element,
                        divided by dim0 * dim1 * mask.sum();
                    'batch_mean': the loss summed over dims 1-3 (one value per
                        batch element), divided by dim1 * mask.sum();
                    'none': the unreduced loss.

    NOTE(review): both normalisers assume ``mask`` is a single spatial map that
    broadcasts over the first two loss dims; a mask that already carries those
    dims would be over-counted in the denominator - confirm against callers.

    Returns: the loss reduced as requested.
    Raises: NotImplementedError for any other ``reduction``.
    """
    elementwise = criterion(pred * mask, gt * mask)  # B, 1, 10, 10

    if reduction == 'none':
        return elementwise
    if reduction == 'mean':
        denom = elementwise.shape[0] * elementwise.shape[1] * mask.sum()
        return elementwise.sum() / denom
    if reduction == 'batch_mean':
        # mean across each batch element
        denom = elementwise.shape[1] * mask.sum()
        return elementwise.sum(dim=(1, 2, 3)) / denom
    raise NotImplementedError
    
def clean_vf(raw_vf, padding_mask, factor=20):
    """Return a copy of a VF image prepared for plotting.

    Padding pixels are set to 1, and the pixels at rows 3-4 of column 7 are set
    to 0. ``raw_vf`` itself is never modified.

    Parameters:
        raw_vf - raw VF in grayscale representation, as a numpy array or torch
                 tensor. Supported layouts: (H, W), (H, W, C), or any array with
                 more than 3 dims whose spatial axes are 2 and 3
                 (e.g. (B, C, H, W)).
        padding_mask - boolean mask over the (H, W) grid (true where there is
                       padding, false elsewhere)
        factor - size increase factor. NOTE(review): currently ignored and kept
                 only for API compatibility; the returned image is NOT enlarged.
    Returns: cleaned copy of the VF, with the same shape as ``raw_vf`` (a torch
             tensor for tensor input, otherwise a numpy array).
    """
    # Work on a copy. Callers typically pass views (``.detach().cpu()`` or
    # ``.numpy()``), which share storage with the caller's CPU tensors, so
    # writing in place would silently corrupt the caller's data.
    if isinstance(raw_vf, torch.Tensor):
        vf = raw_vf.clone()
    else:
        vf = np.array(raw_vf, copy=True)

    # NOTE(review): [3:5, 7] is a fixed grid location, presumably the blind
    # spot - confirm.
    if vf.ndim > 3:
        vf[:, :, padding_mask] = 1
        vf[:, :, 3:5, 7] = 0
    else:
        # A 2-D boolean mask indexes the first two axes, so this handles both
        # (H, W) and (H, W, C) inputs; every channel is written.
        vf[padding_mask] = 1
        vf[3:5, 7] = 0

    return vf

def save_single_image(x, idx, pth):
    """Save element ``idx`` of ``x`` as a grayscale image at ``pth``.

    The leading (channel) axis of ``x[idx]`` is moved to position 2, so a
    (C, H, W) sample is drawn as (H, W, C).
    """
    figure = plt.figure()
    channel_last = x[idx].movedim(0, 2)
    plt.imshow(channel_last.detach().cpu().numpy(), cmap='gray')
    plt.savefig(pth)
    plt.close(figure)

def save_image(baseline, future, generated, idx, pth, mask):
    """Save the baseline, future and generated VFs of sample ``idx`` side by side.

    For each of the three batches, element ``idx`` has its leading (channel)
    axis moved to position 2, is converted to numpy, cleaned with ``clean_vf``
    using ``mask``, and drawn in grayscale in its own panel of a 1x3 figure
    saved to ``pth``.
    """
    panels = [
        clean_vf(batch[idx].movedim(0, 2).detach().cpu().numpy(), mask)
        for batch in (baseline, future, generated)
    ]

    _, axes = plt.subplots(1, 3)
    for axis, panel in zip(axes, panels):
        axis.imshow(panel, cmap='gray')
    plt.savefig(pth)
    plt.close()

def draw_sample_image(x, postfix, pth, mask):
    """Save a grid visualisation of the batch ``x`` to ``pth``.

    Parameters:
        x - torch tensor batch, in a layout accepted by
            ``torchvision.utils.make_grid`` (presumably (B, C, H, W) - confirm)
        postfix - text appended to the plot title "Visualization of ..."
        pth - location to save image
        mask - padding mask passed to ``clean_vf``
    """
    fig = plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.title(f'Visualization of {postfix}')
    # When ``x`` already lives on the CPU, .detach().cpu() returns a tensor that
    # shares its storage. Clone it so the cleaning step cannot overwrite the
    # caller's data.
    x = clean_vf(x.detach().cpu().clone(), mask)
    grid = make_grid(x, padding=2, normalize=True)
    # make_grid returns (C, H, W); imshow wants channel-last.
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.savefig(pth)
    plt.close(fig)
    
def remove_subplot_lines(ax):
    """Hide the ticks, tick labels and surrounding frame of a matplotlib axis.

    Axis labels set via ``set_xlabel``/``set_ylabel`` are left in place.

    Parameters:
        ax - matplotlib axis object
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    
def plot_gt_pred_batch(in_batch, gt_batch, pred_batch, padding_mask, save_pth, sub_titles=None, title=None):
    """Plot a batch of input, groundtruth and predicted VF images as a grid.

    There is one column per batch element. The first ``n_frames`` rows show the
    frames of ``in_batch`` (labelled 'Groundtruth 1..n'). The next row shows
    ``gt_batch`` (labelled 'Pattern'), and the last row shows ``pred_batch``
    (labelled 'Sampled').

    Parameters:
        in_batch - batch of inputs, tensor of shape (B, n_frames, H, W)
        gt_batch - batch of groundtruth data, 4-D tensor (B, C, H, W)
                   (C presumably 1 - confirm)
        pred_batch - batch of predicted data, same layout as gt_batch
        padding_mask - padding mask (true where there is padding, false elsewhere)
        save_pth - location to save image
        sub_titles - optional column titles, one per batch element (default None)
        title - title for plot (default None)
    """
    assert len(gt_batch) == len(pred_batch) == len(in_batch)
    batch_size, n_frames, _, _ = in_batch.shape
    # Validate before creating the figure so a failure does not leave it open.
    if sub_titles is not None:
        assert len(sub_titles) == batch_size

    # squeeze=False keeps ``axs`` 2-D even when batch_size == 1, so the
    # axs[row, col] indexing below always works.
    fig, axs = plt.subplots(2 + n_frames, batch_size, figsize=(20, 8), squeeze=False)

    def _channel_last_copy(batch):
        # .numpy() of a CPU tensor shares its memory; copy so that cleaning
        # for display can never write through to the caller's tensors.
        return torch.permute(batch, (0, 2, 3, 1)).detach().cpu().numpy().copy()

    inputs = _channel_last_copy(in_batch)
    labels = _channel_last_copy(gt_batch)
    preds = _channel_last_copy(pred_batch)

    for col in range(batch_size):
        frames = clean_vf(inputs[col], padding_mask)
        label = clean_vf(labels[col], padding_mask)
        pred = clean_vf(preds[col], padding_mask)
        for row in range(n_frames):
            axs[row, col].imshow(frames[:, :, row], cmap='gray')
        axs[n_frames, col].imshow(label, cmap='gray')
        axs[n_frames + 1, col].imshow(pred, cmap='gray')

        for row in range(2 + n_frames):
            remove_subplot_lines(axs[row, col])

    for row in range(n_frames):
        axs[row, 0].set_ylabel(f'Groundtruth {row + 1}', fontsize=18)
    axs[n_frames, 0].set_ylabel('Pattern', fontsize=18)
    axs[n_frames + 1, 0].set_ylabel('Sampled', fontsize=18)
    if title is not None:
        fig.suptitle(title)
    if sub_titles is not None:
        for col, sub_title in enumerate(sub_titles):
            axs[0, col].set_title(sub_title)

    plt.tight_layout()
    plt.savefig(save_pth)
    plt.close(fig)

# Pairs of joint indices to join with a line when drawing a pose. The indices
# are 1-based (they run from 1 to 17); plot_mocap subtracts 1 before indexing
# ``joints``.
# NOTE(review): the pairing looks like the COCO 17-keypoint skeleton - confirm.
SKELETON = [
            [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13],[6,7],
            [6,8],[7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]
        ]
def plot_mocap(ax, joints):
    """Draw one pose on ``ax``.

    Each joint is drawn as a black dot, and one line is drawn per SKELETON link.

    Parameters:
        ax - matplotlib axis object
        joints - array indexed as joints[joint, coord]; only coord 0 (x) and
                 coord 1 (y) are drawn
    """
    ax.scatter(joints[:, 0], joints[:, 1], s=50, c='k')
    for start, end in SKELETON:
        a, b = start - 1, end - 1  # SKELETON indices are 1-based
        xs = [joints[a, 0], joints[b, 0]]
        ys = [joints[a, 1], joints[b, 1]]
        ax.plot(xs, ys, linewidth=4)

def plot_mocap_gt_pred(batch_in, batch_gts, batch_patterns, batch_sampled, save_pth, sub_titles=None, title=None):
    """Plot history, groundtruth, pattern and sampled poses on a 3-row grid.

    Columns 0..h-1 hold the history frames (row 0 only). Columns h..h+H-1 hold
    the horizon: groundtruth in row 0, patterns in row 1 and sampled
    predictions in row 2. Unused cells stay blank; all axes are hidden.

    Parameters:
        batch_in - history poses, shape (h, 17, 3)
        batch_gts - groundtruth horizon poses, shape (H, 17, 3)
        batch_patterns - pattern poses, same shape as batch_gts
        batch_sampled - sampled poses, same shape as batch_gts
        save_pth - location to save image
        sub_titles - optional column titles, applied to row 0 in column order
                     (extra entries are ignored; missing ones leave a column
                     untitled) (default None)
        title - title for the whole figure (default None)
    """
    h, _, _ = batch_in.shape  # h, 17, 3
    H, _, _ = batch_gts.shape  # H, 17, 3 (same shape for batch_patterns, batch_sampled)

    # squeeze=False keeps ``axs`` 2-D even for a single column.
    fig, axs = plt.subplots(3, h + H, figsize=(100, 30), squeeze=False)

    for ax in axs.flat:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    # histories
    for i in range(h):
        plot_mocap(axs[0, i], batch_in[i])
    # groundtruth horizons (row 0), patterns (row 1), predictions (row 2)
    for row, horizon in enumerate((batch_gts, batch_patterns, batch_sampled)):
        for i in range(H):
            plot_mocap(axs[row, h + i], horizon[i])

    # Previously accepted but silently ignored.
    if title is not None:
        fig.suptitle(title)
    if sub_titles is not None:
        for ax, sub_title in zip(axs[0], sub_titles):
            ax.set_title(sub_title)

    plt.tight_layout()
    plt.savefig(save_pth)
    plt.close(fig)