R"""
"""
#
import abc
import numpy as onp
import torch
import time
import os
from typing import List, Tuple, Optional, Generic, TypeVar
from ..meta.meta import Meta
from .transfer import transfer
from .types import TIMECOST
from ..task.task import Task
from ..utils.info import info5
from .gradclip import GRADCLIPS
from lib.model.mlp import MLP
import torch.optim as optim
from dgl import DGLGraph
import dgl
from sklearn import linear_model
from sklearn.utils.extmath import randomized_svd
import scipy.sparse as sp
import lib.utils.ppr as ppr
import numpy as onp
import torch.nn as nn


# Type variable for the metaset a framework operates on; any ``Meta``
# subclass is accepted.
META = TypeVar("META", bound=Meta)


# Output location and training-schedule constants.
# Root directory under which each framework instance creates its own
# numbered run sub-directory.
DIR: str = "log"
# Divisor applied to every optimizer learning rate when validation stalls.
LR_DECAY: int = 10
# Training stops once a decayed learning rate falls below this threshold.
# The extra 1 / LR_DECAY factor places it just above 1e-4 (presumably so a
# rate that has decayed to ~1e-4 still triggers despite float rounding).
LR_THRES: float = 1e-4 * (1 + 1 / LR_DECAY)
# Minimum absolute / relative decrease of the validation metric that counts
# as an improvement; 0 means any strict decrease counts.
IMP_ABS: float = 0
IMP_REL: float = 0

