import os
import sys
import warnings
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.utils import resample
from sklearn.utils.class_weight import compute_class_weight
from torch_geometric.loader import DataListLoader
from tqdm import tqdm

from data.dataset import SceneGraphDataset
from learning.util.metrics import (
    get_graph_metrics,
    get_metrics,
    log_wandb,
    log_wandb_transfer_learning,
)
from learning.util.trainer import Trainer

# Append the parent directory of sys.path[0] to the module search path.
# NOTE(review): this runs after the project imports above, so it can only
# influence imports performed later — confirm that is the intent.
sys.path.append(os.path.dirname(sys.path[0]))
# Suppress every FutureWarning for the rest of the process.
warnings.simplefilter(action="ignore", category=FutureWarning)


class CURVE_Trainer(Trainer):
    def __init__(self, config, wandb_a=None):
        """Initialise the base ``Trainer`` and the CURVE-specific state.

        Args:
            config: Configuration object; its ``model_config`` mapping must
                provide ``num_of_classes`` and ``device``.
            wandb_a: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(config, wandb_a)
        self.scene_graph_dataset = SceneGraphDataset()

        model_cfg = self.config.model_config
        # One feature name per class index: "type_0", "type_1", ...
        self.feature_list = ["type_" + str(class_idx) for class_idx in range(model_cfg["num_of_classes"])]
        self.device = model_cfg["device"]
        # Starting value for the best test MCC tracked by update_sg_best_metrics.
        self.best_mcc = -1.0

    @staticmethod
    def _gaussian_kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        KL( N(mu, diag(exp(logvar))) || N(0, I) ), averaged over all elements.
        """
        return 0.5 * torch.mean(torch.exp(logvar) + mu.pow(2) - 1.0 - logvar)

    def split_dataset(self):
        task_type = self.config.training_config["task_type"]
        if task_type not in ["sequence_classification", "graph_classification", "collision_prediction"]:
            raise ValueError("split_dataset(): task type error")

        self.training_data, self.testing_data = self.build_scenegraph_dataset()

        self.total_train_labels = np.concatenate(
            [np.full(len(d["sequence"]), d["label"]) for d in self.training_data]
        )
        self.total_test_labels = np.concatenate(
            [np.full(len(d["sequence"]), d["label"]) for d in self.testing_data]
        )

        self.training_labels = [d["label"] for d in self.training_data]
        self.testing_labels = [d["label"] for d in self.testing_data]

        if task_type == "sequence_classification":
            self.class_weights = torch.from_numpy(
                compute_class_weight(
                    class_weight="balanced",
                    classes=np.unique(self.training_labels),
                    y=self.training_labels,
                )
            )
            if self.config.training_config["n_fold"] <= 1:
                print("Number of Sequences Included: ", len(self.training_data))
                print("Number of Testing Sequences Included: ", len(self.testing_data))
                print(
                    "Num Labels in Each Class: "
                    + str(np.unique(self.training_labels, return_counts=True)[1])
                    + ", Class Weights: "
                    + str(self.class_weights)
                )
                print(
                    "Number of Testing Labels in Each Class: "
                    + str(np.unique(self.testing_labels, return_counts=True)[1])
                    + ", Class Weights: "
                    + str(self.class_weights)
                )

        elif task_type == "collision_prediction":
            self.class_weights = torch.from_numpy(
                compute_class_weight(
                    class_weight="balanced",
                    classes=np.unique(self.total_train_labels),
                    y=self.total_train_labels,
                )
            )
            if self.config.training_config["n_fold"] <= 1:
                print("Number of Training Sequences Included: ", len(self.training_data))
                print("Number of Testing Sequences Included: ", len(self.testing_data))
                print(
                    "Number of Training Labels in Each Class: "
                    + str(np.unique(self.total_train_labels, return_counts=True)[1])
                    + ", Class Weights: "
                    + str(self.class_weights)
                )
                print(
                    "Number of Testing Labels in Each Class: "
                    + str(np.unique(self.total_test_labels, return_counts=True)[1])
                    + ", Class Weights: "
                    + str(self.class_weights)
                )

    def build_transfer_learning_dataset(self):
        """Load the transfer dataset and assemble ``self.transfer_data``.

        The dataset is loaded from ``location_data["transfer_path"]`` and
        stored on ``self.scene_graph_dataset``.  Each entry of
        ``self.transfer_data`` is a dict with ``sequence``, ``label`` and
        ``folder_name``; for the ``carla`` dataset type the sequence is first
        passed through ``process_carla_graph_sequences``.  Also sets
        ``total_transfer_data_labels``, ``transfer_data_labels`` and resets
        ``class_weights`` to ``[1.0, 1.0]``.

        Raises:
            ValueError: If ``scenegraph_dataset_type`` is neither ``carla``
                nor ``real``.
        """
        loader = SceneGraphDataset()
        loader.dataset_save_path = self.config.location_data["transfer_path"]
        self.scene_graph_dataset = loader.load()
        dataset = self.scene_graph_dataset

        self.transfer_data = []
        ordered_keys = sorted(dataset.labels)

        dataset_type = self.config.training_config["scenegraph_dataset_type"]
        if dataset_type not in ("carla", "real"):
            raise ValueError("dataset_type unrecognized")

        for position, key in enumerate(ordered_keys):
            if dataset_type == "carla":
                frames = dataset.process_carla_graph_sequences(
                    dataset.scene_graphs[key],
                    self.feature_list,
                    folder_name=dataset.folder_names[position],
                )
            else:
                frames = dataset.scene_graphs[key]

            self.transfer_data.append(
                {
                    "sequence": frames,
                    "label": dataset.labels[key],
                    "folder_name": dataset.folder_names[position],
                }
            )

        # Sequence label repeated once per element of the sequence.
        expanded = [np.full(len(item["sequence"]), item["label"]) for item in self.transfer_data]
        self.total_transfer_data_labels = np.concatenate(expanded)
        self.transfer_data_labels = [item["label"] for item in self.transfer_data]

        print("Running transfer learning on dataset: ", self.config.location_data["transfer_path"])
        print("total labels: ", len(self.transfer_data_labels))

        self.class_weights = torch.tensor([1.0, 1.0])

    def evaluate_transfer_learning(self, current_epoch=None):
        """Run inference on the transfer dataset and collect its metrics.

        Args:
            current_epoch: Epoch index forwarded to ``update_sg_best_metrics``.

        Returns:
            Tuple ``(outputs_test, labels_test, metrics)`` where ``metrics``
            holds the ``test`` metrics, ``avg_inf_time`` and the current
            ``best_*`` values tracked on the trainer.
        """
        (
            outputs_test,
            labels_test,
            loss_value,
            _,
            _,
            prediction_frame,
            sequence_length,
            risky_indices,
            safe_indices,
            elapsed,
            _,
            tpr,
            fpr,
            tnr,
            fnr,
            _,
        ) = self.inference(self.transfer_data, self.transfer_data_labels)

        test_metrics = get_metrics(outputs_test, labels_test)
        # Insertion order matches the printed/logged layout.
        extras = (
            ("loss", loss_value),
            ("avg_prediction_frame", prediction_frame),
            ("avg_seq_len", sequence_length),
            ("avg_predicted_risky_indices", risky_indices),
            ("avg_predicted_safe_indices", safe_indices),
            ("seq_tpr", tpr),
            ("seq_tnr", tnr),
            ("seq_fpr", fpr),
            ("seq_fnr", fnr),
        )
        for key, value in extras:
            test_metrics[key] = value

        metrics = {"test": test_metrics}
        metrics["avg_inf_time"] = elapsed / (len(labels_test))

        self.update_sg_best_metrics(metrics, current_epoch)

        # metrics key -> trainer attribute holding the best value so far.
        best_fields = (
            ("best_epoch", "best_epoch"),
            ("best_val_loss", "best_val_loss"),
            ("best_val_acc", "best_val_acc"),
            ("best_val_auc", "best_val_auc"),
            ("best_val_conf", "best_val_confusion"),
            ("best_val_f1", "best_val_f1"),
            ("best_val_mcc", "best_val_mcc"),
            ("best_val_acc_balanced", "best_val_acc_balanced"),
            ("best_avg_pred_frame", "best_avg_pred_frame"),
        )
        for key, attribute in best_fields:
            metrics[key] = getattr(self, attribute)

        if self.config.training_config["n_fold"] <= 1 and self.log:
            log_wandb_transfer_learning(metrics)
            print(metrics)

        return outputs_test, labels_test, metrics

    def build_scenegraph_dataset(self):
        """Load the scene-graph dataset and return a stratified train/test split.

        The dataset is loaded from ``location_data["input_path"]`` and stored
        on ``self.scene_graph_dataset``.  Every sequence becomes a dict with
        ``sequence``, ``label`` and ``folder_name`` and is bucketed by label
        (``0`` versus everything else).  When ``downsample`` is enabled the
        label-0 bucket is reduced to ``min(len(class_0), len(class_1))``
        samples, drawn *without replacement* and seeded with ``config.seed``
        so no sequence is duplicated and the result is reproducible.

        If ``location_data["transfer_path"]`` is not ``None`` the transfer
        dataset is built as a side effect (this replaces
        ``self.scene_graph_dataset``).

        Returns:
            Tuple ``(train, test)`` of lists of sample dicts.

        Raises:
            ValueError: If ``scenegraph_dataset_type`` is not ``carla``/``real``
                or, for ``real``, ``scenegraph_data_format`` is not
                ``carla``/``honda``.
        """
        scene_graph_dataset = SceneGraphDataset()
        scene_graph_dataset.dataset_save_path = self.config.location_data["input_path"]
        self.scene_graph_dataset = scene_graph_dataset.load()
        dataset = self.scene_graph_dataset

        dataset_type = self.config.training_config["scenegraph_dataset_type"]
        data_format = self.config.training_config.get("scenegraph_data_format", None)

        # The data format is only consulted for the "real" dataset type.
        if dataset_type == "real":
            if data_format not in ("carla", "honda"):
                raise ValueError("scenegraph_data_format not recognized")
        elif dataset_type != "carla":
            raise ValueError("scenegraph_dataset_type not recognized")

        # All supported type/format combinations build samples the same way.
        class_0, class_1 = [], []
        for ind, seq in enumerate(sorted(dataset.labels)):
            label = dataset.labels[seq]
            sample = {
                "sequence": dataset.scene_graphs[seq],
                "label": label,
                "folder_name": dataset.folder_names[ind],
            }
            (class_0 if label == 0 else class_1).append(sample)

        y_0 = [0] * len(class_0)
        y_1 = [1] * len(class_1)

        if self.config.training_config["downsample"]:
            min_number = min(len(class_0), len(class_1))
            # replace=False: a true subsample.  The sklearn default
            # (replace=True) bootstraps, which can put copies of one sequence
            # on both sides of the train/test split.
            class_0, y_0 = resample(
                class_0,
                y_0,
                replace=False,
                n_samples=min_number,
                random_state=self.config.seed,
            )

        train, test, _, _ = train_test_split(
            class_0 + class_1,
            y_0 + y_1,
            test_size=self.config.training_config["split_ratio"],
            shuffle=True,
            stratify=y_0 + y_1,
            random_state=self.config.seed,
        )

        if self.config.location_data["transfer_path"] is not None:
            self.build_transfer_learning_dataset()

        return train, test

    def preprocess_nodes(self, nodes: Dict) -> List[torch.FloatTensor]:
        actor_names = self.config.relation_extraction_settings["ACTOR_NAMES"]
        node_feature_list = []

        for node_id in nodes.keys():
            if nodes[node_id]["label"] in ["Root Road"]:
                continue

            if nodes[node_id]["label"] == "ego car":
                nodes[node_id]["attr"].update(
                    {"left": 640 - 100, "top": 620, "right": 640 + 100, "bottom": 720}
                )
                nodes[node_id]["label"] = "ego car_"

            if nodes[node_id]["label"] == "Left Lane":
                nodes[node_id]["attr"].update({"left": 0, "top": 0, "right": 0, "bottom": 0})
                nodes[node_id]["label"] = "lane_l"

            if nodes[node_id]["label"] == "Right Lane":
                nodes[node_id]["attr"].update({"left": 1280, "top": 0, "right": 1280, "bottom": 0})
                nodes[node_id]["label"] = "lane_r"

            if nodes[node_id]["label"] == "Middle Lane":
                nodes[node_id]["attr"].update({"left": 640, "top": 0, "right": 640, "bottom": 0})
                nodes[node_id]["label"] = "lane_m"

            node_features = [0.0] * 15
            node_features[0] = nodes[node_id]["attr"]["left"] / 1280
            node_features[1] = nodes[node_id]["attr"]["top"] / 720
            node_features[2] = nodes[node_id]["attr"]["right"] / 1280
            node_features[3] = nodes[node_id]["attr"]["bottom"] / 720
            node_features[4] = node_features[2] - node_features[0]
            node_features[5] = node_features[3] - node_features[1]

            actor_key = nodes[node_id]["label"].split("_")[0]
            node_features[6 + actor_names.index(actor_key)] = 1

            node_feature_list.append(torch.FloatTensor(node_features))

        return node_feature_list

    def train(self):
        task_type = self.config.training_config["task_type"]
        if task_type not in ["sequence_classification", "graph_classification", "collision_prediction"]:
            raise ValueError("train(): task type error")

        tqdm_bar = tqdm(range(self.config.training_config["epochs"]))
        warmup_epochs = 50

        gamma_div = float(self.config.training_config.get("gamma_div", 0.1))
        pos_weight = float(self.config.training_config.get("pos_weight", 2.0))
        rank_margin = float(self.config.training_config.get("rank_margin", 0.1))
        min_std_threshold = float(self.config.training_config.get("min_std_threshold", 0.5))

        lambda_kl = float(self.config.training_config.get("lambda_kl", 1e-3))

        for epoch_idx in tqdm_bar:
            acc_loss_train = 0.0
            self.sequence_loader = DataListLoader(
                self.training_data, batch_size=self.config.training_config["batch_size"]
            )

            for data_list in self.sequence_loader:
                self.model.train()
                self.optimizer.zero_grad()

                labels = torch.empty(0, dtype=torch.long, device=self.device)
                all_logits_mu = torch.empty(0, 2, device=self.device)
                all_logits_logvar = torch.empty(0, 2, device=self.device)

                batch_div_loss = 0.0
                batch_alpha_list = []

                for sequence in data_list:
                    data, label = sequence["sequence"], sequence["label"]
                    new_sequence = []

                    data = [data[frame] for frame in sorted(data.keys())] if isinstance(data, dict) else data
                    for d in data:
                        g = d.g
                        nodes = dict(g.nodes(data=True))
                        node_feature_list = self.preprocess_nodes(nodes)
                        new_sequence.append(torch.stack(node_feature_list).to(self.device))

                    result = self.model.forward(new_sequence)
                    logits_mu = result["logits_mu"]
                    logits_logvar = result["logits_logvar"]

                    if "diversity_loss" in result:
                        batch_div_loss += result["diversity_loss"]
                    if "mean_alpha" in result:
                        batch_alpha_list.append(result["mean_alpha"])

                    if task_type == "sequence_classification":
                        labels = torch.cat(
                            [labels, torch.tensor(label, device=self.device).long().unsqueeze(0)],
                            dim=0,
                        )
                    elif task_type == "collision_prediction":
                        label_tensor = torch.tensor(np.full(logits_mu.shape[0], label), device=self.device).long()
                        labels = torch.cat([labels, label_tensor], dim=0)
                    else:
                        raise ValueError("task_type is unimplemented")

                    all_logits_mu = torch.cat([all_logits_mu, logits_mu.view(-1, 2)], dim=0)
                    all_logits_logvar = torch.cat([all_logits_logvar, logits_logvar.view(-1, 2)], dim=0)

                if epoch_idx < warmup_epochs:
                    loss_train = F.cross_entropy(all_logits_mu, labels, label_smoothing=0.1)
                    if len(batch_alpha_list) > 0:
                        avg_alpha_val = torch.stack(batch_alpha_list).mean().item()
                    else:
                        avg_alpha_val = 0.0
                else:
                    div_loss = batch_div_loss / max(len(data_list), 1)
                    pred_loss = self.loss_func(all_logits_mu, all_logits_logvar, labels)
                    kl_loss = self._gaussian_kl_to_standard_normal(all_logits_mu, all_logits_logvar)
                    loss_train = pred_loss + gamma_div * div_loss + lambda_kl * kl_loss

                    pos_mask = labels == 1
                    if pos_mask.any():
                        pos_loss = self.loss_func(
                            all_logits_mu[pos_mask],
                            all_logits_logvar[pos_mask],
                            labels[pos_mask],
                        )
                        loss_train = loss_train + pos_weight * pos_loss

                    preds = all_logits_mu.argmax(dim=1)
                    correct_mask = preds == labels
                    wrong_mask = ~correct_mask

                    if correct_mask.any() and wrong_mask.any():
                        unc_correct = torch.exp(0.5 * all_logits_logvar[correct_mask]).mean()
                        unc_wrong = torch.exp(0.5 * all_logits_logvar[wrong_mask]).mean()
                        loss_train = loss_train + torch.relu(unc_correct - unc_wrong + rank_margin)

                        loss_collapse = torch.relu(min_std_threshold - unc_wrong)
                        loss_train = loss_train + loss_collapse

                    if len(batch_alpha_list) > 0:
                        avg_alpha_val = torch.stack(batch_alpha_list).mean().item()
                    else:
                        avg_alpha_val = 0.0

                loss_train.backward()
                acc_loss_train += loss_train.detach().cpu().item() * len(data_list)
                self.optimizer.step()

            acc_loss_train /= max(len(self.training_data), 1)
            tqdm_bar.set_description(
                f"Epoch: {epoch_idx:04d}, Loss: {acc_loss_train:.4f}, Alpha: {avg_alpha_val:.3f}"
            )

            if epoch_idx % self.config.training_config["test_step"] == 0:
                self.evaluate(epoch_idx)

    def cross_valid(self):
        """Train and evaluate the model with stratified ``n_fold`` cross-validation.

        The current train and test sets are pooled and re-split with
        ``StratifiedKFold``.  For every fold the trainer's data, label lists
        and expanded label arrays are replaced with that fold's values, the
        best-model tracking state is reset, the model is trained and
        evaluated, and the result is passed to
        ``update_sg_cross_valid_metrics``.  The model and optimizer are
        rebuilt between folds.  ``self.results`` is deleted at the end.

        NOTE(review): ``self.class_weights`` is left as computed on the
        original split and is not recomputed per fold — confirm this is
        intended.
        """
        def expand_labels(samples):
            # Repeat every sequence label once per element of its sequence.
            return np.concatenate([np.full(len(s["sequence"]), s["label"]) for s in samples])

        training_config = self.config.training_config
        skf = StratifiedKFold(n_splits=training_config["n_fold"])
        X = np.array(self.training_data + self.testing_data)
        y = np.array(self.training_labels + self.testing_labels)

        self.results = {}
        self.fold = 1

        for train_index, test_index in skf.split(X, y):
            self.training_data, self.testing_data = X[train_index], X[test_index]
            self.training_labels, self.testing_labels = y[train_index], y[test_index]

            # Keep the expanded label arrays in sync with this fold so the
            # summary below does not report counts from the original split.
            self.total_train_labels = expand_labels(self.training_data)
            self.total_test_labels = expand_labels(self.testing_data)

            task_type = training_config["task_type"]
            if task_type == "sequence_classification":
                print(f"\nFold {self.fold}")
                print("Number of Sequences Included: ", len(self.training_data))
                print(
                    "Num Labels in Each Class: "
                    + str(np.unique(self.training_labels, return_counts=True)[1])
                    + ", Class Weights: "
                    + str(self.class_weights)
                )
            elif task_type == "collision_prediction":
                print(f"\nFold {self.fold}")
                print("Number of Training Sequences Included: ", len(self.training_data))
                print("Number of Testing Sequences Included: ", len(self.testing_data))
                print(
                    "Number of Training Labels in Each Class: "
                    + str(np.unique(self.total_train_labels, return_counts=True)[1])
                    + ", Class Weights: "
                    + str(self.class_weights)
                )
                print(
                    "Number of Testing Labels in Each Class: "
                    + str(np.unique(self.total_test_labels, return_counts=True)[1])
                    + ", Class Weights: "
                    + str(self.class_weights)
                )

            # Reset best-model tracking for the fresh fold.  best_mcc is the
            # value update_sg_best_metrics compares against, so without this
            # reset later folds could never register a best model unless they
            # beat every earlier fold.
            self.best_val_loss = 99999
            self.best_mcc = -1.0
            self.train()

            outputs_train, labels_train, outputs_test, labels_test, metrics = self.evaluate(self.fold)
            self.update_sg_cross_valid_metrics(outputs_train, labels_train, outputs_test, labels_test, metrics)

            # Start the next fold from a freshly built model and optimizer.
            if self.fold != training_config["n_fold"]:
                del self.model
                del self.optimizer
                self.build_model()

            self.fold += 1

        del self.results

    def inference(self, testing_data, testing_labels):
        labels = torch.LongTensor().to(self.device)
        outputs = torch.FloatTensor().to(self.device)
        uncertainties = torch.FloatTensor().to(self.device)

        acc_loss_test = 0.0
        attns_weights = []
        node_attns = []
        folder_names = []

        sum_prediction_frame = 0
        sum_seq_len = 0
        num_risky_sequences = 0
        num_safe_sequences = 0
        sum_predicted_risky_indices = 0
        sum_predicted_safe_indices = 0

        inference_time = 0.0
        prof_result = ""

        correct_risky_seq = 0
        correct_safe_seq = 0
        incorrect_risky_seq = 0
        incorrect_safe_seq = 0

        all_graph_lists = []

        if self.config.model_config.get("log_graphs", False):
            os.makedirs(self.config.model_config["ss_graph_path"] + "correct", exist_ok=True)
            os.makedirs(self.config.model_config["ss_graph_path"] + "incorrect", exist_ok=True)
            os.makedirs(self.config.model_config["orig_graph_path"] + "correct", exist_ok=True)
            os.makedirs(self.config.model_config["orig_graph_path"] + "incorrect", exist_ok=True)

        with torch.autograd.profiler.profile(enabled=False, use_cuda=True) as prof:
            self.model.eval()
            with torch.no_grad():
                for i in range(len(testing_data)):
                    data, label = testing_data[i]["sequence"], testing_labels[i]
                    folder_name = testing_data[i]["folder_name"]
                    folder_names.append(folder_name)

                    new_sequence = []
                    data = [data[frame] for frame in sorted(data.keys())] if isinstance(data, dict) else data
                    for d in data:
                        g = d.g
                        nodes = dict(g.nodes(data=True))
                        node_feature_list = self.preprocess_nodes(nodes)
                        new_sequence.append(torch.stack(node_feature_list).to(self.device))

                    out = self.model.forward(new_sequence)

                    output = out["output"].view(-1, 2)
                    uncertainty = out["uncertainty"].view(-1)
                    graph_list = out["graph_list"]

                    label_tensor = torch.tensor(np.full(output.shape[0], label), device=self.device).long()

                    if self.config.model_config.get("log_graphs", False):
                        preds_seq = output.max(1)[1].type_as(label_tensor)
                        if torch.equal(label_tensor, preds_seq):
                            torch.save(graph_list, self.config.model_config["ss_graph_path"] + "correct/" + folder_name + ".pt")
                            torch.save(data, self.config.model_config["orig_graph_path"] + "correct/" + folder_name + ".pt")
                        else:
                            torch.save(graph_list, self.config.model_config["ss_graph_path"] + "incorrect/" + folder_name + ".pt")
                            torch.save(data, self.config.model_config["orig_graph_path"] + "incorrect/" + folder_name + ".pt")
                        all_graph_lists.append(graph_list)

                    outputs = torch.cat([outputs, output], dim=0)
                    labels = torch.cat([labels, label_tensor], dim=0)
                    uncertainties = torch.cat([uncertainties, uncertainty], dim=0)

                    preds = output.max(1)[1].type_as(label_tensor)

                    if (label_tensor == 1).any():
                        num_risky_sequences += 1
                        sum_seq_len += output.shape[0]
                        if (preds == 1).any():
                            correct_risky_seq += 1
                            sum_prediction_frame += torch.where(preds == 1)[0][0].item()
                            sum_predicted_risky_indices += (
                                torch.sum(torch.where(preds == 1)[0] + 1).item()
                                / np.sum(range(output.shape[0] + 1))
                            )
                        else:
                            incorrect_risky_seq += 1
                            sum_prediction_frame += output.shape[0]
                    else:
                        num_safe_sequences += 1
                        if (preds == 1).any():
                            incorrect_safe_seq += 1
                        else:
                            correct_safe_seq += 1
                        if (preds == 0).any():
                            sum_predicted_safe_indices += (
                                torch.sum(torch.where(preds == 0)[0] + 1).item()
                                / np.sum(range(output.shape[0] + 1))
                            )

                    loss_test = F.cross_entropy(output, label_tensor)
                    acc_loss_test += loss_test.detach().cpu().item()

        flat_graphs = [g for sublist in all_graph_lists for g in sublist] if len(all_graph_lists) > 0 else []
        graph_metrics = get_graph_metrics(flat_graphs) if len(flat_graphs) > 0 else {}
        print(graph_metrics)

        avg_risky_prediction_frame = sum_prediction_frame / num_risky_sequences if num_risky_sequences else 0
        avg_risky_seq_len = sum_seq_len / num_risky_sequences if num_risky_sequences else 0
        avg_predicted_risky_indices = sum_predicted_risky_indices / num_risky_sequences if num_risky_sequences else 0
        avg_predicted_safe_indices = sum_predicted_safe_indices / num_safe_sequences if num_safe_sequences else 0

        seq_tpr = correct_risky_seq / num_risky_sequences if num_risky_sequences else 0
        seq_fpr = incorrect_safe_seq / num_safe_sequences if num_safe_sequences else 0
        seq_tnr = correct_safe_seq / num_safe_sequences if num_safe_sequences else 0
        seq_fnr = incorrect_risky_seq / num_risky_sequences if num_risky_sequences else 0

        if prof is not None:
            prof_result = prof.key_averages().table(sort_by="cuda_time_total")

        return (
            outputs,
            labels,
            acc_loss_test / max(len(testing_data), 1),
            attns_weights,
            node_attns,
            avg_risky_prediction_frame,
            avg_risky_seq_len,
            avg_predicted_risky_indices,
            avg_predicted_safe_indices,
            inference_time,
            prof_result,
            seq_tpr,
            seq_fpr,
            seq_tnr,
            seq_fnr,
            folder_names,
        )

    def evaluate(self, current_epoch=None):
        """Evaluate the model on the training and testing sets.

        Runs ``inference`` on both sets, assembles their metrics, writes the
        training profiler text to ``graph_profile_metrics.txt``, prints a
        short report, refreshes the best-so-far values through
        ``update_sg_best_metrics`` and logs the metrics when ``n_fold <= 1``
        and logging is enabled.

        Args:
            current_epoch: Epoch (or fold) index forwarded to
                ``update_sg_best_metrics``.

        Returns:
            Tuple ``(outputs_train, labels_train, outputs_test, labels_test,
            metrics)``.
        """

        def attach(split_metrics, loss, pred_frame, seq_len, risky_idx, safe_idx, tpr, tnr, fpr, fnr):
            # Add the sequence-level statistics in a fixed key order.
            split_metrics["loss"] = loss
            split_metrics["avg_prediction_frame"] = pred_frame
            split_metrics["avg_seq_len"] = seq_len
            split_metrics["avg_predicted_risky_indices"] = risky_idx
            split_metrics["avg_predicted_safe_indices"] = safe_idx
            split_metrics["seq_tpr"] = tpr
            split_metrics["seq_tnr"] = tnr
            split_metrics["seq_fpr"] = fpr
            split_metrics["seq_fnr"] = fnr
            return split_metrics

        metrics = {}

        # ---- training set -------------------------------------------------
        (
            outputs_train,
            labels_train,
            acc_loss_train,
            _,
            _,
            train_pred_frame,
            train_seq_len,
            train_risky_idx,
            train_safe_idx,
            train_inference_time,
            train_profiler_result,
            train_tpr,
            train_fpr,
            train_tnr,
            train_fnr,
            _,
        ) = self.inference(self.training_data, self.training_labels)

        torch.cuda.empty_cache()

        metrics["train"] = attach(
            get_metrics(outputs_train, labels_train),
            acc_loss_train,
            train_pred_frame,
            train_seq_len,
            train_risky_idx,
            train_safe_idx,
            train_tpr,
            train_tnr,
            train_fpr,
            train_fnr,
        )

        with open("graph_profile_metrics.txt", mode="w") as profile_file:
            profile_file.write(train_profiler_result)

        # ---- testing set --------------------------------------------------
        (
            outputs_test,
            labels_test,
            acc_loss_test,
            _,
            _,
            test_pred_frame,
            test_seq_len,
            test_risky_idx,
            test_safe_idx,
            test_inference_time,
            _,
            test_tpr,
            test_fpr,
            test_tnr,
            test_fnr,
            _,
        ) = self.inference(self.testing_data, self.testing_labels)

        metrics["test"] = attach(
            get_metrics(outputs_test, labels_test),
            acc_loss_test,
            test_pred_frame,
            test_seq_len,
            test_risky_idx,
            test_safe_idx,
            test_tpr,
            test_tnr,
            test_fpr,
            test_fnr,
        )

        total_rows = len(labels_train) + len(labels_test)
        metrics["avg_inf_time"] = (train_inference_time + test_inference_time) / (total_rows)

        report_parts = [
            "\ntrain loss: ",
            str(acc_loss_train),
            ", acc:",
            str(metrics["train"]["acc"]),
            " ",
            str(metrics["train"]["confusion"]),
            " mcc:",
            str(metrics["train"]["mcc"]),
            "\ntest loss: ",
            str(acc_loss_test),
            ", acc:",
            str(metrics["test"]["acc"]),
            " ",
            str(metrics["test"]["confusion"]),
            " mcc:",
            str(metrics["test"]["mcc"]),
        ]
        print("".join(report_parts))

        self.update_sg_best_metrics(metrics, current_epoch)

        # metrics key -> trainer attribute holding the best value so far.
        best_fields = (
            ("best_epoch", "best_epoch"),
            ("best_val_loss", "best_val_loss"),
            ("best_val_acc", "best_val_acc"),
            ("best_val_auc", "best_val_auc"),
            ("best_val_conf", "best_val_confusion"),
            ("best_val_f1", "best_val_f1"),
            ("best_val_mcc", "best_val_mcc"),
            ("best_val_acc_balanced", "best_val_acc_balanced"),
            ("best_avg_pred_frame", "best_avg_pred_frame"),
        )
        for key, attribute in best_fields:
            metrics[key] = getattr(self, attribute)

        if self.config.training_config["n_fold"] <= 1 and self.log:
            log_wandb(metrics)

        torch.cuda.empty_cache()
        return outputs_train, labels_train, outputs_test, labels_test, metrics

    def update_sg_best_metrics(self, metrics, current_epoch):
        """
        Record the test-split metrics as the new "best" when the test MCC improves.

        The test MCC in ``metrics["test"]["mcc"]`` is compared against
        ``self.best_mcc``; only a strictly greater value triggers an update.
        On an update the ``best_*`` attributes are refreshed from
        ``metrics["test"]`` and ``self.save_model`` is called with a suffix
        that embeds the MCC formatted to four decimals.

        Args:
            metrics: mapping with a ``"test"`` entry holding the keys
                ``mcc``, ``loss``, ``acc``, ``auc``, ``confusion``, ``f1``,
                ``balanced_acc`` and ``avg_prediction_frame``.
            current_epoch: epoch to record as the best one; when ``None`` the
                configured ``training_config["epochs"]`` value is used instead.
        """
        test_scores = metrics["test"]
        candidate_mcc = test_scores["mcc"]

        # Written as ``not (a > b)`` (rather than ``a <= b``) so that any
        # comparison evaluating to False is treated as "no improvement".
        if not (candidate_mcc > self.best_mcc):
            return

        self.best_mcc = candidate_mcc
        self.best_val_mcc = test_scores["mcc"]
        self.best_val_loss = test_scores["loss"]

        if current_epoch is None:
            self.best_epoch = self.config.training_config["epochs"]
        else:
            self.best_epoch = current_epoch

        # (attribute on self, key in the test metrics) pairs copied verbatim.
        remaining = (
            ("best_val_acc", "acc"),
            ("best_val_auc", "auc"),
            ("best_val_confusion", "confusion"),
            ("best_val_f1", "f1"),
            ("best_val_acc_balanced", "balanced_acc"),
            ("best_avg_pred_frame", "avg_prediction_frame"),
        )
        for attr_name, metric_key in remaining:
            setattr(self, attr_name, test_scores[metric_key])

        self.save_model(suffix=f"_mcc_{candidate_mcc:.4f}")

    def update_sg_cross_valid_metrics(self, outputs_train, labels_train, outputs_test, labels_test, metrics):
        """
        Accumulate one fold's results and, on the last fold, aggregate and report them.

        On fold 1 the outputs, labels and metrics are stored in
        ``self.results``. On every later fold the outputs/labels are
        concatenated along dim 0 with ``torch.cat`` and the scalar metrics
        are appended with ``np.append``. When ``self.fold`` equals
        ``training_config["n_fold"]`` the pooled outputs/labels are passed to
        ``get_metrics``, the per-fold values are averaged with
        ``np.average``, a summary is printed and (if ``self.log`` is truthy)
        sent to ``log_wandb``.

        Args:
            outputs_train: this fold's training outputs (must be accepted by
                ``torch.cat`` for folds after the first).
            labels_train: this fold's training labels (same requirement).
            outputs_test: this fold's test outputs (same requirement).
            labels_test: this fold's test labels (same requirement).
            metrics: this fold's metrics mapping with ``"train"`` and
                ``"test"`` sub-mappings, ``"avg_inf_time"`` and the
                ``best_*`` entries. It is not modified by this method.

        Returns:
            ``None`` for every fold except the last. On the last fold, the
            tuple ``(outputs_train, labels_train, outputs_test, labels_test,
            final_results)`` built from the values pooled over all folds.
        """
        splits = ("train", "test")
        # Per-split values that are tracked per fold and averaged at the end.
        per_split_keys = (
            "loss",
            "avg_prediction_frame",
            "avg_seq_len",
            "avg_predicted_risky_indices",
            "avg_predicted_safe_indices",
        )
        # Top-level "best so far" values reported by the per-fold evaluation.
        best_keys = (
            "best_epoch",
            "best_val_loss",
            "best_val_acc",
            "best_val_auc",
            "best_val_conf",
            "best_val_f1",
            "best_val_mcc",
            "best_val_acc_balanced",
            "best_avg_pred_frame",
        )
        top_level_keys = ("avg_inf_time",) + best_keys
        fold_tensors = {
            "outputs_train": outputs_train,
            "labels_train": labels_train,
            "outputs_test": outputs_test,
            "labels_test": labels_test,
        }

        if self.fold == 1:
            self.results.update(fold_tensors)
            for split in splits:
                # Shallow copy: later folds rewrite keys of this dict, and the
                # caller's fold-1 ``metrics`` must not be mutated as a side effect.
                self.results[split] = dict(metrics[split])
            for key in top_level_keys:
                self.results[key] = metrics[key]
        else:
            for name, fold_value in fold_tensors.items():
                self.results[name] = torch.cat((self.results[name], fold_value), dim=0)
            for split in splits:
                for key in per_split_keys:
                    self.results[split][key] = np.append(self.results[split][key], metrics[split][key])
            # NOTE(review): np.append without ``axis`` flattens its inputs, so
            # array-like ``best_val_conf`` values end up as one flat 1-D array
            # across folds. Kept as-is for backward compatibility -- confirm
            # that downstream consumers expect this.
            for key in top_level_keys:
                self.results[key] = np.append(self.results[key], metrics[key])

        if self.fold != self.config.training_config["n_fold"]:
            return None

        final_results = {}
        for split in splits:
            # Classification metrics are recomputed on the outputs pooled over
            # all folds; the tracked scalars are averaged across folds.
            split_summary = get_metrics(self.results[f"outputs_{split}"], self.results[f"labels_{split}"])
            for key in per_split_keys:
                split_summary[key] = np.average(self.results[split][key])
            final_results[split] = split_summary

        final_results["avg_inf_time"] = np.average(self.results["avg_inf_time"])
        for key in best_keys:
            if key == "best_val_conf":
                # Confusion values are reported as accumulated, not averaged.
                final_results[key] = self.results[key]
            else:
                final_results[key] = np.average(self.results[key])

        print("\nFinal Averaged Results")
        print(
            "\naverage train loss: "
            + str(final_results["train"]["loss"])
            + ", average acc:"
            + str(final_results["train"]["acc"])
            + " "
            + str(final_results["train"]["confusion"])
            + " "
            + str(final_results["train"]["auc"])
            + "\naverage test loss: "
            + str(final_results["test"]["loss"])
            + ", average acc:"
            + str(final_results["test"]["acc"])
            + " "
            + str(final_results["test"]["confusion"])
            + " "
            + str(final_results["test"]["auc"])
        )

        if self.log:
            log_wandb(final_results)

        return (
            self.results["outputs_train"],
            self.results["labels_train"],
            self.results["outputs_test"],
            self.results["labels_test"],
            final_results,
        )

