import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from typing import Literal
# plt.rcParams['font.family'] = 'Times New Roman'
# Global Matplotlib defaults, applied once when this module is imported.
plt.rcParams.update({
    # Font set used for rendering math text.
    'mathtext.fontset': 'cm',
    # Author's note: a font size of 16 in the figure matches the font size
    # of the surrounding text in the THU thesis.
    'font.size': 16,
})

class Line():
    """A plottable series: x values, y values and stored plot keyword arguments.

    Attributes:
        x: x values. A sequence passed to the constructor is stored by
            reference (not copied); otherwise a new empty list is created.
        y: y values, handled the same way as ``x``.
        kwargs: keyword arguments forwarded to ``ax.plot`` on every draw.
    """

    def __init__(self, x=None, y=None, **kwargs):
        """Create a line.

        Args:
            x: initial x values, or ``None`` for a new empty list.
            y: initial y values, or ``None`` for a new empty list.
            **kwargs: style options stored and forwarded to ``ax.plot``.
        """
        # Build the empty lists here rather than in the signature: a literal
        # ``[]`` default is created once and shared by every instance, so
        # appending to one line would silently modify all the others.
        self.x = [] if x is None else x
        self.y = [] if y is None else y
        self.kwargs = kwargs

    def extend_x(self, new_xs: list):
        """Append every element of ``new_xs`` to the x values."""
        self.x.extend(new_xs)

    def extend_y(self, new_ys: list):
        """Append every element of ``new_ys`` to the y values."""
        self.y.extend(new_ys)

    def append_x(self, new_x):
        """Append a single x value."""
        self.x.append(new_x)

    def append_y(self, new_y):
        """Append a single y value."""
        self.y.append(new_y)

    def set_args(self, **kwargs):
        """Replace (not merge) the stored plot keyword arguments."""
        self.kwargs = kwargs

    def plot(self, ax, **kwargs):
        """Draw the line on ``ax`` with ``ax.plot``.

        Args:
            ax: object providing a ``plot(x, y, **kwargs)`` method.
            **kwargs: per-call options. When a key is also present in the
                stored ``self.kwargs``, the per-call value wins.
        """
        # Merge into one dict first: unpacking both mappings directly into the
        # call raises TypeError as soon as they share a key (e.g. ``color``).
        ax.plot(self.x, self.y, **{**self.kwargs, **kwargs})

class LineErrorbar(Line):
    """A ``Line`` drawn with vertical error bars through ``ax.errorbar``.

    Attributes:
        y_errorbar: error-bar sizes passed as ``yerr`` when plotting.
        eb_method: how ``setYFromdf`` computes the error bars:
            ``'std'``  - standard deviation with ``ddof=1`` (sample),
            ``'std_'`` - standard deviation with ``ddof=0`` (population),
            ``'se'``   - ``DataFrame.sem`` (standard error of the mean).
    """

    def __init__(self, x=None, y=None, y_errorbar=None,
                 eb_method: Literal['std', 'std_', 'se'] = 'std_', **kwargs):
        """Create an error-bar line.

        Args:
            x: initial x values, or ``None`` for a new empty list.
            y: initial y values, or ``None`` for a new empty list.
            y_errorbar: initial error-bar sizes, or ``None`` for a new
                empty list.
            eb_method: error-bar method used by ``setYFromdf``.
            **kwargs: style options stored and forwarded to ``ax.errorbar``.
        """
        # Create the empty lists per instance; literal ``[]`` defaults would
        # be shared between all instances constructed without arguments.
        x = [] if x is None else x
        y = [] if y is None else y
        super().__init__(x, y, **kwargs)
        self.y_errorbar = [] if y_errorbar is None else y_errorbar
        self.eb_method = eb_method

    def setYFromdf(self, df: pd.DataFrame):
        """Set ``y`` and ``y_errorbar`` from the rows of ``df``.

        ``y`` becomes the mean of each row (``axis=1``) and ``y_errorbar``
        the per-row spread selected by ``self.eb_method``. ``x`` is not
        modified by this method.

        Args:
            df: data frame whose rows are aggregated.

        Raises:
            ValueError: if ``self.eb_method`` is not one of
                ``'std'``, ``'std_'`` or ``'se'``.
        """
        if self.eb_method == 'std':
            # Sample standard deviation.
            eb = df.std(axis=1, ddof=1)
        elif self.eb_method == 'std_':
            # Population standard deviation.
            eb = df.std(axis=1, ddof=0)
        elif self.eb_method == 'se':
            eb = df.sem(axis=1)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the assignment below.
            raise ValueError(
                f"unknown eb_method {self.eb_method!r}; "
                "expected 'std', 'std_' or 'se'"
            )
        self.y = df.mean(axis=1)
        self.y_errorbar = eb

    def plot(self, ax, **kwargs):
        """Draw the line with error bars on ``ax`` using ``ax.errorbar``.

        Args:
            ax: object providing an ``errorbar(x, y, yerr=..., **kwargs)``
                method.
            **kwargs: per-call options. When a key is also present in the
                stored ``self.kwargs``, the per-call value wins.
        """
        # Merge first so a key present in both mappings does not raise
        # "got multiple values for keyword argument".
        ax.errorbar(self.x, self.y, yerr=self.y_errorbar,
                    **{**self.kwargs, **kwargs})

