import copy
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler

from dataset import Dataset, adj2saprse_tensor
from model import scUniGP
from train_val import train, validate
from losses import get_loss_function


def _read_matrix(path):
    """Read a CSV with a header row and an index column into a NumPy array.

    Args:
        path: Path of the CSV file.

    Returns:
        ``np.ndarray`` holding every column except the index column.
    """
    return np.array(pd.read_csv(path, index_col=0, header=0))


def _mean_or_zero(values):
    """Return the arithmetic mean of ``values``, or 0.0 if it is empty."""
    return sum(values) / len(values) if values else 0.0


def main(data_dir, args):
    """Train a scUniGP model on one dataset directory and report metrics.

    The directory must contain ``BL--ExpressionData.csv``, ``Train_set.csv``,
    ``Validation_set.csv``, ``Test_set.csv``, ``TF.csv`` (with an ``index``
    column) and the four pre-computed embedding files
    ``TF_gat_Channel1.csv``, ``Target_gat_Channel2.csv``,
    ``gene_gat1_embedding128.csv`` and ``gene_gat2_embedding64.csv``.

    The model is validated after every epoch. Whenever validation AUROC
    improves, the weights are saved to
    ``<save_dir>/<last three components of data_dir joined by '_'>.pth``.
    The test set is evaluated every ``test_interval`` epochs and on the final
    epoch. Training stops early after ``args.patience`` consecutive epochs
    without a validation-AUROC improvement, or as soon as a validation AUROC
    is below 0.501. A summary is printed and appended to a results log.

    Args:
        data_dir: Directory holding the input CSV files listed above.
        args: Namespace-like object (also forwarded unchanged to ``train``).
            Required attributes: ``batch_size``, ``embed_size``,
            ``num_layers``, ``num_head``, ``lr``, ``epochs``, ``step_size``,
            ``gamma``, ``scheduler_flag``, ``patience``.
            Optional attributes (default): ``weight_decay`` (1e-5),
            ``loss_type`` ('combined'), ``focal_weight`` (0.7),
            ``bce_weight`` (0.2), ``consistency_weight`` (0.1),
            ``focal_alpha`` (1), ``focal_gamma`` (2),
            ``label_smoothing`` (0.1), ``num_workers`` (8),
            ``save_dir`` ('./model_pt'), ``test_interval`` (5) and
            ``result_file`` ('yourpath/scUniGP/result_tf500.log').

    Side effects:
        Sets the module-level global ``schedulerflag``, creates the checkpoint
        and results directories if missing, writes the checkpoint file, and
        appends to the results log.
    """
    # --- Load inputs -------------------------------------------------------
    expression_data = _read_matrix(os.path.join(data_dir, 'BL--ExpressionData.csv'))
    tf_gnn_embedding_data = _read_matrix(os.path.join(data_dir, 'TF_gat_Channel1.csv'))
    target_gnn_embedding_data = _read_matrix(os.path.join(data_dir, 'Target_gat_Channel2.csv'))
    l1_gnn_embedding_data = _read_matrix(os.path.join(data_dir, 'gene_gat1_embedding128.csv'))
    l2_gnn_embedding_data = _read_matrix(os.path.join(data_dir, 'gene_gat2_embedding64.csv'))

    tf_table = pd.read_csv(os.path.join(data_dir, 'TF.csv'), index_col=0, header=0)
    TF = torch.from_numpy(tf_table['index'].values.astype(np.int64))

    # StandardScaler standardizes columns. Transposing first standardizes each
    # row of the expression matrix (presumably one row per gene); the result
    # is then transposed back to the original orientation.
    expression_data = StandardScaler().fit_transform(expression_data.T).T
    expression_data_shape = expression_data.shape

    train_dataset = Dataset(os.path.join(data_dir, 'Train_set.csv'), expression_data)
    val_dataset = Dataset(os.path.join(data_dir, 'Validation_set.csv'), expression_data)
    test_dataset = Dataset(os.path.join(data_dir, 'Test_set.csv'), expression_data)

    # NOTE(review): `adj` is not used anywhere below. The calls are kept in case
    # Adj_Generate / adj2saprse_tensor have side effects -- confirm before removing.
    adj = train_dataset.Adj_Generate(TF, loop=False)
    adj = adj2saprse_tensor(adj).coalesce()

    # --- Hyper-parameters --------------------------------------------------
    batch_size = args.batch_size
    embed_size = args.embed_size
    num_layers = args.num_layers
    num_head = args.num_head
    lr = args.lr
    epochs = args.epochs
    step_size = args.step_size
    gamma = args.gamma
    # NOTE(review): this rebinds a global of *this* module only. Other modules
    # (e.g. train_val) do not see it unless they import it from here -- confirm.
    global schedulerflag
    schedulerflag = args.scheduler_flag

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=False,
                              num_workers=getattr(args, 'num_workers', 8))
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            drop_last=False)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=False)

    # --- Model, optimizer, loss --------------------------------------------
    T = scUniGP(expression_data_shape, embed_size, num_layers, num_head,
                tf_gnn_embedding_data, target_gnn_embedding_data,
                l1_gnn_embedding_data, l2_gnn_embedding_data,
                use_l1=True, use_l2=True, use_tf=True, use_target=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    T = T.to(device)
    optimizer = torch.optim.Adam(T.parameters(), lr=lr,
                                 weight_decay=getattr(args, 'weight_decay', 1e-5))
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

    loss_type = getattr(args, 'loss_type', 'combined')
    loss_func = get_loss_function(
        loss_type=loss_type,
        focal_weight=getattr(args, 'focal_weight', 0.7),
        bce_weight=getattr(args, 'bce_weight', 0.2),
        consistency_weight=getattr(args, 'consistency_weight', 0.1),
        focal_alpha=getattr(args, 'focal_alpha', 1),
        focal_gamma=getattr(args, 'focal_gamma', 2),
        label_smoothing=getattr(args, 'label_smoothing', 0.1)
    )
    print("current loss function :", loss_type)

    # --- Training loop -----------------------------------------------------
    best_val_auc = 0.0
    best_val_aupr = 0.0
    all_val_aucs = []
    all_val_auprs = []
    # NOTE(review): "best test" below is the max over periodic test evaluations,
    # not the test score of the best-validation checkpoint, so it is an
    # optimistic figure.
    best_test_auc = 0.0
    best_test_aupr = 0.0
    all_test_aucs = []
    all_test_auprs = []

    patience = args.patience
    epochs_no_improve = 0
    save_dir = getattr(args, 'save_dir', './model_pt')
    os.makedirs(save_dir, exist_ok=True)
    last3 = os.path.normpath(data_dir).split(os.sep)[-3:]
    model_save_path = os.path.join(save_dir, '_'.join(last3) + '.pth')
    best_model_wts = None  # deep-copied snapshot of the best-validation weights
    test_interval = getattr(args, 'test_interval', 5)

    for epoch in range(1, epochs + 1):
        train(T, train_loader, loss_func, optimizer, epoch, scheduler, args)
        AUC_val, AUPR_val = validate(T, val_loader, loss_func)

        if AUC_val > best_val_auc:
            best_val_auc = AUC_val
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors;
            # deep-copy so the snapshot does not change as training continues.
            best_model_wts = copy.deepcopy(T.state_dict())
            torch.save(best_model_wts, model_save_path)
            print(f"[Model Saved] 当前最优模型已保存到: {model_save_path}")
        else:
            epochs_no_improve += 1
        if AUPR_val > best_val_aupr:
            best_val_aupr = AUPR_val

        all_val_aucs.append(AUC_val)
        all_val_auprs.append(AUPR_val)

        print('-' * 100)
        print(f'| end of epoch {epoch:3d} | valid AUROC {AUC_val:8.3f} | valid AUPRC {AUPR_val:8.3f}')
        print(f'| Current Best | valid AUROC {best_val_auc:8.3f} | valid AUPRC {best_val_aupr:8.3f}')

        # Treat a near-chance validation AUROC as a failed run and stop.
        if AUC_val < 0.501:
            print("AUC_val < 0.501 !!")
            break

        periodic = test_interval > 0 and epoch % test_interval == 0
        if periodic or epoch == epochs:
            AUC_test, AUPR_test = validate(T, test_loader, loss_func)

            if AUC_test > best_test_auc:
                best_test_auc = AUC_test
            if AUPR_test > best_test_aupr:
                best_test_aupr = AUPR_test

            all_test_aucs.append(AUC_test)
            all_test_auprs.append(AUPR_test)

            print('| end of epoch {:3d} | test  AUROC {:8.3f} | test  AUPRC {:8.3f}'.format(epoch, AUC_test, AUPR_test))
            print('| Current Best | test  AUROC {:8.3f} | test  AUPRC {:8.3f}'.format(best_test_auc, best_test_aupr))

        print('-' * 100)

        if epochs_no_improve >= patience:
            print(f'Early stopping triggered after {patience} epochs with no improvement.')
            break

    # --- Reporting ---------------------------------------------------------
    # _mean_or_zero avoids ZeroDivisionError when no epoch ran (epochs < 1).
    avg_val_auc = _mean_or_zero(all_val_aucs)
    avg_val_aupr = _mean_or_zero(all_val_auprs)
    avg_test_auc = _mean_or_zero(all_test_aucs)
    avg_test_aupr = _mean_or_zero(all_test_auprs)

    summary_lines = [
        f'Average val AUROC: {avg_val_auc:.3f} | Best val AUROC: {best_val_auc:.3f}',
        f'Average val AUPRC: {avg_val_aupr:.3f} | Best val AUPRC: {best_val_aupr:.3f}',
        f'Average test AUROC: {avg_test_auc:.3f} | Best test AUROC: {best_test_auc:.3f}',
        f'Average test AUPRC: {avg_test_aupr:.3f} | Best test AUPRC: {best_test_aupr:.3f}',
    ]

    print('\nFinal Results:')
    for line in summary_lines:
        print(line)

    result_file = getattr(args, 'result_file', 'yourpath/scUniGP/result_tf500.log')
    # Create the log directory up front: a missing directory would otherwise
    # raise only after the whole training run has finished.
    result_dir = os.path.dirname(result_file)
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    with open(result_file, 'a', encoding='utf-8') as f:
        f.write(f'==== Results for {data_dir} ====\n')
        f.write('Final Results:\n')
        f.writelines(line + '\n' for line in summary_lines)
        f.write('\n')
    print(f"Results appended to {result_file}")