import os
import torch
import argparse
import time
import logging
from torch.utils.data import Dataset, SequentialSampler, RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn
import torch.nn.functional as F

from dgl.dataloading.dataloader import DataLoader, GraphDataLoader

import torch.distributed as dist
import yaml
import random
import numpy as np
import pandas as pd
import json
from easydict import EasyDict as edict
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, roc_auc_score, recall_score, precision_score, f1_score, roc_auc_score, average_precision_score, confusion_matrix, precision_recall_curve
from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, RobertaConfig, RobertaModel, RobertaTokenizer
import math
import dgl
from gensim.models import Word2Vec

# from ogb.nodeproppred import DglNodePropPredDataset
dgl.use_libxsmm(False)

from model.ERUMModel import *
from model.ERUMLayers import RUMLayer
from utils.data_loader import *
from utils.dataset import *
from utils.utils import *
from utils.FGM import ATModel


# Make CUDA kernel launches synchronous so device-side errors are reported at
# the call that caused them (a debugging aid; it slows GPU execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

logger = logging.getLogger(__name__)

import warnings


def warn(*args, **kwargs):
    """Discard a warning; installed below in place of ``warnings.warn``."""
    return None


# NOTE(review): this only silences calls that go through the Python-level
# ``warnings.warn`` attribute looked up after this point.
warnings.warn = warn

def load_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    ``PYTHONHASHSEED`` is exported as well; note that it only affects Python
    processes started afterwards (the current interpreter's hash seed is fixed
    at startup).

    Args:
        seed: integer seed applied to every generator.

    Returns:
        The same ``seed`` value.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    return seed


def focal_loss(inputs, targets, alpha=None, gamma=2.0, reduction='mean'):
    """Multi-class focal loss: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        inputs: raw logits with the class dimension at dim 1 (as accepted by
            ``F.cross_entropy``).
        targets: integer class indices (soft/probability targets are not
            supported when ``alpha`` is given).
        alpha: optional per-class weight tensor (same role as the ``weight``
            argument of ``F.cross_entropy``).
        gamma: focusing parameter; ``gamma=0`` reduces to (weighted) cross-entropy.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Returns:
        The reduced loss (scalar) or the per-element loss for ``'none'``.

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    # p_t must come from the *unweighted* CE: exp(-alpha_t * CE) would be
    # p_t ** alpha_t, which distorts the focusing term.
    ce_loss = F.cross_entropy(inputs, targets, reduction='none')
    pt = torch.exp(-ce_loss)
    loss = (1 - pt) ** gamma * ce_loss
    if alpha is not None:
        alpha = alpha.to(device=loss.device, dtype=loss.dtype)
        # Ignored targets (ignore_index, -100 by default) already have zero
        # loss; clamping only keeps the indexing valid for them.
        loss = alpha[targets.clamp(min=0)] * loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError("Unsupported reduction: {!r}".format(reduction))


def compute_kl_loss(p, q, pad_mask=None):
    """Symmetric KL divergence between two sets of logits (R-Drop style).

    Returns ``(KL(softmax(q) || softmax(p)) + KL(softmax(p) || softmax(q))) / 2``,
    where each KL term is summed over every element (no averaging).

    Args:
        p, q: logits of the same shape; the softmax is taken over the last dim.
        pad_mask: optional boolean mask (broadcastable to the KL tensor) whose
            True positions are zeroed before summing, e.g. for padded tokens in
            sequence-level tasks.
    """
    def _directional_kl(src, tgt):
        # F.kl_div(log_softmax(src), softmax(tgt)) is the elementwise KL(tgt || src).
        kl = F.kl_div(F.log_softmax(src, dim=-1), F.softmax(tgt, dim=-1), reduction='none')
        if pad_mask is not None:
            kl.masked_fill_(pad_mask, 0.)
        return kl.sum()

    return (_directional_kl(p, q) + _directional_kl(q, p)) / 2

def _forward_loss(args, model, batched_graph, batched_targets, criterion, e=None):
    """Run one forward pass and return ``(logits, loss)`` for a batch.

    ``loss`` is the model's own auxiliary loss (its second output) plus
    ``criterion(logits, batched_targets)``.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    if args.model == 'ERUMGraphRegressionModel':
        logits, aux_loss = model(batched_graph, batched_graph.ndata['feat'], e=e)
    elif args.model == 'ERUMGraphClassificationModel':
        logits, aux_loss = model(batched_graph, e=e)
    else:
        raise ValueError("Unsupported model: {}".format(args.model))
    return logits, aux_loss + criterion(logits, batched_targets)


