import os
import re 
import inspect
import scipy
import json
import time
import collections
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset

# Last run directory passed to getcond()/gettrain(); both overwrite it through
# `global savepath`. Nothing else in this file reads it (the nested figure
# helpers in final_plot take their own `savepath` parameter) — presumably kept
# for interactive inspection; TODO confirm before removing.
savepath = None

class HashableDict(dict):
    """A ``dict`` that can be used as a dictionary key or set member.

    The hash is built from the *set* of ``(key, value)`` pairs, so it does not
    depend on insertion order. This matches ``dict.__eq__``, which also ignores
    order. All values must be hashable. An instance must not be mutated while
    it is stored in a hashed container, or lookups will stop finding it.
    """

    def __hash__(self) -> int:
        # frozenset (not tuple): two dicts that compare equal must hash equally,
        # regardless of the order in which their keys were inserted.
        return hash(frozenset(self.items()))

def get_info(datapath):
    """Infer the experiment grid size from ``<case>-<repeat>`` entries in *datapath*.

    Only entries whose names start with ``<digits>-<digits>`` (e.g. ``3-1``) are
    counted; every other entry is ignored.

    Returns:
        ``(cases, repeat)``: the largest case index + 1 and the largest repeat
        index + 1, or ``(0, 0)`` when nothing matches.
    """
    # `\d+` rather than `\d*`: an empty group (names like "1-" or "-2") would
    # otherwise reach int("") and raise ValueError.
    reobj = re.compile(r"(\d+)-(\d+)")
    cases = repeat = 0
    for name in os.listdir(datapath):
        mobj = reobj.match(name)
        if mobj:
            cases = max(cases, int(mobj[1]) + 1)
            repeat = max(repeat, int(mobj[2]) + 1)
    return cases, repeat

def read_log(filepath, keys):
    """Parse ``<name>: <value>`` header lines of a log file into a HashableDict.

    Args:
        filepath: path of the text log to scan.
        keys: iterable of ``(name, converter)`` pairs. A line starting with
            ``name`` has its first ``len(name) + 2`` characters (the name plus a
            two-character separator, presumably ``": "``) dropped. The stripped
            remainder is passed to ``converter``.

    Returns:
        HashableDict mapping each found name to its converted value. When a name
        occurs on several lines, the last occurrence wins.
    """
    parsed = HashableDict()
    with open(filepath, "r") as handle:
        for text in handle.readlines():
            for spec in keys:
                name, convert = spec[0], spec[1]
                if not text.startswith(name):
                    continue
                parsed[name] = convert(text[len(name) + 2:].strip())
    return parsed

def extract_cond(path):
    """Return the condition number implied by the smallest loss recorded in *path*.

    Two layouts are supported:

    * ``train_loss.txt`` + ``test_loss.txt`` (Poisson equation runs): every value
      in either file is a candidate loss.
    * otherwise ``log.txt`` (Wave equation runs): each line shaped like
      ``<step> [a] [b] [c]`` contributes ``a`` and ``b`` as candidate losses.

    The stored loss is ``1 / cond**2``, so the smallest loss is converted back to
    a condition number with ``1 / sqrt(loss)``.

    Returns:
        The condition number, or ``np.nan`` (with a warning) if no loss was found.
    """
    train_file = os.path.join(path, "train_loss.txt")
    test_file = os.path.join(path, "test_loss.txt")
    if os.path.exists(train_file) and os.path.exists(test_file):
        # Flatten and concatenate rather than stacking: the two files may hold a
        # different number of entries, which np.array([...]) cannot stack.
        losses = np.concatenate([
            np.ravel(np.loadtxt(train_file)),
            np.ravel(np.loadtxt(test_file)),
        ])
    else:
        reobj = re.compile(r"^(\d+)\s+\[(.*)\]\s+\[(.*)\]\s+\[(.*)\]\s+$")
        values = []
        with open(os.path.join(path, "log.txt"), "r") as f:
            for line in f:
                mobj = reobj.match(line)
                if mobj:
                    values.extend((float(mobj[2]), float(mobj[3])))
        losses = np.array(values)

    if losses.size == 0:
        print(f"\033[33mWarning: Failed in finding condition losses in {path}\033[0m")
        return np.nan
    return 1 / np.sqrt(losses.min())

def extract_relerr(path):
    """Return the final relative error: column 5 of the last row of ``errors.txt``."""
    # ndmin=2 keeps a single-row file 2-D, so [-1, 5] does not raise IndexError.
    errors = np.loadtxt(os.path.join(path, "errors.txt"), ndmin=2)
    return errors[-1, 5]

def extract_relerr_hist(path):
    """Return the relative-error history: column 5 of every row of ``errors.txt``."""
    # ndmin=2 keeps a single-row file 2-D, so [:, 5] does not raise IndexError.
    errors = np.loadtxt(os.path.join(path, "errors.txt"), ndmin=2)
    return errors[:, 5]

