# import basic libs
import os
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import linregress

def plot_hist(data: np.ndarray,
              save_path: str,
              fname: str,
              title: str = None,
              xlabel: str = None,
              ylabel: str = None,
              bins: int = 20,
              figsize: tuple = (7, 5)) -> None:
    """Draw a single gridded histogram of ``data`` and save it to disk.

    The figure is written to ``os.path.join(save_path, fname)`` and then
    closed.

    Args:
        data (np.ndarray): the values to bin.
        save_path (str): directory the figure is written into.
        fname (str): file name of the saved figure.
        title (str): figure title (omitted when None).
        xlabel (str): x-axis label (omitted when None).
        ylabel (str): y-axis label (omitted when None).
        bins (int): number of histogram bins.
        figsize (tuple): figure size passed to ``plt.subplots``.
    """
    _, axes = plt.subplots(figsize=figsize)
    axes.hist(data, bins=bins)
    axes.grid()
    axes.set(title=title, xlabel=xlabel, ylabel=ylabel)

    out_file = os.path.join(save_path, fname)
    plt.savefig(out_file)
    plt.close()


def plot_multiple_hist(data_dict: dict,
                       save_path: str,
                       fname: str,
                       title: str = None,
                       xlabel: str = None,
                       ylabel: str = None,
                       bins: int = 20,
                       figsize: tuple = (7, 5)) -> None:
    """plot several overlaid histograms in one figure.

    One semi-transparent histogram is drawn per dictionary entry, labelled
    with its key, and a legend is placed in the upper-right corner.

    Args:
        data_dict (dict): maps a legend label to the data to histogram.
        save_path (str): the path to save the figure.
        fname (str): the file name of the figure.
        title (str): the title of the figure.
        xlabel (str): the label of x axis.
        ylabel (str): the label of y axis.
        bins (int): the number of bins (applied to every histogram).
        figsize (tuple): the size of the figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # alpha=0.5 keeps the distributions visible where they overlap.
    for label, data in data_dict.items():
        ax.hist(data, bins=bins, alpha=0.5, label=label)

    ax.grid()
    ax.legend(loc='upper right')
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)

    path = os.path.join(save_path, fname)
    fig.savefig(path)
    plt.close(fig)


def plot_twinx_curves(yleft: np.ndarray,
                      yright: np.ndarray,
                      x: np.ndarray,
                      save_path: str,
                      fname: str,
                      xlabel: str,
                      yleftlabel: str,
                      yrightlabel: str,
                      title: str,
                      yleftlim: list,
                      yrightlim: list,
                      figsize: tuple = (7, 6)) -> None:
    """plot two curves sharing one x axis, each against its own y axis.

    ``yleft`` is drawn in red against the left y axis and ``yright`` in blue
    against a twin right y axis. Both curves also get small black scatter
    markers at their data points.

    Args:
        yleft (np.ndarray): values plotted against the left y axis.
        yright (np.ndarray): values plotted against the right y axis.
        x (np.ndarray): the shared x values.
        save_path (str): the path to save the figure.
        fname (str): the file name of the figure.
        xlabel (str): the label of x axis.
        yleftlabel (str): label (and legend entry) of the left curve.
        yrightlabel (str): label (and legend entry) of the right curve.
        title (str): the title of the figure.
        yleftlim (list): y range of the left axis, passed to ``set_ylim``.
        yrightlim (list): y range of the right axis, passed to ``set_ylim``.
        figsize (tuple): the size of the figure.
    """
    fig, axleft = plt.subplots(figsize=figsize)
    axright = axleft.twinx()

    axleft.plot(x, yleft, label=yleftlabel, c="r")
    axleft.scatter(x, yleft, s=5, c="k")
    axright.plot(x, yright, label=yrightlabel, c="b")
    axright.scatter(x, yright, s=5, c="k")
    # Each axis draws its own grid at its own tick positions.
    axleft.grid()
    axright.grid()

    axleft.set(xlabel=xlabel, title=title)
    # Colour each y label to match its curve.
    axleft.set_ylabel(yleftlabel, color="r")
    axright.set_ylabel(yrightlabel, color="b")

    # One borderless legend per axis: left curve upper-left, right curve upper-right.
    axleft.legend(frameon=False, loc="upper left")
    axright.legend(frameon=False, loc="upper right")

    axleft.set_ylim(yleftlim)
    axright.set_ylim(yrightlim)

    path = os.path.join(save_path, fname)
    fig.savefig(path)
    plt.close(fig)


def plot_multiple_curves(data: dict,
                         save_path: str, 
                         fname: str,
                         xlabel: str,
                         xticks: list = None,
                         title: str = None,
                         ylim: list = None,
                         rot: int = 0,
                         figsize: tuple = (7, 5)) -> None:
    """plot one marked line per dictionary entry in a single figure.

    Args:
        data (dict): maps a legend label to an ``(x, y)`` pair for that curve.
        save_path (str): the path to save the figure.
        fname (str): the file name of the figure.
        xlabel (str): the label of x axis.
        xticks (list): labels placed at x positions 0, 1, ..., len(xticks)-1.
            This only lines up with the curves when their x values are those
            indices (NOTE(review): confirm with callers). Ignored when empty
            or None.
        title (str): the title of the figure.
        ylim (list): the range of y axis; None keeps the automatic range.
        rot (int): rotation of the x tick labels, in degrees.
        figsize (tuple): the size of the figure.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for label, (x, y) in data.items():
        ax.plot(x, y, label=label, marker="o", linestyle="-")

    ax.grid()
    ax.set(xlabel=xlabel, title=title)

    # Framed legend at matplotlib's default ("best") location.
    ax.legend(frameon=True, prop={'size': 10})
    fig.tight_layout()
    ax.set_ylim(ylim)

    if xticks:
        plt.xticks(range(len(xticks)), xticks)
    plt.xticks(rotation=rot)
    # Leave extra room at the bottom for (possibly rotated) tick labels.
    fig.subplots_adjust(bottom=0.15)

    path = os.path.join(save_path, fname)
    fig.savefig(path)
    plt.close(fig)


