import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import pickle
from typing import Literal
# plt.rcParams['font.family'] = 'Times New Roman'
# Global matplotlib style, applied once at import time; rcParams is process-wide
# state, so this affects every figure created after this module is imported.
plt.rcParams['mathtext.fontset'] = 'cm'    # Computer Modern for math text (LaTeX look)
plt.rcParams['font.size'] = 16    # 16 pt in figures matches the body-text font size of the THU thesis template
from .save_csv import SaveCsvHeader

def get_line_from_df(df:pd.DataFrame,col_name,x_name):
    """Build a ``Line`` from two columns of *df*.

    Args:
        df: source table.
        col_name: column used for the y values.
        x_name: column used for the x values.

    Returns:
        A ``Line`` whose x/y are the selected columns (as returned by
        ``df.loc``; no copy or length check is made).
    """
    xs = df.loc[:, x_name]
    ys = df.loc[:, col_name]
    return Line(xs, ys)

class Line():
    """A plottable series: x values, y values and matplotlib style kwargs.

    ``x`` and ``y`` are stored as given (no copy). When omitted, each
    instance gets its own fresh empty list, so lines built incrementally with
    ``append_*`` / ``extend_*`` never share storage.
    """

    def __init__(self, x=None, y=None, **kwargs):
        # ``None`` sentinel instead of a ``[]`` default: a mutable default is
        # created once and shared by every Line constructed without data, so
        # appending to one such line would silently append to all of them.
        self.x = [] if x is None else x
        self.y = [] if y is None else y
        # Style kwargs forwarded to the matplotlib call in ``plot``.
        self.kwargs = kwargs

    def extend_x(self, new_xs: list):
        """Append every value of *new_xs* to x (x must be a list)."""
        self.x.extend(new_xs)

    def extend_y(self, new_ys: list):
        """Append every value of *new_ys* to y (y must be a list)."""
        self.y.extend(new_ys)

    def append_x(self, new_x):
        """Append a single value to x (x must be a list)."""
        self.x.append(new_x)

    def append_y(self, new_y):
        """Append a single value to y (y must be a list)."""
        self.y.append(new_y)

    def set_args(self, **kwargs):
        """Replace (not merge) the stored style kwargs."""
        self.kwargs = kwargs

    def plot(self, ax, **kwargs):
        """Draw the line on *ax* via ``ax.plot``.

        *kwargs* act as defaults, e.g. the label/colour chosen by ``plot1``.
        Style kwargs stored on the line take precedence, so a colour set via
        the constructor or ``set_args`` is honoured. Expanding both mappings
        directly into the call would raise ``TypeError`` for any shared key.
        """
        ax.plot(self.x, self.y, **{**kwargs, **self.kwargs})

class LineErrorbar(Line):
    """A line whose y values are means of repeated samples, drawn with error bars.

    ``eb_method`` selects how the error-bar half-length is computed:

    * ``'std'``  - sample standard deviation (ddof=1)
    * ``'std_'`` - population standard deviation (ddof=0), the default
    * ``'se'``   - standard error of the mean (sample std / sqrt(n))

    Any other value raises ``ValueError`` when error bars are computed.
    """

    def __init__(self, x=None, y=None, y_errorbar=None,
                 eb_method: Literal['std', 'std_', 'se'] = 'std_', **kwargs):
        # Fresh lists per instance: ``[]`` defaults would be shared by every
        # LineErrorbar built without data, so appends would leak between them.
        super().__init__([] if x is None else x, [] if y is None else y, **kwargs)
        self.y_errorbar = [] if y_errorbar is None else y_errorbar
        self.eb_method = eb_method

    def _unknown_method(self):
        """Build the error raised for an unsupported ``eb_method``."""
        return ValueError(
            f"unknown eb_method {self.eb_method!r}; expected 'std', 'std_' or 'se'"
        )

    def setYFromdf(self, df: pd.DataFrame):
        """Set y to the row-wise mean of *df* and y_errorbar to the row-wise spread.

        Each row of *df* yields one y value, so the columns are presumably the
        repeats for one x position (confirm against callers). x is left
        unchanged. y and y_errorbar become pandas Series.

        Raises:
            ValueError: if ``eb_method`` is not one of the supported values.
        """
        if self.eb_method == 'std':
            eb = df.std(axis=1, ddof=1)   # sample standard deviation
        elif self.eb_method == 'std_':
            eb = df.std(axis=1, ddof=0)   # population standard deviation
        elif self.eb_method == 'se':
            eb = df.sem(axis=1)           # standard error of the mean (ddof=1)
        else:
            raise self._unknown_method()
        self.y = df.mean(axis=1)
        self.y_errorbar = eb

    def append_y(self, new_y_list):
        """Append the mean of *new_y_list* to y and its spread to y_errorbar.

        The spread is computed before anything is appended. An invalid
        ``eb_method`` therefore raises without leaving y and y_errorbar with
        different lengths.

        Raises:
            ValueError: if ``eb_method`` is not one of the supported values.
        """
        if self.eb_method == 'std':
            eb = np.std(new_y_list, ddof=1)
        elif self.eb_method == 'std_':
            eb = np.std(new_y_list, ddof=0)
        elif self.eb_method == 'se':
            eb = np.std(new_y_list, ddof=1) / np.sqrt(len(new_y_list))
        else:
            raise self._unknown_method()
        self.y.append(np.mean(new_y_list))
        self.y_errorbar.append(eb)

    def plot(self, ax, **kwargs):
        """Draw the means with error bars via ``ax.errorbar``.

        *kwargs* act as defaults, e.g. the label/colour from ``plot1``. Style
        kwargs stored on the line take precedence, e.g.
        ``set_args(fmt='o-', capsize=5)``. A key present in both no longer
        raises ``TypeError``.
        """
        ax.errorbar(self.x, self.y, yerr=self.y_errorbar, **{**kwargs, **self.kwargs})