def extract_time(path):
    """Return the last ``Training costs: <seconds>s`` value found in ``run.log``.

    Prints a warning and returns ``np.nan`` when no such line exists.
    """
    pattern = r"Training costs: ([\d\.]+)s"
    with open(os.path.join(path, "run.log"), "r") as log_file:
        lines = log_file.readlines()
    # Scan backwards so the most recent report wins.
    for line in reversed(lines):
        found = re.match(pattern, line)
        if found:
            return float(found[1])
    print(f"\033[33mWarning: Failed in finding training time in {path}\033[0m")
    return np.nan

def extract_precond_time(path):
    """Return the last ``Computing preconditioner costs: <seconds>s`` value in ``run.log``.

    Prints a warning and returns ``np.nan`` when no such line exists.
    """
    pattern = r"Computing preconditioner costs: ([\d\.]+)s"
    with open(os.path.join(path, "run.log"), "r") as log_file:
        lines = log_file.readlines()
    # Scan backwards so the most recent report wins.
    for line in reversed(lines):
        found = re.match(pattern, line)
        if found:
            return float(found[1])
    print(f"\033[33mWarning: Failed in finding precondition time in {path}\033[0m")
    return np.nan

def extract_time_pinn(path):
    """Return the last ``'train' took <seconds> s`` value found in ``log.txt``.

    Prints a warning and returns ``np.nan`` when no such line exists.
    """
    pattern = r"'train' took ([\d\.]+) s"
    with open(os.path.join(path, "log.txt"), "r") as log_file:
        lines = log_file.readlines()
    # Scan backwards so the most recent report wins.
    for line in reversed(lines):
        found = re.match(pattern, line)
        if found:
            return float(found[1])
    print(f"\033[33mWarning: Failed in finding training time in {path}\033[0m")
    return np.nan

def extract_train_hist(path):
    """Return every ``L2RE <value>`` (scientific notation) found in ``run.log``, in file order."""
    pattern = re.compile(r".*L2RE ([\d\.]+e[-\+\d]+)")
    with open(os.path.join(path, "run.log"), "r") as log_file:
        matches = (pattern.match(line) for line in log_file)
        return [float(m[1]) for m in matches if m]

def process_case_repeat(datapath, cases, repeat, func):
    """Apply *func* to every ``<datapath>/<i>-<j>`` run directory.

    Args:
        datapath: folder holding ``<case>-<repeat>`` sub-directories.
        cases: number of case indices ``i`` to visit.
        repeat: number of repeat indices ``j`` to visit per case.
        func: callable taking a run directory path.

    Returns:
        ``np.array`` with one row per case and one entry per repeat. Missing run
        directories are reported and filled with ``np.nan``.
        NOTE(review): when *func* returns arrays (e.g. a history), a NaN
        placeholder makes the rows inhomogeneous and ``np.array`` may reject them.
    """
    result = []
    for i in range(cases):
        row = []
        for j in range(repeat):
            path = os.path.join(datapath, f"{i}-{j}")
            if not os.path.exists(path):
                print(f"\033[33mWarning:Could not find path {path}.\033[0m")
                row.append(np.nan)
                # Do not call func on a directory that does not exist.
                continue
            row.append(func(path))
        result.append(row)

    return np.array(result)

def process_time_based(datapath, func, exp_start_time=None, exp_end_time=None, keys=(("Problem Name", str),)):
    """Apply *func* to timestamped run directories and group the results by log header.

    Run directories are matched against
    ``<YYYY-mm-dd-HH-MM-SS>_<name>_<int>_<int>[_<int>]``; the first integer is used
    as the repeat index. Only runs whose timestamp lies within
    ``[exp_start_time, exp_end_time]`` are processed (either bound may be ``None``).
    Each run's ``run.log`` is parsed by ``read_log(..., keys)``, and the resulting
    HashableDict is the grouping key. A second run with the same key and repeat
    index is reported as a duplicate but still collected.

    Args:
        datapath: folder containing the run directories.
        func: callable applied to each selected run directory.
        exp_start_time, exp_end_time: ``time.struct_time`` bounds, or ``None``.
        keys: ``(name, converter)`` pairs forwarded to ``read_log``.

    Returns:
        ``collections.defaultdict(list)`` mapping parsed-log keys to lists of
        ``func(run_dir)`` results. Indexing a missing key yields (and inserts) ``[]``.
    """
    name_re = re.compile(r"([\d-]+)_([-A-Za-z0-9]+)_(\d+)_(\d+)(_\d+)?")
    ret = collections.defaultdict(list)
    dup = {}
    for path in os.listdir(datapath):
        mobj = name_re.match(path)
        if not mobj:
            continue
        try:
            date = time.strptime(mobj[1], "%Y-%m-%d-%H-%M-%S")
        except ValueError:
            # `[\d-]+` also accepts digit runs that are not valid timestamps.
            print(f"\033[33mWarning: Could not parse timestamp of {path}, skipped.\033[0m")
            continue
        repeat = int(mobj[3])

        in_window = (not exp_start_time or exp_start_time <= date) and \
                    (not exp_end_time or date <= exp_end_time)
        if not in_window:
            continue

        log = read_log(os.path.join(datapath, path, "run.log"), keys)
        if (log, repeat) in dup:
            print(f"\033[33mWarning: Duplicated Experiment {log} found. Repeat {repeat}\033[0m")
            print(f"\033[33m    Experiment 1: {dup[(log, repeat)]}\033[0m")
            print(f"\033[33m    Experiment 2: {mobj[1]}\033[0m")
        dup[(log, repeat)] = mobj[1]
        ret[log].append(func(os.path.join(datapath, path)))
    print("Experiment find:", len(dup))
    return ret

