import os
import time
import math
import random
import numpy as np
import argparse
import torch
import torch.nn as nn
import wandb
# Set PyTorch to print full tensors
torch.set_printoptions(threshold=float('inf'), linewidth=200)
from gnn_data import GNN_DATA, GNN_DATA_Binding
from gnn_model import GIN_Net2,SimpleMLP,PIPR
from prompt_model import InactivePromptBinding
from utils import Metrictor_PPI, print_file, GateLoss,print_file_color,negative_sampling_from_node_pairs
import torch.nn.functional as F
from tqdm import tqdm
import pdb



def boolean_string(s):
    """Convert the exact strings 'True' / 'False' into the matching bool.

    Intended as an argparse ``type=`` converter: the match is case-sensitive
    and anything other than those two literals is rejected.

    Raises:
        ValueError: if ``s`` is not exactly 'True' or 'False'.
    """
    literal_to_bool = {'True': True, 'False': False}
    try:
        return literal_to_bool[s]
    except KeyError:
        raise ValueError('Not a valid boolean string') from None

def smooth_labels(labels, smoothing=0.1):
    """Pull hard binary targets toward 0.5 to discourage overconfident outputs.

    Each label is scaled by ``1 - smoothing`` and then shifted by
    ``smoothing * 0.5``; with the default, 0 becomes 0.05 and 1 becomes 0.95.
    """
    keep_weight = 1 - smoothing
    midpoint_share = smoothing * 0.5
    return labels * keep_weight + midpoint_share


# Command-line interface of the training script. Options without an explicit
# default fall back to None and are forwarded unchanged by main(), so the
# values required by the chosen model must be supplied on the command line.
parser = argparse.ArgumentParser(description='Train Model')
parser.add_argument('--description', default=None, type=str,
                    help='train description')
parser.add_argument('--ppi_path', default=None, type=str,
                    help="ppi path")
parser.add_argument('--pseq_path', default=None, type=str,
                    help="protein sequence path")
parser.add_argument('--vec_path', default=None, type=str,
                    help='protein sequence vector path')
parser.add_argument('--split_new', default=None, type=boolean_string,
                    help='split new index file or not')
parser.add_argument('--split_mode', default=None, type=str,
                    help='split method, random, bfs or dfs')
parser.add_argument('--train_valid_index_path', default=None, type=str,
                    help='cnn_rnn and gnn unified train and valid ppi index')
parser.add_argument('--use_lr_scheduler', default=None, type=boolean_string,
                    help="train use learning rate scheduler or not")
parser.add_argument('--save_path', default=None, type=str,
                    help='model save path')
parser.add_argument('--graph_only_train', default=None, type=boolean_string,
                    help='whether the training ppi graph is constructed from train edges only '
                         'or from all edges (train with test)')
parser.add_argument('--batch_size', default=None, type=int,
                    help="gnn train batch size, edge batch size")
parser.add_argument('--epochs', default=None, type=int,
                    help='train epoch number')
parser.add_argument('--num_token', default=None, type=int,
                    help='number of prompt tokens (prompt model only)')
parser.add_argument('--hidden', default=None, type=int,
                    help='hidden dim number')
parser.add_argument('--lr', default=None, type=float,
                    help='learning rate')
parser.add_argument('--gin_num_layer', default=None, type=int,
                    help='number of GIN layers')
parser.add_argument('--th_epoch', default=None, type=int,
                    help='first epoch at which the gate mechanism is enabled '
                         '(gate_work = epoch >= th_epoch)')
parser.add_argument('--use_jk', default=None, type=boolean_string,
                    help="value of the use_jk option passed to the prompt model")
parser.add_argument('--use_GRU', default=None, type=boolean_string,
                    help="value of the use_GRU option passed to the prompt model")
parser.add_argument('--token_init', default='xavier', type=str,
                    help="token initialization method")
parser.add_argument('--model_type', default='prompt', type=str, choices=['prompt', 'gnn', 'mlp','PIPR'],
                    help='model type: prompt (InactivePromptBinding), gnn (GIN_Net2), '
                         'mlp (SimpleMLP), or PIPR (PIPR)')

def _flatten_per_sample(values):
    """Squeeze a per-sample tensor to 1-D, keeping a batch of one indexable."""
    flat = values.squeeze() if values.dim() > 1 else values
    if flat.dim() == 0:
        # squeeze() collapses a single-sample tensor such as (1, 1) into a
        # scalar, which can neither be sized nor sliced along dim 0.
        flat = flat.reshape(1)
    return flat


