import math
import numpy as np
import jax.numpy as jnp
import yaml
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.gridspec import GridSpec
import datetime
import scipy.sparse as sp
import scipy.sparse.linalg as splinalg
import pickle
import os
import pennylane as qml
import jax
import logging

# Root-logger setup for this module: INFO and above, timestamps rendered on a
# 12-hour clock (%I) without date or AM/PM marker.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(filename)s  : %(message)s ",
    datefmt="%I:%M:%S",
)



def setting(BASE_DIR):
    """Parse command-line options and derive the experiment configuration.

    Unrecognised command-line arguments are ignored (``parse_known_args``),
    so this can be called from scripts launched with extra flags.

    Args:
        BASE_DIR: Root directory; datasets are looked up in ``<BASE_DIR>/data``.

    Returns:
        tuple ``(args, config)``:
            * ``args`` -- the parsed ``argparse.Namespace`` extended with
              derived fields (``output_dim``, ``DB_size``, ``N``,
              ``input_dim``, ``cur_time``, ``PDE_INFO``, ``TRAIN_INFO``,
              ``RESULT_PATH``/``CKPT_PATH``/``SOLUTION_PATH`` and the
              ``DB_*`` dataset paths);
            * ``config`` -- a ``dict`` of every non-dunder, non-callable
              attribute of ``args`` (an attribute named ``device``, if any,
              is stored as ``str``).

    Raises:
        NotImplementedError: if the chosen ansatz has no ``output_dim`` rule
            or the forcing type has no dataset-naming rule.

    Side effect: seeds NumPy's global RNG with ``args.random_seed``.
    """
    import argparse
    parser = argparse.ArgumentParser("SEM")

    parser.add_argument("--jax_seed", type=int, default=42)
    parser.add_argument("--random_seed", type=int, default=1993)

    ### PDE Information
    parser.add_argument("--exp_name", type=str, default='CD')
    DIMENSION = ['1D', '2D', '3D']
    parser.add_argument("--dimension", type=str, default='2D', choices=DIMENSION)

    EQUATIONS = ['RD_1D', 'CD_1D', 'Helmholtz_1D', 'RD_2D', 'CD_2D', 'Helmholtz_2D']
    parser.add_argument("--equation", type=str
                        # , default= 'RD_2D'
                        # , default='Helmholtz_2D'
                        , default='CD_2D'
                        , choices=EQUATIONS
                        )
    # type=float: without it a value passed on the command line would stay a
    # string (e.g. '50'), unlike the float default.
    parser.add_argument("--pde_parameter", type=float, default=50.)
    parser.add_argument("--basistype", type=str, default='Legendre', choices=['Chebyshev', 'Legendre'])

    BOUNDARY_CONDITION = ['Dirichlet', 'Neumann']
    parser.add_argument("--boundary_condition", type=str, default='Dirichlet', choices=BOUNDARY_CONDITION)

    ANGLENET_NAME = ['MLP', 'CNN']
    parser.add_argument("--anglenet_type", type=str
                        # , default='MLP'
                        , default='CNN'
                        , choices=ANGLENET_NAME)

    # NOTE: 'QCNNNN' is accepted here but has no output_dim rule below, so it
    # raises NotImplementedError after parsing.
    ANSATZ_NAME = ['QCNN', 'QCNNNN', 'StronglyEntanglingLayers',
                   'BasicEntanglerLayersRX', 'BasicEntanglerLayersRY']
    parser.add_argument("--ansatz_name", type=str
                        , default='StronglyEntanglingLayers'
                        , choices=ANSATZ_NAME)

    ### QML
    parser.add_argument("--n_qubits", type=int
                        # , default=2
                        # , default=4
                        , default=6
                        # , default=8
                        )
    parser.add_argument("--n_layers", type=int
                        # , default=1
                        , default=20
                        )

    ### Preconditioning
    PRECONDITIONS = ['No', 'Diagonal', 'GauseSeidel', 'ILU', 'RowScaling', 'ColScaling', 'ROWCOL']
    parser.add_argument("--precondition", type=str
                        , default='No'
                        # , default='ILU'
                        # , default='ROWCOL'
                        , choices=PRECONDITIONS)

    # Training Information
    # Known optimizer schedules (informational; not enforced via ``choices``).
    OPTIMIZERS = ['Adam', 'LBFGS', 'From_LBFGS_To_Adam', 'From_Adam_To_LBFGS']
    parser.add_argument("--optimizer_type", type=str
                        # , default='Adam'
                        , default='LBFGS'
                        # , default='From_LBFGS_To_Adam'
                        # , default='From_Adam_To_LBFGS'
                        )
    parser.add_argument("--EPOCH", type=int, default=200000)

    parser.add_argument("--lr", type=float, default=0.0001)  # Helmholtz
    parser.add_argument("--batch_train", type=int
                        , default=800  # LBFGS
                        # , default=200 # Adam
                        )
    parser.add_argument("--batch_test", type=int, default=200)

    # Forcing Data Information
    FORCINGTYPES = ['uniform', 'zero', 'ones']
    parser.add_argument("--train_size", type=int
                        # , default=8000
                        , default=800
                        )
    parser.add_argument("--test_size", type=int
                        # , default=2000
                        , default=200
                        )
    parser.add_argument("--DB_forcingtype", type=str, default='uniform', choices=FORCINGTYPES)
    parser.add_argument("--DB_sin_mean", type=float, default=0.0)
    parser.add_argument("--DB_sin_sd", type=float
                        # , default=10.0
                        , default=1.0
                        )
    parser.add_argument("--DB_cos_mean", type=float, default=0.0)
    parser.add_argument("--DB_cos_sd", type=float
                        # , default=10.0
                        , default=1.0
                        )

    args, _ = parser.parse_known_args()

    # Number of trainable angles the ansatz needs.
    if args.ansatz_name == 'StronglyEntanglingLayers':
        args.output_dim = 3 * args.n_layers * args.n_qubits
    elif args.ansatz_name in ('BasicEntanglerLayersRX', 'BasicEntanglerLayersRY'):
        args.output_dim = args.n_layers * args.n_qubits
    elif args.ansatz_name == 'QCNN':
        args.output_dim = 17 * args.n_layers
    else:
        raise NotImplementedError('setting - WRONG ANSATZ')

    args.DB_size = args.train_size + args.test_size

    args.N = 2**(args.n_qubits // 2) + 1

    print(f'N = {args.N}')

    args.input_dim = (args.N + 1) ** 2

    args.cur_time = get_current_time()

    args.PDE_INFO = f"{args.equation}_{args.pde_parameter}_{args.basistype}_{args.boundary_condition}"

    args.TRAIN_INFO = (args.exp_name
                       + '_' + args.dimension
                       + '_' + f'{args.cur_time}'
                       + '_' + f'n_{args.n_qubits}'
                       + '_' + args.ansatz_name
                       + '_' + f'n_layers_{args.n_layers}'
                       + '_' + args.precondition
                       )

    args.RESULT_PATH = os.path.join(f'training_{args.optimizer_type}'
                                    , args.PDE_INFO
                                    , args.TRAIN_INFO
                                    )
    args.CKPT_PATH = os.path.join(args.RESULT_PATH, 'loss_and_ckpt')
    args.SOLUTION_PATH = os.path.join(args.RESULT_PATH, 'solution')

    args.DB_FOLDER = os.path.join(BASE_DIR, 'data')
    if args.DB_forcingtype == 'uniform':
        args.DB_NAME = (args.PDE_INFO
                        + f'_n_qubits_{args.n_qubits}_N_{args.N}'
                        + '_train_' + str(args.train_size)
                        + '_' + args.DB_forcingtype
                        + '_' + f'{args.DB_sin_mean:.2f}'
                        + '_' + f'{args.DB_sin_sd:.2f}'
                        + '_' + f'{args.DB_cos_mean:.2f}'
                        + '_' + f'{args.DB_cos_sd:.2f}'
                        )
    elif args.DB_forcingtype == 'ones':
        # DB_forcingtype is the only forcing-type option defined above.
        args.DB_NAME = args.PDE_INFO + '_' + args.DB_forcingtype
    else:
        raise NotImplementedError('Wrong Forcing Type')
    args.DB_ROOT = os.path.join(args.DB_FOLDER, args.DB_NAME + '.pickle')

    np.random.seed(args.random_seed)

    def return_instance_variable(obj):
        """Sorted names of obj's non-dunder, non-callable attributes."""
        attrs = [attr for attr in dir(obj) if not attr.startswith('__') and not callable(getattr(obj, attr))]
        return sorted(attrs)

    config = {}
    for name in return_instance_variable(args):
        if name == 'device':
            config[name] = str(getattr(args, name))
        else:
            config[name] = getattr(args, name)
    return args, config


def get_current_time():
    """Return the local time as a compact ``YYYYMMDDTHHMMSS`` stamp.

    The stamp has no spaces, colons, dashes or fractional seconds, so it is
    safe to embed in file and directory names.
    """
    now = datetime.datetime.now()
    return now.strftime('%Y%m%dT%H%M%S')

def save_config(config, RESULT_PATH):
    """Serialise ``config`` as block-style YAML to ``<RESULT_PATH>/config.yaml``.

    The directory must already exist; an existing file is overwritten.
    """
    config_file = os.path.join(RESULT_PATH, "config.yaml")
    with open(config_file, mode="w") as stream:
        yaml.dump(config, stream=stream, default_flow_style=False)




class precondition:
    """Left/right preconditioner for a dense linear system ``AA @ x = b``.

    After construction the instance holds

    * ``PL_inv``  -- left factor applied to the right-hand side, or ``None``;
    * ``A_tilde`` -- the preconditioned matrix ``PL_inv @ AA @ PR_inv``
      (a ``None`` factor acts as the identity);
    * ``PR_inv``  -- right factor mapping solved coefficients back, or ``None``.

    Supported ``PRECONDITION`` names: 'No', 'INVERSE', 'ILU', 'Diagonal',
    'GauseSeidel' (alias 'GaussSeidel'), 'RowScaling', 'ColScaling',
    'ROWCOL', 'ITSELF'.
    """

    def __init__(self, AA, PRECONDITION):
        """Build the preconditioner named ``PRECONDITION`` for matrix ``AA``.

        Raises:
            NotImplementedError: for an unknown ``PRECONDITION`` name.
        """
        self.PRECONDITION = PRECONDITION

        if PRECONDITION == 'No':
            self.PL_inv, self.PR_inv = None, None
            self.A_tilde = AA
            return

        builders = {
            'INVERSE': self.INVERSE,
            'ILU': self.ILU,
            'Diagonal': self.Diagonal,
            # 'GauseSeidel' is the spelling listed in the CLI choices; it must
            # dispatch to the (correctly spelled) GaussSeidel method.
            'GauseSeidel': self.GaussSeidel,
            'GaussSeidel': self.GaussSeidel,
            'RowScaling': self.RowScaling,
            'ColScaling': self.ColScaling,
            'ROWCOL': self.ROWCOL,
            'ITSELF': self.ITSELF,
        }
        builder = builders.get(PRECONDITION)
        if builder is None:
            raise NotImplementedError('Wrong Precondition')
        self.PL_inv, self.A_tilde, self.PR_inv = builder(AA)

    def make_coeff(self):
        """Return a map from batched solution rows ``alpha`` (``(batch, n)``) to
        ``PR_inv``-transformed rows; the identity when there is no right factor."""
        if self.PR_inv is not None:
            return lambda alpha: jnp.einsum('ij,bj->bi', self.PR_inv, alpha)
        else:
            return lambda alpha: alpha

    def make_RHS(self):
        """Return a map applying ``PL_inv`` to batched right-hand-side rows
        ``bar_f`` (``(batch, n)``); the identity when there is no left factor."""
        if self.PL_inv is not None:
            return lambda bar_f: jnp.einsum('ij,bj->bi', self.PL_inv, bar_f)
        else:
            return lambda bar_f: bar_f

    def Diagonal(self, AA):
        """Jacobi preconditioner: left factor ``diag(AA)^{-1}``."""
        # The inverse of a diagonal matrix is the diagonal of reciprocals; a
        # general matrix inverse is unnecessary.
        PL_inv = jnp.diag(1.0 / jnp.diag(AA))
        A_tilde = PL_inv @ AA
        return PL_inv, A_tilde, None

    def GaussSeidel(self, AA):
        """Gauss-Seidel preconditioner: left factor ``(D + L)^{-1}`` where ``D``
        is the diagonal and ``L`` the strictly lower triangle of ``AA``."""
        DD = jnp.diag(jnp.diag(AA))
        LL = jnp.tril(AA, -1)
        PL_inv = jnp.linalg.inv(DD + LL)
        A_tilde = PL_inv @ AA
        PR_inv = jnp.eye(AA.shape[0], dtype=AA.dtype)
        return PL_inv, A_tilde, PR_inv

    def INVERSE(self, AA):
        """Exact inverse as left factor (``A_tilde`` is ``AA^{-1} @ AA``)."""
        PL_inv = jnp.linalg.inv(AA)
        A_tilde = PL_inv @ AA
        return PL_inv, A_tilde, None

    def ITSELF(self, AA):
        """Use ``AA`` itself as left factor (``A_tilde = AA @ AA``)."""
        PL_inv = AA
        A_tilde = AA @ AA
        return PL_inv, A_tilde, None

    def ILU(self, AA):
        """Incomplete-LU preconditioner: left factor ``M^{-1}`` with ``M ~= AA``."""
        A_np = np.asarray(AA)
        # spilu works natively on CSC (other formats are converted with a warning).
        A_csc = sp.csc_matrix(A_np)
        ilu = splinalg.spilu(A_csc)
        # SuperLU factors a *permuted* matrix (Pr @ A @ Pc = L @ U), so L @ U
        # alone does not approximate AA. ``solve`` applies the permutations and
        # returns M^{-1} @ rhs; solving against the identity yields M^{-1}.
        identity = np.eye(A_np.shape[0], dtype=A_csc.dtype)
        LU_inv = jnp.asarray(ilu.solve(identity))
        LU_inv_A = LU_inv @ AA
        return LU_inv, LU_inv_A, None

    def RowScaling(self, AA):
        """Left factor scaling every row of ``AA`` to unit 2-norm."""
        row_norms = jnp.linalg.norm(AA, ord=2, axis=1, keepdims=True)
        row_scaling = jnp.diag(1.0 / row_norms.squeeze())
        A_row_scaled = row_scaling @ AA
        PR_inv = jnp.eye(AA.shape[0], dtype=AA.dtype)
        return row_scaling, A_row_scaled, PR_inv

    def ColScaling(self, AA):
        """Right factor scaling every column of ``AA`` to unit 2-norm."""
        PL_inv = jnp.eye(AA.shape[0], dtype=AA.dtype)
        col_norms = jnp.linalg.norm(AA, ord=2, axis=0, keepdims=True)
        col_scaling = jnp.diag(1.0 / col_norms.squeeze())
        A_scaled = AA @ col_scaling
        return PL_inv, A_scaled, col_scaling

    def ROWCOL(self, AA):
        """Row scaling followed by column scaling of the row-scaled matrix."""
        # Row scaling
        row_norms = jnp.linalg.norm(AA, ord=2, axis=1, keepdims=True)
        row_scaling = jnp.diag(1.0 / row_norms.squeeze())
        A_row_scaled = row_scaling @ AA

        # Column scaling (norms taken after the row scaling)
        col_norms = jnp.linalg.norm(A_row_scaled, ord=2, axis=0, keepdims=True)
        col_scaling = jnp.diag(1.0 / col_norms.squeeze())

        A_scaled = A_row_scaled @ col_scaling
        return row_scaling, A_scaled, col_scaling




class PlotSystem:
    """Matplotlib plotting helpers that share one set of size/font defaults.

    Each method creates its own figure and closes it at the end of a normal
    return, so repeated calls do not leave figures open.
    """

    def __init__(self, args=None):
        """Set default figure and font sizes.

        Args:
            args: optional run configuration exposing ``RESULT_PATH``; read by
                ``Two_Functions_1d`` when it is not given ``SAVE_DIR``.
        """
        self.args = args

        self.figwide = 12
        self.figheight = 8

        self.title_fontsize = 22
        self.label_fontsize = 18
        self.ticks_fontsize = 18
        self.legend_fontsize = 18
        self.linewidth = 3.0

    def Two_Functions_1d(self, IMGNAME, x1, y1, label1, x2, y2, label2, IMG_TITLE=None,
                         SAVE_DIR=None):
        """Plot two 1-D curves on one figure and save it as ``<dir>/<IMGNAME>``.

        ``dir`` is ``SAVE_DIR`` when given, otherwise ``self.args.RESULT_PATH``.
        Each ``(x, y)`` pair must have identical shapes and all four arrays the
        same number of dimensions.

        Raises:
            ValueError: if neither ``SAVE_DIR`` nor ``self.args`` is available.
        """
        if SAVE_DIR is None:
            args = getattr(self, 'args', None)
            if args is None:
                raise ValueError('Two_Functions_1d: pass SAVE_DIR or construct PlotSystem(args)')
            SAVE_DIR = args.RESULT_PATH
        assert len(x1.shape) == len(y1.shape) == len(x2.shape) == len(y2.shape), f'{x1.shape} {y1.shape} {x2.shape} {y2.shape}'
        assert x1.shape == y1.shape
        assert x2.shape == y2.shape

        plt.figure(figsize=(self.figwide, self.figheight))
        plt.plot(x1.squeeze(), y1, label=label1)
        plt.plot(x2.squeeze(), y2, label=label2)
        plt.xlabel('x', fontsize=self.label_fontsize)
        plt.ylabel('y', fontsize=self.label_fontsize)

        plt.xticks(fontsize=self.ticks_fontsize)
        plt.yticks(fontsize=self.ticks_fontsize)
        plt.title(IMG_TITLE, fontsize=self.title_fontsize)
        plt.legend(loc='best', fontsize=self.legend_fontsize, ncol=2)

        plt.grid(True)
        plt.savefig(os.path.join(SAVE_DIR, IMGNAME))
        plt.close()
        return None

    def Polynomials(self, xx, PP, SAVE_PATH=None, SHOW = True):
        """Plot every row ``PP[n]`` against ``xx``; save only if ``SAVE_PATH`` is given."""
        plt.figure(figsize=(self.figwide, self.figheight))

        for poly in PP:
            plt.plot(xx, poly)

        plt.xticks(fontsize=self.ticks_fontsize)
        plt.yticks(fontsize=self.ticks_fontsize)

        plt.xlabel(r'$x$ (Collocation)', fontsize=self.label_fontsize)
        plt.ylabel(r'$P$', fontsize=self.label_fontsize)
        plt.title('Orthogonal Polynomials', fontsize=self.title_fontsize)

        if SAVE_PATH is not None:
            plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Basis(self, xx, phi, SAVE_PATH='basis_plot.png', SHOW = True):
        """Plot every row ``phi[n, :]`` against ``xx`` and save to ``SAVE_PATH``."""
        plt.figure(figsize=(self.figwide, self.figheight))

        for n in range(phi.shape[0]):
            plt.plot(xx, phi[n, :])

        plt.xticks(fontsize=self.ticks_fontsize)
        plt.yticks(fontsize=self.ticks_fontsize)

        plt.xlabel(r'$x$ (Collocation)', fontsize=self.label_fontsize)
        plt.ylabel(r'$\phi$', fontsize=self.label_fontsize)
        plt.title('Basis', fontsize=self.title_fontsize)

        plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Basis_2D(self, X, Y, phi, SAVE_PATH='basis_2D_plot.png', SHOW = True):
        """Contour and surface view of a 2-D field ``phi`` on the grid ``(X, Y)``."""
        fig = plt.figure(figsize=(12, 6))

        # Contour plot (2D view)
        ax1 = fig.add_subplot(1, 2, 1)
        contour = ax1.contourf(X, Y, phi, cmap='viridis', levels=50)
        plt.colorbar(contour, ax=ax1)
        ax1.set_title('Contour Plot', fontsize=14)
        ax1.set_xlabel(r'$x$', fontsize=12)
        ax1.set_ylabel(r'$y$', fontsize=12)

        # Surface plot (3D view)
        ax2 = fig.add_subplot(1, 2, 2, projection='3d')
        surf = ax2.plot_surface(X, Y, phi, cmap='viridis', edgecolor='none')
        fig.colorbar(surf, ax=ax2, shrink=0.5, aspect=10)
        ax2.set_title('Surface Plot', fontsize=14)
        ax2.set_xlabel(r'$x$', fontsize=12)
        ax2.set_ylabel(r'$y$', fontsize=12)
        ax2.set_zlabel(r'Solution', fontsize=12)

        plt.tight_layout()
        plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Pauli_Decomposition_Plot(self, MATRIX_SIZE, L_LIST, SAVE_PATH=None, SHOW=False):
        """Plot Pauli-decomposition term count ``L`` vs. matrix size with N and N log N guides."""
        plt.figure(figsize=(8, 5))
        plt.plot(MATRIX_SIZE, L_LIST, label=r"$L$", color='b', linewidth=2)

        plt.plot(MATRIX_SIZE, MATRIX_SIZE, linestyle="--", color='g', label=r"$N$")
        plt.plot(MATRIX_SIZE, MATRIX_SIZE * np.log(MATRIX_SIZE), linestyle="--", color='r', label=r"$N \log N$")

        plt.xlabel(r'Matrix Size $N$ x $N$')
        plt.ylabel(r"Pauli-decomposition $L$")
        plt.title(r"Pauli-decomposition $L$ with increasing matrix size")
        plt.legend()
        plt.grid(True)

        plt.tight_layout()

        if SAVE_PATH is not None:
            plt.savefig(SAVE_PATH)

        if SHOW:
            plt.show()
        plt.close()

    def Solution_1D(self, xx
                    , pred_u_dict=None
                    , true_u=None
                    , NORM=None, SAVE_PATH = None, SHOW=False, COLOR='b'):
        """Plot 1-D predictions (one curve per ``pred_u_dict`` entry) and ``true_u``.

        If ``NORM`` is given, its ``'RelL2'`` and ``'RelLinf'`` values form the
        title. The figure is saved (dpi 300) only when ``SAVE_PATH`` is given.
        """
        linewidth = self.linewidth

        plt.figure(figsize=(10, 7))

        if pred_u_dict is not None:
            for label, values in pred_u_dict.items():
                plt.plot(xx, values.squeeze(), label=label, color=COLOR, linewidth=linewidth)

        if true_u is not None:
            plt.plot(xx, true_u.squeeze(), 'r--', label='True solution', linewidth=linewidth)

        if NORM is not None:
            text_str = (r"Relative $L_2$"
                        + f"={NORM['RelL2']}" + '\n'
                        + r"Relative $L_\text{inf}$"
                        + f"={NORM['RelLinf']}")
            plt.title(text_str, fontsize=self.title_fontsize)

        plt.xlabel(r'$x$ (Collocation)', fontsize=self.label_fontsize)
        plt.ylabel(r'$u$', fontsize=self.label_fontsize)
        plt.tick_params(axis='both', labelsize=self.ticks_fontsize)

        plt.legend(loc='best', fontsize=self.legend_fontsize)
        if SAVE_PATH is not None:
            plt.savefig(SAVE_PATH, dpi=300)
        if SHOW:
            plt.show()
        plt.close()

    def Solution_2D(self, X, Y
                    , pred_u=None
                    , true_u=None
                    , NORM=None
                    , SAVE_PATH='solution_2D_plot.png'
                    , SHOW=False):
        """Plot predicted and true 2-D solutions plus their absolute error.

        ``pred_u``/``true_u`` are reshaped to ``(N+1, N+1)`` with
        ``N + 1 = X.shape[0]`` (assumes a square grid). When ``NORM`` is given,
        its ``'RelL2'``/``'RelLinf'`` values are appended to the error panel's
        title.

        Raises:
            ValueError: if ``pred_u`` or ``true_u`` is missing.
        """
        if pred_u is None or true_u is None:
            raise ValueError('Solution_2D requires both pred_u and true_u')
        NN = X.shape[0] - 1
        pred_u = np.asarray(pred_u).reshape(NN + 1, NN + 1)
        true_u = np.asarray(true_u).reshape(NN + 1, NN + 1)

        fig = plt.figure(figsize=(20, 6))

        ax1 = fig.add_subplot(1, 3, 1, projection='3d')
        ax1.set_title('Prediction')
        ax1.plot_surface(X, Y, pred_u, cmap='coolwarm', edgecolor='none', alpha=0.5)

        ax2 = fig.add_subplot(1, 3, 2, projection='3d')
        ax2.set_title('True Solution')
        ax2.plot_surface(X, Y, true_u, cmap='viridis', edgecolor='none')

        ax3 = fig.add_subplot(1, 3, 3)
        ax3.set_aspect('equal')
        contour = ax3.contourf(X, Y, np.abs(pred_u - true_u), cmap='viridis', levels=50)
        plt.colorbar(contour, ax=ax3)

        error_title = 'Absolute Error'
        if NORM is not None:
            error_title += ('\n' + r"$L_2$" + f"={NORM['RelL2']:.4e}"
                            + '\n' + r"$L_\text{inf}$" + f"={NORM['RelLinf']:.4e}")
        ax3.set_title(error_title, fontsize=14)

        for ax in (ax1, ax2, ax3):
            ax.set_xlabel(r'$x$', fontsize=12)
            ax.set_ylabel(r'$y$', fontsize=12)
        plt.tight_layout()

        plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    def Solution_3D(self, X, Y, Z, solution,
                    SAVE_PATH=None, SHOW=True,
                    title_fontsize=32, label_fontsize=28, tick_fontsize=24):
        """Plot a 3-D field: three mid-plane contour slices plus a 3-D scatter.

        ``X``, ``Y``, ``Z`` are assumed to be ``(N+1, N+1, N+1)`` grids indexed
        ``[x, y, z]`` (TODO confirm meshgrid indexing); ``solution`` may be flat
        and is reshaped to the same shape. NumPy, JAX and CPU torch arrays are
        accepted (converted with ``np.asarray``).
        """
        NN = X.shape[0] - 1
        mid = NN // 2
        X_np = np.asarray(X)
        Y_np = np.asarray(Y)
        Z_np = np.asarray(Z)
        solution_np = np.asarray(solution).reshape(NN + 1, NN + 1, NN + 1)

        fig = plt.figure(figsize=(50, 8))

        ax1 = fig.add_subplot(141)
        ax1.set_xlabel('X', fontsize=label_fontsize)
        ax1.set_ylabel('Y', fontsize=label_fontsize)
        ax1.set_title('xy-plane', fontsize=title_fontsize)
        ax1.contourf(X_np[:, :, mid], Y_np[:, :, mid], solution_np[:, :, mid], cmap='viridis', levels=50)
        ax1.tick_params(axis='both', labelsize=tick_fontsize)

        ax2 = fig.add_subplot(142)
        ax2.set_xlabel('X', fontsize=label_fontsize)
        ax2.set_ylabel('Z', fontsize=label_fontsize)
        ax2.set_title('xz-plane', fontsize=title_fontsize)
        ax2.contourf(X_np[:, mid, :], Z_np[:, mid, :], solution_np[:, mid, :], cmap='viridis', levels=50)
        ax2.tick_params(axis='both', labelsize=tick_fontsize)

        ax3 = fig.add_subplot(143)
        ax3.set_xlabel('Y', fontsize=label_fontsize)
        ax3.set_ylabel('Z', fontsize=label_fontsize)
        ax3.set_title('yz-plane', fontsize=title_fontsize)
        ax3.contourf(Y_np[mid, :, :], Z_np[mid, :, :], solution_np[mid, :, :], cmap='viridis', levels=50)
        ax3.tick_params(axis='both', labelsize=tick_fontsize)

        ax4 = fig.add_subplot(144, projection='3d')
        ax4.set_xlabel('X', fontsize=label_fontsize)
        ax4.set_ylabel('Y', fontsize=label_fontsize)
        ax4.set_zlabel('Z', fontsize=label_fontsize)
        ax4.set_title('3D', fontsize=title_fontsize)
        img4 = ax4.scatter(X_np, Y_np, Z_np, c=solution_np, cmap='viridis', marker='o')
        ax4.tick_params(axis='both', labelsize=tick_fontsize)

        fig.colorbar(img4, ax=ax4, shrink=0.7, aspect=10)
        # Layout is computed once all axes exist.
        plt.tight_layout()

        if SAVE_PATH:
            plt.savefig(SAVE_PATH)
        if SHOW:
            plt.show()
        plt.close()

    @staticmethod
    def _split_loss_item(item, default_color):
        """Return ``(y_values, fmt)`` for one value of a losses dict.

        A value is either the y-values themselves, or a two-element list
        ``[y_values, fmt]`` whose first element is a list or ndarray.
        """
        if (
            isinstance(item, list)
            and len(item) == 2
            and isinstance(item[0], (list, np.ndarray))
        ):
            return item[0], item[1]
        return item, default_color

    @staticmethod
    def _guide_line_exponent(values):
        """Return ``n`` so that ``10**-n`` is the largest power of ten not
        exceeding ``min(values) + 1e-20``, or ``None`` if no value qualifies.

        Only finite, non-negative values are considered, since ``log10`` is
        undefined for the rest.
        """
        arr = np.asarray(values, dtype=float).ravel()
        arr = arr[np.isfinite(arr) & (arr >= 0)]
        if arr.size == 0:
            return None
        min_val = float(arr.min()) + 1e-20
        return int(-math.floor(math.log10(min_val)))

    def Losses(self, losses, SAVEPATH, log_scale=False, TITLE=None, SHOW=False
               , GUIDE_LINE = True, OPTIMIZER_CHANGE=None, COLOR='b'):
        """Plot one or several loss curves and save the figure to ``SAVEPATH``.

        Args:
            losses: a single sequence of loss values, or a dict mapping a
                label to either its values or ``[values, fmt]`` where ``fmt``
                is a matplotlib format/colour string for that curve.
            SAVEPATH: output image path.
            log_scale: use a logarithmic y-axis.
            TITLE: axes title.
            SHOW: also call ``plt.show()``.
            GUIDE_LINE: draw dashed horizontal lines at 1, 1e-1, ..., 1e-n,
                where 1e-n is the largest power of ten not exceeding the
                smallest finite non-negative loss; skipped if there is none.
            OPTIMIZER_CHANGE: epochs at which to draw red vertical markers.
            COLOR: colour for curves that do not specify their own.
        """
        fig, ax = plt.subplots(figsize=(self.figwide, self.figheight))

        if isinstance(losses, dict):
            series = []
            for label, item in losses.items():
                y, color = self._split_loss_item(item, COLOR)
                # Passed positionally: matplotlib treats it as a format string
                # (a colour spec, optionally with a line style).
                ax.plot(y, color, label=label, linewidth=2)
                series.append(y)
        else:
            # ``losses`` is a single curve of y-values.
            ax.plot(losses, label='loss', color=COLOR, linewidth=2)
            series = [losses]

        if GUIDE_LINE:
            all_vals = (np.concatenate([np.asarray(y, dtype=float).ravel() for y in series])
                        if series else np.empty(0))
            n = self._guide_line_exponent(all_vals)
            if n is not None:
                # Reference lines at 1e0, 1e-1, ..., 1e-n.
                for i in range(n + 1):
                    ax.axhline(y=10**(-i), linestyle='--', linewidth=1)

        if log_scale:
            ax.set_yscale('log')

        if OPTIMIZER_CHANGE is not None:
            for ep in OPTIMIZER_CHANGE:
                ax.axvline(
                    x=ep,
                    color='red',
                    linestyle='--',
                    linewidth=1,
                    label=f'Optimizer changed at epoch = {ep}'
                )

        ax.set_title(TITLE, fontsize=self.title_fontsize)
        ax.set_xlabel('Epoch', fontsize=self.label_fontsize)
        ax.set_ylabel('Loss', fontsize=self.label_fontsize)
        ax.tick_params(axis='both', labelsize=self.ticks_fontsize)
        ax.legend(fontsize=self.legend_fontsize)
        ax.grid(True)

        fig.savefig(SAVEPATH)
        if SHOW:
            plt.show()
        plt.close(fig)
    # End Plot System


def From_loss_list_To_TXT(loss_list, TXT_PATH):
    """Write a loss history to ``TXT_PATH`` as ``Epoch k: Loss v`` lines.

    Epochs are numbered from 1 and any existing file is overwritten.
    Always returns ``None``.
    """
    with open(TXT_PATH, "w") as out:
        out.writelines(
            f"Epoch {idx}: Loss {value}\n"
            for idx, value in enumerate(loss_list, start=1)
        )
    return None