def getcond(datapath = None):
    """Load per-case condition numbers.

    Args:
        datapath: a ``.txt`` file whose second line holds a Python expression
            (e.g. a list of numbers), or a folder of ``<case>-<repeat>`` runs.
            When ``None`` the path is read from stdin; an empty string aborts.

    Returns:
        ``None`` for an empty path. For a ``.txt`` file, the evaluated second line
        as an ``np.array``. For a run folder, the NaN-aware max over repeats of
        ``extract_cond`` for each case. A run folder also becomes the module-level
        ``savepath``.
    """
    global savepath
    if datapath is None:
        datapath = input("cond path:")
    if datapath == "":
        return None
    if datapath.endswith(".txt"):
        with open(datapath, "r") as f:
            # NOTE(review): eval() executes arbitrary code from the file. It is
            # acceptable only for trusted local result files; ast.literal_eval
            # would be safer if the line is always a plain literal.
            return np.array(eval(f.readlines()[1]))
    savepath = datapath
    cases, repeat = get_info(datapath)
    cond = process_case_repeat(datapath, cases, repeat, extract_cond)
    print(cond)
    return np.nanmax(cond, axis=1)
    
def gettrain(datapath=None, mean=True):
    """Load final relative errors and their histories from a ``<case>-<repeat>`` folder.

    Args:
        datapath: run folder; read from stdin when ``None``. An empty string aborts.
        mean: if true, average the final errors over repeats (NaN-aware).

    Returns:
        ``(relerr, relerr_hist)``, or ``(None, None)`` for an empty path. The folder
        also becomes the module-level ``savepath``.
    """
    global savepath
    if datapath is None:
        datapath = input("train path:")
    if datapath == "":
        return None, None
    savepath = datapath
    cases, repeat = get_info(datapath)
    final_errors = process_case_repeat(datapath, cases, repeat, extract_relerr)
    print(final_errors)
    histories = process_case_repeat(datapath, cases, repeat, extract_relerr_hist)
    summary = np.nanmean(final_errors, axis=1) if mean else final_errors
    return summary, histories

def gettime(datapath=None, exp_start_time=None, exp_end_time=None, keys=(("Problem Name", str),)):
    """Group training wall-clock times (``extract_time``) of timestamped runs by log header.

    Thin wrapper around ``process_time_based``; see it for the time window and *keys*.
    """
    # Immutable tuple default instead of a shared mutable list.
    return process_time_based(datapath, extract_time, exp_start_time, exp_end_time, keys)

def getcondtime(datapath=None, exp_start_time=None, exp_end_time=None, keys=(("Problem Name", str),)):
    """Group preconditioner-computation times (``extract_precond_time``) by log header.

    Thin wrapper around ``process_time_based``; see it for the time window and *keys*.
    """
    # Immutable tuple default instead of a shared mutable list.
    return process_time_based(datapath, extract_precond_time, exp_start_time, exp_end_time, keys)
            
def gettime_pinn(datapath=None):
    """Return PINN training times (``extract_time_pinn``) for every ``<case>-<repeat>`` run.

    The result has one row per case and one column per repeat.
    """
    return process_case_repeat(datapath, *get_info(datapath), extract_time_pinn)

def gethist(datapath=None, exp_start_time=None, exp_end_time=None, keys=(("Problem Name", str),)):
    """Group L2RE training histories (``extract_train_hist``) of timestamped runs by log header.

    Thin wrapper around ``process_time_based``; see it for the time window and *keys*.
    """
    # Immutable tuple default instead of a shared mutable list.
    return process_time_based(datapath, extract_train_hist, exp_start_time, exp_end_time, keys)

