import torch
import matplotlib.pyplot as plt
# Store fonts in saved PDFs as Type 42 (TrueType) instead of Matplotlib's
# default Type 3, so that text in exported figures stays editable/embeddable.
plt.rcParams.update({'pdf.fonttype': 42})
import math

def put_state_dict_on_gpu(state_dict, gpu_id):
    """Return a new dict whose values are ``state_dict``'s values moved to a GPU.

    Parameters
    ----------
    state_dict: mapping
            Maps keys to objects exposing a ``.cuda(device)`` method
            (presumably ``torch.Tensor`` parameters -- confirm with callers).
    gpu_id:
            Device argument forwarded unchanged to each ``.cuda(...)`` call.

    Returns
    -------
    dict
            Same keys, in the same order, each mapped to the result of
            ``value.cuda(gpu_id)``. The input mapping itself is not modified.
    """
    return {name: tensor.cuda(gpu_id) for name, tensor in state_dict.items()}
    
    
    
def set_size(width, fraction=1):
    """Compute a figure size that needs no rescaling when included in LaTeX.

    Parameters
    ----------
    width: float
            Document textwidth or columnwidth in pts
    fraction: float, optional
            Share of that width the figure should occupy (default: all of it)

    Returns
    -------
    fig_dim: tuple
            ``(width, height)`` of the figure in inches; the height is the
            width multiplied by the golden ratio conjugate (~0.618).
    """
    # One inch is 72.27 pt, so this factor turns points into inches.
    inch_per_point = 1 / 72.27

    # Height/width ratio used for an aesthetically pleasing figure.
    # https://disq.us/p/2940ij3
    aspect = (5 ** 0.5 - 1) / 2

    width_inches = (width * fraction) * inch_per_point
    return (width_inches, width_inches * aspect)
    
    
def plot_trends(trends, x_axis, y_axis, start=0, end=float('inf'),
                save_in=None, dataset_folder=None, name=None):
    """Plot one line per case on a shared figure, optionally saving it as PDF.

    Parameters
    ----------
    trends: mapping
            Maps a case label (used in the legend) to an object that, indexed
            with ``x_axis`` and ``y_axis``, yields two parallel sequences.
    x_axis: str
            Key of the x values inside each trend; also the x-axis label.
    y_axis: str
            Key of the y values inside each trend; also the y-axis label.
    start, end: float, optional
            Inclusive bounds on x; points outside ``[start, end]`` are dropped.
    save_in: optional
            Accepted for backward compatibility; currently ignored.
    dataset_folder: str, optional
            Sub-folder of ``Results/`` in which the PDF is written.
    name: str, optional
            File name prefix. The figure is saved to
            ``Results/{dataset_folder}/{name}_{y_axis}_{x_axis}.pdf`` only
            when this is not ``None``. The target folder is not created here.

    The figure is always shown with ``plt.show()`` at the end.
    """
    fig = plt.figure()
    fig.set_facecolor('white')
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.grid(True)

    shapes = ["8", "s", "p", "P", "*", "h", "H", "x", "d", "D"]

    for i, case in enumerate(trends):
        trend = trends[case]
        # Keep only the points whose x value lies inside [start, end].
        kept = [(x, y) for x, y in zip(trend[x_axis], trend[y_axis])
                if start <= x <= end]
        X = [x for x, _ in kept]
        Y = [y for _, y in kept]

        # Cycle through the marker list so more than len(shapes) cases
        # do not raise IndexError.
        marker = shapes[i % len(shapes)]
        # Roughly ten markers per line; the stride must be at least 1 because
        # 0 (empty selection) is not a valid `markevery` step.
        stride = max(1, math.ceil(len(X) * 0.1))
        plt.plot(X, Y, marker=marker, label=case, markevery=stride)

    plt.legend()
    if name is not None:
        plt.savefig(f'Results/{dataset_folder}/{name}_{y_axis}_{x_axis}.pdf')
    plt.show()
    
    
def plot_save(trends, x_axis, y_axis, start=0, end=float('inf'), name=None):
    """Plot one line per case in a LaTeX-sized seaborn-styled figure and save it.

    Parameters
    ----------
    trends: mapping
            Maps a case label (used in the legend) to an object that, indexed
            with ``x_axis`` and ``y_axis``, yields two parallel sequences.
    x_axis: str
            Key of the x values inside each trend; also the x-axis label.
    y_axis: str
            Key of the y values inside each trend; also the y-axis label.
    start, end: float, optional
            Inclusive bounds on x; points outside ``[start, end]`` are dropped.
    name: str, optional
            File name prefix. The figure is always written to
            ``{name}_{y_axis}_{x_axis}.pdf`` (so the default produces a file
            starting with ``None_``).

    The chosen style is applied globally via ``plt.style.use`` and the figure
    is shown with ``plt.show()`` after saving.
    """
    # The bundled 'seaborn' style was renamed to 'seaborn-v0_8' in
    # Matplotlib 3.6 and the old name is rejected by newer releases, so use
    # whichever of the two this installation provides.
    style = 'seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8'
    plt.style.use(style)
    width = 430
    # Initialise figure instance
    fig, ax = plt.subplots(1, 1, figsize=set_size(width))

    ax.set_xlabel(x_axis, fontsize=16)
    ax.set_ylabel(y_axis, fontsize=16)

    shapes = ["8", "s", "p", "P", "*", "h", "H", "x", "d", "D"]

    for i, case in enumerate(trends):
        trend = trends[case]
        # Keep only the points whose x value lies inside [start, end].
        kept = [(x, y) for x, y in zip(trend[x_axis], trend[y_axis])
                if start <= x <= end]
        X = [x for x, _ in kept]
        Y = [y for _, y in kept]

        # Cycle through the marker list so more than len(shapes) cases
        # do not raise IndexError.
        marker = shapes[i % len(shapes)]
        # Roughly ten markers per line; the stride must be at least 1 because
        # 0 (empty selection) is not a valid `markevery` step.
        stride = max(1, math.ceil(len(X) * 0.1))
        ax.plot(X, Y, marker=marker, label=case, markevery=stride)

    ax.legend()
    fig.savefig(f'{name}_{y_axis}_{x_axis}.pdf', format='pdf', bbox_inches='tight')
    plt.show()