import warnings

# Silence only this specific warning text, which presumably is emitted while
# importing torch_geometric.graphgym when pytorch_lightning is not installed.
# The filter is registered here, before the torch_geometric imports below, so
# that it is already active if the warning fires at import time.
warnings.filterwarnings(
    "ignore",
    message="Please install 'pytorch_lightning' for using the GraphGym experiment manager",
)

import logging
import time

import numpy as np
import torch
from scipy.stats import stats
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    mean_absolute_error,
    mean_squared_error,
    confusion_matrix,
)
from sklearn.metrics import r2_score
from torch_geometric.graphgym import get_current_gpu_usage
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.logger import infer_task, Logger
from torch_geometric.graphgym.utils.io import dict_to_json, dict_to_tb
from torchmetrics.functional import auroc

from custom_modules.metric_wrapper import MetricWrapper


def accuracy_SBM(targets, pred_int):
    """Accuracy eval for Benchmarking GNN's PATTERN and CLUSTER datasets.
    https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/train/metrics.py#L34

    Mean per-class recall: for every label occurring in ``targets`` or
    ``pred_int`` (the same sorted label set scikit-learn's
    ``confusion_matrix`` uses), recall = correct predictions of that label /
    number of targets with that label. A label that appears only in the
    predictions contributes a recall of 0 but still counts in the denominator.

    Args:
        targets: Integer class labels (torch tensor or array-like); flattened.
        pred_int: Predicted integer class labels, same number of elements.

    Returns:
        float: Mean per-class recall as a fraction in [0, 1]; 0.0 when both
        inputs are empty.

    Raises:
        ValueError: If ``targets`` and ``pred_int`` differ in length.
    """

    def _to_numpy(values):
        # Accept torch tensors (any device) as well as numpy / list inputs.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    true_np = _to_numpy(targets).ravel()
    pred_np = _to_numpy(pred_int).ravel()
    if true_np.shape[0] != pred_np.shape[0]:
        raise ValueError(
            f"targets and pred_int have inconsistent lengths: "
            f"{true_np.shape[0]} != {pred_np.shape[0]}"
        )

    # Map every label to its position in the sorted label union, so that
    # class index k really refers to labels[k] even when the labels are not
    # exactly 0..K-1.
    labels, inverse = np.unique(
        np.concatenate([true_np, pred_np]), return_inverse=True
    )
    num_classes = labels.shape[0]
    if num_classes == 0:
        return 0.0
    inverse = inverse.ravel()
    n_true = true_np.shape[0]
    true_idx = inverse[:n_true]
    pred_idx = inverse[n_true:]

    # support[k]: number of targets of class k (confusion-matrix row sum);
    # hits[k]: number of correctly predicted samples of class k (diagonal).
    support = np.bincount(true_idx, minlength=num_classes).astype(np.float64)
    hits = np.bincount(
        true_idx[true_idx == pred_idx], minlength=num_classes
    ).astype(np.float64)
    recall = np.divide(hits, support, out=np.zeros(num_classes), where=support > 0)
    return float(recall.sum() / num_classes)