def final_save():
    """Collect hard-coded experiment results and dump them as JSON to
    ``results/final_results.txt`` (the input of ``final_plot``).

    Sources are the run folders and log time-windows written literally below.
    Condition/error arrays come from getcond/gettrain. Timing and ablation
    histories come from the timestamped ``../logs`` runs.
    """
    time_fmt = "%Y-%m-%d-%H-%M-%S"

    def lookup(table, fields):
        # process_time_based returns a defaultdict: indexing a missing key would
        # silently yield (and insert) an empty list. Warn instead of hiding it.
        key = HashableDict(fields)
        if key not in table:
            print(f"\033[33mWarning: No experiment found for {fields}\033[0m")
            return []
        return table[key]

    wave_cond = getcond("runs/09.18-16.48.34Wave_NNCond")
    helm_cond = getcond("results/helm_thycond.txt")
    burg_cond = getcond("results/burger_thycond.txt")

    wave_relr, wave_hist = gettrain("runs/09.22-19.24.48Wave_NoTransfer", mean=False)
    helm_relr, helm_hist = gettrain("runs/09.15-20.45.03Helmholtz", mean=False)
    burg_relr, burg_hist = gettrain("runs/09.12-15.47.11Burger", mean=False)

    P_list = np.linspace(1, 5, num=41)
    pois_cond_thy = 4 / P_list ** 2
    pois_cond_nn = getcond("runs/09.20-20.01.39Poisson_NNCond")
    pois_cond_matrix = np.loadtxt("results/poisson_thycond.txt")

    ablation_raw = gethist("../logs",
            exp_start_time = time.strptime("2023-09-22-14-20-06", time_fmt),
            exp_end_time   = time.strptime("2023-09-22-15-28-05", time_fmt),
            keys=[("Problem Name", str), ("Use Preconditioner", (lambda s: s == 'True')), ("Drop Tolerance (in ILU)", float)])
    # Key format: "<problem>-<drop tolerance>", or "<problem>-None" without preconditioner.
    ablation_exp_hist = {
        f"{key['Problem Name']}-{key['Drop Tolerance (in ILU)'] if key['Use Preconditioner'] else 'None'}": val
        for key, val in ablation_raw.items()
    }
    ablation_exp_cond = {
        "Poisson2d-C":   [1.10e+0, 2.82e+0, 1.52e+1, 6.03e+1, 1.13e+2],
        "Poisson2d-CG":  [1.01e+0, 1.19e+0, 2.55e+0, 7.22e+0, 1.27e+1],
        "Poisson3d-CG":  [6.77e+0, 1.17e+0, 1.38e+0, 1.77e+0, 2.20e+0],
        "Poisson2d-MS":  [3.23e+0, 3.25e+1, 2.47e+2, 3.42e+2, 3.39e+0],
    }

    cpinn_times = gettime("../logs",
            exp_start_time = time.strptime("2023-09-27-16-06-06", time_fmt),
            exp_end_time   = time.strptime("2023-09-27-17-02-36", time_fmt))
    burgers1d_cpinn_time = lookup(cpinn_times, {"Problem Name": "Burgers1d-C"})
    heat2d_cpinn_time = lookup(cpinn_times, {"Problem Name": "Heat2d-VC"})
    ns2d_cpinn_time = lookup(cpinn_times, {"Problem Name": "NS2d-C"})
    wave1d_cpinn_time = lookup(cpinn_times, {"Problem Name": "Wave1d-C"})

    pinn_times = gettime_pinn("../../PINNacle/runs/09.26-20.59.40PINN_Time_Test")

    # Scaling-up study: the same runs provide both training and preconditioner times.
    meshes = [0.01, 0.02, 0.03, 0.04, 0.05]
    scaling_window = dict(
        exp_start_time = time.strptime("2023-09-27-20-52-34", time_fmt),
        exp_end_time   = time.strptime("2023-09-27-22-31-56", time_fmt),
        keys=[("Problem Name", str), ("Grid Size (mesh length)", float)],
    )
    poisson3d_train_times = gettime("../logs", **scaling_window)
    poisson3d_cpinn_time = [
        lookup(poisson3d_train_times, {"Problem Name": "Poisson3d-CG", "Grid Size (mesh length)": mesh})
        for mesh in meshes
    ]
    poisson3d_precond_times = getcondtime("../logs", **scaling_window)
    poisson3d_precond_time = [
        lookup(poisson3d_precond_times, {"Problem Name": "Poisson3d-CG", "Grid Size (mesh length)": mesh})
        for mesh in meshes
    ]
    # (mean, std) per mesh length, same order as `meshes`
    poisson3d_fenics_time = [[2.48e+01, 1.92e+00],
                             [2.28e+00, 7.87e-01],
                             [5.25e-01, 3.25e-02],
                             [2.30e-01, 7.24e-03],
                             [9.64e-02, 2.76e-03]]

    with open("results/final_results.txt", "w") as f:
        json.dump({
            "wave_cond": wave_cond.tolist(),
            "helm_cond": helm_cond.tolist(),
            "burg_cond": burg_cond.tolist(),
            "wave_relr": wave_relr.tolist(),
            "helm_relr": helm_relr.tolist(),
            "burg_relr": burg_relr.tolist(),
            "wave_hist": wave_hist.tolist(),
            "helm_hist": helm_hist.tolist(),
            "burg_hist": burg_hist.tolist(),
            "pois_cond_thy": pois_cond_thy.tolist(),
            "pois_cond_nn": pois_cond_nn.tolist(),
            "pois_cond_matrix": pois_cond_matrix.tolist(),
            "ablation_exp_hist": ablation_exp_hist,
            "ablation_exp_cond": ablation_exp_cond,
            "cpinn_time" : {
                'burgers1d': burgers1d_cpinn_time,
                'heat2d':    heat2d_cpinn_time,
                'wave1d':    wave1d_cpinn_time,
                'ns2d':      ns2d_cpinn_time,
            },
            "pinn_time": {
                "burgers1d": pinn_times[0].tolist(),
                "poisson3d": pinn_times[1].tolist(),
                "ns2d":      pinn_times[2].tolist(),
                "wave1d":    pinn_times[3].tolist(),
                "heat2d":    pinn_times[4].tolist(),
            },
            "scaling_up_times": {
                'cpinn_time': poisson3d_cpinn_time,
                "precond_time": poisson3d_precond_time,
                "fenics_time": poisson3d_fenics_time,
            }
        }, f, indent=4)