def _backward(args, loss, optimizer, retain_graph=False):
    """Backpropagate ``loss``, routing it through apex AMP loss scaling when fp16 is on."""
    if args.fp16:
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(retain_graph=retain_graph)
    else:
        loss.backward(retain_graph=retain_graph)


def _save_model(model, output_root, checkpoint_prefix):
    """Save the (DDP-unwrapped) model's state_dict to ``output_root/checkpoint_prefix/model.bin``.

    Returns:
        The path of the written file.
    """
    output_dir = os.path.join(output_root, checkpoint_prefix)
    os.makedirs(output_dir, exist_ok=True)
    model_to_save = model.module if hasattr(model, 'module') else model
    output_path = os.path.join(output_dir, 'model.bin')
    torch.save(model_to_save.state_dict(), output_path)
    return output_path


def train(args, model, train_dataset, train_dataloader, eval_dataset=None, eval_dataloader=None):
    """Train ``model`` on ``train_dataloader``.

    Mutates ``args``: ``max_steps`` (total optimizer updates), ``save_steps`` and
    ``logging_steps`` (optimizer updates per epoch) and ``warmup_steps`` (the
    scheduler warmup length actually used, 10% of ``max_steps``).

    When ``args.evaluate_during_training`` is set and the run is single-process
    (``local_rank == -1``), the model is evaluated at the end of every epoch
    ``idx`` with ``idx % args.pre_epoch_eval == 0``; improved weights are saved to
    ``<output_dir>/checkpoint-best-f1/model.bin`` and
    ``<output_dir>/checkpoint-best-acc/model.bin``.
    """
    accum_steps = max(1, args.gradient_accumulation_steps)
    num_batches = len(train_dataloader)
    # The epoch loop below runs int(args.epochs) epochs, so size the schedule on that.
    num_epochs = int(args.epochs)
    # One optimizer update per ``accum_steps`` batches; a trailing partial window
    # is flushed at the end of each epoch, hence the ceiling.
    steps_per_epoch = max(1, math.ceil(num_batches / accum_steps))
    args.max_steps = max(1, num_epochs * steps_per_epoch)
    args.save_steps = steps_per_epoch
    args.logging_steps = steps_per_epoch
    args.warmup_steps = int(args.max_steps * 0.1)
    model.to(args.device)

    # Prepare optimizer and schedule (warmup then linear/polynomial decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    if getattr(args, 'learning_decay', 0) > 0:
        scheduler = get_polynomial_decay_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps,
            power=args.learning_decay
        )
    else:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps
        )

    # NOTE: anomaly detection is a debugging aid and slows training noticeably.
    torch.autograd.set_detect_anomaly(True)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Only use DistributedDataParallel for multi-GPU training
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True
        )

    # Resume optimizer/scheduler state if a previous run left it behind.
    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last')
    scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt')
    optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt')
    if os.path.exists(scheduler_last):
        scheduler.load_state_dict(torch.load(scheduler_last))
    if os.path.exists(optimizer_last):
        optimizer.load_state_dict(torch.load(optimizer_last))

    # Train!
    if args.local_rank in [-1, 0]:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Num Epochs = %d", args.epochs)
        logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                    args.train_batch_size * args.gradient_accumulation_steps * (
                        torch.distributed.get_world_size() if args.local_rank != -1 else 1))
        logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", args.max_steps)

    global_step = args.start_step
    best_f1 = 0.0
    best_acc = 0.0
    criterion = nn.CrossEntropyLoss()

    model.zero_grad()
    fgm = ATModel(model)
    for idx in range(args.start_epoch, num_epochs):
        # Reshuffle differently each epoch under DistributedSampler.
        if args.local_rank != -1:
            train_dataloader.sampler.set_epoch(idx)

        tr_num = 0
        train_loss = 0.0
        epoch_start = time.time()

        for step, batch in enumerate(train_dataloader):
            batched_graph = batch[0].to(args.device)
            batched_targets = batch[1].to(args.device)

            model.train()
            # Edge features are only fed to the model for directed graphs.
            e = batched_graph.edata['edge_feature'].to(args.device) if args.directed else None

            logits, loss = _forward_loss(args, model, batched_graph, batched_targets, criterion, e=e)
            if args.local_rank != -1:
                loss = loss.mean()
            if accum_steps > 1:
                loss = loss / accum_steps

            if not args.do_fgm:
                _backward(args, loss, optimizer)
            else:
                # Keep the clean graph: kl_loss backpropagates through ``logits`` again.
                _backward(args, loss, optimizer, retain_graph=True)

                fgm.attack_emb()
                # Adversarial pass on exactly the same inputs (including edge features).
                logits_adv, loss_adv = _forward_loss(args, model, batched_graph, batched_targets, criterion, e=e)
                kl_loss = compute_kl_loss(logits, logits_adv)

                if args.local_rank != -1:
                    loss_adv = loss_adv.mean()
                    kl_loss = kl_loss.mean()
                if accum_steps > 1:
                    loss_adv = loss_adv / accum_steps
                    kl_loss = kl_loss / accum_steps

                # Keep the adversarial graph for kl_loss's backward.
                _backward(args, loss_adv, optimizer, retain_graph=True)
                _backward(args, kl_loss, optimizer)

                fgm.restore_emb()

            train_loss += loss.item()
            tr_num += 1

            # Update at the end of each accumulation window and at the end of the
            # epoch, so leftover gradients never leak into the next epoch.
            if (step + 1) % accum_steps == 0 or step + 1 == num_batches:
                # Clip the fully accumulated gradient right before the update.
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm, norm_type=2)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1

                should_evaluate = (
                    args.local_rank == -1
                    and args.evaluate_during_training
                    and args.save_steps > 0
                    and global_step % args.save_steps == 0
                    and idx % args.pre_epoch_eval == 0
                )
                if should_evaluate:
                    results = evaluate(args, model, eval_dataset, eval_dataloader, eval_when_training=True)
                    for key, value in results.items():
                        logger.info("  %s = %s", key, round(value, 4))

                    if results['eval_f1'] > best_f1:
                        best_f1 = results['eval_f1']
                        logger.info("  " + "*" * 20)
                        logger.info("  Best f1:%s", round(best_f1, 4))
                        logger.info("  " + "*" * 20)
                        saved_path = _save_model(model, args.output_dir, 'checkpoint-best-f1')
                        logger.info("Saving best f1 model checkpoint to %s", saved_path)

                    if results['eval_acc'] > best_acc:
                        best_acc = results['eval_acc']
                        logger.info("  " + "*" * 20)
                        logger.info("  Best acc:%s", round(best_acc, 4))
                        logger.info("  " + "*" * 20)
                        saved_path = _save_model(model, args.output_dir, 'checkpoint-best-acc')
                        logger.info("Saving best acc model checkpoint to %s", saved_path)

        # Mean of the (accumulation-scaled) per-batch losses for this epoch.
        avg_loss = round(train_loss / tr_num, 5) if tr_num else 0.0

        cost_time = time.time() - epoch_start
        if args.local_rank in [-1, 0]:
            logger.info("epoch {} loss {} cost time {}".format(idx, avg_loss, cost_time))