def _print_first_batch_diagnostics(model, epoch, output, pred_probs, label, prob,
                                   metrics, is_prompt_model, gate_work):
    """Print prediction-balance, logit and gate statistics for the first batch of an epoch."""
    total_predictions = pred_probs.size(0)
    pos_ratio = (pred_probs > 0.5).sum().item() / total_predictions
    actual_pos_ratio = label.sum().item() / total_predictions

    # When nearly everything is predicted positive, also report the F1 that a
    # median threshold would give. This is diagnostic only: the metrics that
    # are accumulated and returned always use the fixed 0.5 threshold.
    if pos_ratio > 0.9:
        dynamic_threshold = pred_probs.median().item()
        print(f"  Detected prediction bias towards positive class, dynamically adjusting threshold to: {dynamic_threshold:.3f}")
        if dynamic_threshold != 0.5:
            dynamic_pred = (pred_probs > dynamic_threshold).float()
            dynamic_metrics = Metrictor_PPI(dynamic_pred.cpu().data, label.cpu().data, True)
            dynamic_metrics.show_result()
            print(f"  F1 under dynamic threshold: {dynamic_metrics.F1:.3f} (vs fixed threshold: {metrics.F1:.3f})")

    logits_mean = output.mean().item()
    logits_std = output.std().item()
    logits_range = f"[{output.min().item():.3f}, {output.max().item():.3f}]"

    gate_info = ""
    if is_prompt_model and gate_work and isinstance(prob, torch.Tensor) and prob.numel() > 0:
        gate_mean = prob.mean().item()
        gate_std = prob.std().item()
        gate_info = f", Gate mean: {gate_mean:.3f}, Gate std: {gate_std:.3f}"

    if is_prompt_model:
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        gate_info += f", Trainable params: {trainable_params}/{total_params} ({trainable_params/total_params*100:.1f}%)"

    print(f"Epoch {epoch}, Step 0 - Positive prediction ratio: {pos_ratio:.3f}, Actual positive ratio: {actual_pos_ratio:.3f}")
    print(f"  Logits stats: mean={logits_mean:.3f}, std={logits_std:.3f}, range={logits_range}")
    print(f"  Prediction probability range: [{pred_probs.min():.3f}, {pred_probs.max():.3f}]{gate_info}")


