import matplotlib.pyplot as plt
import numpy as np


def format_settings(
        wspace=0.25, 
        hspace=0.4, 
        left=0.12, 
        right=0.9, 
        bottom=0.15, 
        top=0.95,
        fs=12,
        show_dpi=80,
        save_dpi=300,
        lw=1.5,
        ms=5,
        axlw=1.5,
        major_tick_len=5,
        major_tick_width=1.5,
        major_tick_pad=5,
        minor_tick_len=0,
        minor_tick_width=0,
        minor_tick_pad=5,
        ):
    '''
        ：
            fig = plt.figure(figsize=(12, 4), dpi=300)
            format_settings()
            grid = plt.GridSpec(2, 2)
            ax1 = fig.add_subplot(grid[0, 0]) # 
            ax2 = fig.add_subplot(grid[0, 1]) # 
            ax3 = fig.add_subplot(grid[:, 0]) # 
        ：
            figsize12，。
            figsize，。
    '''
    # 
    plt.rcParams['lines.linewidth'] = lw
    
    # 
    plt.rcParams['lines.markersize'] = ms
    
    #   w: h:
    plt.subplots_adjust(wspace=wspace, hspace=hspace, left=left, right=right, bottom=bottom, top=top)

    # 
    plt.rcParams['font.size'] = fs
    plt.rcParams['axes.labelsize'] = fs
    plt.rcParams['axes.titlesize'] = fs
    plt.rcParams['xtick.labelsize'] =fs
    plt.rcParams['ytick.labelsize'] = fs
    plt.rcParams['legend.fontsize'] = fs
    # 
    plt.rcParams['axes.linewidth'] = axlw
    # 
    plt.rcParams['axes.spines.top'] = True
    plt.rcParams['axes.spines.right'] = True
    plt.rcParams['axes.spines.left'] = True
    plt.rcParams['axes.spines.bottom'] = True

    # 
    plt.rcParams['xtick.major.width'] = major_tick_width
    plt.rcParams['ytick.major.width'] = major_tick_width
    plt.rcParams['xtick.minor.width'] = minor_tick_width
    plt.rcParams['ytick.minor.width'] = minor_tick_width
    # 
    plt.rcParams['xtick.major.size'] = major_tick_len
    plt.rcParams['ytick.major.size'] = major_tick_len
    plt.rcParams['xtick.minor.size'] = minor_tick_len
    plt.rcParams['ytick.minor.size'] = minor_tick_len
    # 
    plt.rcParams['xtick.major.pad'] = major_tick_pad
    plt.rcParams['ytick.major.pad'] = major_tick_pad
    plt.rcParams['xtick.minor.pad'] = minor_tick_pad
    plt.rcParams['ytick.minor.pad'] = minor_tick_pad
    
    # 
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    # 
    plt.rcParams['xtick.top'] = False 
    plt.rcParams['ytick.right'] = False
    # 
    plt.rcParams['xtick.minor.visible'] = False
    plt.rcParams['ytick.minor.visible'] = False
    # 
    plt.rcParams['legend.frameon'] = False
    # 
    plt.rcParams['figure.dpi'] = show_dpi
    # 
    plt.rcParams['savefig.dpi'] = save_dpi



def get_color_list(n_colors, cmap='viridis', color_min=0.5, color_max=1, invert=False):
    r'''
        cmapn_colors
        cmap: 
            ：'Blues', 'Greens', 'Reds', 'Oranges', 'Greys', 'Purples'
            ：'viridis', 'plasma', 'inferno', 'magma', 'seismic'
        color_min: ，0.5，0.0
        color_max: 
        invert: ，，invert=True
    '''
    colormap = plt.get_cmap(cmap)
    if invert:
        color_list = [colormap(i) for i in np.linspace(color_max, color_min, n_colors)]
    else:
        color_list = [colormap(i) for i in np.linspace(color_min, color_max, n_colors)]
    return color_list


def get_color_groups(n_group, n_colors, cmap_list=None, color_min=0.5, color_max=1, invert=False):
    r'''
        ，n_colors
        cmap_list: ，None，
    '''
    if cmap_list is None:
        cmap_list = ['Blues', 'Reds', 'Greens', 'Oranges', 'Greys', 'Purples', 'YlOrBr', 'PuBuGn', 'BuPu']
        
    color_groups = [get_color_list(n_colors, cmap=cmap_list[i], color_min=color_min, color_max=color_max, invert=invert) for i in range(n_group)]
    
    return color_groups


def display_fig(img_list, width=300, margin=10, border=1):
    r'''
        jupyter notebook
    '''
    from IPython.display import display, Image, HTML
    
    html_str = ""

    if type(width) == int:
        widths = [width] * len(img_list)
    else:
        widths = width
    if type(margin) == int:
        margins = [margin] * len(img_list)
    else:
        margins = margin
    if type(border) == int:
        borders = [border] * len(img_list)
    else:
        borders = border


    for img, width, margin, border in zip(img_list, widths, margins, borders):
        html_str += f"<img style='width: {width}px; margin: {margin}px; float: left; border: {border}px solid black;' src='{img}' />"

    html_str += "<div style='clear: both;'></div>"

    display(HTML(html_str))