def plot1(lines:dict,savefile, title='',xlim=None,ylim=None,xlabel='',ylabel=''):
    """Plot several lines on a single axes and save the figure to *savefile*.

    Args:
        lines: mapping of legend label -> line object exposing
            ``plot(ax, **kwargs)`` (e.g. ``Line`` / ``LineErrorbar``). Each line
            is drawn with ``label=<key>`` and a colour from a fixed 6-colour
            palette. The palette cycles when there are more than 6 lines.
        savefile: path passed to ``fig.savefig``.
        title, xlabel, ylabel: axes title and axis labels.
        xlim, ylim: axis limits passed to ``set_xlim`` / ``set_ylim``; ``None``
            leaves the limits unchanged.

    The figure is always closed afterwards, even if plotting or saving fails.
    """
    flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#ccaf2e"]
    # ["#b49d18","#34495e","#95a5a6","#9b59b6"]
    palette = sns.color_palette(flatui)
    fig, ax = plt.subplots(1, 1, figsize=[10, 6], dpi=200, squeeze=True)
    ax: plt.Axes

    try:
        for i, (name, line) in enumerate(lines.items()):
            # Wrap around the palette; plain palette[i] raised IndexError past 6 lines.
            line.plot(ax, label=name, color=palette[i % len(palette)])

        ax.legend()
        ax.set_title(title)
        ax.grid()

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        fig.tight_layout()
        fig.savefig(savefile)
    finally:
        # Close this specific figure so repeated calls (or failures) don't
        # accumulate open figures in pyplot's global state.
        plt.close(fig)


class FigStrategy():
    """Base class describing how to build a group of figures from saved results.

    Subclasses implement ``get_figs(save_folder)`` returning
    ``{fig_name: {line_label: line}}``. ``plotFigStrategies`` and
    ``dumpFigStrategies`` write each figure to
    ``<save_folder>/<save_name>_<fig_name>.png`` / ``.pkl``.
    """

    def __init__(self, save_name):
        # Prefix for the output file names of every figure this strategy produces.
        self.save_name = save_name

    @staticmethod
    def get_df(save_folder, save_name):
        """Read ``<save_folder>/<save_name>.csv`` into a DataFrame via ``SaveCsvHeader``."""
        save_file = os.path.join(save_folder, f"{save_name}.csv")
        saveCsvObj = SaveCsvHeader(save_file, None)
        return saveCsvObj.read_to_df()

    def get_figs(self, save_folder):
        """Return ``{fig_name: {line_label: line}}`` for this strategy.

        Must be overridden. The base version previously lacked ``self``, so
        calling it on an instance raised a confusing ``TypeError``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_figs(save_folder)"
        )

def plotFigStrategies(save_folder,figStrategies:list[FigStrategy]):
    """Render every figure of every strategy to ``<save_folder>/<save_name>_<fig_name>.png``.

    Each strategy's figures are plotted exactly once, under that strategy's
    own ``save_name``.
    """
    for figStrategy in figStrategies:
        save_name = figStrategy.save_name
        # Use only this strategy's figures. Accumulating all strategies into one
        # shared dict re-plotted earlier strategies' figures under every later
        # strategy's save_name, producing duplicate, mislabelled files.
        figs = figStrategy.get_figs(save_folder)
        for name, lines in figs.items():
            plot1(lines, f"{save_folder}/{save_name}_{name}.png")

def dumpFigStrategies(save_folder,figStrategies:list[FigStrategy]):
    """Pickle every figure of every strategy to ``<save_folder>/<save_name>_<fig_name>.pkl``.

    Each file holds a plain ``dict`` copy of that figure's
    ``{line_label: line}`` mapping. The mapping is also printed. Unpickling
    requires any custom classes stored in it to be importable.
    """
    for figStrategy in figStrategies:
        save_name = figStrategy.save_name
        # Only this strategy's figures. A dict shared across strategies caused
        # earlier figures to be re-dumped under later strategies' save_names.
        figs = figStrategy.get_figs(save_folder)
        for name, fig in figs.items():
            fig_dict = dict(fig)
            print(fig_dict)
            with open(f"{save_folder}/{save_name}_{name}.pkl", 'wb') as f:
                pickle.dump(fig_dict, f)


# def plot(lines):
#     fig , axes = plt.subplots(1,2,figsize=[15,6],dpi=200,squeeze=True)

#     # relative position for subplot order
#     xt = -0.1   
#     yt = 0.98

#     # fig1
#     ax : plt.Axes = axes[0]
#     ax.plot(x1_data,y1_data,'--o',
#             color='blue',linewidth=1.0,markersize=5.0)
#     # ax.plot(x1_data,y1_data,'--o',
#     #         color='blue',linewidth=1.0,markersize=5.0)
#     # ax.plot(x2_data,y2_data,color='red',linewidth=2)

#     ax.set_title('(a)',x=xt,y=yt)
#     ax.grid()

#     ax.set_xlim([0,1])
#     ax.set_ylim([-1.0,1.0])
#     ax.set_xlabel(r'$\varphi$ / $\pi$')
#     ax.set_ylabel(r'P ($\uparrow$)')

#     ax.set_xticks(np.linspace(0,1,5))
#     ax.set_yticks(np.linspace(-1,1,5))

#     ax.text(x=0.20,y=0.8,s=r'P = 0.98')

#     twin_ax = ax.twinx()
#     twin_ax.set_ylabel(r'P ($\downarrow$)')



#     fig.tight_layout()