def train_one_epoch(model, graph, loss_fn, loss_fn_gate, optimizer, device, batch_size, epoch, th_epoch, got):
    """Train ``model`` for one pass over ``graph.train_pair_index``.

    The labelled training pairs are consumed in consecutive mini-batches of
    ``batch_size`` columns. For each batch the logits are clamped to
    [-1.5, 1.5], the loss is ``loss_fn`` on label-smoothed targets plus an
    optional gate loss, gradients are clipped to a norm of 0.5, and the
    per-batch loss and metrics are sent to wandb.

    Args:
        model: callable as ``model(x, edge_index, pair_ids, active_prune=...,
            gate_work=...)`` returning ``(logits, prob)``. A model exposing both
            ``num_token`` and ``A_B_token`` is treated as a prompt model and has
            its parameters frozen/unfrozen according to ``gate_work``.
        graph: object providing ``x``, ``edge_index_got``, ``train_pair_index``
            (indexed as ``[:, start:stop]``) and ``train_pair_label``.
        loss_fn: loss applied to ``(logits, smoothed_labels)``; the label
            slice is passed through without reshaping, so its shape must
            already be compatible with the logits.
        loss_fn_gate: unused; kept so existing callers keep working.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device on which the zero gate-loss placeholder is created.
        batch_size: number of pairs per mini-batch.
        epoch: current epoch index.
        th_epoch: the gate mechanism is active once ``epoch >= th_epoch``.
        got: unused; kept so existing callers keep working.

    Returns:
        dict with the per-batch averages ``loss``, ``gate_loss``, ``recall``,
        ``precision`` and ``f1`` (all zero when there are no training pairs).
    """
    gate_work = epoch >= th_epoch
    is_prompt_model = hasattr(model, 'num_token') and hasattr(model, 'A_B_token')

    # Only prompt models switch between training stages.
    if is_prompt_model:
        if gate_work:
            model.trainable_param_gin()
            model.freeze_param_token()
        else:
            model.freeze_param_gin()
    else:
        # For non-prompt models (GNN, MLP), ensure all parameters are trainable
        for param in model.parameters():
            param.requires_grad = True

    # Derive the step count from the tensor that is actually sliced below, so
    # every pair is visited in a batch of at most batch_size and no batch is
    # empty, even when the number of labelled pairs differs from
    # len(graph.train_mask).
    num_pairs = graph.train_pair_index.shape[1]
    steps = math.ceil(num_pairs / batch_size)
    recall_sum = precision_sum = f1_sum = loss_sum = gate_loss_sum = 0.0
    model.train()

    for step in tqdm(range(steps)):
        start = step * batch_size
        stop = start + batch_size  # slicing clips the final, shorter batch
        train_edge_id = graph.train_pair_index[:, start:stop]
        label = graph.train_pair_label[start:stop]

        output, prob = model(graph.x, graph.edge_index_got, train_edge_id, active_prune=0.2, gate_work=gate_work)
        output = torch.clamp(output, min=-1.5, max=1.5)
        # Detached probabilities: used as gate target, for thresholding and
        # for diagnostics, never for back-propagation.
        pred_probs = torch.sigmoid(output).detach()

        gate_loss = torch.tensor(0.0, device=device)
        if is_prompt_model and gate_work and isinstance(prob, torch.Tensor) and prob.numel() > 0:
            # Regress the gate output onto the model's own detached confidence.
            prob_flat = _flatten_per_sample(prob)
            pred_flat = _flatten_per_sample(pred_probs)
            min_size = min(prob_flat.size(0), pred_flat.size(0))
            if min_size > 0:
                gate_loss = F.mse_loss(prob_flat[:min_size], pred_flat[:min_size]) * 0.1

        smooth_label = smooth_labels(label, smoothing=0.1)
        loss = loss_fn(output, smooth_label) + gate_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        # Metrics are computed against the hard (unsmoothed) labels.
        pred = (pred_probs > 0.5).float()
        metrics = Metrictor_PPI(pred.cpu().data, label.cpu().data, True)
        metrics.show_result()

        if step == 0:
            _print_first_batch_diagnostics(model, epoch, output, pred_probs, label, prob,
                                           metrics, is_prompt_model, gate_work)

        recall_sum += metrics.Recall
        precision_sum += metrics.Precision
        f1_sum += metrics.F1
        loss_sum += loss.item()
        gate_loss_sum += gate_loss.item()

        wandb.log({
            'train/loss': loss.item(),
            'train/gate_loss': gate_loss.item(),
            'train/precision': metrics.Precision,
            'train/recall': metrics.Recall,
            'train/F1': metrics.F1,
        })

    # With no training pairs all sums are zero; avoid dividing by zero.
    denom = max(steps, 1)
    return {
        "loss": loss_sum / denom,
        "gate_loss": gate_loss_sum / denom,
        "recall": recall_sum / denom,
        "precision": precision_sum / denom,
        "f1": f1_sum / denom
    }

def eval_pair_batch(graph, pair_name: str, step: int, steps: int, batch_size: int,
                    model, loss_fn, device, gate_work=0.1, active_prune=0.1):

    pair_index = getattr(graph, f"{pair_name}_pair_index")
    pair_label = getattr(graph, f"{pair_name}_pair_label")

    start = step * batch_size
    end = pair_index.shape[1] if step == steps - 1 else start + batch_size

    edge_ids = pair_index[:, start:end]
    label = pair_label[start:end].view(-1, 1)

    output, _ = model(graph.x, graph.edge_index_got, edge_ids,
                      active_prune=active_prune, gate_work=gate_work)

    output = torch.clamp(output, min=-1.5, max=1.5)

    if step == 0:
        pred_probs = torch.sigmoid(output)
        pos_predictions = (pred_probs > 0.5).sum().item()
        total_predictions = pred_probs.size(0)
        pos_ratio = pos_predictions / total_predictions
        actual_pos = label.sum().item()
        actual_pos_ratio = actual_pos / total_predictions
        logits_range = f"[{output.min().item():.3f}, {output.max().item():.3f}]"
        print(f"  {pair_name.upper()} validation - Positive prediction ratio: {pos_ratio:.3f}, Actual positive ratio: {actual_pos_ratio:.3f}, Logits range: {logits_range}")

        if pos_ratio > 0.9 or pos_ratio < 0.1:
            adaptive_threshold = pred_probs.median().item()
            adaptive_pred = (pred_probs > adaptive_threshold).float().to(device)
            print(f"    Detected extreme prediction, using adaptive threshold: {adaptive_threshold:.3f}")
            return loss_fn(output, label).item(), adaptive_pred.cpu(), label.cpu()
    
    # print(torch.sigmoid(output).view(-1))
    loss = loss_fn(output, label)
    pred = (torch.sigmoid(output) > 0.5).float().to(device)

    return loss.item(), pred.cpu(), label.cpu()

