import logging
from typing import List

import torch
from torch import Tensor
from torch_geometric.utils import degree
from torch_geometric.utils import remove_self_loops
from torch_geometric.utils import scatter
from yacs.config import CfgNode


def flatten_dict(metrics):
    """Collapse per-split metric histories into a single flat dict for wandb.

    Args:
        metrics: Sequence of per-split histories, ordered train, val, test.
            Each history is a sequence of dicts; only its last entry is used.

    Returns:
        A flat dictionary whose keys are the metric names of the latest
        entry of each split, prefixed with "train/", "val/" or "test/".
    """
    split_labels = ["train", "val", "test"]
    flat = {}
    for position, history in enumerate(metrics):
        # Only the most recent entry of each split is reported.
        latest = history[-1]
        for metric_name, metric_value in latest.items():
            flat[f"{split_labels[position]}/{metric_name}"] = metric_value
    return flat


def cfg_to_dict(cfg_node, key_list=None):
    """Convert a config node to dictionary.

    Yacs doesn't have a default function to convert the cfg object to plain
    python dict. The following function was taken from
    https://github.com/rbgirshick/yacs/issues/19

    Args:
        cfg_node: A ``CfgNode`` to convert, or (when recursing) a leaf value.
        key_list: Keys leading from the root to ``cfg_node``; used only to
            build the dotted key path shown in the warning message. Defaults
            to an empty path.

    Returns:
        A plain dict mirroring ``cfg_node`` with every nested ``CfgNode``
        converted recursively. A non-``CfgNode`` input is returned unchanged
        (containers such as lists are not descended into).
    """
    _VALID_TYPES = {tuple, list, str, int, float, bool}

    # Use a None sentinel rather than a shared mutable list as the default.
    if key_list is None:
        key_list = []

    if not isinstance(cfg_node, CfgNode):
        # Exact type match on purpose: a leaf whose type is not literally one
        # of the valid types is reported, but still passed through as-is.
        if type(cfg_node) not in _VALID_TYPES:
            logging.warning(
                f"Key {'.'.join(key_list)} with "
                f"value {type(cfg_node)} is not "
                f"a valid type; valid types: {_VALID_TYPES}"
            )
        return cfg_node

    # `key_list + [k]` builds a fresh path per child, so siblings never share
    # (or mutate) the same list.
    return {k: cfg_to_dict(v, key_list + [k]) for k, v in cfg_node.items()}


def make_wandb_name(cfg):
    """Compose the wandb run name as "<dataset>.<model>.r<run_id>".

    The dataset part starts from ``cfg.dataset.format`` with a leading "OGB"
    and then a leading "PyG-" stripped; the generic formats
    "GNNBenchmarkDataset" and "TUDataset" are dropped entirely. Unless
    ``cfg.dataset.name`` is "none", the dataset name ("LocalDegreeProfile" is
    abbreviated to "LDP") and ``cfg.dataset.task_type`` are appended,
    dash-separated.

    The model part is ``cfg.model.type``, followed by ".<name_tag>" when
    ``cfg.name_tag`` is truthy, and by "+LapPE" / "+RWSE" for each enabled
    positional encoding.

    Args:
        cfg: Config object exposing ``dataset.format``, ``dataset.name``,
            ``dataset.task_type``, ``model.type``, ``name_tag``,
            ``posenc_LapPE.enable``, ``posenc_RWSE.enable`` and ``run_id``.

    Returns:
        The run name string.
    """
    # --- Dataset part ---
    dataset_part = cfg.dataset.format
    # Prefixes are stripped one after another, in this order.
    for prefix in ("OGB", "PyG-"):
        if dataset_part.startswith(prefix):
            dataset_part = dataset_part[len(prefix):]
    if dataset_part in ("GNNBenchmarkDataset", "TUDataset"):
        # Shorten some verbose dataset naming schemes.
        dataset_part = ""

    ds_name = cfg.dataset.name
    if ds_name != "none":
        if dataset_part:
            dataset_part += "-"
        dataset_part += "LDP" if ds_name == "LocalDegreeProfile" else ds_name
        # Separator is added only if something precedes the task type.
        if dataset_part:
            dataset_part += "-"
        dataset_part += cfg.dataset.task_type

    # --- Model part ---
    model_part = cfg.model.type
    if cfg.name_tag:
        model_part += f".{cfg.name_tag}"
    if cfg.posenc_LapPE.enable:
        model_part += "+LapPE"
    if cfg.posenc_RWSE.enable:
        model_part += "+RWSE"

    # --- Full run name ---
    return ".".join((dataset_part, model_part, f"r{cfg.run_id}"))
