import math
import numpy as np
import jax
import jax.numpy as jnp
import yaml
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.gridspec import GridSpec
import datetime
import scipy.sparse as sp
import scipy.sparse.linalg as splinalg
import pickle
import os
import pennylane as qml

import logging

# Configure the root logger once at import time: INFO level, each record
# prefixed with a 12-hour HH:MM:SS timestamp and the emitting file's name.
logging.basicConfig(
    format="%(asctime)s - %(filename)s  : %(message)s ",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)



def setting(BASE_DIR, argv=None):
    """Parse the experiment options and derive run/result paths from them.

    Args:
        BASE_DIR: Accepted for interface compatibility; not used in this function.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Unknown arguments are ignored
            (``parse_known_args``).

    Returns:
        ``(args, config)``. ``args`` is the parsed ``argparse.Namespace``
        extended with derived fields (output_dim, DB_size, N, nn_input_dim,
        sol_dim, cur_time, PDE_INFO, TRAIN_INFO, RESULT_PATH, CKPT_PATH,
        SOLUTION_PATH, DB_NAME, DB_FOLDER, DB_ROOT). ``config`` is a dict of
        every non-dunder, non-callable attribute of ``args``, keyed by name.

    Raises:
        NotImplementedError: For an unsupported ansatz, or for the
            'zero' forcing type, which has no dataset naming here.

    Side effect:
        Seeds NumPy's global RNG with ``args.random_seed``.
    """
    import argparse
    parser = argparse.ArgumentParser("SEM")

    parser.add_argument("--jax_seed", type=int, default=42)
    parser.add_argument("--random_seed", type=int, default=1993)

    ### Experiment
    # Other values previously toggled in as the default: 'EXERCISE', 'Ansatz',
    # 'DD', 'CNN', 'Precondition', 'RD', 'Helmholtz', 'CD'.
    parser.add_argument("--exp_name", type=str, default='Classic')

    ### PDE Information
    DIMENSION = ['1D', '2D', '3D']
    EQUATIONS = ['RD_1D', 'CD_1D', 'Helmholtz_1D', 'RD_2D', 'CD_2D']
    parser.add_argument("--dimension", type=str, default='1D', choices=DIMENSION)
    parser.add_argument("--equation", type=str, default='CD_1D', choices=EQUATIONS)
    # type=float: without it a value given on the command line would be a str
    # while the default is a float.
    parser.add_argument("--pde_parameter", type=float, default=0.05)
    parser.add_argument("--basistype", type=str, default='Legendre',
                        choices=['Chebyshev', 'Legendre'])
    BOUNDARY_CONDITION = ['Dirichlet', 'Neumann']
    parser.add_argument("--boundary_condition", type=str, default='Dirichlet',
                        choices=BOUNDARY_CONDITION)

    ### QML
    parser.add_argument("--n_qubits", type=int, default=5)
    ANSATZ_NAME = ['StronglyEntanglingLayers', 'BasicEntanglerLayers']
    parser.add_argument("--ansatz_name", type=str, default='StronglyEntanglingLayers',
                        choices=ANSATZ_NAME)
    parser.add_argument("--n_layers", type=int, default=12)

    ### Preconditioning
    PRECONDITIONS = ['No', 'Diagonal', 'GauseSeidel', 'ILU',
                     'RowScaling', 'ColScaling', 'ROWCOL', 'ITSELF']
    parser.add_argument("--precondition", type=str, default='No', choices=PRECONDITIONS)

    ### Training Information
    OPTIMIZERS = ['Adam', 'LBFGS', 'From_LBFGS_To_Adam', 'From_Adam_To_LBFGS']
    parser.add_argument("--optimizer_type", type=str, default='LBFGS', choices=OPTIMIZERS)
    parser.add_argument("--EPOCH", type=int, default=200000)
    parser.add_argument("--lr", type=float, default=0.0001)
    # 800 is the L-BFGS setting; 200 was the setting used with Adam.
    parser.add_argument("--batch_train", type=int, default=800)
    parser.add_argument("--batch_test", type=int, default=200)

    ### Forcing Data Information
    FORCINGTYPES = ['uniform', 'zero', 'ones']
    parser.add_argument("--train_size", type=int, default=800)
    parser.add_argument("--test_size", type=int, default=200)
    parser.add_argument("--DB_forcingtype", type=str, default='uniform', choices=FORCINGTYPES)
    parser.add_argument("--DB_sin_mean", type=float, default=0.0)
    parser.add_argument("--DB_sin_sd", type=float, default=1.0)
    parser.add_argument("--DB_cos_mean", type=float, default=0.0)
    parser.add_argument("--DB_cos_sd", type=float, default=1.0)

    # Unknown arguments are silently ignored.
    args, _ = parser.parse_known_args(argv)

    # Number of ansatz weights (StronglyEntanglingLayers takes 3 angles per
    # qubit per layer, BasicEntanglerLayers takes 1).
    if args.ansatz_name == 'BasicEntanglerLayers':
        args.output_dim = args.n_layers * args.n_qubits
    elif args.ansatz_name == 'StronglyEntanglingLayers':
        args.output_dim = 3 * args.n_layers * args.n_qubits
    # NOTE(review): the QCNN branches are unreachable while --ansatz_name's
    # choices exclude them.
    elif args.ansatz_name == 'QCNN':
        args.output_dim = 17 * args.n_layers
    elif args.ansatz_name == 'QCNNNN':
        args.output_dim = (17 + 3 * args.n_qubits) * args.n_layers
    else:
        raise NotImplementedError('setting - WRONG ANSATZ')

    args.DB_size = args.train_size + args.test_size

    args.N = 2**args.n_qubits + 1
    args.nn_input_dim = args.N + 1
    args.sol_dim = args.N + 1

    args.cur_time = get_current_time()

    args.PDE_INFO = f"{args.equation}_{args.pde_parameter}_{args.basistype}_{args.boundary_condition}"
    args.TRAIN_INFO = '_'.join([
        args.exp_name,
        f'{args.cur_time}',
        f'n_{args.n_qubits}',
        args.ansatz_name,
        f'n_layers_{args.n_layers}',
        args.precondition,
    ])

    # 'Classic' experiments get their own top-level results directory.
    result_root = f'training_{args.optimizer_type}'
    if args.exp_name == 'Classic':
        result_root += '_classic'
    args.RESULT_PATH = os.path.join(result_root, args.PDE_INFO, args.TRAIN_INFO)
    args.CKPT_PATH = os.path.join(args.RESULT_PATH, 'loss_and_ckpt')
    args.SOLUTION_PATH = os.path.join(args.RESULT_PATH, 'solution')

    if args.DB_forcingtype == 'uniform':
        args.DB_NAME = (args.PDE_INFO
                        + f'_n_qubits_{args.n_qubits}_N_{args.N}'
                        + f'_train_{args.train_size}'
                        + f'_{args.DB_forcingtype}'
                        + f'_{args.DB_sin_mean:.2f}'
                        + f'_{args.DB_sin_sd:.2f}'
                        + f'_{args.DB_cos_mean:.2f}'
                        + f'_{args.DB_cos_sd:.2f}'
                        )
    elif args.DB_forcingtype == 'ones':
        # Uses DB_forcingtype; args has no 'train_forcingtype' attribute.
        args.DB_NAME = args.PDE_INFO + '_' + args.DB_forcingtype
    else:
        # 'zero' is an accepted choice but has no dataset naming here.
        raise NotImplementedError
    args.DB_FOLDER = os.path.join(args.RESULT_PATH, 'data')
    args.DB_ROOT = os.path.join(args.DB_FOLDER, args.DB_NAME + '.pickle')
    np.random.seed(args.random_seed)

    # Snapshot every public, non-callable attribute of args, sorted by name;
    # a 'device' entry, if present, is stored as its string form.
    public_names = sorted(
        name for name in dir(args)
        if not name.startswith('__') and not callable(getattr(args, name))
    )
    config = {
        name: (str(getattr(args, name)) if name == 'device' else getattr(args, name))
        for name in public_names
    }
    return args, config