@torch.no_grad()
def evaluate(model, graph, loss_fn, device, batch_size, gate_work, mode):
    """Evaluate ``model`` on the validation splits stored on ``graph``.

    The splits are ``bs``, ``es`` and ``ns`` when ``mode == 'random'`` and
    ``es`` and ``ns`` otherwise. Each split is scored batch by batch with
    ``eval_pair_batch``; precision/recall/F1 are computed on the concatenated
    predictions of the split.

    Args:
        model: model to evaluate; switched to eval mode.
        graph: object providing ``<split>_pair_index`` / ``<split>_pair_label``.
        loss_fn: loss forwarded to ``eval_pair_batch``.
        device: device forwarded to ``eval_pair_batch``.
        batch_size: number of pairs per evaluation batch.
        gate_work: forwarded to the model's forward call.
        mode: split mode; ``'random'`` additionally evaluates the ``bs`` split.

    Returns:
        dict mapping each split name, plus ``'all'``, to a dict with ``loss``,
        ``recall``, ``precision`` and ``f1``. A split's loss is the mean of its
        batch losses. The ``'all'`` entry pools the predictions of every split
        and averages the batch losses of every split.
    """
    model.eval()

    # NOTE(review): this re-enables requires_grad on every parameter after each
    # evaluation, undoing whatever freezing was applied before it. Gradients
    # are not needed under no_grad; kept as-is because the training loop may
    # rely on the reset - confirm before removing.
    for param in model.parameters():
        param.requires_grad = True

    pair_names = ['bs', 'es', 'ns'] if mode == 'random' else ['es', 'ns']

    result_dict = {}
    pooled_preds = []
    pooled_labels = []
    total_loss = 0.0
    total_steps = 0

    for pair_name in pair_names:
        pair_index = getattr(graph, f"{pair_name}_pair_index")
        steps = math.ceil(pair_index.shape[1] / batch_size)

        split_preds = []
        split_labels = []
        loss_sum = 0.0

        for step in range(steps):
            loss_item, pred, label = eval_pair_batch(
                graph, pair_name=pair_name, step=step, steps=steps,
                batch_size=batch_size, model=model, loss_fn=loss_fn,
                device=device, gate_work=gate_work
            )
            loss_sum += loss_item
            split_preds.append(pred)
            split_labels.append(label)

        split_preds = torch.cat(split_preds, dim=0)
        split_labels = torch.cat(split_labels, dim=0)
        pooled_preds.append(split_preds)
        pooled_labels.append(split_labels)
        total_loss += loss_sum
        total_steps += steps

        metrics = Metrictor_PPI(split_preds, split_labels, True)
        metrics.show_result()

        result_dict[pair_name] = {
            "loss": loss_sum / steps,
            "recall": metrics.Recall,
            "precision": metrics.Precision,
            "f1": metrics.F1
        }

    pooled_preds = torch.cat(pooled_preds, dim=0)
    pooled_labels = torch.cat(pooled_labels, dim=0)
    metrics = Metrictor_PPI(pooled_preds, pooled_labels, True)
    metrics.show_result()
    result_dict['all'] = {
        # Average over the batches of every split, not only the split that
        # happened to be evaluated last.
        "loss": total_loss / total_steps,
        "recall": metrics.Recall,
        "precision": metrics.Precision,
        "f1": metrics.F1
    }

    return result_dict