def final_plot():
    """Render all figures from ``results/final_results.txt`` (written by final_save).

    Each nested ``figure*`` helper draws one figure and saves it. The calls at the
    end of this function choose the output files under ``results/``.
    """
    with open("results/final_results.txt", "r") as f:
        data = json.load(f)

    wave_cond = np.array(data['wave_cond'])
    wave_relr = np.array(data['wave_relr'])
    wave_hist = np.array(data['wave_hist'])
    burg_cond = np.array(data['burg_cond'])
    burg_relr = np.array(data['burg_relr'])
    helm_cond = np.array(data['helm_cond'])
    helm_relr = np.array(data['helm_relr'])
    pois_cond_thy    = np.array(data['pois_cond_thy'])
    pois_cond_nn     = np.array(data['pois_cond_nn'])
    pois_cond_matrix = np.array(data['pois_cond_matrix'])

    ablation_exp_hist = data['ablation_exp_hist']
    ablation_exp_cond = data['ablation_exp_cond']

    cpinn_time = data['cpinn_time']
    pinn_time = data['pinn_time']
    # Only the cases that have CPINN timings are converted (and compared later).
    for key in cpinn_time.keys():
        cpinn_time[key] = np.array(cpinn_time[key])
        pinn_time[key] = np.array(pinn_time[key])

    scaling_up_times = data['scaling_up_times']
    for key in scaling_up_times.keys():
        scaling_up_times[key] = np.array(scaling_up_times[key])

    def log_unit(values):
        """Map positive *values* onto [0, 1] on a log scale (min -> 0, max -> 1)."""
        log_min, log_max = np.log(values.min()), np.log(values.max())
        return (np.log(values) - log_min) / (log_max - log_min)

    def figure0(savepath="results/figure0.png"):
        """Plot the Poisson ``pois_cond_*`` curves (Theory, NN, FDM 2/4/32) against P,
        with a zoomed inset around P in [2.0, 2.1]."""
        _, ax = plt.subplots()
        plt.rc('font',family='Times New Roman')
        plt.rcParams['font.sans-serif'] = 'times new roman'
        plt.xticks([1.0, 2, 3.0, 4, 5.0], fontproperties = 'times new roman', size=30)

        def format_ticks(value, pos):
            return f'{value:.1f}'  # one decimal place, e.g. "2.0"

        plt.gca().xaxis.set_major_formatter(FuncFormatter(format_ticks))
        plt.yticks(fontproperties = 'times new roman', size=30)
        plt.xlabel(r'$\mathdefault{P}$', fontsize=38)
        plt.ylabel(r'$\mathdefault{||\mathcal{F}^{-1}||}$',fontsize=38)
        plt.yscale('log')
        plt.ylim(3/25, 8)
        plt.grid()

        P_list = np.linspace(1, 5, num=41)
        matrix_colors = [ "#6babf0",  "#2474b5","#4b8fd2", "#8bc7ff","#004e8a",]
        plt.plot(P_list, pois_cond_matrix[0], label=f"FDM {2}", color=matrix_colors[0], linewidth=4)
        plt.plot(P_list, pois_cond_nn, label="NN", color="#008475", linewidth=4)
        plt.plot(P_list, pois_cond_matrix[1], label=f"FDM {4}", color=matrix_colors[1], linewidth=4)
        plt.plot(P_list, pois_cond_thy, label="Theory", color="#b35531", linewidth=4, zorder=5)
        plt.plot(P_list, pois_cond_matrix[4], label=f"FDM {32}", color=matrix_colors[4], linewidth=4)

        plt.legend(loc = 'upper center', fontsize=28,  bbox_to_anchor=(0.5, 1.45), ncol=3, handlelength=1,  mode=None)

        axins = zoomed_inset_axes(ax, 22, loc=1)
        for i, mesh in enumerate([2, 4, 8, 16, 32]):
            if mesh in [2, 4, 32]:
                axins.plot(P_list, pois_cond_matrix[i], label=f"Matrix {mesh}", color=matrix_colors[i], linewidth=8)
        axins.plot(P_list, pois_cond_nn, label="NN", color="#008475", linewidth=8)
        axins.plot(P_list, pois_cond_thy, label="Theory", color="#b35531", linewidth=8, zorder=5)
        x1, x2, y1, y2 = 2.0, 2.1, 10**-0.02, 10**0.02
        axins.grid()
        axins.set_xticks([])
        axins.set_yticks([])
        axins.set_xlim(x1, x2)
        axins.set_ylim(y1, y2)
        mark_inset(ax, axins, loc1=2, loc2=3, fc="none", ec="black", zorder=40)

        plt.savefig(savepath, bbox_inches='tight')
        plt.close()

    def normalize(arr, method="linear"):
        """Rescale *arr* to [0, 1] after the monotone transform selected by *method*.

        Raises:
            ValueError: for an unknown *method*.
        """
        transforms = {
            'linear': lambda a: a,
            'log': np.log,
            'poly2': lambda a: a**(1/2),
            'poly100': lambda a: a**(1/100),
            'exp': np.exp,
            # shift so the minimum maps to log(1e-2) instead of log(0)
            'log_': lambda a: np.log((a - a.min()) + 1e-2),
        }
        if method not in transforms:
            raise ValueError(f"Unknown normalization method: {method}")
        arr = transforms[method](arr)
        return (arr - arr.min()) / (arr.max() - arr.min())

    def scatter_and_regression(x, y, label, color, curve="exp", dist="squarelog"):
        """Scatter (x, y), fit *curve* by minimising the mean *dist* to y, draw the fit
        dashed, and return the Pearson r of ``linregress(x, log(y))``.

        That r only describes the drawn fit for curve="exp" with dist="squarelog"
        (least squares on log y); the assert checks both fits agree.
        """
        # `import scipy` alone does not load these subpackages on older SciPy versions.
        import scipy.optimize
        import scipy.stats

        curve = {
            "linear": (lambda x, a, b: a * x + b),
            "exp": (lambda x, a, b: np.exp(a * x + b)),
            "poly2": (lambda x, a, b, c: a * x**2 + b * x + c),
        }[curve]

        dist = {
            "linear": (lambda x, y: np.abs(x-y)),
            "square": (lambda x, y: (x-y)**2),
            "log": (lambda x, y: np.abs(np.log(x) - np.log(y))),
            "squarelog": (lambda x, y: (np.log(x)-np.log(y))**2)
        }[dist]

        def curve_fit_loss(params):
            return dist(y, curve(x, *params)).mean()

        # One free parameter per curve argument except x.
        res = scipy.optimize.minimize(curve_fit_loss, np.zeros(len(inspect.signature(curve).parameters)-1))

        slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(x, np.log(y))
        assert np.isclose(np.array([slope, intercept]), res.x).all(), "Please use exp curve and squarelog dist if want R2"
        print(slope, intercept, r_value)

        plt.plot(x, curve(x, *res.x), '--', linewidth=4, color=color, label=label)
        plt.scatter(x, y, s=150, alpha=0.5, color=color, edgecolor='none')
        return r_value

    def figure1(savepath="results/figure1.png", xscale="linear", yscale="log"):
        """Plot the scatter of different cases, which use cond as x and err as y"""
        plt.figure()

        plt.rc('font',family='Times New Roman')
        plt.rcParams['font.sans-serif'] = 'times new roman'
        plt.xticks([1.0, 2, 3.0, 4, 5.0], fontproperties = 'times new roman', size=30)

        def format_ticks(value, pos):
            return f'{value:.1f}'  # one decimal place, e.g. "2.0"

        def r2_label(r_value):
            # Legend-only entry showing the R^2 computed from this data (not hard-coded).
            return rf"$\mathdefault{{R^2}}$: {r_value**2:.2f}"

        plt.gca().xaxis.set_major_formatter(FuncFormatter(format_ticks))
        plt.yticks(fontproperties = 'times new roman', size=30)
        plt.xlabel(r'Normalized Cond', fontsize=38)
        plt.ylabel(r'L2RE',fontsize=38)
        plt.xscale(xscale)
        plt.yscale(yscale)
        plt.grid()

        wave_r2 = scatter_and_regression(normalize(wave_cond, "log"), wave_relr.mean(axis=1), label="Wave", color="#2474b5")
        plt.scatter([], [], label=r2_label(wave_r2), alpha=0)
        burg_r2 = scatter_and_regression(normalize(burg_cond, "log_"), burg_relr.mean(axis=1), label="Burg", color="#b35531")
        plt.scatter([], [], label=r2_label(burg_r2), alpha=0)
        helm_r2 = scatter_and_regression(normalize(helm_cond, "poly2"), helm_relr.mean(axis=1), label="Helm", color="#008475")
        plt.scatter([], [], label=r2_label(helm_r2), alpha=0)

        plt.legend(loc = 'upper center', fontsize=28,  bbox_to_anchor=(0.5, 1.45), ncol=3, handlelength=1.5)
        plt.savefig(savepath, bbox_inches='tight')
        plt.close()
        print(f"Figure 1: wave_r2: {wave_r2**2} burg_r2: {burg_r2**2} helm_r2: {helm_r2**2}")

    def figure2(savepath="results/figure2.png"):
        """Plot mean Wave L2RE histories for selected cases, coloured by condition number."""
        plt.figure()

        plt.rc('font',family='Times New Roman')
        plt.rcParams['font.sans-serif'] = 'times new roman'
        plt.xticks([0, 10000, 20000],fontproperties = 'times new roman', size=30)
        plt.yticks(fontproperties = 'times new roman', size=30)
        plt.xlabel(r'Iterations', fontsize=38)
        plt.ylabel(r'L2RE',fontsize=38)
        plt.yscale('log')
        plt.grid()

        epochs = np.arange(wave_hist.shape[2]) * 100 + 100
        order = np.argsort(wave_cond)
        clip = -1  # drop the final logged point

        # Positions in the condition-sorted order of the cases to draw.
        ids = [0, 5, 6, 10, 20, 22, 25, 30, 35]
        cases = [order[i] for i in ids]
        conds = np.array([wave_cond[case] for case in cases])
        color_index = log_unit(conds)

        colors = plt.cm.coolwarm(np.linspace(0, 1, order.shape[0], endpoint=False))
        # Repeat the first colour 10 extra times to widen the low end of the map.
        colors = np.concatenate([colors[0:1, :]] * 10 + [colors], axis=0)
        custom_map = mpl.colors.ListedColormap(colors)
        for i, (case, color_idx) in enumerate(zip(cases, color_index)):
            plt.plot(epochs[:clip], wave_hist[case,:,:clip].mean(axis=0), alpha=0.9, color=custom_map(color_idx), label=f"Case {i}", linewidth=4)

        plt.text(19500, 20, "Cond", fontsize=38)

        sm = mpl.cm.ScalarMappable(cmap=custom_map, norm=mpl.colors.LogNorm(vmin=conds.min(), vmax=conds.max()))
        sm.set_array(color_index)
        colorbar = plt.colorbar(sm)
        ticks = [100, 1000]
        tick_labels = [r"$\mathdefault{10^%d}$"%(int(np.log10(val))) for val in ticks]
        colorbar.set_ticks(ticks)
        colorbar.set_ticklabels(tick_labels,fontproperties = 'times new roman', size=30)
        plt.savefig(savepath, bbox_inches='tight')
        plt.close()

    def figure_abla(case):
        """Plot the ILU drop-tolerance ablation histories of *case* to results/figure_abla_<case>.pdf."""
        savepath = f"results/figure_abla_{case}.pdf"
        plt.figure()
        plt.rc('font',family='Times New Roman')
        plt.rcParams['font.sans-serif'] = 'times new roman'
        plt.xticks([0, 10000, 20000],fontproperties = 'times new roman', size=30)
        plt.yticks(fontproperties = 'times new roman', size=30)
        plt.xlabel(r'Iterations', fontsize=38)
        plt.ylabel(r'L2RE',fontsize=38)
        plt.yscale('log')
        plt.grid()

        epochs = np.arange(np.array(ablation_exp_hist[case + "-None"]).shape[1]) * 1000
        conds = np.array(ablation_exp_cond[case])
        color_index = log_unit(conds)

        colors = plt.cm.coolwarm(np.linspace(0, 1, num=50, endpoint=False))
        custom_map = mpl.colors.ListedColormap(colors)
        for i, drop in enumerate(["-0.0001", "-0.001", "-0.01", "-0.1", "-None"]):
            data = np.array(ablation_exp_hist[case + drop])
            # Dashed line = run without preconditioner.
            plt.plot(epochs, data.mean(axis=0), ('--' if drop=='-None' else '-'), \
                     alpha=0.9, color=custom_map(color_index[i]), label=drop[1:], linewidth=4)

        label_y_position = {
            "Poisson2d-C": 2.5,
            "Poisson2d-CG": 2.5,
            "Poisson3d-CG": 1.5,
            "Poisson2d-MS": 1.3,
        }
        plt.text(19500, label_y_position[case], "Cond", fontsize=38)

        sm = mpl.cm.ScalarMappable(cmap=custom_map, norm=mpl.colors.LogNorm(vmin=conds.min(), vmax=conds.max()))
        sm.set_array(color_index)
        colorbar = plt.colorbar(sm)
        ticks = [t for t in [1, 10, 100, 1000] if conds.min() <= t <= conds.max()]
        tick_labels = [r"$\mathdefault{10^%d}$"%(int(np.log10(val))) for val in ticks]
        colorbar.set_ticks(ticks)
        colorbar.set_ticklabels(tick_labels,fontproperties = 'times new roman', size=30)
        plt.savefig(savepath, bbox_inches='tight')
        plt.close()

    def figure_time0(savepath="results/figure_time0.pdf"):
        """Bar chart of PINN vs. PCPINN training time per case, error bars spanning min..max."""
        cases = {"Burg": "burgers1d",
                 "Heat": "heat2d",
                 "NS": "ns2d",
                 "Wave": "wave1d"}
        x = np.arange(len(cases)) * 2
        bar_width = 0.6

        plt.figure()
        plt.rc('font',family='Times New Roman')
        plt.rcParams['font.sans-serif'] = 'times new roman'
        plt.xticks(x, list(cases.keys()), fontproperties = 'times new roman', size=30)
        plt.yticks(fontproperties = 'times new roman', size=30)
        plt.xlabel(r'Case', fontsize=38)
        plt.ylabel(r'Time',fontsize=38)
        plt.grid()

        pinn_data = np.array([pinn_time[name] for name in cases.values()])
        cpinn_data = np.array([cpinn_time[name] for name in cases.values()])

        error_bar_style = dict(lw=3, capsize=5, capthick=3)
        plt.bar(x - bar_width / 2, pinn_data.mean(axis=1), bar_width,
                yerr=(pinn_data.mean(axis=1) - pinn_data.min(axis=1), pinn_data.max(axis=1) - pinn_data.mean(axis=1)),
                error_kw=error_bar_style, label="PINN", zorder=4, color="#2474B5")
        plt.bar(x + bar_width / 2, cpinn_data.mean(axis=1), bar_width,
                yerr=(cpinn_data.mean(axis=1) - cpinn_data.min(axis=1), cpinn_data.max(axis=1) - cpinn_data.mean(axis=1)),
                error_kw=error_bar_style, label="PCPINN", zorder=4, color="#B85029")

        plt.legend(loc = 'upper center', fontsize=28, bbox_to_anchor=(0.5, 1.25), ncol=2)
        plt.savefig(savepath, bbox_inches='tight')
        plt.close()

    def figure_time1(savepath="results/figure_time1.pdf"):
        """Log-log plot of PCPINN, preconditioner and FEM time vs. grid size, each
        normalised by its own smallest mean time."""
        plt.figure(figsize=(9.6, 6))
        plt.rcParams['font.sans-serif'] = 'times new roman'
        plt.yscale('log')
        plt.xscale('log')

        # Grid sizes for mesh lengths 0.01, 0.02, 0.03, 0.04, 0.05 (same order as the data rows).
        grid = np.array([702328, 94920, 31092, 13680, 7528])

        plt.xticks(grid, fontproperties = 'times new roman', size=30)
        plt.yticks(fontproperties = 'times new roman', size=30)
        plt.xlabel(r'Grid Size', fontsize=38)
        plt.ylabel(r'Relative Time',fontsize=38)
        plt.grid()

        def format_ticks(value, pos):
            # Tick values arrive as floats; `:d` on a float raises ValueError, so
            # convert to int after rounding to the nearest thousand.
            return f'{int((value + 500) // 1000):d}K'

        plt.gca().xaxis.set_major_formatter(FuncFormatter(format_ticks))

        data = scaling_up_times['cpinn_time']
        mean, std = data.mean(axis=1), data.std(axis=1)
        plt.errorbar(grid, mean/mean.min(), yerr=std/mean.min(), color="#B85029", linewidth=4, marker='o', markersize=10, zorder=5, label="PCPINN")
        data = scaling_up_times['precond_time']
        mean, std = data.mean(axis=1), data.std(axis=1)
        plt.errorbar(grid, mean/mean.min(), yerr=std/mean.min(), color="#008339", linewidth=4, marker='o', markersize=10, zorder=5, label="PreCond")
        mean, std = scaling_up_times['fenics_time'][:, 0], scaling_up_times['fenics_time'][:, 1]
        plt.errorbar(grid, mean/mean.min(), yerr=std/mean.min(), color="#2474B5", linewidth=4, marker='o', markersize=10, zorder=5, label="FEM")

        plt.legend(loc = 'upper center', fontsize=28, bbox_to_anchor=(0.5, 1.25), ncol=3)
        plt.savefig(savepath, bbox_inches='tight')
        plt.close()

    figure0("results/figure0.pdf")
    figure1("results/figure1.pdf")
    figure2("results/figure2.pdf")
    figure_time0()
    figure_abla("Poisson2d-C")
    figure_abla("Poisson2d-CG")
    figure_abla("Poisson3d-CG")
    figure_abla("Poisson2d-MS")
    figure_time1()


if __name__ == "__main__":
    # final_save() rebuilds results/final_results.txt from the raw run folders;
    # it is commented out so that, by default, only the plotting step
    # (final_plot, which reads that file) runs.
    # final_save()
    final_plot()