class precondition:
    """Build a preconditioned version of a square system matrix ``AA``.

    After construction the instance holds:
        A_tilde -- the matrix to solve with. For every option it equals
                   ``PL_inv @ AA @ PR_inv``, where a ``None`` factor means identity.
        PL_inv  -- left factor, applied to right-hand sides by ``make_RHS``, or None.
        PR_inv  -- right factor, applied to solved coefficients by ``make_coeff``, or None.

    Args:
        AA: Square matrix (NumPy or JAX array).
        PRECONDITION: One of 'No', 'Diagonal', 'GauseSeidel', 'ILU',
            'RowScaling', 'ColScaling', 'ROWCOL', 'ITSELF'.

    Raises:
        NotImplementedError: For an unknown ``PRECONDITION``.
    """

    def __init__(self, AA, PRECONDITION):
        self.PRECONDITION = PRECONDITION

        if PRECONDITION == 'No':
            self.PL_inv, self.PR_inv = None, None
            self.A_tilde = AA
            return

        builders = {
            'ILU': self.ILU,
            'Diagonal': self.Diagonal,
            # The option keeps its historical spelling; the method is GaussSeidel.
            'GauseSeidel': self.GaussSeidel,
            'RowScaling': self.RowScaling,
            'ColScaling': self.ColScaling,
            'ROWCOL': self.ROWCOL,
            'ITSELF': self.ITSELF,
        }
        if PRECONDITION not in builders:
            raise NotImplementedError('Wrong Precondition')
        self.PL_inv, self.A_tilde, self.PR_inv = builders[PRECONDITION](AA)

    def make_coeff(self):
        """Return a function mapping batched coefficients ``alpha`` (rows) to ``PR_inv @ alpha_b``.

        The function is the identity when there is no right factor.
        """
        if self.PR_inv is not None:
            return lambda alpha: jnp.einsum('ij,bj->bi', self.PR_inv, alpha)
        return lambda alpha: alpha

    def make_RHS(self):
        """Return a function mapping batched right-hand sides (rows) to ``PL_inv @ f_b``.

        The function is the identity when there is no left factor.
        """
        if self.PL_inv is not None:
            return lambda bar_f: jnp.einsum('ij,bj->bi', self.PL_inv, bar_f)
        return lambda bar_f: bar_f

    def Diagonal(self, AA):
        """Jacobi preconditioner: left-multiply by the inverse of diag(AA)."""
        # Inverting a diagonal matrix is just reciprocating its entries.
        PL_inv = jnp.diag(1.0 / jnp.diag(AA))
        A_tilde = PL_inv @ AA
        PR_inv = jnp.eye(AA.shape[0], dtype=AA.dtype)
        return PL_inv, A_tilde, PR_inv

    def GaussSeidel(self, AA):
        """Gauss-Seidel preconditioner: left-multiply by inv(D + L), the lower triangle of AA."""
        DD = jnp.diag(jnp.diag(AA))
        LL = jnp.tril(AA, -1)
        PL_inv = jnp.linalg.inv(DD + LL)
        A_tilde = PL_inv @ AA
        PR_inv = jnp.eye(AA.shape[0], dtype=AA.dtype)
        return PL_inv, A_tilde, PR_inv

    def ITSELF(self, AA):
        """Use AA itself as the left factor, giving ``A_tilde = AA @ AA``."""
        PL_inv = AA
        A_tilde = AA @ AA
        return PL_inv, A_tilde, None

    def ILU(self, AA):
        """Incomplete-LU preconditioner via ``scipy.sparse.linalg.spilu`` (default options).

        SuperLU factors a row/column-permuted matrix, ``Pr @ A @ Pc ≈ L @ U``,
        so the approximation of ``A`` itself is ``M = Pr.T @ L @ U @ Pc.T``.
        Returns ``(inv(M), inv(M) @ AA, None)``.
        """
        A_np = np.array(AA)
        n = A_np.shape[0]
        # spilu works on CSC internally; passing CSC avoids a format conversion.
        ilu = splinalg.spilu(sp.csc_matrix(A_np))
        L_np = ilu.L.toarray()
        U_np = ilu.U.toarray()

        idx = np.arange(n)
        Pr = np.zeros((n, n), dtype=L_np.dtype)
        Pr[ilu.perm_r, idx] = 1
        Pc = np.zeros((n, n), dtype=L_np.dtype)
        Pc[idx, ilu.perm_c] = 1
        M = Pr.T @ L_np @ U_np @ Pc.T

        M_inv = jnp.linalg.inv(jnp.array(M))
        return M_inv, M_inv @ AA, None

    def RowScaling(self, AA):
        """Left-scale each row of AA to unit Euclidean norm."""
        # Vector 2-norm per row; no keepdims/squeeze so 1x1 inputs still work.
        row_norms = jnp.linalg.norm(AA, ord=2, axis=1)
        row_scaling = jnp.diag(1.0 / row_norms)
        A_row_scaled = row_scaling @ AA
        PR_inv = jnp.eye(AA.shape[0], dtype=AA.dtype)
        return row_scaling, A_row_scaled, PR_inv

    def ColScaling(self, AA):
        """Right-scale each column of AA to unit Euclidean norm."""
        PL_inv = jnp.eye(AA.shape[0], dtype=AA.dtype)
        col_norms = jnp.linalg.norm(AA, ord=2, axis=0)
        col_scaling = jnp.diag(1.0 / col_norms)
        A_scaled = AA @ col_scaling
        return PL_inv, A_scaled, col_scaling

    def ROWCOL(self, AA):
        """Row scaling followed by column scaling of the row-scaled matrix."""
        # Row scaling.
        row_norms = jnp.linalg.norm(AA, ord=2, axis=1)
        row_scaling = jnp.diag(1.0 / row_norms)
        A_row_scaled = row_scaling @ AA

        # Column scaling of the already row-scaled matrix.
        col_norms = jnp.linalg.norm(A_row_scaled, ord=2, axis=0)
        col_scaling = jnp.diag(1.0 / col_norms)

        A_scaled = A_row_scaled @ col_scaling
        return row_scaling, A_scaled, col_scaling