def _log_validation_split(name, split_metrics, best, epoch, model, train_metrics,
                          save_path, result_file_path, with_train_summary=False):
    """Checkpoint, log to wandb and print the validation result of one split.

    Args:
        name: split name (e.g. ``'bs'``, ``'es'``, ``'ns'``, ``'all'``); used in
            the checkpoint file name, the wandb keys and the printed line.
        split_metrics: dict with ``loss``, ``recall``, ``precision``, ``f1``.
        best: mutable dict with the keys ``f1`` and ``epoch``; updated in place
            when this epoch's F1 is strictly better.
        epoch: current epoch index.
        model: model whose ``state_dict`` is saved on improvement.
        train_metrics: this epoch's training averages.
        save_path: directory that receives the checkpoint.
        result_file_path: file passed to ``print_file_color`` as its
            ``save_file_path``.
        with_train_summary: prefix the printed line with the training averages.
    """
    if split_metrics["f1"] > best["f1"]:
        best["f1"] = split_metrics["f1"]
        best["epoch"] = epoch
        ckpt_path = os.path.join(save_path, f'gnn_model_{name}_valid_best.ckpt')
        torch.save({'epoch': epoch, 'state_dict': model.state_dict()}, ckpt_path)
        wandb.save(ckpt_path)

    wandb.log({
        f'{name}_valid/loss': split_metrics["loss"],
        f'{name}_valid/precision': split_metrics["precision"],
        f'{name}_valid/recall': split_metrics["recall"],
        f'{name}_valid/F1': split_metrics["f1"],
        f'best_{name}_valid_F1': best["f1"],
    })

    # Doubled braces survive the f-string as literal placeholders that
    # print_file_color receives, e.g. "{es_valid_loss}".
    template = (
        f"{name}_validation_avg: loss: {{{name}_valid_loss}}, "
        f"recall: {{{name}_valid_recall}}, precision: {{{name}_valid_precision}}, F1: {{{name}_valid_f1}}, "
        f"Best {name}_valid_f1: {{best_f1}}, in {{best_epoch}} epoch"
    )
    if with_train_summary:
        template = (
            "epoch: {epoch}, Training_avg: loss: {train_loss}, gate_loss: {gate_loss}, "
            "recall: {train_recall}, precision: {train_precision}, F1: {train_f1}, "
        ) + template

    print_file_color(
        template,
        save_file_path=result_file_path,
        color_dict={
            'train_loss': 'green',
            'gate_loss': 'green',
            'best_f1': 'green',
            f'{name}_validation_avg: loss': 'green',
        },
        epoch=epoch,
        train_loss=train_metrics["loss"],
        gate_loss=train_metrics["gate_loss"],
        train_recall=train_metrics["recall"],
        train_precision=train_metrics["precision"],
        train_f1=train_metrics["f1"],
        **{
            f"{name}_valid_loss": split_metrics["loss"],
            f"{name}_valid_recall": split_metrics["recall"],
            f"{name}_valid_precision": split_metrics["precision"],
            f"{name}_valid_f1": split_metrics["f1"],
        },
        best_f1=best["f1"],
        best_epoch=best["epoch"],
    )