def _collect_predictions(args, model, dataloader, criterion):
    """Run ``model`` over ``dataloader`` without gradients.

    Returns:
        ``(probs, labels, mean_loss)``: softmax probabilities over dim 1 and the
        concatenated targets (both as NumPy arrays), and the per-batch loss
        (model auxiliary loss + ``criterion``) averaged over batches.

    Raises:
        ValueError: if ``args.model`` is unsupported or ``dataloader`` is empty.
    """
    model.eval()
    y_pred, y_truth = [], []
    total_loss = 0.0
    num_steps = 0
    for batch in dataloader:
        batched_graph = batch[0].to(args.device)
        batched_targets = batch[1].to(args.device)
        # Edge features are only fed to the model for directed graphs.
        e = batched_graph.edata['edge_feature'].to(args.device) if args.directed else None
        with torch.no_grad():
            if args.model == 'ERUMGraphRegressionModel':
                logits, loss = model(batched_graph, batched_graph.ndata['feat'], e=e)
            elif args.model == 'ERUMGraphClassificationModel':
                logits, loss = model(batched_graph, e=e)
            else:
                raise ValueError("Unsupported model: {}".format(args.model))
            loss = loss + criterion(logits, batched_targets)
            y_pred.append(F.softmax(logits, dim=1))
            y_truth.append(batched_targets)
            total_loss += loss.mean().item()
        num_steps += 1

    if num_steps == 0:
        raise ValueError("Cannot evaluate on an empty dataloader.")
    probs = torch.cat(y_pred, dim=0).cpu().numpy()
    labels = torch.cat(y_truth, dim=0).cpu().numpy()
    return probs, labels, total_loss / num_steps