class CustomLogger(Logger):
    """GraphGym logger for single-dataset runs with additional task metrics.

    Compared to GraphGym's ``Logger`` it computes AUROC with TorchMetrics,
    optionally the balanced ``accuracy-SBM`` metric (when it is
    ``cfg.metric_best``), and Spearman correlation for regression.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Whether to run comparison tests of alternative score implementations.
        self.test_scores = False

    # basic properties
    def basic(self):
        """Return loss, learning rate, parameter count and timing stats.

        ``loss`` is the sample-weighted mean accumulated by ``update_stats``
        (loss * batch_size summed, divided by the number of samples). GPU
        memory is only included when it is reported as non-zero.
        """
        basic_stats = {
            "loss": round(self._loss / self._size_current, max(8, cfg.round)),
            "lr": round(self._lr, max(8, cfg.round)),
            "params": self._params,
            "time_iter": round(self.time_iter(), cfg.round),
        }
        gpu_memory = get_current_gpu_usage()
        if gpu_memory > 0:
            basic_stats["gpu_memory"] = gpu_memory
        return basic_stats

    # task properties
    def classification_binary(self):
        """Binary classification metrics over all predictions of the epoch.

        Returns:
            dict: accuracy, precision, recall, f1 and auc, plus accuracy-SBM
            when it is ``cfg.metric_best``. ``auc`` is 0.0 for datasets with
            10M or more samples, where AUROC is skipped.
        """
        true = torch.cat(self._true).squeeze(-1)
        pred_score = torch.cat(self._pred)
        pred_int = self._get_pred_int(pred_score)

        # AUROC computation for very large datasets is too slow.
        if true.shape[0] < 1e7:
            # TorchMetrics AUROC on GPU if available.
            auroc_score = auroc(
                pred_score.to(torch.device(cfg.accelerator)),
                true.to(torch.device(cfg.accelerator)),
                task="binary",
            )
            if self.test_scores:
                # SK-learn version.
                try:
                    r_a_score = roc_auc_score(
                        true.cpu().numpy(), pred_score.cpu().numpy()
                    )
                except ValueError:
                    r_a_score = 0.0
                assert np.isclose(float(auroc_score), r_a_score)
        else:
            auroc_score = 0.0

        reformat = lambda x: round(float(x), cfg.round)
        # zero_division=0 gives the same values as sklearn's default ("warn"
        # also yields 0) without flooding the log with warnings; it matches
        # the f1 call in classification_multi.
        res = {
            "accuracy": reformat(accuracy_score(true, pred_int)),
            "precision": reformat(precision_score(true, pred_int, zero_division=0)),
            "recall": reformat(recall_score(true, pred_int, zero_division=0)),
            "f1": reformat(f1_score(true, pred_int, zero_division=0)),
            "auc": reformat(auroc_score),
        }
        if cfg.metric_best == "accuracy-SBM":
            res["accuracy-SBM"] = reformat(accuracy_SBM(true, pred_int))
        return res

    def classification_multi(self):
        """Multi-class classification metrics over all predictions of the epoch.

        Returns:
            dict: accuracy and macro f1, accuracy-SBM when it is
            ``cfg.metric_best``, and macro one-vs-rest auc for datasets with
            fewer than 10M samples.
        """
        true, pred_score = torch.cat(self._true), torch.cat(self._pred)
        pred_int = self._get_pred_int(pred_score)
        reformat = lambda x: round(float(x), cfg.round)

        res = {
            "accuracy": reformat(accuracy_score(true, pred_int)),
            "f1": reformat(f1_score(true, pred_int, average="macro", zero_division=0)),
        }
        if cfg.metric_best == "accuracy-SBM":
            res["accuracy-SBM"] = reformat(accuracy_SBM(true, pred_int))
        if true.shape[0] < 1e7:
            # AUROC computation for very large datasets runs out of memory.
            # TorchMetrics AUROC on GPU is much faster than sklearn for large ds
            res["auc"] = reformat(
                auroc(
                    pred_score.to(torch.device(cfg.accelerator)),
                    true.to(torch.device(cfg.accelerator)).squeeze(),
                    task="multiclass",
                    num_classes=pred_score.shape[1],
                    average="macro",
                )
            )

            if self.test_scores:
                # SK-learn version; .exp() assumes pred_score holds
                # log-probabilities (TODO confirm against the model head).
                sk_auc = reformat(
                    roc_auc_score(
                        true, pred_score.exp(), average="macro", multi_class="ovr"
                    )
                )
                assert np.isclose(sk_auc, res["auc"])

        return res

    def regression(self):
        """Regression metrics over all predictions of the epoch.

        ``rmse`` is the mean of the per-output RMSEs, i.e. what
        ``mean_squared_error(..., squared=False)`` used to return. It is
        computed from ``multioutput="raw_values"`` because newer scikit-learn
        releases removed the ``squared`` argument.
        """
        true, pred = torch.cat(self._true), torch.cat(self._pred)
        reformat = lambda x: round(float(x), cfg.round)
        per_output_mse = mean_squared_error(true, pred, multioutput="raw_values")
        return {
            "mae": reformat(mean_absolute_error(true, pred)),
            "r2": reformat(r2_score(true, pred, multioutput="uniform_average")),
            "spearmanr": reformat(
                eval_spearmanr(true.numpy(), pred.numpy())["spearmanr"]
            ),
            # Uniform average over outputs, same as sklearn's default.
            "mse": reformat(per_output_mse.mean()),
            "rmse": reformat(np.sqrt(per_output_mse).mean()),
        }

    def update_stats(
        self, true, pred, loss, lr, time_used, params, dataset_name=None, **kwargs
    ):
        """Accumulate one batch of predictions, loss and timing.

        ``loss`` and every extra keyword value are weighted by the batch size
        (``true.shape[0]``). ``dataset_name`` is accepted but not used here.
        """
        assert true.shape[0] == pred.shape[0]
        batch_size = true.shape[0]
        self._iter += 1
        self._true.append(true)
        self._pred.append(pred)
        self._size_current += batch_size
        self._loss += loss * batch_size
        self._lr = lr
        self._params = params
        self._time_used += time_used
        self._time_total += time_used
        for key, val in kwargs.items():
            if key not in self._custom_stats:
                self._custom_stats[key] = val * batch_size
            else:
                self._custom_stats[key] += val * batch_size

    def write_epoch(self, cur_epoch):
        """Compute, log and persist the statistics for one finished epoch.

        Stats are logged, appended to ``<out_dir>/stats.json`` and optionally
        sent to TensorBoard; the accumulated state is then reset.

        Returns:
            dict: The merged epoch statistics.

        Raises:
            ValueError: If ``self.task_type`` is not a supported task type.
        """
        start_time = time.perf_counter()
        basic_stats = self.basic()

        if self.task_type == "regression":
            task_stats = self.regression()
        elif self.task_type == "classification_binary":
            task_stats = self.classification_binary()
        elif self.task_type == "classification_multi":
            task_stats = self.classification_multi()
        # NOTE(review): classification_multilabel / subtoken_prediction are not
        # defined in this class; confirm the base Logger provides them,
        # otherwise these branches raise AttributeError.
        elif self.task_type == "classification_multilabel":
            task_stats = self.classification_multilabel()
        elif self.task_type == "subtoken_prediction":
            task_stats = self.subtoken_prediction()
        else:
            raise ValueError("Task has to be regression or classification")

        epoch_stats = {
            "epoch": cur_epoch,
            "time_epoch": round(self._time_used, cfg.round),
        }
        eta_stats = {
            "eta": round(self.eta(cur_epoch), cfg.round),
            "eta_hours": round(self.eta(cur_epoch) / 3600, cfg.round),
        }
        custom_stats = self.custom()

        # Only the training split reports ETA.
        if self.name == "train":
            stats = {
                **epoch_stats,
                **eta_stats,
                **basic_stats,
                **task_stats,
                **custom_stats,
            }
        else:
            stats = {**epoch_stats, **basic_stats, **task_stats, **custom_stats}

        # print
        logging.info("{}: {}".format(self.name, stats))
        # json
        dict_to_json(stats, "{}/stats.json".format(self.out_dir))
        # tensorboard
        if cfg.tensorboard_each_run:
            dict_to_tb(stats, self.tb_writer, cur_epoch)
        self.reset()
        if cur_epoch < 3:
            logging.info(
                f"...computing epoch stats took: "
                f"{time.perf_counter() - start_time:.2f}s"
            )
        return stats


class CustomLogger_multi(Logger):
    """GraphGym logger for multi-dataset / multi-task runs.

    Per-batch targets and predictions are stored per task key in
    ``self._custom_stats["true"]`` and ``self._custom_stats["pred"]`` (lists of
    tensors, concatenated at epoch end). ``write_epoch`` reports one metrics
    dict per task configured in ``cfg.dataset_multi.name_list``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Whether to run comparison tests of alternative score implementations.
        self.test_scores = False

    # basic properties
    def basic(self):
        """Return loss, learning rate, parameter count and timing stats.

        ``update_stats`` adds one ``total_loss`` and increments
        ``_size_current`` by one per call, so ``loss`` is the mean loss per
        ``update_stats`` call (per batch), not per sample.
        """
        basic_stats = {
            "loss": round(self._loss / self._size_current, max(8, cfg.round)),
            "lr": round(self._lr, max(8, cfg.round)),
            "params": self._params,
            "time_iter": round(self.time_iter(), cfg.round),
        }
        gpu_memory = get_current_gpu_usage()
        if gpu_memory > 0:
            basic_stats["gpu_memory"] = gpu_memory
        return basic_stats

    # task properties
    def classification_binary(self, task_key):
        """Binary classification metrics for one task.

        Args:
            task_key: Key under which ``update_stats`` stored this task's
                targets and predictions.

        Returns:
            dict: accuracy, precision, recall, f1 and auc, plus accuracy-SBM
            when it is ``cfg.metric_best``.
        """
        true = torch.cat(self._custom_stats["true"][task_key]).squeeze(-1)
        pred_score = torch.cat(self._custom_stats["pred"][task_key])
        pred_int = self._get_pred_int(pred_score)

        # AUROC computation for very large datasets is too slow.
        if true.shape[0] < 1e7:
            # TorchMetrics AUROC on GPU if available.
            auroc_score = auroc(
                pred_score.to(torch.device(cfg.accelerator)),
                true.to(torch.device(cfg.accelerator)),
                task="binary",
            )
            if self.test_scores:
                # SK-learn version.
                try:
                    r_a_score = roc_auc_score(
                        true.cpu().numpy(), pred_score.cpu().numpy()
                    )
                except ValueError:
                    r_a_score = 0.0
                assert np.isclose(float(auroc_score), r_a_score)
        else:
            auroc_score = 0.0

        reformat = lambda x: round(float(x), cfg.round)
        # zero_division=0 returns the same values as sklearn's default while
        # avoiding an UndefinedMetricWarning per task and epoch.
        res = {
            "accuracy": reformat(accuracy_score(true, pred_int)),
            "precision": reformat(precision_score(true, pred_int, zero_division=0)),
            "recall": reformat(recall_score(true, pred_int, zero_division=0)),
            "f1": reformat(f1_score(true, pred_int, zero_division=0)),
            "auc": reformat(auroc_score),
        }
        if cfg.metric_best == "accuracy-SBM":
            res["accuracy-SBM"] = reformat(accuracy_SBM(true, pred_int))
        return res

    def classification_multi(self, task_key):
        """Multi-class classification metrics for one task.

        Only accuracy is reported for the multi-task logger; F1,
        accuracy-SBM and AUROC are currently not computed here.

        Args:
            task_key: Key under which ``update_stats`` stored this task's
                targets and predictions.
        """
        true = torch.cat(self._custom_stats["true"][task_key])
        pred_score = torch.cat(self._custom_stats["pred"][task_key])
        pred_int = self._get_pred_int(pred_score)
        reformat = lambda x: round(float(x), cfg.round)
        return {"accuracy": reformat(accuracy_score(true, pred_int))}

    def regression(self, task_key):
        """Regression metrics for one task.

        ``rmse`` is the mean of the per-output RMSEs, i.e. what
        ``mean_squared_error(..., squared=False)`` used to return. It is
        computed from ``multioutput="raw_values"`` because newer scikit-learn
        releases removed the ``squared`` argument.

        Args:
            task_key: Key under which ``update_stats`` stored this task's
                targets and predictions.
        """
        true = torch.cat(self._custom_stats["true"][task_key])
        pred = torch.cat(self._custom_stats["pred"][task_key])
        reformat = lambda x: round(float(x), cfg.round)
        per_output_mse = mean_squared_error(true, pred, multioutput="raw_values")
        return {
            "mae": reformat(mean_absolute_error(true, pred)),
            "r2": reformat(r2_score(true, pred, multioutput="uniform_average")),
            "spearmanr": reformat(
                eval_spearmanr(true.numpy(), pred.numpy())["spearmanr"]
            ),
            # Uniform average over outputs, same as sklearn's default.
            "mse": reformat(per_output_mse.mean()),
            "rmse": reformat(np.sqrt(per_output_mse).mean()),
        }

    def update_stats(
        self,
        true,
        pred,
        batch_size,
        losses_taskwise,
        lr,
        total_loss,
        time_used,
        params,
        dataset_name=None,
        **kwargs,
    ):
        """Accumulate one step's per-task predictions, losses and timing.

        Args:
            true: Mapping task_key -> target tensor for this step.
            pred: Mapping task_key -> prediction tensor; must contain every key
                of ``true`` with a matching first dimension.
            batch_size: Accepted for interface compatibility; not used.
            losses_taskwise: Mapping task_key -> loss value, summed per task.
            lr: Current learning rate.
            total_loss: Loss of this step; averaged per call in ``basic``.
            time_used: Wall time of this step.
            params: Number of model parameters.
            dataset_name: Accepted for interface compatibility; not used.
            **kwargs: Extra values summed into ``self._custom_stats``.
        """
        for key in true:
            assert key in pred, f"Key '{key}' not found in 'pred' dictionary"
            assert (
                true[key].shape[0] == pred[key].shape[0]
            ), f"Mismatch in shape for key '{key}': {true[key].shape[0]} != {pred[key].shape[0]}"

        self._iter += 1
        self._loss += total_loss
        self._size_current += 1
        self._lr = lr
        self._params = params
        self._time_used += time_used
        self._time_total += time_used

        if true:
            true_store = self._custom_stats.setdefault("true", {})
            pred_store = self._custom_stats.setdefault("pred", {})
            loss_store = self._custom_stats.setdefault("losses_taskwise", {})
            for key in true:
                true_store.setdefault(key, []).append(true[key])
                pred_store.setdefault(key, []).append(pred[key])
                # Out-of-place addition: the first stored value is the
                # caller's own object, and `+=` on a tensor would mutate it.
                if key in loss_store:
                    loss_store[key] = loss_store[key] + losses_taskwise[key]
                else:
                    loss_store[key] = losses_taskwise[key]

        for key, val in kwargs.items():
            if key not in self._custom_stats:
                self._custom_stats[key] = val
            else:
                # Out-of-place for the same reason as the task losses above.
                self._custom_stats[key] = self._custom_stats[key] + val

    def get_task_stats(self, data_cfg, task_key):
        """Dispatch to the metric function matching a dataset's task type.

        A "classification" task with ``task_dim <= 2`` is treated as binary,
        otherwise as multi-class; other task types are used as-is.

        Raises:
            ValueError: If the resolved task type is not supported.
        """
        num_label = data_cfg.task_dim
        if data_cfg.task_type == "classification":
            if num_label <= 2:
                task_type = "classification_binary"
            else:
                task_type = "classification_multi"
        else:
            task_type = data_cfg.task_type

        if task_type == "regression":
            task_stats = self.regression(task_key)
        elif task_type == "classification_binary":
            task_stats = self.classification_binary(task_key)
        elif task_type == "classification_multi":
            task_stats = self.classification_multi(task_key)
        else:
            raise ValueError("Task has to be regression or classification")

        return task_stats

    @staticmethod
    def _flatten_for_tb(stats):
        """Flatten one nesting level ({task: {metric: v}} -> {"task/metric": v}).

        TensorBoard scalar logging needs one scalar per key, while the
        per-task entries of the epoch stats are dicts.
        """
        flat = {}
        for key, value in stats.items():
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    flat[f"{key}/{sub_key}"] = sub_value
            else:
                flat[key] = value
        return flat

    def write_epoch(self, cur_epoch):
        """Compute, log and persist per-task statistics for one finished epoch.

        Tasks from ``cfg.dataset_multi.name_list`` with no recorded predictions
        in this split are skipped with a warning.

        Returns:
            dict: Epoch/basic stats plus one metrics dict per task key.
        """
        start_time = time.perf_counter()
        basic_stats = self.basic()

        recorded_tasks = self._custom_stats.get("true", {})
        task_stats_all = {}
        for dataset_name in cfg.dataset_multi.name_list:
            dataset_cfg = getattr(cfg, dataset_name)
            task_key = (
                f"{dataset_cfg.dataset_name}_{dataset_cfg.task}_{dataset_cfg.task_type}"
            )
            if task_key not in recorded_tasks:
                logging.warning(
                    "%s: no predictions recorded for task '%s'; skipping its metrics",
                    self.name,
                    task_key,
                )
                continue
            task_stats_all[task_key] = self.get_task_stats(dataset_cfg, task_key)

        epoch_stats = {
            "epoch": cur_epoch,
            "time_epoch": round(self._time_used, cfg.round),
        }
        eta_stats = {
            "eta": round(self.eta(cur_epoch), cfg.round),
            "eta_hours": round(self.eta(cur_epoch) / 3600, cfg.round),
        }
        custom_stats = {}

        # Only the training split reports ETA.
        if self.name == "train":
            stats = {
                **epoch_stats,
                **eta_stats,
                **basic_stats,
                **task_stats_all,
                **custom_stats,
            }
        else:
            stats = {**epoch_stats, **basic_stats, **task_stats_all, **custom_stats}

        # print
        logging.info("{}: {}".format(self.name, stats))
        # json
        dict_to_json(stats, "{}/stats.json".format(self.out_dir))
        # tensorboard (per-task dicts flattened to scalar keys)
        if cfg.tensorboard_each_run:
            dict_to_tb(self._flatten_for_tb(stats), self.tb_writer, cur_epoch)
        self.reset()
        if cur_epoch < 3:
            logging.info(
                f"...computing epoch stats took: "
                f"{time.perf_counter() - start_time:.2f}s"
            )
        return stats