def plot1(lines:dict,savefile, title='',xlim=None,ylim=None,xlabel='',ylabel=''):
    """Draw several lines on one axes and save the figure to ``savefile``.

    Args:
        lines: mapping of legend label -> line object. Each object must
            provide ``plot(ax, **kwargs)``; it is called with ``label`` set to
            the mapping key and ``color`` taken from the palette.
        savefile: destination passed to ``fig.savefig``.
        title: axes title.
        xlim: value passed to ``ax.set_xlim`` (``None`` leaves it automatic).
        ylim: value passed to ``ax.set_ylim`` (``None`` leaves it automatic).
        xlabel: x-axis label.
        ylabel: y-axis label.
    """
    flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"]
    palette = sns.color_palette(flatui)
    # With a 1x1 grid and squeeze=True, subplots returns a single Axes.
    fig, ax = plt.subplots(1, 1, figsize=[10, 6], dpi=200, squeeze=True)

    try:
        for i, (name, line) in enumerate(lines.items()):
            # Cycle through the palette so more lines than colours does not
            # raise IndexError.
            line.plot(ax, label=name, color=palette[i % len(palette)])

        ax.legend()
        ax.set_title(title)
        ax.grid()

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        fig.tight_layout()
        fig.savefig(savefile)
    finally:
        # Close this specific figure (not whichever is "current") and do so
        # even when plotting or saving fails, so figures do not accumulate.
        plt.close(fig)


# def plot(lines):
#     fig , axes = plt.subplots(1,2,figsize=[15,6],dpi=200,squeeze=True)

#     # relative position for subplot order
#     xt = -0.1   
#     yt = 0.98

#     # fig1
#     ax : plt.Axes = axes[0]
#     ax.plot(x1_data,y1_data,'--o',
#             color='blue',linewidth=1.0,markersize=5.0)
#     # ax.plot(x1_data,y1_data,'--o',
#     #         color='blue',linewidth=1.0,markersize=5.0)
#     # ax.plot(x2_data,y2_data,color='red',linewidth=2)

#     ax.set_title('(a)',x=xt,y=yt)
#     ax.grid()

#     ax.set_xlim([0,1])
#     ax.set_ylim([-1.0,1.0])
#     ax.set_xlabel(r'$\varphi$ / $\pi$')
#     ax.set_ylabel(r'P ($\uparrow$)')

#     ax.set_xticks(np.linspace(0,1,5))
#     ax.set_yticks(np.linspace(-1,1,5))

#     ax.text(x=0.20,y=0.8,s=r'P = 0.98')

#     twin_ax = ax.twinx()
#     twin_ax.set_ylabel(r'P ($\downarrow$)')



#     fig.tight_layout()