def get_current_time():
    """Return the current local time as a compact ``YYYYMMDDTHHMMSS`` string.

    Seconds resolution; used as a tag inside result directory names.
    """
    return datetime.datetime.now().strftime('%Y%m%dT%H%M%S')

def save_config(config, RESULT_PATH):
    """Serialize ``config`` as block-style YAML to ``<RESULT_PATH>/config.yaml``.

    An existing file is overwritten; ``RESULT_PATH`` must already exist.
    """
    yaml_file = os.path.join(RESULT_PATH, "config.yaml")
    with open(yaml_file, "w") as stream:
        yaml.dump(config, stream, default_flow_style=False)

class PlotSystem:
    """Matplotlib plotting helpers sharing one set of figure/font/line sizes.

    Each plotting method creates its own figure, saves and/or shows it, and
    closes it before returning.
    """

    def __init__(self):
        # Default figure size in inches for single-panel plots.
        self.figwide = 12
        self.figheight = 8

        # Typography and line width shared across plots.
        self.title_fontsize = 22
        self.label_fontsize = 18
        self.ticks_fontsize = 18
        self.legend_fontsize = 18
        self.linewidth = 3.5

    def Two_Functions_1d(self, IMGNAME, x1, y1, label1, x2, y2, label2, IMG_TITLE=None):
        """Overlay two 1D curves and save the figure as ``IMGNAME`` under ``self.args.RESULT_PATH``.

        ``x1``/``y1`` and ``x2``/``y2`` must have identical shapes, and all four
        arrays must have the same rank.

        NOTE(review): ``__init__`` never sets ``self.args``; the caller must
        assign an object with a ``RESULT_PATH`` attribute before calling this.
        """
        PATH = self.args.RESULT_PATH
        assert len(x1.shape) == len(y1.shape) == len(x2.shape) == len(y2.shape), f'{x1.shape} {y1.shape} {x2.shape} {y2.shape}'
        assert x1.shape == y1.shape
        assert x2.shape == y2.shape

        plt.figure(figsize=(self.figwide, self.figheight))
        plt.plot(x1.squeeze(), y1, label=label1)
        plt.plot(x2.squeeze(), y2, label=label2)
        plt.xlabel('x', fontsize=self.label_fontsize)
        plt.ylabel('y', fontsize=self.label_fontsize)

        plt.xticks(fontsize=self.ticks_fontsize)
        plt.yticks(fontsize=self.ticks_fontsize)
        plt.title(IMG_TITLE, fontsize=self.title_fontsize)
        plt.legend(loc='best', fontsize=self.legend_fontsize, ncol=2)

        plt.grid(True)
        plt.savefig(os.path.join(PATH, IMGNAME))
        plt.close()
        return None

    def Polynomials(self, xx, PP, SAVE_PATH=None, SHOW=True):
        """Plot every polynomial ``PP[n]`` against the collocation points ``xx``.

        The figure is saved only when ``SAVE_PATH`` is not None.
        """
        plt.figure(figsize=(self.figwide, self.figheight))

        for curve in PP:
            plt.plot(xx, curve)

        plt.xticks(fontsize=self.ticks_fontsize)
        plt.yticks(fontsize=self.ticks_fontsize)

        plt.xlabel(r'$x$ (Collocation)', fontsize=self.label_fontsize)
        plt.ylabel(r'$P$', fontsize=self.label_fontsize)
        plt.title('Orthogonal Polynomials', fontsize=self.title_fontsize)

        # savefig(None) would raise, and None is the default.
        if SAVE_PATH is not None:
            plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Basis(self, xx, phi, SAVE_PATH='basis_plot.png', SHOW=True):
        """Plot each basis function ``phi[n, :]`` against the collocation points ``xx``."""
        plt.figure(figsize=(self.figwide, self.figheight))

        for n in range(phi.shape[0]):
            plt.plot(xx, phi[n, :])

        plt.xticks(fontsize=self.ticks_fontsize)
        plt.yticks(fontsize=self.ticks_fontsize)

        plt.xlabel(r'$x$ (Collocation)', fontsize=self.label_fontsize)
        plt.ylabel(r'$\phi$', fontsize=self.label_fontsize)
        plt.title('Basis', fontsize=self.title_fontsize)

        plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Basis_2D(self, X, Y, phi, SAVE_PATH='basis_2D_plot.png', SHOW=True):
        """Plot a 2D basis function as a filled contour (left) and a 3D surface (right).

        Parameters:
            X, Y: Coordinate grids accepted by ``contourf``/``plot_surface``.
            phi: Values on the grid, same shape as ``X``/``Y``.
            SAVE_PATH (str): Path to save the generated plot.
            SHOW (bool): Whether to display the plot.
        """
        fig = plt.figure(figsize=(12, 6))

        # Contour plot (2D view).
        ax1 = fig.add_subplot(1, 2, 1)
        contour = ax1.contourf(X, Y, phi, cmap='viridis', levels=50)
        plt.colorbar(contour, ax=ax1)
        ax1.set_title('Contour Plot', fontsize=14)
        ax1.set_xlabel(r'$x$', fontsize=12)
        ax1.set_ylabel(r'$y$', fontsize=12)

        # Surface plot (3D view).
        ax2 = fig.add_subplot(1, 2, 2, projection='3d')
        surf = ax2.plot_surface(X, Y, phi, cmap='viridis', edgecolor='none')
        fig.colorbar(surf, ax=ax2, shrink=0.5, aspect=10)
        ax2.set_title('Surface Plot', fontsize=14)
        ax2.set_xlabel(r'$x$', fontsize=12)
        ax2.set_ylabel(r'$y$', fontsize=12)
        ax2.set_zlabel(r'Solution', fontsize=12)

        plt.tight_layout()
        plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Pauli_Decomposition_Plot(self, MATRIX_SIZE, L_LIST, SAVE_PATH=None, SHOW=False):
        """Plot the Pauli-decomposition term count ``L`` against N, with N and N log N guides."""
        plt.figure(figsize=(8, 5))
        plt.plot(MATRIX_SIZE, L_LIST, label=r"$L$", color='b', linewidth=2)

        plt.plot(MATRIX_SIZE, MATRIX_SIZE, linestyle="--", color='g', label=r"$N$")
        plt.plot(MATRIX_SIZE, MATRIX_SIZE * np.log(MATRIX_SIZE), linestyle="--", color='r', label=r"$N \log N$")

        plt.xlabel(r'Matrix Size $N$ x $N$')
        plt.ylabel(r"Pauli-decomposition $L$")
        plt.title(r"Pauli-decomposition $L$ with increasing matrix size")
        plt.legend()
        plt.grid(True)

        plt.tight_layout()

        if SAVE_PATH is not None:
            plt.savefig(SAVE_PATH)

        if SHOW:
            plt.show()
        plt.close()

    def Solution_1D(self, xx, pred_u_dict=None, true_u=None, NORM=None, SAVE_PATH=None, SHOW=False, COLOR='b'):
        """Plot predicted vs. true 1D solutions (left) and their absolute error (right).

        Args:
            xx: Collocation points (x axis).
            pred_u_dict: Mapping label -> predicted values; each is ``squeeze()``d.
            true_u: Reference solution, drawn as a red dashed line.
            NORM: Dict with 'RelL2' and 'RelLinf' values shown in the left title.
            SAVE_PATH: File to save to (300 dpi); skipped when None.
            SHOW: Whether to call ``plt.show()``.
            COLOR: Line color of the predicted curves.
        """
        linewidth = self.linewidth

        # Two panels: left = solution, right = error.
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

        # (1) Left: solutions.
        if pred_u_dict is not None:
            for label, values in pred_u_dict.items():
                ax1.plot(xx, values.squeeze(), label=label, color=COLOR, linewidth=linewidth)

        if true_u is not None:
            ax1.plot(xx, true_u.squeeze(), 'r--', label='True solution', linewidth=linewidth)

        if NORM is not None:
            L2_norm = NORM['RelL2']
            Linf_norm = NORM['RelLinf']
            # '.3e' = scientific notation with 3 decimals ('3e' was only a width).
            text_str = (r"Relative $L_2$" + f'={L2_norm:.3e}' + '\n' +
                        r"Relative $L_\text{inf}$" + f'={Linf_norm:.3e}')
            ax1.set_title(text_str, fontsize=self.title_fontsize)

        ax1.set_xlabel(r'$x$ (Collocation)', fontsize=self.label_fontsize)
        ax1.set_ylabel(r'$u(x)$', fontsize=self.label_fontsize)
        ax1.tick_params(axis='both', labelsize=self.ticks_fontsize)
        ax1.legend(loc='best', fontsize=self.legend_fontsize)

        # (2) Right: absolute error on a log scale.
        if true_u is not None and pred_u_dict is not None:
            for values in pred_u_dict.values():
                error = np.abs(values.squeeze() - true_u.squeeze())
                ax2.plot(xx, error,
                         'b:',
                         label='Abs Error',
                         linewidth=linewidth,
                         marker='o',              # circle markers
                         markersize=8,
                         markeredgewidth=1.5,
                         markeredgecolor='b',
                         markerfacecolor='white'  # hollow markers
                         )
            # Reference levels at 1e-1, 1e-2, 1e-3.
            for level in (10**(-1), 10**(-2), 10**(-3)):
                ax2.axhline(y=level, linestyle='--', linewidth=1)

            ax2.set_yscale("log")
            ax2.set_xlabel(r'$x$ (Collocation)', fontsize=self.label_fontsize)
            ax2.set_ylabel(r'$|u_{pred}(x) - u_{true}(x)|$', fontsize=self.label_fontsize)
            ax2.tick_params(axis='both', labelsize=self.ticks_fontsize)
            ax2.legend(loc='best', fontsize=self.legend_fontsize)
            ax2.set_title("Absolute Error (log scale)", fontsize=self.title_fontsize)

        plt.tight_layout()
        # savefig(None) would raise, and None is the default.
        if SAVE_PATH is not None:
            plt.savefig(SAVE_PATH, dpi=300)
        if SHOW:
            plt.show()
        plt.close()

    def Solution_2D(self, X, Y, solution, reference=None, NORM=None
                    , SAVE_PATH='solution_2D_plot.png', SHOW=True):
        """Plot a 2D solution as a filled contour (left) and a 3D surface (right).

        Parameters:
            X, Y: Coordinate grids; N is taken as ``X.shape[0] - 1``.
            solution: Values reshaped with ``.view(N+1, N+1)``.
                NOTE(review): ``.view(shape)`` is the torch API, so this
                presumably expects torch tensors; confirm with callers.
            reference: Optional dict label -> reference values (same layout),
                overlaid semi-transparently on the surface plot.
            NORM: Optional dict with 'L2' and 'Linf' values used as panel titles.
            SAVE_PATH (str): Path to save the generated plot.
            SHOW (bool): Whether to display the plot.
        """
        from matplotlib.lines import Line2D

        NN = X.shape[0] - 1
        fig = plt.figure(figsize=(12, 6))

        # Contour plot (2D view).
        ax1 = fig.add_subplot(1, 2, 1)
        contour = ax1.contourf(X, Y, solution.view(NN + 1, NN + 1),
                               cmap='viridis', levels=50)
        plt.colorbar(contour, ax=ax1)
        ax1.set_xlabel(r'$x$', fontsize=12)
        ax1.set_ylabel(r'$y$', fontsize=12)

        # Surface plot (3D view). Surfaces get no automatic legend entries,
        # so proxy lines are collected instead.
        legend_elements = [
            Line2D([0], [0], color='blue', lw=4, label='Solution'),
        ]
        ax2 = fig.add_subplot(1, 2, 2, projection='3d')
        ax2.plot_surface(X, Y, solution.view(NN + 1, NN + 1),
                         cmap='viridis', edgecolor='none')

        if reference is not None:
            for label, u in reference.items():
                # Reference solutions are drawn semi-transparently on top.
                ax2.plot_surface(X, Y, u.view(NN + 1, NN + 1),
                                 cmap='coolwarm', edgecolor='none', alpha=0.5,
                                 label='Reference Solution')
                legend_elements.append(Line2D([0], [0], color='red', lw=4, label=label))

        if NORM is not None:
            L2_norm = NORM['L2']
            Linf_norm = NORM['Linf']
            ax1.set_title(r"$L_2$" + f'={L2_norm:.6f}', fontsize=14)
            ax2.set_title(r"$L_\text{inf}$" + f'={Linf_norm:.6f}', fontsize=14)

        ax2.set_xlabel(r'$x$', fontsize=12)
        ax2.set_ylabel(r'$y$', fontsize=12)
        ax2.legend(handles=legend_elements, loc='upper right')
        plt.tight_layout()

        plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Solution_3D(self, X, Y, Z, solution,
                    SAVE_PATH=None, SHOW=True,
                    title_fontsize=32, label_fontsize=28, tick_fontsize=24):
        """Show a 3D field as three mid-plane contour slices plus a 3D scatter.

        NOTE(review): uses ``.numpy()`` and ``.view(...)``, so X, Y, Z and
        ``solution`` are presumably torch tensors laid out as (N+1, N+1, N+1)
        with N = X.shape[0] - 1; confirm with callers.
        """
        NN = X.shape[0] - 1
        mid = NN // 2
        # Convert tensors to numpy for matplotlib.
        X_np = X.numpy()
        Y_np = Y.numpy()
        Z_np = Z.numpy()
        solution_np = solution.view(NN + 1, NN + 1, NN + 1).numpy()

        fig = plt.figure(figsize=(50, 8))

        # Mid-plane slices: (subplot, horizontal grid, vertical grid, values,
        # horizontal label, vertical label, title). Every slice indexes the
        # coordinate grids and the values at the same position.
        slices = (
            (141, X_np[:, :, mid], Y_np[:, :, mid], solution_np[:, :, mid], 'X', 'Y', 'xy-plane'),
            (142, X_np[:, mid, :], Z_np[:, mid, :], solution_np[:, mid, :], 'X', 'Z', 'xz-plane'),
            (143, Y_np[mid, :, :], Z_np[mid, :, :], solution_np[mid, :, :], 'Y', 'Z', 'yz-plane'),
        )
        for position, horiz, vert, values, horiz_label, vert_label, title in slices:
            ax = fig.add_subplot(position)
            ax.set_xlabel(horiz_label, fontsize=label_fontsize)
            ax.set_ylabel(vert_label, fontsize=label_fontsize)
            ax.set_title(title, fontsize=title_fontsize)
            ax.contourf(horiz, vert, values, cmap='viridis', levels=50)
            ax.tick_params(axis='both', labelsize=tick_fontsize)

        ax4 = fig.add_subplot(144, projection='3d')
        ax4.set_xlabel('X', fontsize=label_fontsize)
        ax4.set_ylabel('Y', fontsize=label_fontsize)
        ax4.set_zlabel('Z', fontsize=label_fontsize)
        ax4.set_title('3D', fontsize=title_fontsize)
        img4 = ax4.scatter(X_np, Y_np, Z_np, c=solution_np, cmap='viridis', marker='o')
        ax4.tick_params(axis='both', labelsize=tick_fontsize)

        fig.colorbar(img4, ax=ax4, shrink=0.7, aspect=10)

        # Layout must be computed after the axes exist (it was a no-op before).
        plt.tight_layout()
        if SAVE_PATH:
            plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    @staticmethod
    def _unpack_loss_item(item, default_color):
        """Split a ``losses`` dict value into ``(y_values, color)``.

        A value of the form ``[y_values, color]``, where ``y_values`` is a list
        or ndarray, carries its own color; anything else is the y-values
        themselves, paired with ``default_color``.
        """
        if (
            isinstance(item, list)
            and len(item) == 2
            and isinstance(item[0], (list, np.ndarray))
        ):
            return item[0], item[1]
        return item, default_color

    @classmethod
    def _flatten_losses(cls, losses):
        """Return every loss value (dict of curves, or a single curve) as one flat list."""
        if isinstance(losses, dict):
            curves = [cls._unpack_loss_item(item, None)[0] for item in losses.values()]
        else:
            curves = [losses]
        return [value for curve in curves for value in np.array(curve).flatten().tolist()]

    @staticmethod
    def _guide_line_exponent(min_val):
        """Return n such that 10^{-n} > min_val >= 10^{-(n+1)}, or None if undefined.

        Undefined means ``min_val`` is non-positive, NaN or infinite.
        """
        try:
            return int(-math.floor(math.log10(min_val)))
        except (ValueError, OverflowError):
            return None

    def Losses(self, losses, SAVEPATH, log_scale=False, TITLE=None, SHOW=False, GUIDE_LINE=True
               , OPTIMIZER_CHANGE=None, COLOR='b'):
        """Plot one or more loss curves against the epoch index and save the figure.

        Args:
            losses: A single sequence of loss values, or a dict label -> values.
                A dict value may also be ``[values, color]`` to set its color.
            SAVEPATH: Output image path.
            log_scale: Use a logarithmic y axis.
            TITLE: Axes title.
            SHOW: Whether to call ``plt.show()``.
            GUIDE_LINE: Draw dashed horizontal lines at 1, 1e-1, ..., down to
                the decade of the smallest loss value.
            OPTIMIZER_CHANGE: Iterable of epochs at which to draw red vertical
                markers (e.g. optimizer switches).
            COLOR: Default line color / matplotlib format string.
        """
        fig, ax = plt.subplots(figsize=(self.figwide, self.figheight))

        if isinstance(losses, dict):
            for label, item in losses.items():
                y, color = self._unpack_loss_item(item, COLOR)
                ax.plot(y, color, label=label, linewidth=2)
        else:
            # losses is a single list of y values.
            ax.plot(losses, COLOR, label='loss', linewidth=2)

        if GUIDE_LINE:
            all_vals = self._flatten_losses(losses)
            if all_vals:
                # The tiny offset keeps log10 finite when the minimum is exactly 0.
                min_val = min(all_vals) + 1e-20
                n = self._guide_line_exponent(min_val)
                if n is None:
                    print(f'min_val : {min_val}')
                else:
                    # Guide lines at 1e0, 1e-1, ..., 1e-n.
                    for i in range(n + 1):
                        ax.axhline(y=10**(-i), linestyle='--', linewidth=1)

        if log_scale:
            ax.set_yscale('log')

        if OPTIMIZER_CHANGE is not None:
            for ep in OPTIMIZER_CHANGE:
                ax.axvline(
                    x=ep,
                    color='red',
                    linestyle='--',
                    linewidth=1,
                    label=f'Optimizer changed at epoch = {ep}'
                )

        ax.set_title(TITLE, fontsize=self.title_fontsize)
        ax.set_xlabel('Epoch', fontsize=self.label_fontsize)
        ax.set_ylabel('Loss', fontsize=self.label_fontsize)
        ax.tick_params(axis='both', labelsize=self.ticks_fontsize)
        ax.legend(fontsize=self.legend_fontsize)
        ax.grid(True)

        fig.savefig(SAVEPATH)
        if SHOW:
            plt.show()
        plt.close(fig)
    # End Plot System


def From_loss_list_To_TXT(loss_list, TXT_PATH):
    """Write ``loss_list`` to ``TXT_PATH``, one ``Epoch <k>: Loss <value>`` line per entry.

    Epochs are numbered from 1. Any existing file is overwritten.
    """
    with open(TXT_PATH, "w") as out:
        out.writelines(
            f"Epoch {step}: Loss {value}\n"
            for step, value in enumerate(loss_list, start=1)
        )
    return None
