import numpy as np 
import os

def create_sampling_time(logMaxIter, log_scale=True, num_points=2000):
    """Return ``num_points`` integer iteration indices below ``10**logMaxIter``.

    Args:
        logMaxIter: base-10 exponent of the (exclusive) upper bound.
        log_scale: if True, samples are geometrically spaced starting at 1;
            otherwise they are evenly spaced starting at 0.
        num_points: number of samples to return (default 2000, the value
            previously hard-coded).

    Returns:
        A 1-D integer numpy array of length ``num_points``. Values are
        truncated to int, so neighbouring samples may repeat when the range
        is smaller than ``num_points``.
    """
    maxIter = int(10**logMaxIter)
    if log_scale:
        return np.geomspace(1, 10**logMaxIter, num_points, endpoint=False, dtype=int)
    # Do NOT use np.arange(..., step=float, dtype=int): with an integer dtype
    # numpy computes the increment from the first two *truncated* values, so a
    # fractional step (e.g. 1000 / 2000 = 0.5) collapses every sample to 0.
    # Compute in float, then truncate once.
    step = maxIter / num_points
    return (np.arange(num_points) * step).astype(int)

class plot_figure(object):
    """Load per-trial result files for one algorithm and plot their mean.

    Attributes:
        sub_sample: stride used when drawing points (every k-th point).
        dir: directory scanned for result files.
        algo_name: substring used to select result files in ``dir``.
        log_flag: if True, ``plot_lines`` sets both axes to log scale.
        metric: key read from each result dict (e.g. ``"dist2ps"``).
        num_trails: number of matched result files (trials).
        res: array of shape (num_trails, num_points) of sampled metric values.
        xvals: iteration values corresponding to the columns of ``res``.
        z: 95% z-score (1.96) already divided by sqrt(num_trails), i.e. the
            multiplier that turns a per-point std into a CI half-width.
    """

    def __init__(self, algo_name, dir, sub_sample=20, log_flag_ = False, plot_steadystate = False, metric = "dist2ps"):
        # NOTE(review): plot_steadystate is accepted but unused here.
        self.sub_sample = sub_sample
        self.dir = dir
        self.algo_name = algo_name
        self.log_flag = log_flag_
        self.metric = metric
        self.num_trails, self.res, self.xvals = self.load_data(algo_name)
        # 95% confidence (use 1.645 for 90%), pre-scaled by the standard error factor.
        self.z = 1.96/np.sqrt(self.num_trails)

    def import_package(self):
        """Configure matplotlib fonts and make the TeX binaries directory reachable."""
        import matplotlib
        # PATH entries must be directories; prepend the one holding pdflatex
        # only once so repeated calls do not keep growing PATH.
        tex_dir = '/usr/bin'
        path_entries = [p for p in os.environ.get('PATH', '').split(os.pathsep) if p]
        if tex_dir not in path_entries:
            os.environ['PATH'] = os.pathsep.join([tex_dir] + path_entries)
        matplotlib.rcParams['ps.useafm'] = True
        matplotlib.rcParams['pdf.use14corefonts'] = True
        matplotlib.rcParams['text.usetex'] = False

    def load_data(self, algo_name):
        """Load every file in ``self.dir`` whose name contains ``algo_name``.

        Each file is expected to hold a pickled dict with an ``'iter_ss'``
        list and a ``self.metric`` list of equal length (assumed from usage).

        Returns:
            (number of files, res array of shape (n_files, 10000),
             sampled iteration values of shape (10000,))

        Raises:
            FileNotFoundError: if no file name contains ``algo_name``.
        """
        # Sorted so the file supplying 'iter_ss' is deterministic.
        matched = sorted(f for f in os.listdir(self.dir) if algo_name in f)
        if not matched:
            raise FileNotFoundError(
                f"no result files containing {algo_name!r} in {self.dir!r}")
        # allow_pickle is needed because each file stores a pickled dict;
        # only load result files you generated yourself.
        files = [np.load(os.path.join(self.dir, f), allow_pickle=True) for f in matched]
        xvals = files[0].item().get('iter_ss')
        num_points = int(1e4)
        # Log-spaced indices from 1 up to len(xvals) - 1 (10**0 == 1, so
        # index 0 is never sampled).
        log_indices = np.logspace(0, np.log10(len(xvals) - 1), num=num_points).astype(int)
        sample_iter = np.array(xvals)[log_indices]
        res = np.array([np.array(f.item().get(self.metric))[log_indices] for f in files])
        return len(matched), res, sample_iter

    def _confidence_band(self):
        """Return (mean, lower, upper) across trials for the confidence band."""
        mean = np.mean(self.res, axis = 0)
        std = np.std(self.res, axis = 0)
        # self.z already includes 1/sqrt(num_trails), so the half-width is
        # z * std. Dividing by sqrt(num_trails) again made the band too narrow.
        half_width = self.z * std
        lb = np.squeeze(mean - half_width)
        ub = np.squeeze(mean + half_width)
        return mean, lb, ub

    def plot_lines(self, ax, color, line='-', label='', plot_star = False, shadow_flag=True, legend=True, ax_insert = None, range_=None):
        """Plot the trial mean (and optionally its CI band) on ``ax``.

        Args:
            ax: main matplotlib axes.
            color, line, label: line colour, linestyle and legend label.
            plot_star: mark the first sampled point with a star.
            shadow_flag: draw the confidence band with ``fill_between``.
            legend: call ``ax.legend``.
            ax_insert: optional inset axes for a zoomed view.
            range_: (left_end, right_end, y_upper, y_lower) for the inset;
                required when ``ax_insert`` is given.
        """
        mean, lb, ub = self._confidence_band()
        if 'acc' in self.metric:
            print(f"{self.metric}, final acc is {mean[-1]}")

        step = self.sub_sample
        ax.plot(self.xvals[::step], mean[::step], label=label, color=color, linestyle=line, linewidth=2)
        if shadow_flag:
            ax.fill_between(self.xvals[::step], lb[::step], ub[::step], color=color, alpha=.05)
        if legend:
            ax.legend(prop={'size': 12})
        if self.log_flag:
            ax.set_xscale('log')
            ax.set_yscale('log')
        ax.grid()
        ax.tick_params(axis='both', which='major', labelsize=13)
        if plot_star:
            ax.plot(self.xvals[0], mean[0], marker = '*', color = color, markerfacecolor=color,ms=15)
        if ax_insert is not None:
            left_end, right_end, y_upper, y_lower = range_
            # Default to len(xvals) (not -1) so a bound beyond the data keeps
            # the last point instead of dropping it / plotting only it.
            n = len(self.xvals)
            left_idx = next((i for i, x in enumerate(self.xvals) if x > left_end), n)
            right_idx = next((i for i, x in enumerate(self.xvals) if x > right_end), n)
            ax_insert.plot(self.xvals[left_idx:right_idx:step], mean[left_idx:right_idx:step], label=label, color=color, linestyle=line, linewidth=2)
            ax_insert.set_xscale('log')
            ax_insert.grid(True)
            ax_insert.set_ylim([y_lower, y_upper])
            ax_insert.minorticks_off()

    
    