def plot_multi_curves_with_hlines(Y: dict,
                                  H: dict,
                                  x: np.ndarray,
                                  save_path: str, 
                                  fname: str,
                                  xlabel: str,
                                  xticks: list = None,
                                  title: str = None,
                                  ylim: list = None,
                                  rot: int = 0,
                                  seed: int = 123,
                                  figsize: tuple = (7, 5)) -> None:
    """plot one marked line per entry of ``Y`` plus a dashed horizontal
    reference line per entry of ``H``.

    Args:
        Y (dict): maps a legend label to the y values of a curve (all curves
            share ``x``).
        H (dict): maps a legend label to the y level of a horizontal line.
        x (np.ndarray): the shared x values of the curves.
        save_path (str): the path to save the figure.
        fname (str): the file name of the figure.
        xlabel (str): the label of x axis.
        xticks (list): labels placed at x positions 0, 1, ..., len(xticks)-1;
            ignored when empty or None.
        title (str): the title of the figure.
        ylim (list): the range of y axis; None keeps the automatic range.
        rot (int): rotation of the x tick labels, in degrees.
        seed (int): seed for the random colours of the horizontal lines.
            The same seed always gives the same colours.
        figsize (tuple): the size of the figure.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for label, y in Y.items():
        ax.plot(x, y, label=label, marker="o", linestyle="-")

    # A private generator gives reproducible colours without reseeding numpy's
    # global RNG (np.random.seed would silently change the caller's random
    # state). RandomState(seed) yields the same sequence as np.random.seed(seed).
    rng = np.random.RandomState(seed)
    for label, h in H.items():
        # NOTE(review): the line spans data x in [0, len(x)], which matches the
        # curves only when x is 0..len(x)-1 — confirm with callers.
        ax.hlines(y=h, xmin=0, xmax=len(x), linestyles="--", label=label,
                  colors=rng.rand(3,))

    ax.grid()
    ax.set(xlabel=xlabel, title=title)

    # Framed legend at matplotlib's default ("best") location.
    ax.legend(frameon=True, prop={'size': 10})
    ax.set_ylim(ylim)

    if xticks:
        plt.xticks(range(len(xticks)), xticks)
    plt.xticks(rotation=rot)

    path = os.path.join(save_path, fname)
    fig.savefig(path)
    plt.close(fig)


def plot_correlation(y: np.ndarray,
                     x: np.ndarray,
                     save_path: str, 
                     fname: str,
                     ylabel: str,
                     xlabel: str,
                     title: str = None,
                     figsize: tuple = (7, 5)) -> None:
    """scatter-plot y against x with its least-squares regression line.

    The Pearson r, the p-value, the slope and the intercept from
    ``scipy.stats.linregress`` are written in the upper-left of the axes.

    Args:
        y (np.ndarray): the y values (dependent variable).
        x (np.ndarray): the x values (independent variable).
        save_path (str): the path to save the figure.
        fname (str): the file name of the figure.
        ylabel (str): the label of y axis.
        xlabel (str): the label of x axis.
        title (str): the title of the figure.
        figsize (tuple): the size of the figure.
    """
    # asarray lets plain sequences work too (``slope * list`` would fail).
    x = np.asarray(x)
    y = np.asarray(y)

    slope, intercept, r_value, p_value, _ = linregress(x, y)
    line = slope * x + intercept

    fig, ax = plt.subplots(figsize=figsize)
    ax.scatter(x, y, label='Data points')
    ax.plot(x, line, color='red', label='Fitted line')
    # Annotate the fit statistics in axes coordinates (0..1), top-left.
    ax.text(0.1, 0.9, f"r_value: {r_value:.3f}", transform=ax.transAxes)
    ax.text(0.1, 0.85, f"p_value: {p_value:.3f}", transform=ax.transAxes)
    ax.text(0.1, 0.8, f"slope: {slope:.3f}", transform=ax.transAxes)
    ax.text(0.1, 0.75, f"intercept: {intercept:.3f}", transform=ax.transAxes)
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
    ax.legend()

    path = os.path.join(save_path, fname)
    fig.savefig(path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def save_images(imgs: torch.Tensor, 
                save_path: str,
                filename: str,  
                nrow: int = 1) -> None:
    """save images to a single file via ``torchvision.utils.save_image``.

    Args:
        imgs (torch.Tensor): the images to save. They are passed unchanged to
            ``torchvision.utils.save_image``, which expects float values in
            [0, 1]. Any normalisation must be undone by the caller.
        save_path (str): directory to save into. It is created if missing; an
            empty string means the current working directory.
        filename (str): the filename of the saved image.
        nrow (int): the number of images per row of the saved grid.
    Returns:
        None
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    # os.makedirs("") raises, so an empty path (cwd) is skipped.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    torchvision.utils.save_image(imgs,
                                 os.path.join(save_path, filename),
                                 nrow=nrow)