def _binary_metrics(probs, labels, dynamic_thresholds):
    """Compute binary-classification metrics from class probabilities.

    Column 1 of ``probs`` is used as the positive-class score (assumes two
    classes with label 1 = positive). With ``dynamic_thresholds`` the decision
    threshold is the one maximising F1 on ``labels``; otherwise 0.5.

    Returns:
        dict with keys acc, recall, precision, f1, pr_auc, roc_auc, fpr,
        specificity (in that order). ``roc_auc`` is NaN when ``labels``
        contains a single class.
    """
    positive_probs = probs[:, 1]
    if dynamic_thresholds:
        precision, recall, thresholds = precision_recall_curve(labels, positive_probs)
        f1_scores = (2 * precision * recall) / (precision + recall + 1e-8)
        # The last (precision=1, recall=0) point has no associated threshold.
        best_threshold = thresholds[np.nanargmax(f1_scores[:-1])]
        # precision_recall_curve treats score >= threshold as positive.
        preds = (positive_probs >= best_threshold).astype(int)
    else:
        preds = (positive_probs > 0.5).astype(int)

    acc = accuracy_score(labels, preds)
    recall_val = recall_score(labels, preds)
    precision_val = precision_score(labels, preds)
    f1 = f1_score(labels, preds)

    pr_auc = average_precision_score(labels, positive_probs)
    # roc_auc_score raises when only one class is present.
    roc_auc = roc_auc_score(labels, positive_probs) if len(np.unique(labels)) > 1 else float('nan')

    # Pin the label order so the matrix is always 2x2, even if a class is absent.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
    false_positive_rate = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    return {
        "acc": acc,
        "recall": recall_val,
        "precision": precision_val,
        "f1": f1,
        "pr_auc": pr_auc,
        "roc_auc": roc_auc,
        "fpr": false_positive_rate,
        "specificity": specificity,
    }


def evaluate(args, model, eval_dataset, eval_dataloader, eval_when_training=False):
    """Evaluate ``model`` on the validation set.

    Only runs on the main process (``local_rank`` -1 or 0); other ranks get
    ``None``. ``eval_when_training`` is accepted for API compatibility and is
    not used.

    Returns:
        dict with ``eval_loss``, ``eval_acc``, ``eval_recall``, ``eval_precision``,
        ``eval_f1``, ``eval_pr_auc``, ``eval_roc_auc``, ``eval_fpr`` and
        ``eval_specificity`` as floats.
    """
    if args.local_rank not in [-1, 0]:
        return
    os.makedirs(args.output_dir, exist_ok=True)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    probs, labels, eval_loss = _collect_predictions(args, model, eval_dataloader, nn.CrossEntropyLoss())
    metrics = _binary_metrics(probs, labels, args.dynamic_thresholds)

    valid_result = {"eval_loss": float(eval_loss)}
    valid_result.update({"eval_" + name: float(value) for name, value in metrics.items()})
    return valid_result


def test(args, model, test_dataset, test_dataloader):
    """Evaluate ``model`` on the test set.

    Only runs on the main process (``local_rank`` -1 or 0); other ranks get
    ``None``.

    Returns:
        dict with ``test_loss``, ``test_acc``, ``test_recall``, ``test_precision``,
        ``test_f1``, ``test_pr_auc``, ``test_roc_auc``, ``test_fpr`` and
        ``test_specificity`` as floats.
    """
    if args.local_rank not in [-1, 0]:
        return
    os.makedirs(args.output_dir, exist_ok=True)

    # Test!
    logger.info("***** Running Test *****")
    logger.info("  Num examples = %d", len(test_dataset))
    logger.info("  Batch size = %d", args.test_batch_size)

    probs, labels, test_loss = _collect_predictions(args, model, test_dataloader, nn.CrossEntropyLoss())
    # NOTE(review): with --dynamic_thresholds the threshold is tuned on the test
    # labels themselves, which makes the reported test metrics optimistic.
    metrics = _binary_metrics(probs, labels, args.dynamic_thresholds)

    test_result = {"test_loss": float(test_loss)}
    test_result.update({"test_" + name: float(value) for name, value in metrics.items()})
    return test_result

