import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Public API: the names exported by a star-import of this module.
__all__ = ['str2bool', 'plot_function', 'UnNormalize']

def str2bool(v):
    """Convert a truthy/falsy string to a ``bool``.

    Matching is case-insensitive. Invalid input raises
    ``argparse.ArgumentTypeError``, which lets the function be passed as an
    ``argparse`` ``type=`` callable.

    Args:
        v: A ``bool`` (returned unchanged) or a string to interpret.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1';
        ``False`` for 'no', 'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the
            recognised spellings.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def plot_function(target_x, target_y, context_x, context_y, pred_y, var, filename, title):
    """Plot the predicted mean, an uncertainty band and the context points.

    Only the first element of the batch is drawn. The figure is written to
    ``filename`` and then closed; nothing is returned. Calling this function
    also applies the seaborn ``'white'`` theme via ``sns.set_theme``.

    Args:
        target_x: An array of shape batchsize x number_targets x 1 that contains the
            x values of the target points.
        target_y: An array of shape batchsize x number_targets x 1 that contains the
            y values of the target points.
        context_x: An array of shape batchsize x number_context x 1 that contains
            the x values of the context points.
        context_y: An array of shape batchsize x number_context x 1 that contains
            the y values of the context points.
        pred_y: An array of shape batchsize x number_targets x 1 that contains the
            predicted means of the y values at the target points in target_x.
        var: An array of shape batchsize x number_targets x 1 that contains the
            predicted variance of the y values at the target points in target_x.
            The shaded band spans ``pred_y - var`` to ``pred_y + var`` exactly
            as given (no square root is taken here).
        filename: Destination passed to ``plt.savefig``.
        title: Title text; it is upper-cased before being drawn.

    Note:
        Every array argument must support slicing and ``.to('cpu').data``
        (presumably torch tensors -- confirm against callers).
    """
    sns.set_theme(style='white')
    palette = sns.color_palette('tab10')

    FIGSIZE = (20, 10)
    LINEWIDTH = 10
    MARKERSIZE = 30
    FONTSIZE = 70
    FONTWEIGHT = 'bold'

    # Keep only the first batch element and move it to the CPU for plotting.
    target_x, target_y, context_x, context_y, pred_y, var = (
        t[:1].to('cpu').data
        for t in (target_x, target_y, context_x, context_y, pred_y, var)
    )

    fig = plt.figure(figsize=FIGSIZE)
    try:
        # The fmt strings carry only line/marker style; colour comes solely
        # from ``color=`` so matplotlib does not warn about a redundant
        # colour definition.
        plt.plot(target_x[0], pred_y[0], '-', color=palette[0],
                 linewidth=LINEWIDTH, label='RC', alpha=0.5)
        plt.plot(target_x[0], target_y[0], ':', color='black',
                 linewidth=LINEWIDTH, label='FS')
        plt.plot(context_x[0], context_y[0], 'o', color=palette[3],
                 markersize=MARKERSIZE, label='SS')

        plt.fill_between(
            target_x[0, :, 0],
            pred_y[0, :, 0] - var[0, :, 0],
            pred_y[0, :, 0] + var[0, :, 0],
            alpha=0.2,
            facecolor='#65c9f7',
            interpolate=True)

        plt.yticks([-2, 0, 2], fontsize=FONTSIZE, fontweight=FONTWEIGHT)
        plt.xticks([-2, 0, 2], fontsize=FONTSIZE, fontweight=FONTWEIGHT)
        plt.grid(False)

        # Negative pad pulls the title down inside the axes area.
        ax = plt.gca()
        ax.set_title(title.upper(), y=1.0, pad=-70, fontsize=FONTSIZE, fontweight=FONTWEIGHT)

        plt.tight_layout()
        plt.savefig(filename, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

class UnNormalize:
    """Undo a per-slice mean/std normalisation, modifying the input in place.

    Each slice along the first dimension of the input is multiplied by the
    matching ``std`` entry and then shifted by the matching ``mean`` entry.
    """

    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """Store the per-slice offsets (``mean``) and scales (``std``)."""
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Apply ``x * std + mean`` to each leading-dimension slice in place.

        Args:
            tensor: An iterable whose items expose in-place ``mul_`` and
                ``add_`` (e.g. a torch tensor iterated over its first
                dimension). Iteration stops at the shortest of ``tensor``,
                ``mean`` and ``std``.

        Returns:
            The same ``tensor`` object that was passed in.
        """
        for channel, offset, scale in zip(tensor, self.mean, self.std):
            channel.mul_(scale)
            channel.add_(offset)
        return tensor
