import time
from typing import List, Optional

import torch
import yaml
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import Tensor
from torch_geometric.utils import degree

def load_config(args):
    """Resolve the experiment name and the configuration to run with.

    Training mode (``args.train is True``): a fresh experiment name is built
    from the dataset name, model name, current local time and random seed,
    prefixed with ``DEBUG_`` when ``args.debug is True``. ``args`` itself is
    returned as the config with ``evaluation.ckpt_file`` cleared; note that
    this mutates ``args`` in place.

    Any other mode: the experiment name is the first ``/``-separated
    component of ``args.evaluation.ckpt_file``. The ``<exp_name>.yml`` file
    under ``TBLOG_HPARAMS_DIR`` is loaded and wrapped with
    ``OmegaConf.create``. Its ``evaluation.ckpt_file`` is set to the string
    form of ``CKPT_DIR / args.evaluation.ckpt_file``, and ``train`` is forced
    to ``False``.

    Args:
        args: The run configuration. It must support both attribute and item
            access; presumably an OmegaConf ``DictConfig`` (confirm with
            caller).

    Returns:
        tuple: ``(exp_name, config)``.
    """
    from gpl import TBLOG_HPARAMS_DIR, CKPT_DIR

    if args.train is not True:
        # Evaluation: reuse the config that was saved for this experiment.
        from gpl.training import yaml_load

        ckpt_path = CKPT_DIR / args.evaluation.ckpt_file
        exp_name = args.evaluation.ckpt_file.split('/')[0]

        hparams_file = TBLOG_HPARAMS_DIR / f'{exp_name}.yml'
        config = OmegaConf.create(yaml_load(hparams_file))
        config['evaluation']['ckpt_file'] = str(ckpt_path)
        config['train'] = False
        print(f'\nThese are [Old evaluation] args. {str(hparams_file)} is loaded')
        return exp_name, config

    # Training: name a brand-new experiment and use the given args as-is.
    exp_name = '_'.join([
        args.dataset.name,
        args.model.name,
        time.strftime('%Y-%m-%d_%H-%M-%S'),
        f"seed{args.random_seed}",
    ])
    if args.debug is True:
        exp_name = 'DEBUG_' + exp_name
    config = args  # alias, not a copy: the edit below also changes ``args``
    config['evaluation']['ckpt_file'] = ''
    print('\nThese are [Training] args')
    return exp_name, config


# adapted from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/unbatch.html#unbatch_edge_index
def unbatch_edge_index(
    edge_index: Tensor,
    batch: Tensor,
    batch_size: Optional[int] = None,
) -> List[Tensor]:
    r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.

    Node indices in each returned piece are re-based so that they start at
    ``0`` within their own graph.

    Args:
        edge_index (Tensor): The edge_index tensor of shape ``[2, E]``.
            Must be ordered by graph.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered.
        batch_size (int, optional): The number of graphs :math:`B`. If not
            given, it is inferred as ``batch.max() + 1`` (``0`` for an empty
            batch vector).

    Returns:
        A tuple with one ``[2, E_i]`` tensor per graph (``B`` entries).
        Graphs without any edges, including trailing ones, get an empty
        ``[2, 0]`` tensor, so the result stays aligned with per-graph node
        splits.

    :rtype: :class:`List[Tensor]`
    """
    if batch_size is None:
        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0

    # Nodes per graph. Its exclusive prefix sum is the position of each
    # graph's first node in the batched node ordering.
    deg = torch.bincount(batch, minlength=batch_size)
    ptr = deg.cumsum(dim=0) - deg

    edge_batch = batch[edge_index[0]]
    edge_index = edge_index - ptr[edge_batch]
    # Count edges for *every* graph (minlength) so that graphs with no edges
    # at the end of the batch still get an (empty) split instead of being
    # dropped.
    sizes = torch.bincount(edge_batch, minlength=batch_size).cpu().tolist()
    return edge_index.split(sizes, dim=1)