class FrameworkIndexable(abc.ABC, Generic[META]):
    R"""
    Framework with indexable (by finite integer) meta samples.

    Holds the metaset, the neural network and its optimizers, and creates a
    fresh numbered run directory under ``DIR`` per instance, into which all
    checkpoints and logs are written.  Subclasses implement ``train`` and
    ``evaluate``.  The SDGNN-related methods additionally call
    ``train_sdgnn``, ``evaluate_sdgnn``, ``continue_train_low_level_mlp``
    and ``freeze_target_model``, which must be provided elsewhere (they are
    not among the methods visible in this section of the class).
    """
    # Class-level switch for subclasses: when True, the smallest training
    # meta index is passed to train/evaluate as ``meta_index_pad``; when
    # False, ``None`` is passed.  Presumably used to pad incomplete batches
    # -- confirm in subclasses.
    BATCH_PAD: bool

    def __init__(
        self,
        identifier: str, metaset: META, neuralnet: Task,
        /,
        *,
        lr: float, weight_decay: float, seed: int, device: str,
        metaspindle: str, gradclip: str,
    ) -> None:
        R"""
        Initialize the framework.

        Stores the configuration.  When the network reports resettable
        parameters, it picks the gradient-clipping rule and builds an Adam
        optimizer.  It then moves the network to ``device``, creates a fresh
        numbered run directory under ``DIR``, derives every output path from
        that directory, and prepares the SDGNN components through
        ``prepare_sdgnn`` (fixed lr 0.0005, weight decay 1e-6).

        Args:
            identifier: Stem used for the per-run output file names.
            metaset: Dataset of integer-indexable meta samples.
            neuralnet: Network to tune.
            lr: Learning rate of the main Adam optimizer.
            weight_decay: Weight decay of the main Adam optimizer.
            seed: Seed of the RNG used to shuffle training indices.
            device: Torch device string the network is moved to.
            metaspindle: Forwarded to the metaset split/normalize routines.
            gradclip: Key into ``GRADCLIPS``.
        """
        self.identifier = identifier
        self.metaset = metaset
        self.neuralnet = neuralnet
        self.seed = seed
        self.device = device
        self.metaspindle = metaspindle

        # Only parametric networks get a clipping rule and an optimizer.
        has_tunable_params = neuralnet.num_resetted_params > 0
        if has_tunable_params:
            self.gradclip = GRADCLIPS[gradclip]
            self.optim = torch.optim.Adam(
                self.neuralnet.parameters(), lr=lr, weight_decay=weight_decay,
            )

        # Move model to device after creating the optimizer.
        self.neuralnet = neuralnet.to(device)

        # Ensure the log root exists, then claim a new numbered run directory.
        os.makedirs(DIR, exist_ok=True)
        self.dir_this_time = self.create_next_numerical_directory(DIR)

        def in_run_dir(file_name):
            return os.path.join(self.dir_this_time, file_name)

        stem = self.identifier
        self.save_result = in_run_dir("results.txt")
        self.ptnnp = in_run_dir("{:s}.ptnnp".format(stem))
        self.ptres = in_run_dir("{:s}.ptres".format(stem))
        self.ptlog = in_run_dir("{:s}.ptlog".format(stem))
        self.model_path = in_run_dir("{:s}.pth".format(stem))
        self.whole_model_path = in_run_dir("whole_model_tgnn.pth")
        self.ptbev = in_run_dir("{:s}.ptbev".format(stem))
        self.best_epoch_saved = 0
        self.prepare_sdgnn(0.0005, 0.000001)
        self.sdgnn_saved_path = in_run_dir("sdgnn.pth")
        self.weights_saved_path = in_run_dir("weights.pth")


    def prepare_sdgnn(self, lr, weight_decay):
        R"""
        Set up the SDGNN transformation model, graph weights and MLP optimizer.

        Moves ``neuralnet.transformation_model`` to ``self.device`` and gives
        it an Adam optimizer (``self.optimizer_sdgnn``).  Builds the adjacency
        matrix of the metaset graph from ``edge_srcs``/``edge_dsts`` (one
        unit entry per edge) and stores the dense result of
        ``ppr.topk_ppr_matrix`` (alpha=0.5, eps=1e-4, topk=64, over all
        nodes) as ``self.weights``.  Presumably this is a top-k personalized
        PageRank matrix, per the ``ppr`` helper's name.  Then it creates an
        MSE loss, calls ``freeze_target_model``, and builds
        ``self.optimizer_mlp`` over ``neuralnet.mlp``.

        Args:
            lr: Learning rate of both Adam optimizers created here.
            weight_decay: Weight decay of both Adam optimizers created here.
        """
        net = self.neuralnet
        net.transformation_model = net.transformation_model.to(self.device)
        self.optimizer_sdgnn = optim.Adam(
            net.transformation_model.parameters(),
            lr=lr, weight_decay=weight_decay,
        )

        node_count = self.metaset.num_nodes
        every_node = onp.arange(node_count)
        sources = self.metaset.edge_srcs
        targets = self.metaset.edge_dsts
        # COO -> CSR; scipy sums duplicate (src, dst) entries on conversion.
        adjacency = sp.coo_matrix(
            (onp.ones(sources.shape[0]), (sources, targets)),
            shape=(node_count, node_count),
        ).tocsr()

        self.weights = ppr.topk_ppr_matrix(
            adj_matrix=adjacency, alpha=0.5, eps=1e-4, idx=every_node, topk=64,
        ).toarray()
        self.loss_func = nn.MSELoss()
        self.freeze_target_model()
        self.optimizer_mlp = optim.Adam(
            self.neuralnet.mlp.parameters(), lr=lr, weight_decay=weight_decay,
        )



    def create_next_numerical_directory(self, parent_directory):
        # Ensure the parent directory exists.
        if not os.path.exists(parent_directory):
            print(f"The directory {parent_directory} does not exist.")
            return

        # List all items in the parent directory.
        all_items = os.listdir(parent_directory)

        # Filter out items that are not directories.
        directories = [item for item in all_items if os.path.isdir(os.path.join(parent_directory, item))]

        # Convert directory names to integers, ignore non-integer names.
        numerical_directories = []
        for dir_name in directories:
            try:
                numerical_directories.append(int(dir_name))
            except ValueError:
                # This directory name is not an integer, ignore it.
                continue

        # Find the largest number (if any).
        if numerical_directories:
            largest_number = max(numerical_directories)
        else:
            largest_number = 0  # No numerical directories exist, start with 0.

        # The name for the new directory is one more than the largest number found.
        new_dir_name = str(largest_number + 1)
        new_dir_path = os.path.join(parent_directory, new_dir_name)

        # Create the new directory.
        os.mkdir(new_dir_path)
        print("=" * 10 + " " + "Created new directory: {}".format(new_dir_path)+ " " + "=" * 10)
        return new_dir_path

    @abc.abstractmethod
    def train(
        self,
        meta_indices: List[int], meta_index_pad: Optional[int],
        meta_batch_size: int, pinned: List[torch.Tensor],
        /,
    ) -> TIMECOST:
        R"""
        Run one training pass over the given meta samples.

        Must be overridden by subclasses.  ``fit`` calls it once per epoch
        with the shuffled training indices; it is mostly used for tuning the
        neural-network parameters.

        Args:
            meta_indices: Meta sample indices to train on.
            meta_index_pad: Smallest training index when ``BATCH_PAD`` is
                set, otherwise ``None``.
            meta_batch_size: Number of meta samples per batch.
            pinned: Output of ``transfer(metaset.pin(batch_size), device)``.

        Returns:
            Timing record; ``fit`` reads its "generate", "transfer",
            "forward" and "backward" entries.
        """
        pass

    @abc.abstractmethod
    def evaluate(
        self,
        meta_indices: List[int], meta_index_pad: Optional[int],
        meta_batch_size: int, pinned: List[torch.Tensor],
        /,
    ) -> Tuple[List[float], TIMECOST]:
        R"""
        Evaluate the network on the given meta samples.

        Must be overridden by subclasses.  ``fit`` and ``besteval`` call it
        under ``torch.no_grad()`` on the validation and test indices; it is
        mostly used for evaluating the neural-network parameters.

        Args:
            meta_indices: Meta sample indices to evaluate.
            meta_index_pad: Smallest training index when ``BATCH_PAD`` is
                set, otherwise ``None``.
            meta_batch_size: Number of meta samples per batch.
            pinned: Output of ``transfer(metaset.pin(batch_size), device)``.

        Returns:
            ``(metrics, timecost)``.  The callers index the metric values with
            ``validon``, and ``fit`` reads the timing record's "generate",
            "transfer" and "forward" entries.
        """
        pass

    def preprocess(
        self,
        proportion: Tuple[int, int, int], priority: Tuple[int, int, int],
        train_prop: Tuple[int, int, bool],
        /,
    ) -> None:
        R"""
        Split and normalize the metaset before fitting.

        Uses ``metaset.fitsplit`` when ``train_prop[1] == 0``, otherwise
        ``metaset.reducesplit`` with ``train_prop`` unpacked.  The resulting
        train/valid/test index arrays are stored on ``self``.  Normalization
        factors computed from the training indices are kept in
        ``self.factors``.  The split sizes and ``metaset.distrep(n=3)`` are
        printed.

        ``priority`` is forwarded to the split routine.  The original note
        says that, for time-series data, it can be set so the most recent
        data goes to training and older data to validation/testing,
        simulating the real-world flow of data.  Confirm against the
        metaset's split implementation.
        """
        print("=" * 10 + " " + "(Prep)rocessing" + " " + "=" * 10)
        if train_prop[1] == 0:
            splits = self.metaset.fitsplit(proportion, priority, self.metaspindle)
        else:
            splits = self.metaset.reducesplit(
                proportion, priority, self.metaspindle, *train_prop,
            )
        (
            self.meta_indices_train,
            self.meta_indices_valid,
            self.meta_indices_test,
        ) = splits

        sizes = (
            len(self.meta_indices_train),
            len(self.meta_indices_valid),
            len(self.meta_indices_test),
        )
        total = sizes[0] + sizes[1] + sizes[2]
        self.factors = self.metaset.normalizeby(
            self.meta_indices_train, self.metaspindle,
        )

        labels = ("Train", "Validate", "Test")
        report = {
            "Split": {
                label: "{:d}/{:d}".format(size, total)
                for label, size in zip(labels, sizes)
            },
        }
        print(info5(report))
        print(self.metaset.distrep(n=3))

    def fit(
        self,
        proportion: Tuple[int, int, int], priority: Tuple[int, int, int],
        train_prop: Tuple[int, int, bool],
        /,
        *,
        batch_size: int, max_epochs: int, validon: int, validrep: str,
        patience: int,
    ) -> None:
        R"""
        Fit neural network of the framework based on initialization status.

        Runs ``preprocess``, validates once before training, then trains for
        up to ``max_epochs`` epochs.  After each epoch the validation metric
        at index ``validon`` is compared with the best so far, where lower is
        better.  On improvement the network, its sub-modules and a combined
        checkpoint are saved.

        After ``patience`` consecutive non-improving epochs, every learning
        rate of ``self.optim`` is divided by ``LR_DECAY`` and the counter is
        reset.  Training stops as soon as any rate falls below ``LR_THRES``.

        Afterwards the current (last-epoch, not reloaded best) network is
        evaluated on the validation and test splits.  A log tuple
        ``(factors, logs, final valid metric, test metrics, peak GPU memory
        in KB, timecosts)`` is saved to ``self.ptlog``.  The per-epoch
        training costs in ``timecosts["train.*"]`` are lists with one entry
        per epoch.

        Args:
            proportion: Forwarded to ``preprocess``.
            priority: Forwarded to ``preprocess``.
            train_prop: Forwarded to ``preprocess``.
            batch_size: Batch size for pinning, training and evaluation.
            max_epochs: Maximum number of epochs; forced to 0 when the
                network has no resettable parameters.
            validon: Index of the metric used for model selection.
            validrep: Display name of that metric.
            patience: Non-improving epochs tolerated before decaying the
                learning rate; a negative value disables decay.
        """
        self.preprocess(proportion, priority, train_prop)

        # Zero-out number of tuning epochs for non-parametric cases.
        if self.neuralnet.num_resetted_params == 0:
            max_epochs = 0

        rng = onp.random.RandomState(self.seed)
        timecosts: TIMECOST = {}
        meta_index_pad = (
            onp.min(self.meta_indices_train).item() if self.BATCH_PAD else None
        )
        epochlen = len(str(max_epochs))
        noimplen = len(str(patience))
        logs = []

        # Pin shared memory.
        elapsed = time.time()
        pinned_numpy = self.metaset.pin(batch_size)
        timecosts["pin.generate"] = time.time() - elapsed
        elapsed = time.time()
        pinned_ondev = transfer(pinned_numpy, self.device)
        timecosts["pin.transfer"] = time.time() - elapsed

        # Validate once before training.
        print("=" * 10 + " " + "Train & Validate" + " " + "=" * 10)
        with torch.no_grad():
            (metrics, timeparts) = self.evaluate(
                self.meta_indices_valid.tolist(), meta_index_pad,
                batch_size, pinned_ondev,
            )
        timecosts["valid.generate"] = timeparts["generate"]
        timecosts["valid.transfer"] = timeparts["transfer"]
        timecosts["valid.forward"] = timeparts["forward"]
        logs.append(metrics)

        # Initialize performance status with the untrained network.
        torch.save(self.neuralnet.state_dict(), self.ptnnp)
        metric_best = metrics[validon]
        num_not_improving = 0
        print(
            "[{:>{:d}d}/{:d}] {:s}: {:>8s} ({:>8s}) {:s}{:<{:d}s}".format(
                0, epochlen, max_epochs, validrep,
                "{:.6f}".format(metrics[validon])[:8],
                "{:.6f}".format(metric_best)[:8],
                "\x1b[92m↑\x1b[0m", "", noimplen,
            ),
        )

        # Training costs are accumulated with one entry per epoch.
        timecosts["train.generate"] = []
        timecosts["train.transfer"] = []
        timecosts["train.forward"] = []
        timecosts["train.backward"] = []
        for epoch in range(1, 1 + max_epochs):
            shuffling = rng.permutation(len(self.meta_indices_train))
            meta_indices_train = self.meta_indices_train[shuffling]
            # Train.
            timeparts = self.train(
                meta_indices_train.tolist(), meta_index_pad, batch_size,
                pinned_ondev,
            )
            timecosts["train.generate"].append(timeparts["generate"])
            timecosts["train.transfer"].append(timeparts["transfer"])
            timecosts["train.forward"].append(timeparts["forward"])
            timecosts["train.backward"].append(timeparts["backward"])

            # Validate.
            with torch.no_grad():
                (metrics, timeparts) = self.evaluate(
                    self.meta_indices_valid.tolist(), meta_index_pad,
                    batch_size, pinned_ondev,
                )
            timecosts["valid.generate"] = timeparts["generate"]
            timecosts["valid.transfer"] = timeparts["transfer"]
            timecosts["valid.forward"] = timeparts["forward"]
            logs.append(metrics)

            # Update performance status (lower metric is better).
            if (
                metrics[validon] < metric_best - IMP_ABS
                or metrics[validon] < metric_best * (1 - IMP_REL)
            ):
                # Save the new best network, its sub-modules and a combined
                # checkpoint.
                torch.save(self.neuralnet.state_dict(), self.model_path)
                torch.save(self.neuralnet.state_dict(), self.ptnnp)
                snn_edge_dest_dir = os.path.join(self.dir_this_time, 'snn_edge.pth')
                snn_node_dest_dir = os.path.join(self.dir_this_time, 'snn_node.pth')
                gnnx2_dest_dir = os.path.join(self.dir_this_time, 'gnnx2.pth')
                lower_level_mlp_dir = os.path.join(self.dir_this_time, 'lower_level_mlp.pth')
                torch.save(self.neuralnet.tgnn.snn_edge.state_dict(), snn_edge_dest_dir)
                torch.save(self.neuralnet.tgnn.snn_node.state_dict(), snn_node_dest_dir)
                torch.save(self.neuralnet.tgnn.gnnx2.state_dict(), gnnx2_dest_dir)
                torch.save(self.neuralnet.mlp.state_dict(), lower_level_mlp_dir)
                torch.save(
                    {
                        'max_epoch': 1 + max_epochs,
                        'epoch': epoch,
                        'model_state_dict': self.neuralnet.tgnn.state_dict(),
                        'optimizer_state_dict': self.optim.state_dict(),
                        'snn_node': self.neuralnet.tgnn.snn_node.state_dict(),
                        'snn_edge': self.neuralnet.tgnn.snn_edge.state_dict(),
                        'gnn_model': self.neuralnet.tgnn.gnnx2.state_dict(),
                    },
                    self.whole_model_path,
                )
                metric_best = metrics[validon]
                improving = True
                num_not_improving = 0
            else:
                improving = False
                num_not_improving = num_not_improving + 1
            print(
                "[{:>{:d}d}/{:d}] {:s}: {:>8s} ({:>8s}) {:s}{:<{:d}s}{:s}"
                .format(
                    epoch, epochlen, max_epochs, validrep,
                    "{:.6f}".format(metrics[validon])[:8],
                    "{:.6f}".format(metric_best)[:8],
                    "\x1b[92m↑\x1b[0m" if improving else "\x1b[91m↓\x1b[0m",
                    "" if improving else str(num_not_improving), noimplen,
                    " --" if num_not_improving == patience else "",
                ),
            )

            # Reduce the learning rate if the network has not improved for
            # ``patience`` epochs; stop once it becomes too small.
            if patience >= 0 and num_not_improving == patience:
                lr_reach_epsilon = False
                if self.neuralnet.num_resetted_params > 0:
                    for group in self.optim.param_groups:
                        group["lr"] = group["lr"] / LR_DECAY
                        if group["lr"] < LR_THRES:
                            lr_reach_epsilon = True
                    num_not_improving = 0
                else:
                    # Directly terminate on the first decay for the
                    # non-parametric case.
                    lr_reach_epsilon = True
                if lr_reach_epsilon:
                    # Early stop if learning rate is too small.
                    break

        if hasattr(self.neuralnet, "tgnn"):
            # ``torch.cuda.get_device_name`` raises without CUDA; fall back to
            # "cpu" so CPU-only runs still reach the test and log stages.
            device_name = (
                torch.cuda.get_device_name()
                if torch.cuda.is_available() else "cpu"
            )
            torch.save(
                (device_name, getattr(self.neuralnet.tgnn, "COSTS")),
                self.ptres,
            )

        # Final test after training (current, not reloaded best, weights).
        print("=" * 10 + " " + "Test" + " " + "=" * 10)
        with torch.no_grad():
            (metrics_valid, timeparts) = self.evaluate(
                self.meta_indices_valid.tolist(), meta_index_pad,
                batch_size, pinned_ondev,
            )
            (metrics_test, timeparts) = self.evaluate(
                self.meta_indices_test.tolist(), meta_index_pad,
                batch_size, pinned_ondev,
            )
        timecosts["test.generate"] = timeparts["generate"]
        timecosts["test.transfer"] = timeparts["transfer"]
        timecosts["test.forward"] = timeparts["forward"]
        print(
            "Valid\x1b[94m:\x1b[0m \x1b[3m{:s}\x1b[0m: {:s}"
            .format(validrep, "{:.6f}".format(metrics_valid[validon])[:8]),
        )
        print(
            " Test\x1b[94m:\x1b[0m \x1b[3m{:s}\x1b[0m: {:s}"
            .format(validrep, "{:.6f}".format(metrics_test[validon])[:8]),
        )

        print("=" * 10 + " " + "(Res)ource (Stat)istics" + " " + "=" * 10)
        gpu_mem_peak = int(onp.ceil(torch.cuda.max_memory_allocated() / 1024))
        print("Max GPU Memory: {:d} KB".format(gpu_mem_peak))
        torch.save(
            (
                self.factors, logs, metrics_valid[validon], metrics_test,
                gpu_mem_peak, timecosts,
            ),
            self.ptlog,
        )


    def fit_sdgnn(
        self,
        proportion: Tuple[int, int, int], priority: Tuple[int, int, int],
        train_prop: Tuple[int, int, bool],
        /,
        *,
        batch_size: int, max_epochs: int, validon: int, validrep: str,
        patience: int,
    ) -> None:
        R"""
        Fit SDGNN of the framework based on initialization status.

        Each epoch trains through ``train_sdgnn`` (which also receives the
        epoch number) and validates through ``evaluate_sdgnn``.  Whenever the
        validation metric at ``validon`` decreases (lower is better), the
        following are saved: the network, its sub-modules, the
        transformation model, ``self.weights`` and a combined checkpoint.

        NOTE(review): unlike ``fit`` there is no learning-rate decay or early
        stopping; all ``max_epochs`` epochs run, and ``patience`` only drives
        the " --" marker in the progress line.  No test evaluation or log
        file is produced here.

        Args:
            proportion: Forwarded to ``preprocess``.
            priority: Forwarded to ``preprocess``.
            train_prop: Forwarded to ``preprocess``.
            batch_size: Batch size for pinning, training and evaluation.
            max_epochs: Number of epochs; forced to 0 when the network has no
                resettable parameters.
            validon: Index of the metric used for model selection.
            validrep: Display name of that metric.
            patience: Stall count at which the " --" marker is printed.
        """
        self.preprocess(proportion, priority, train_prop)

        if self.neuralnet.num_resetted_params == 0:
            # Non-parametric networks are never tuned.
            max_epochs = 0

        rng = onp.random.RandomState(self.seed)
        timecosts: TIMECOST = {}
        pad = onp.min(self.meta_indices_train).item() if self.BATCH_PAD else None
        epoch_width = len(str(max_epochs))
        history = []
        stall_width = len(str(patience))

        # Pin shared memory and move it to the device, timing both steps.
        tic = time.time()
        pinned_numpy = self.metaset.pin(batch_size)
        timecosts["pin.generate"] = time.time() - tic
        tic = time.time()
        pinned_ondev = transfer(pinned_numpy, self.device)
        timecosts["pin.transfer"] = time.time() - tic

        def run_validation():
            # Evaluate the validation split without autograd; the latest
            # evaluation's timings replace the previous ones.
            with torch.no_grad():
                scores, parts = self.evaluate_sdgnn(
                    self.meta_indices_valid.tolist(), pad,
                    batch_size, pinned_ondev,
                )
            for stage in ("generate", "transfer", "forward"):
                timecosts["valid." + stage] = parts[stage]
            history.append(scores)
            return scores

        def save_best(epoch):
            # Persist the whole network, each sub-module, the SDGNN model and
            # the graph weights, then one combined checkpoint.
            net = self.neuralnet
            torch.save(net.state_dict(), self.model_path)
            torch.save(net.state_dict(), self.ptnnp)
            run_dir = self.dir_this_time
            tgnn = net.tgnn
            torch.save(tgnn.snn_edge.state_dict(), os.path.join(run_dir, 'snn_edge.pth'))
            torch.save(tgnn.snn_node.state_dict(), os.path.join(run_dir, 'snn_node.pth'))
            torch.save(tgnn.gnnx2.state_dict(), os.path.join(run_dir, 'gnnx2.pth'))
            torch.save(net.mlp.state_dict(), os.path.join(run_dir, 'lower_level_mlp.pth'))
            torch.save(net.transformation_model.state_dict(), os.path.join(run_dir, 'sdgnn.pth'))
            torch.save(self.weights, os.path.join(run_dir, 'weights.pth'))
            # NOTE(review): stores ``self.optim``; ``prepare_sdgnn`` also
            # creates ``self.optimizer_sdgnn`` -- confirm which one
            # ``train_sdgnn`` steps if this checkpoint is used for resuming.
            torch.save(
                {
                    'max_epoch': 1 + max_epochs,
                    'epoch': epoch,
                    'model_state_dict': tgnn.state_dict(),
                    'optimizer_state_dict': self.optim.state_dict(),
                    'snn_node': tgnn.snn_node.state_dict(),
                    'snn_edge': tgnn.snn_edge.state_dict(),
                    'gnn_model': tgnn.gnnx2.state_dict(),
                    'sdgnn': net.transformation_model.state_dict(),
                    'weights': self.weights,
                },
                self.whole_model_path,
            )

        # Baseline validation before any training.
        print("=" * 10 + " " + "Train & Validate" + " " + "=" * 10)
        metrics = run_validation()
        torch.save(self.neuralnet.state_dict(), self.ptnnp)
        metric_best = metrics[validon]
        stalled = 0
        print(
            "[{:>{:d}d}/{:d}] {:s}: {:>8s} ({:>8s}) {:s}{:<{:d}s}".format(
                0, epoch_width, max_epochs, validrep,
                "{:.6f}".format(metrics[validon])[:8],
                "{:.6f}".format(metric_best)[:8],
                "\x1b[92m↑\x1b[0m", "", stall_width,
            ),
        )

        for stage in ("generate", "transfer", "forward", "backward"):
            timecosts["train." + stage] = []

        for epoch in range(1, max_epochs + 1):
            order = rng.permutation(len(self.meta_indices_train))
            self.train_sdgnn(
                self.meta_indices_train[order].tolist(), pad, batch_size,
                pinned_ondev, epoch,
            )
            metrics = run_validation()
            score = metrics[validon]
            improving = (
                score < metric_best - IMP_ABS
                or score < metric_best * (1 - IMP_REL)
            )
            if improving:
                save_best(epoch)
                metric_best = score
                stalled = 0
            else:
                stalled += 1
            print(
                "[{:>{:d}d}/{:d}] {:s}: {:>8s} ({:>8s}) {:s}{:<{:d}s}{:s}".format(
                    epoch, epoch_width, max_epochs, validrep,
                    "{:.6f}".format(score)[:8],
                    "{:.6f}".format(metric_best)[:8],
                    "\x1b[92m↑\x1b[0m" if improving else "\x1b[91m↓\x1b[0m",
                    "" if improving else str(stalled), stall_width,
                    " --" if stalled == patience else "",
                ),
            )


    def fit_low_level_mlp(
        self,
        proportion: Tuple[int, int, int], priority: Tuple[int, int, int],
        train_prop: Tuple[int, int, bool],
        /,
        *,
        batch_size: int, max_epochs: int, validon: int, validrep: str,
        patience: int,
    ) -> None:
        R"""
        Fit the lower-level MLP of the framework based on initialization status.

        Each epoch trains through ``continue_train_low_level_mlp`` (which also
        receives the epoch number) and validates through ``evaluate_sdgnn``.
        Whenever the validation metric at ``validon`` decreases (lower is
        better), the following are saved: the network, its sub-modules, the
        transformation model, ``self.weights`` and a combined checkpoint.

        NOTE(review): there is no learning-rate decay or early stopping; all
        ``max_epochs`` epochs run, and ``patience`` only drives the " --"
        marker in the progress line.  No test evaluation or log file is
        produced here.

        Args:
            proportion: Forwarded to ``preprocess``.
            priority: Forwarded to ``preprocess``.
            train_prop: Forwarded to ``preprocess``.
            batch_size: Batch size for pinning, training and evaluation.
            max_epochs: Number of epochs; forced to 0 when the network has no
                resettable parameters.
            validon: Index of the metric used for model selection.
            validrep: Display name of that metric.
            patience: Stall count at which the " --" marker is printed.
        """
        self.preprocess(proportion, priority, train_prop)

        # Zero-out number of tuning epochs for non-parametric cases.
        if self.neuralnet.num_resetted_params == 0:
            max_epochs = 0

        rng = onp.random.RandomState(self.seed)
        timecosts: TIMECOST = {}
        meta_index_pad = (
            onp.min(self.meta_indices_train).item() if self.BATCH_PAD else None
        )
        epochlen = len(str(max_epochs))
        noimplen = len(str(patience))
        logs = []

        # Pin shared memory.
        elapsed = time.time()
        pinned_numpy = self.metaset.pin(batch_size)
        timecosts["pin.generate"] = time.time() - elapsed
        elapsed = time.time()
        pinned_ondev = transfer(pinned_numpy, self.device)
        timecosts["pin.transfer"] = time.time() - elapsed

        def validate():
            # Evaluate the validation split without autograd and record the
            # latest evaluation timings.
            with torch.no_grad():
                (scores, parts) = self.evaluate_sdgnn(
                    self.meta_indices_valid.tolist(), meta_index_pad,
                    batch_size, pinned_ondev,
                )
            timecosts["valid.generate"] = parts["generate"]
            timecosts["valid.transfer"] = parts["transfer"]
            timecosts["valid.forward"] = parts["forward"]
            logs.append(scores)
            return scores

        def save_checkpoint(epoch):
            # Persist the whole network, each sub-module, the SDGNN model and
            # the graph weights, then one combined checkpoint.
            torch.save(self.neuralnet.state_dict(), self.model_path)
            torch.save(self.neuralnet.state_dict(), self.ptnnp)
            run_dir = self.dir_this_time
            tgnn = self.neuralnet.tgnn
            torch.save(tgnn.snn_edge.state_dict(), os.path.join(run_dir, 'snn_edge.pth'))
            torch.save(tgnn.snn_node.state_dict(), os.path.join(run_dir, 'snn_node.pth'))
            torch.save(tgnn.gnnx2.state_dict(), os.path.join(run_dir, 'gnnx2.pth'))
            torch.save(
                self.neuralnet.mlp.state_dict(),
                os.path.join(run_dir, 'lower_level_mlp.pth'),
            )
            torch.save(
                self.neuralnet.transformation_model.state_dict(),
                os.path.join(run_dir, "sdgnn.pth"),
            )
            torch.save(self.weights, os.path.join(run_dir, "weights.pth"))
            # NOTE(review): stores ``self.optim``; ``prepare_sdgnn`` also
            # creates ``self.optimizer_mlp`` -- confirm which one
            # ``continue_train_low_level_mlp`` steps if resuming from this.
            torch.save(
                {
                    'max_epoch': 1 + max_epochs,
                    'epoch': epoch,
                    'model_state_dict': tgnn.state_dict(),
                    'optimizer_state_dict': self.optim.state_dict(),
                    'snn_node': tgnn.snn_node.state_dict(),
                    'snn_edge': tgnn.snn_edge.state_dict(),
                    'gnn_model': tgnn.gnnx2.state_dict(),
                    'sdgnn': self.neuralnet.transformation_model.state_dict(),
                    'weights': self.weights,
                },
                self.whole_model_path,
            )

        # Validate once before training.
        print("=" * 10 + " " + "Train & Validate" + " " + "=" * 10)
        metrics = validate()

        # Initialize performance status.
        torch.save(self.neuralnet.state_dict(), self.ptnnp)
        metric_best = metrics[validon]
        num_not_improving = 0
        print(
            "[{:>{:d}d}/{:d}] {:s}: {:>8s} ({:>8s}) {:s}{:<{:d}s}".format(
                0, epochlen, max_epochs, validrep,
                "{:.6f}".format(metrics[validon])[:8],
                "{:.6f}".format(metric_best)[:8],
                "\x1b[92m↑\x1b[0m", "", noimplen,
            ),
        )

        timecosts["train.generate"] = []
        timecosts["train.transfer"] = []
        timecosts["train.forward"] = []
        timecosts["train.backward"] = []

        for epoch in range(1, 1 + max_epochs):
            shuffling = rng.permutation(len(self.meta_indices_train))
            meta_indices_train = self.meta_indices_train[shuffling]
            # Train.
            self.continue_train_low_level_mlp(
                meta_indices_train.tolist(), meta_index_pad, batch_size,
                pinned_ondev, epoch,
            )

            # Validate.
            metrics = validate()

            # Update performance status (lower metric is better).
            if (
                metrics[validon] < metric_best - IMP_ABS
                or metrics[validon] < metric_best * (1 - IMP_REL)
            ):
                save_checkpoint(epoch)
                metric_best = metrics[validon]
                improving = True
                num_not_improving = 0
            else:
                improving = False
                num_not_improving = num_not_improving + 1
            print(
                "[{:>{:d}d}/{:d}] {:s}: {:>8s} ({:>8s}) {:s}{:<{:d}s}{:s}"
                .format(
                    epoch, epochlen, max_epochs, validrep,
                    "{:.6f}".format(metrics[validon])[:8],
                    "{:.6f}".format(metric_best)[:8],
                    "\x1b[92m↑\x1b[0m" if improving else "\x1b[91m↓\x1b[0m",
                    "" if improving else str(num_not_improving), noimplen,
                    " --" if num_not_improving == patience else "",
                ),
            )





    def besteval(
        self,
        proportion: Tuple[int, int, int], priority: Tuple[int, int, int],
        train_prop: Tuple[int, int, bool],
        /,
        batch_size: int, validon: int, validrep: str, resume: str,
    ) -> None:
        R"""
        Evaluate the best saved network after training and re-log the results.

        Re-runs ``preprocess``, reloads the state dict stored at
        ``self.ptnnp`` and evaluates it on the validation and test splits.
        The log tuple in ``self.ptlog`` is then re-saved to ``self.ptbev``
        with the new metrics.  In the visible code, only ``fit`` writes
        ``self.ptlog``, and the run directory is per instance.  The whole
        network object is saved to ``self.model_path``, and all metric
        values are written as text to ``self.save_result``.

        ``resume`` is accepted for interface compatibility but is unused.
        """
        self.preprocess(proportion, priority, train_prop)
        pad = onp.min(self.meta_indices_train).item() if self.BATCH_PAD else None

        # Pin shared memory and move it to the device.
        pinned_ondev = transfer(self.metaset.pin(batch_size), self.device)

        print("=" * 10 + " " + "Test (best)" + " " + "=" * 10)
        self.neuralnet.load_state_dict(torch.load(self.ptnnp))
        with torch.no_grad():
            metrics_valid, _ = self.evaluate(
                self.meta_indices_valid.tolist(), pad, batch_size, pinned_ondev,
            )
            metrics_test, _ = self.evaluate(
                self.meta_indices_test.tolist(), pad, batch_size, pinned_ondev,
            )
        valid_score = "{:.6f}".format(metrics_valid[validon])[:8]
        test_score = "{:.6f}".format(metrics_test[validon])[:8]
        print("Valid\x1b[94m:\x1b[0m \x1b[3m{:s}\x1b[0m: {:s}".format(validrep, valid_score))
        print(" Test\x1b[94m:\x1b[0m \x1b[3m{:s}\x1b[0m: {:s}".format(validrep, test_score))

        print("=" * 10 + " " + "Relog" + " " + "=" * 10)
        factors, logs, _, _, gpu_mem_peak, timecosts = torch.load(self.ptlog)
        relogged = (
            factors, logs, metrics_valid[validon], metrics_test,
            gpu_mem_peak, timecosts,
        )
        torch.save(relogged, self.ptbev)
        torch.save(self.neuralnet, self.model_path)

        def write_section(handle, heading, values):
            # Heading, fixed column header, then comma-terminated values.
            handle.write(heading)
            handle.write("mse, rmse, mape, mae,\n")
            for value in values:
                handle.write("{:s},".format("{:.7f}".format(value)))
            handle.write('\n')

        with open(self.save_result, 'w') as file:
            write_section(file, "Valid:\n", metrics_valid)
            write_section(file, "Test:\n", metrics_test)

    def besteval_sdgnn(
        self,
        proportion: Tuple[int, int, int], priority: Tuple[int, int, int],
        train_prop: Tuple[int, int, bool],
        /,
        batch_size: int, validon: int, validrep: str, resume: str,
    ) -> None:
        R"""
        Evaluate the SDGNN pipeline after training and re-log the results.

        Re-runs ``preprocess`` and evaluates the network through
        ``evaluate_sdgnn`` on the validation and test splits.  The log tuple
        in ``self.ptlog`` is re-saved to ``self.ptbev`` with the new metrics.
        In the visible code, only ``fit`` writes ``self.ptlog``.  The whole
        network object is saved to ``self.model_path``, and all metric
        values are written as text to ``self.save_result``.

        NOTE(review): unlike ``besteval``, the checkpoint at ``self.ptnnp``
        is not reloaded here (the load was disabled in the original).  The
        network is evaluated in whatever state it is currently in, despite
        the "Test (best)" banner.  Confirm that this is intended.

        ``resume`` is accepted for interface compatibility but is unused.
        """
        self.preprocess(proportion, priority, train_prop)
        pad = onp.min(self.meta_indices_train).item() if self.BATCH_PAD else None

        # Pin shared memory and move it to the device.
        pinned_ondev = transfer(self.metaset.pin(batch_size), self.device)

        print("=" * 10 + " " + "Test (best)" + " " + "=" * 10)
        with torch.no_grad():
            metrics_valid, _ = self.evaluate_sdgnn(
                self.meta_indices_valid.tolist(), pad, batch_size, pinned_ondev,
            )
            metrics_test, _ = self.evaluate_sdgnn(
                self.meta_indices_test.tolist(), pad, batch_size, pinned_ondev,
            )
        valid_score = "{:.6f}".format(metrics_valid[validon])[:8]
        test_score = "{:.6f}".format(metrics_test[validon])[:8]
        print("Valid\x1b[94m:\x1b[0m \x1b[3m{:s}\x1b[0m: {:s}".format(validrep, valid_score))
        print(" Test\x1b[94m:\x1b[0m \x1b[3m{:s}\x1b[0m: {:s}".format(validrep, test_score))

        print("=" * 10 + " " + "Relog" + " " + "=" * 10)
        factors, logs, _, _, gpu_mem_peak, timecosts = torch.load(self.ptlog)
        relogged = (
            factors, logs, metrics_valid[validon], metrics_test,
            gpu_mem_peak, timecosts,
        )
        torch.save(relogged, self.ptbev)
        torch.save(self.neuralnet, self.model_path)

        def write_section(handle, heading, values):
            # Heading, fixed column header, then comma-terminated values.
            handle.write(heading)
            handle.write("mse, rmse, mape, mae,\n")
            for value in values:
                handle.write("{:s},".format("{:.7f}".format(value)))
            handle.write('\n')

        with open(self.save_result, 'w') as file:
            write_section(file, "Valid:\n", metrics_valid)
            write_section(file, "Test:\n", metrics_test)
