# from graphgym.loader.datasets.ocb_dataset import OCBDataset
# from graphgym.loader.master_loader import load_dataset_master
from graphgym.optimizer.extra_optimizers import ExtendedSchedulerConfig
from graphgym.logger import create_logger
from graphgym.finetuning import load_pretrained_model_cfg, \
    init_model_from_pretrained

import os
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.optim import create_optimizer, \
    create_scheduler, OptimizerConfig
from torch_geometric.graphgym.train import train
from torch_geometric.graphgym.register import train_dict
from torch_geometric.graphgym.loader import create_loader

import datetime
import torch
import logging

from torch_geometric.data import DataLoader, Batch

def custom_collate_fn(batch):
    """
    Collate a list of graphs into one batch, then re-attach selected
    attributes as a plain concatenation along dimension 0.

    The attributes handled this way are 'rrwp', 'rrwp_index', 'rrwp_val',
    'log_deg' and 'deg'. Whether an attribute is processed is decided by
    looking at the first element of ``batch`` only.

    NOTE(review): ``Batch.from_data_list`` may already collate these
    attributes itself (possibly with index offsets); the values set here
    replace whatever it produced -- confirm this override is intended,
    in particular for 'rrwp_index'.

    Args:
        batch (list of Data objects): The samples drawn from the dataset.

    Returns:
        Batch: The batched object with the attributes above overwritten by
        their concatenation across all samples.
    """
    # Let PyG build the batch first; the loop below only overrides fields.
    merged = Batch.from_data_list(batch)

    for name in ('rrwp', 'rrwp_index', 'rrwp_val', 'log_deg', 'deg'):
        # Presence is probed on the first sample only.
        if not hasattr(batch[0], name):
            continue
        pieces = [getattr(sample, name) for sample in batch]
        setattr(merged, name, torch.cat(pieces, dim=0))

    return merged


def new_optimizer_config(cfg):
    """Build an :obj:`OptimizerConfig` from the ``cfg.optim`` section.

    Args:
        cfg (CfgNode): Configuration node; ``cfg.optim`` must provide
            ``optimizer``, ``base_lr``, ``weight_decay`` and ``momentum``.

    Returns:
        OptimizerConfig: Config object carrying those four values.
    """
    optim = cfg.optim
    forwarded = ("optimizer", "base_lr", "weight_decay", "momentum")
    # Each field is passed through under the same keyword name.
    return OptimizerConfig(**{key: getattr(optim, key) for key in forwarded})

def custom_set_out_dir(cfg, cfg_fname, name_tag):
    """Point :obj:`cfg.out_dir` at a run-specific sub-directory.

    The sub-directory name is the config file's base name without its
    extension, followed by ``-<name_tag>`` when a non-empty tag is given.
    The result is joined onto the current value of :obj:`cfg.out_dir`.

    Args:
        cfg (CfgNode): Configuration node; its ``out_dir`` is updated in place
        cfg_fname (string): Filename for the yaml format configuration file
        name_tag (string): Additional name tag to identify this execution of
            the configuration file; skipped when empty or otherwise falsy
    """
    stem, _extension = os.path.splitext(os.path.basename(cfg_fname))
    if name_tag:
        stem = f"{stem}-{name_tag}"
    cfg.out_dir = os.path.join(cfg.out_dir, stem)

def new_scheduler_config(cfg):
    """Build an :obj:`ExtendedSchedulerConfig` from the global configuration.

    Most values are read from ``cfg.optim``; ``train_mode`` and
    ``eval_period`` are taken from ``cfg.train``.

    Args:
        cfg (CfgNode): Configuration node providing the ``optim`` and
            ``train`` sections.

    Returns:
        ExtendedSchedulerConfig: Scheduler settings assembled from ``cfg``.
    """
    optim = cfg.optim
    settings = dict(
        scheduler=optim.scheduler,
        steps=optim.steps,
        lr_decay=optim.lr_decay,
        max_epoch=optim.max_epoch,
        reduce_factor=optim.reduce_factor,
        schedule_patience=optim.schedule_patience,
        min_lr=optim.min_lr,
        num_warmup_epochs=optim.num_warmup_epochs,
        # These two come from the training section, not the optimizer one.
        train_mode=cfg.train.mode,
        eval_period=cfg.train.eval_period,
        num_cycles=optim.num_cycles,
        min_lr_mode=optim.min_lr_mode,
    )
    return ExtendedSchedulerConfig(**settings)

if __name__ == "__main__":
    # Script entry point: load the configuration, build loaders, model,
    # optimizer and scheduler, then dispatch to the selected training loop.

    # Load cmd line args (``args.cfg_file`` is read below).
    args = parse_args()

    # Initialise the global cfg, then overlay the user-supplied config.
    set_cfg(cfg)
    # Allow the config file / overrides to introduce keys that are not part
    # of the defaults.
    cfg.set_new_allowed(True)
    cfg.work_dir = os.getcwd()
    load_cfg(cfg, args)

    # Use the optional ``experiment_name`` as the run's name tag; fall back
    # to an empty tag when the config does not define it.
    cfg.name_tag = getattr(cfg, 'experiment_name', "")

    # With the 'masked' prior each of these dataset sizes is grown by one
    # (presumably to make room for a mask token -- confirm against the
    # dataset / model code).
    if getattr(cfg.train, 'prior', None) == 'masked':
        cfg.dataset.nnode_types += 1
        cfg.dataset.nnode_features += 1
        cfg.dataset.nedge_types += 1

    cfg.cfg_file = args.cfg_file
    custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag)

    # Set Pytorch environment
    torch.set_num_threads(cfg.num_threads)

    loaders = create_loader()

    # Prefer the GPU, but do not hard-fail on hosts without CUDA: the
    # previous unconditional "cuda" made the script unusable on CPU-only
    # machines. The device is fixed before the config is dumped so the
    # dumped file records the value actually used.
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    dump_cfg(cfg)
    loggers = create_logger()
    if cfg.device != "cuda":
        # Emitted only after create_logger() so this message cannot
        # pre-empt whatever logging configuration that call performs.
        logging.warning("CUDA is not available; falling back to device "
                        "'%s'", cfg.device)

    model = create_model()
    if cfg.pretrained.dir:
        # Initialise from a pretrained checkpoint directory when one is
        # configured.
        model = init_model_from_pretrained(
            model, cfg.pretrained.dir, cfg.pretrained.freeze_main,
            cfg.pretrained.reset_prediction_head
        )
    optimizer = create_optimizer(model.parameters(),
                                 new_optimizer_config(cfg))
    scheduler = create_scheduler(optimizer, new_scheduler_config(cfg))

    # Print model info
    logging.info(model)
    logging.info(cfg)
    cfg.params = params_count(model)
    logging.info('Num parameters: %s', cfg.params)

    # Start training
    if cfg.train.mode == 'standard':
        # The default GraphGym loop is used here; warn when experiment
        # trackers are enabled, since they are not supported in this mode.
        if cfg.wandb.use:
            logging.warning("[W] WandB logging is not supported with the "
                            "default train.mode, set it to `custom`")
        if cfg.mlflow.use:
            logging.warning("[ML] MLflow logging is not supported with the "
                            "default train.mode, set it to `custom`")
        train(loggers, loaders, model, optimizer, scheduler)
    else:
        # Any other mode is looked up in the registry of training functions.
        train_dict[cfg.train.mode](loggers, loaders, model, optimizer,
                                   scheduler)
    logging.info("[*] All done: %s", datetime.datetime.now())