def train(model, graph, ppi_list, loss_fn, optimizer, device,
          result_file_path, save_path,
          batch_size=512, epochs=1000, scheduler=None,
          got=False, th_epoch=10, mode='random'):
    """Run the full train/validate loop and checkpoint the best model per split.

    Every epoch trains once with ``train_one_epoch``, validates with
    ``evaluate``, optionally steps ``scheduler`` on the training loss, and
    then, for each validation split, saves a checkpoint whenever the split's
    F1 strictly improves, logs the metrics to wandb and appends a summary
    line to ``result_file_path``.

    Args:
        model: model to train.
        graph: graph object holding features, edges and the labelled pairs.
        ppi_list: unused; kept so existing callers keep working.
        loss_fn: classification loss on the logits.
        optimizer: optimizer stepping ``model``'s parameters.
        device: torch device used for training and evaluation.
        result_file_path: text file receiving the per-epoch summaries.
        save_path: directory receiving ``gnn_model_<split>_valid_best.ckpt``.
        batch_size: number of pairs per batch.
        epochs: number of epochs; there is no early stopping.
        scheduler: optional scheduler whose ``step`` takes the training loss.
        got: forwarded to ``train_one_epoch``.
        th_epoch: epoch from which the gate mechanism is active.
        mode: split mode; ``'random'`` additionally tracks the ``bs`` split.
    """
    print(f"\n{'='*60}")
    print("Protection measures against 'always predicting positive class' activated:")
    print("   Logits clipping: [-1.5, +1.5]")
    # Report the learning rate actually configured on the optimizer rather
    # than a hard-coded value.
    print(f"   Learning rate: {optimizer.param_groups[0]['lr']}")
    print("   Enhanced gradient clipping (max_norm=0.5)")
    print("   Enhanced regularization (weight_decay=1e-3, dropout=0.3)")
    print("   Label smoothing (smoothing=0.1)")
    print("   Temperature scaling calibration")
    print("   Adaptive validation threshold")
    print(f"{'='*60}\n")

    loss_fn_gate = GateLoss(reduction='mean')

    # Splits are processed in this order every epoch; 'bs' is only produced
    # by evaluate() in random mode.
    split_names = ['es', 'ns', 'all']
    if mode == 'random':
        split_names.insert(0, 'bs')
    best = {name: {"f1": 0.0, "epoch": 0} for name in split_names}

    for epoch in range(epochs):
        train_metrics = train_one_epoch(
            model, graph, loss_fn, loss_fn_gate, optimizer, device,
            batch_size, epoch, th_epoch, got
        )

        gate_work = epoch >= th_epoch
        valid_metrics = evaluate(model, graph, loss_fn, device, batch_size, gate_work, mode)

        if scheduler is not None:
            scheduler.step(train_metrics["loss"])
            print_file(f"epoch: {epoch}, now learning rate: {scheduler.optimizer.param_groups[0]['lr']}", save_file_path=result_file_path)

        # Training summary without a save_file_path argument; the per-split
        # lines below are the ones given the results file.
        print_file_color(
            "epoch: {epoch}, Training_avg: loss: {train_loss}, gate_loss: {gate_loss}, "
            "recall: {train_recall}, precision: {train_precision}, F1: {train_f1} ",
            color_dict={'train_loss': 'yellow'},
            epoch=epoch,
            train_loss=train_metrics["loss"],
            gate_loss=train_metrics["gate_loss"],
            train_recall=train_metrics["recall"],
            train_precision=train_metrics["precision"],
            train_f1=train_metrics["f1"])

        for name in split_names:
            _log_validation_split(
                name, valid_metrics[name], best[name], epoch, model, train_metrics,
                save_path, result_file_path,
                # Only the 'bs' line repeats the training averages.
                with_train_summary=(name == 'bs'),
            )



def _attach_shuffled_pairs(graph, split_dict, name):
    """Store ``{name}_pair_index`` / ``{name}_pair_label`` on ``graph`` in one shared random order."""
    pair_index = split_dict[f'{name}_pair_index']
    pair_label = split_dict[f'{name}_pair_label']
    # A single permutation keeps every pair aligned with its label.
    order = torch.randperm(pair_index.shape[1])
    setattr(graph, f'{name}_pair_index', pair_index[:, order])
    setattr(graph, f'{name}_pair_label', pair_label[order])