def create_logger(cfg):
    """Build one logger per data split, ordered train / val / test.

    In multi-dataset mode (``cfg.dataset_multi.name_list`` is non-empty) the
    number of splits is forced to 3 on ``cfg.share`` and ``CustomLogger_multi``
    instances are created. Otherwise ``cfg.share.num_splits`` ``CustomLogger``
    instances are created, each with the task type from ``infer_task()``.

    Args:
        cfg: Global experiment configuration.

    Returns:
        list: Logger objects, one per split.
    """
    split_names = ["train", "val", "test"]
    multi_dataset = len(cfg.dataset_multi.name_list) > 0
    if multi_dataset:
        cfg.share.num_splits = 3
    return [
        CustomLogger_multi(name=split_names[split_idx])
        if multi_dataset
        else CustomLogger(name=split_names[split_idx], task_type=infer_task())
        for split_idx in range(cfg.share.num_splits)
    ]


def eval_spearmanr(y_true, y_pred):
    """Compute Spearman Rho averaged across tasks.

    Args:
        y_true: Targets, 1-D (single task) or 2-D ``(n_samples, n_tasks)``.
            Samples whose target is NaN are ignored, per task.
        y_pred: Predictions with the same number of samples; for 1-D
            ``y_true`` any shape with ``n_samples`` elements is accepted.

    Returns:
        dict: ``{"spearmanr": value}``, the mean of the per-task Spearman
        coefficients. A task with constant inputs gives NaN (SciPy's
        behaviour), which propagates into the mean. NaN if there are no tasks.
    """
    # Public namespace; `scipy.stats.stats` is a deprecated private module.
    from scipy.stats import spearmanr

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim == 1:
        # Treat a single task as one column so NaN filtering applies uniformly.
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    res_list = []
    for i in range(y_true.shape[1]):
        # ignore nan values
        is_labeled = ~np.isnan(y_true[:, i])
        res_list.append(spearmanr(y_true[is_labeled, i], y_pred[is_labeled, i])[0])

    if not res_list:
        return {"spearmanr": float("nan")}
    return {"spearmanr": sum(res_list) / len(res_list)}