def main():
    """Parse CLI/config options, build the model, then run train / eval / test.

    Returns:
        ``(eval_results, test_results)``: the metric dicts produced by
        ``evaluate`` and ``test`` (empty dicts for stages that were not run).
    """
    parser = argparse.ArgumentParser(description='RUM-VDL')

    ## Required parameters
    parser.add_argument('--config', type=str, default="config.yaml", help="Path of config file")
    parser.add_argument('--lang', type=str, default="c", help="language of dataset", choices=['c', 'java'])
    parser.add_argument("--train_data_file", default="../dataset/train.jsonl", type=str,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir", default="./saved_models", type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--eval_data_file", default="../dataset/valid.jsonl", type=str,
                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--test_data_file", default="../dataset/test.jsonl", type=str,
                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                             "The training dataset will be truncated in block of this size for training."
                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--num_nodes", default=400, type=int)
    parser.add_argument("--word2vec_path", default="./saved_models/word2vec")
    parser.add_argument("--model_name_or_path", default="microsoft/codebert-base")
    parser.add_argument("--model_type", default="roberta")
    parser.add_argument("--encoder_type", default="by_token")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")

    parser.add_argument("--train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--eval_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--test_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for testing.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--learning_decay", default=-0.9999999, type=float,
                        help="The initial learning decay for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-5, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--unweighted_MSE", action='store_true',
                        help="Whether to use unweighted MSE.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--epochs", default=1.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--pre_epoch_eval", default=1.0, type=float)
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument('--save_total_limit', type=int, default=None,
                        help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default')
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")

    supported_models = ['ERUMGraphRegressionModel', 'ERUMGraphClassificationModel']
    parser.add_argument("--model", default="RUMModel", type=str,
                        help="choose model.", choices=supported_models)
    parser.add_argument("--hidden_size", default=256, type=int,
                        help="hidden size.")
    parser.add_argument("--feature_dim_size", default=129, type=int,
                        help="feature dim size.")
    parser.add_argument("--num_classes", default=2, type=int,
                        help="num classes.")
    parser.add_argument("--training_percent", default=1., type=float, help="percet of training sample")

    parser.add_argument("--do_save_best_f1", action='store_true',
                        help="Whether to save best-f1 model.")
    parser.add_argument("--do_use_best_f1", action='store_true',
                        help="Whether to use best-f1 model.")
    parser.add_argument("--dynamic_thresholds", action='store_true',
                        help="Whether to use dynamic thresholds.")
    parser.add_argument("--do_fgm", action='store_true',
                        help="Whether to use fgm.")

    # ERUM
    parser.add_argument("--num_samples", default=8, type=int,
                        help="num samples.")
    parser.add_argument("--walk_length", default=8, type=int,
                        help="walk length.")
    parser.add_argument("--depth", default=1, type=int,
                        help="depth.")
    parser.add_argument("--num_layers", default=2, type=int,
                        help="num layers.")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--self_supervise_weight", type=float, default=1e-5)
    parser.add_argument("--consistency_weight", type=float, default=1)
    parser.add_argument("--consistency_temperature", type=float, default=0.5)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--activation", type=str, default="ELU",
                        choices=["ReLU", "ELU", "SiLU"])
    parser.add_argument("--directed", action='store_true',
                        help="Whether to use directed graph.")

    args = parser.parse_args()

    # argparse does not check defaults against ``choices``: without --model the
    # default 'RUMModel' would leave ``model`` undefined further down.
    if args.model not in supported_models:
        parser.error("argument --model: invalid choice: {!r} (choose from {})".format(
            args.model, ', '.join(repr(m) for m in supported_models)))

    # Values from the YAML config (when present) override the CLI.
    if os.path.exists(args.config):
        with open(args.config, 'r') as config_file:
            config = edict(yaml.load(config_file, Loader=yaml.FullLoader))
        args.seed = config.seed
        args.dataset = config.dataset
        args.do_train = config.do_train

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, int(args.server_port)), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count() if torch.cuda.is_available() and not args.no_cuda else 0
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1

    args.device = device
    args.per_gpu_train_batch_size = args.train_batch_size
    args.per_gpu_eval_batch_size = args.eval_batch_size
    args.per_gpu_test_batch_size = args.test_batch_size

    # Total batch sizes (used for logging; the dataloaders use the per-GPU sizes)
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    args.test_batch_size = args.per_gpu_test_batch_size * max(1, args.n_gpu)

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN
    )
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    load_seed(args.seed)

    # Synchronize all ranks before building the model
    if args.local_rank != -1:
        torch.distributed.barrier()

    args.start_epoch = 0
    args.start_step = 0

    # Directed graphs carry a 2-dim edge feature (direction + type).
    args.edge_feat_dim = 2 if args.directed else 0

    if args.model == 'ERUMGraphRegressionModel':
        model = ERUMGraphRegressionModel(
            in_features=args.feature_dim_size,
            out_features=args.num_classes,
            hidden_features=args.hidden_size,
            depth=args.depth,
            num_samples=args.num_samples,
            length=args.walk_length,
            dropout=args.dropout,
            self_supervise_weight=args.self_supervise_weight,
            consistency_weight=args.consistency_weight,
            temperature=args.consistency_temperature,
            activation=getattr(torch.nn, args.activation)(),
            edge_features=args.edge_feat_dim,
        )
    else:  # 'ERUMGraphClassificationModel' (validated above)
        model = ERUMGraphClassificationModel(
            in_features=args.feature_dim_size,
            out_features=args.num_classes,
            hidden_features=args.hidden_size,
            depth=args.depth,
            num_samples=args.num_samples,
            walk_length=args.walk_length,
            dropout=args.dropout,
            block_size=args.block_size,
            model_name_or_path=args.model_name_or_path,
            encoder_type=args.encoder_type,
            temperature=args.consistency_temperature,
            self_supervise_weight=args.self_supervise_weight,
            consistency_weight=args.consistency_weight,
            activation=getattr(torch.nn, args.activation)(),
            edge_features=args.edge_feat_dim,
        )

    # Training (train() moves the model to args.device itself)
    if args.do_train:
        train_dataset, train_dataloader = load_integrated_data(
            args,
            args.train_data_file,
            args.per_gpu_train_batch_size,  # Use per GPU batch size
            mode_type='train',
            sample_percent=args.training_percent
        )

        eval_dataset, eval_dataloader = None, None
        if args.evaluate_during_training:
            eval_dataset, eval_dataloader = load_integrated_data(
                args,
                args.eval_data_file,
                args.per_gpu_eval_batch_size,  # Use per GPU batch size
                mode_type='eval'
            )
        train(args, model, train_dataset, train_dataloader, eval_dataset, eval_dataloader)

    eval_results = {}
    test_results = {}
    # choose which saved checkpoint to evaluate
    if args.do_use_best_f1:
        checkpoint_prefix = 'checkpoint-best-f1/model.bin'
        info = 'by best f1'
    else:
        checkpoint_prefix = 'checkpoint-best-acc/model.bin'
        info = 'by best acc'
    checkpoint_path = os.path.join(args.output_dir, checkpoint_prefix)

    # Evaluation
    if args.do_eval and args.local_rank in [-1, 0]:
        eval_dataset, eval_dataloader = load_integrated_data(
            args,
            args.eval_data_file,
            args.per_gpu_eval_batch_size,
            mode_type='eval'
        )

        # map_location lets a GPU-saved checkpoint load on CPU-only hosts.
        model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
        model.to(args.device)
        eval_results = evaluate(args, model, eval_dataset, eval_dataloader)
        logger.info(f"***** Eval results {info} *****")
        for key in sorted(eval_results.keys()):
            logger.info("  %s = %s", key, str(round(eval_results[key], 4)))

    # test
    if args.do_test and args.local_rank in [-1, 0]:
        test_dataset, test_dataloader = load_integrated_data(
            args,
            args.test_data_file,
            args.per_gpu_test_batch_size,
            mode_type='test'
        )

        model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
        model.to(args.device)
        test_results = test(args, model, test_dataset, test_dataloader)

        logger.info(f"***** Test results {info} *****")
        for key in sorted(test_results.keys()):
            logger.info("  %s = %s", key, str(round(test_results[key], 4)))

    return eval_results, test_results

if __name__ == '__main__':
    # Report total wall-clock time of the whole run.
    _run_started = time.time()
    main()
    print(f'---------------cost time: {time.time() - _run_started}---------------')