def main():
    """Parse the CLI arguments, build data and model, and run training.

    Steps: seed the RNGs, start a wandb run, load and split the PPI data,
    attach the shuffled train/validation pairs to the graph, build the model
    selected by ``--model_type``, create the run directory with a
    ``config.txt`` dump of the arguments, then call ``train``.
    """
    SEED = 40
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    args = parser.parse_args()
    print(args.graph_only_train)
    wandb.init(
        project="GNN-PPI-yeast-seed-all-"+str(SEED),
        name=f"{args.description}_{args.model_type}_{time.strftime('%Y-%m-%d_%H-%M-%S')}",
        config=vars(args),
    )
    ppi_data = GNN_DATA_Binding(ppi_path=args.ppi_path)

    print("use_get_feature_origin")
    ppi_data.get_feature_origin(pseq_path=args.pseq_path, vec_path=args.vec_path)

    ppi_data.generate_data()
    print('a',ppi_data.data)

    print("----------------------- start split train and valid index -------------------")
    print("whether to split new train and valid index file, {}".format(args.split_new))
    if args.split_new:
        print("use {} method to split".format(args.split_mode))
    ppi_data.split_dataset(args.train_valid_index_path, random_new=args.split_new, mode=args.split_mode)
    print("----------------------- Done split train and valid index -------------------")

    graph = ppi_data.data
    ppi_list = ppi_data.ppi_list

    # Edge set handed to the model: the training edges followed by the same
    # edges with source and target swapped.
    train_edges = graph.edge_index[:, graph.train_mask]
    graph.edge_index_got = torch.cat((train_edges, train_edges[[1, 0]]), dim=1)
    graph.train_mask_got = list(range(len(graph.train_mask)))

    # The order of these calls fixes the order of the randperm draws and
    # therefore the shuffles obtained for a given seed.
    split_names = ['train']
    if args.split_mode == 'random':
        split_names.append('bs')
    split_names += ['es', 'ns']
    for split_name in split_names:
        _attach_shuffled_pairs(graph, ppi_data.ppi_split_dict, split_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    graph.to(device)

    # Initialize model based on model_type
    if args.model_type == 'prompt':
        print("Using Prompt model (InactivePromptBinding)")
        model = InactivePromptBinding(in_len=2000, hidden=args.hidden, gin_num_layer=args.gin_num_layer,
                                      num_token=args.num_token, use_jk=args.use_jk, use_GRU=args.use_GRU,
                                      method=args.token_init, device=device).to(device)
    elif args.model_type == 'gnn':
        print("Using GNN model (GIN_Net2)")
        base_model = GIN_Net2(in_len=2000, in_feature=13, gin_in_feature=256, num_layers=args.gin_num_layer, pool_size=3, cnn_hidden=1).to(device)

        # Wrapper class to make GIN_Net2 compatible with prompt training interface
        class GINWrapper(nn.Module):
            def __init__(self, base_model):
                super().__init__()
                self.base_model = base_model

            def forward(self, x, edge_index, train_edge_id, active_prune=0.1, gate_work=False):
                output = self.base_model(x, edge_index, train_edge_id)
                prob = torch.sigmoid(output)  # Dummy prob for compatibility
                return output, prob

            def trainable_param_gin(self):
                for param in self.base_model.parameters():
                    param.requires_grad = True

            def freeze_param_gin(self):
                for param in self.base_model.parameters():
                    param.requires_grad = False

            def freeze_param_token(self):
                pass  # No token parameters in GNN

        model = GINWrapper(base_model)

    elif args.model_type == 'mlp':
        print("Using MLP model (SimpleMLP)")
        model = SimpleMLP(13,args.hidden).to(device)
    elif args.model_type == 'PIPR':
        model = PIPR(in_len=2000, in_feature=13, gin_in_feature=256, num_layers=args.gin_num_layer, pool_size=3, cnn_hidden=1).to(device)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)

    scheduler = None
    if args.use_lr_scheduler:
        # `verbose` is deprecated since PyTorch 2.2; train() already writes the
        # current learning rate to the results file every epoch.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20)

    # Use weighted loss for better class balance (especially important for MLP)
    # pos_weight helps when there's slight class imbalance
    pos_weight = torch.tensor([1.0]).to(device)  # Can be adjusted based on data distribution
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(device)

    # makedirs also creates missing parent directories of --save_path.
    os.makedirs(args.save_path, exist_ok=True)

    time_stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    if args.model_type == 'prompt':
        run_dir_name = "{}_prompt_{}_{}_{}_{}_{}_{}".format(
            args.model_type, args.description, time_stamp, args.hidden,
            args.lr, args.gin_num_layer, args.num_token)
    else:
        # Non-prompt models have no prompt tokens; the slot is filled with 'na'.
        run_dir_name = "{}_{}_{}_{}_{}_{}_{}".format(
            args.model_type, args.description, time_stamp, args.hidden,
            args.lr, args.gin_num_layer, 'na')
    save_path = os.path.join(args.save_path, run_dir_name)
    result_file_path = os.path.join(save_path, "valid_results.txt")
    config_path = os.path.join(save_path, "config.txt")
    os.makedirs(save_path, exist_ok=True)

    # Record the exact arguments of this run next to its checkpoints.
    with open(config_path, 'w') as f:
        for key, value in vars(args).items():
            f.write("{} = {}".format(key, value))
            f.write('\n')
        f.write('\n')

    train(model, graph, ppi_list, loss_fn, optimizer, device,
          result_file_path, save_path,
          batch_size=args.batch_size, epochs=args.epochs, scheduler=scheduler,
          got=args.graph_only_train, th_epoch=args.th_epoch, mode=args.split_mode)

    wandb.finish()



# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()