from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
from torch.optim import Adam
from scGCN import GlobalBranch
from torch.optim.lr_scheduler import StepLR
import scipy.sparse as sp
from utils import scRNADataset, load_data, adj2saprse_tensor, Evaluation,  Network_Statistic
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random
import glob
import os

import time
import argparse

def _str2bool(value):
    """Parse a command-line boolean option.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``; this converter understands the usual spellings.

    :param value: Raw command-line string (or an already-parsed bool).
    :return: The parsed boolean.
    :raises argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=3e-3, help='Initial learning rate.')
parser.add_argument('--epochs', type=int, default=30, help='Number of epoch.')
parser.add_argument('--alpha', type=float, default=0.2, help='Alpha for the leaky_relu.')
# nargs=3: the model is built from exactly three hidden sizes (hidden_dim[0..2]);
# with a plain type=int a command-line value would be a bare int, not a list.
parser.add_argument('--hidden_dim', type=int, nargs=3, default=[128, 64, 32], help='The dimension of hidden layer')
parser.add_argument('--output_dim', type=int, default=16, help='The dimension of latent layer')
parser.add_argument('--batch_size', type=int, default=256, help='The size of each batch')
parser.add_argument('--loop', type=_str2bool, default=False, help='whether to add self-loop in adjacent matrix')
parser.add_argument('--seed', type=int, default=8, help='Random seed')
parser.add_argument('--Type', type=str, default='dot', help='score metric')
parser.add_argument('--flag', type=_str2bool, default=False, help='the identifier whether to conduct causal inference')
parser.add_argument('--reduction', type=str, default='concate', help='how to integrate multihead attention')
parser.add_argument('--base_path', type=str, default=None, help='Base path for the dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping.')
# Upper-case before validating so 'auc'/'aupr' work; anything else is rejected
# instead of silently falling back to AUPR.
parser.add_argument('--monitor', type=str.upper, default='AUC', choices=['AUC', 'AUPR'],
                    help='Metric to monitor for early stopping (AUC or AUPR).')

args = parser.parse_args()

# Seed every RNG the script uses for reproducibility.
seed = args.seed
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Placeholder dataset directory; replaced later when --base_path is given.
base_path = 'yourpath/scUniGP/Data_processing/Dataspilt/STRING/mHSC-L/TFs_500'

def embed2file(gene_file, **kwargs):
    """Write one or more embedding tensors to CSV files indexed by gene name.

    Only the first column of ``gene_file`` is read. Its values, in file order,
    become the row index (labelled ``Gene``) of every output file.

    :param gene_file: Path to a CSV whose first column lists the genes; row i of
        each tensor is written under gene i.
    :param kwargs: Mapping ``{output_path: tensor}``. Entries whose tensor is
        ``None`` are skipped.
    :raises ValueError: If a tensor's row count differs from the number of genes
        (raised by ``pandas.DataFrame``).
    """
    # usecols=[0]: only the gene names are needed, so don't parse the whole
    # (potentially large) expression matrix.
    gene_names = pd.read_csv(gene_file, usecols=[0]).iloc[:, 0]
    gene_index = pd.Index(gene_names, name='Gene')
    for path, tensor in kwargs.items():
        if tensor is None:
            continue
        # Detach first so the device->host copy is not tracked by autograd.
        embed_np = tensor.detach().cpu().numpy()
        embed_df = pd.DataFrame(embed_np, index=gene_index)
        embed_df.to_csv(path)
        print(f"Embedding saved to {path}")

# A non-empty --base_path replaces the placeholder dataset directory.
base_path = args.base_path or base_path

# Input CSVs: expression matrix, TF / target gene-index tables, and the
# train / validation / test pair splits.
exp_file, tf_file, target_file = (
    f'{base_path}/{fname}'
    for fname in ('BL--ExpressionData.csv', 'TF.csv', 'Target.csv')
)
train_file, val_file, test_file = (
    f'{base_path}/{fname}'
    for fname in ('Train_set.csv', 'Validation_set.csv', 'Test_set.csv')
)

# Output CSVs for the learned embeddings. The two per-layer GCN files carry
# the width of the corresponding hidden layer in their name.
layer1_embed_path, layer2_embed_path = (
    f'{base_path}/gene_gcn{layer}_embedding{width}.csv'
    for layer, width in ((1, args.hidden_dim[0]), (2, args.hidden_dim[1]))
)
tf_embed_path = f'{base_path}/gcn_TF_Channel1.csv'
target_embed_path = f'{base_path}/gcn_Target_Channel2.csv'



# Load the expression matrix (first CSV column used as the row index) and let the
# project helper load_data produce the feature matrix. exp_data() returns an
# object that torch.from_numpy accepts below (presumably a NumPy array).
data_input = pd.read_csv(exp_file,index_col=0)
loader = load_data(data_input)
feature = loader.exp_data()
# Gene indices of the TFs, read from the 'index' column of TF.csv.
tf = pd.read_csv(tf_file,index_col=0)['index'].values.astype(np.int64)


# Gene indices of the targets. NOTE(review): `target` is not referenced again in
# this file as shown -- confirm whether it is still needed.
target = pd.read_csv(target_file,index_col=0)['index'].values.astype(np.int64)
feature = torch.from_numpy(feature)
tf = torch.from_numpy(tf)



# Use the first CUDA device when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_feature = feature.to(device)
tf = tf.to(device)


# Pair tables for each split as NumPy arrays (the first CSV column becomes the
# index and is excluded by .values). Later code treats columns [:, :2] as the
# (TF, target) pair and column [:, -1] as the label.
train_data = pd.read_csv(train_file, index_col=0).values
validation_data = pd.read_csv(val_file, index_col=0).values
test_data = pd.read_csv(test_file, index_col=0).values


# Wrap each split in the project's scRNADataset. feature.shape[0] is passed as
# the gene count (assumes one feature row per gene -- TODO confirm).
train_load = scRNADataset(train_data, feature.shape[0], flag=args.flag)
val_load = scRNADataset(validation_data, feature.shape[0], flag=args.flag)
test_load = scRNADataset(test_data, feature.shape[0], flag=args.flag)

# Build the graph adjacency from the training split only (validation/test pairs
# are not used), then convert it with the project helper adj2saprse_tensor.
adj = train_load.Adj_Generate(tf,loop=args.loop)
adj = adj2saprse_tensor(adj)


# Tensor versions of the pair tables (model inputs and evaluation labels).
train_data = torch.from_numpy(train_data)
val_data = torch.from_numpy(validation_data)
test_data = torch.from_numpy(test_data)

# Input width is the number of feature columns; the three hidden widths come
# from --hidden_dim; `type` / `reduction` forward the score-metric and
# multi-head-integration options from the command line.
model = GlobalBranch(input_dim=feature.size()[1],
                hidden1_dim=args.hidden_dim[0],
                hidden2_dim=args.hidden_dim[1],
                hidden3_dim=args.hidden_dim[2],
                output_dim=args.output_dim,
                device=device,
                type=args.Type,
                reduction=args.reduction
                )


# Move the graph, model and pair tensors to the selected device. Note that
# `validation_data` is rebound here from the NumPy array to the device tensor.
adj = adj.to(device)
model = model.to(device)
train_data = train_data.to(device)
validation_data = val_data.to(device)
test_data = test_data.to(device)

# Adam optimiser. StepLR multiplies the learning rate by 0.99 on every
# scheduler.step() call; the training loop calls it once per batch.
optimizer = Adam(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.99)

import copy  # used to snapshot the best state_dict below

# Create the checkpoint directory and a timestamped checkpoint file name.
timestamp = int(time.time())
model_dir = 'model'
os.makedirs(model_dir, exist_ok=True)  # create the directory if it does not exist
model_path = os.path.join(model_dir, f'saved_model_{timestamp}.pkl')

# Buffers for the training labels and positive-class scores of every batch;
# cleared at the start of each epoch.
all_train_y_true = []
all_train_y_pred = []

# --- Early Stopping and Checkpointing Initialization ---
# Start at -inf so the first epoch with a finite validation score is recorded
# (an initial 0.0 could never be beaten by a score of exactly 0.0).
best_val_score = float('-inf')
best_model_state = None
patience_counter = 0
print(f"Monitoring validation {args.monitor} for early stopping with patience {args.patience}.")
# ---

for epoch in range(args.epochs):
    running_loss = 0.0
    model.train()

    # Reset the per-epoch metric buffers.
    all_train_y_true.clear()
    all_train_y_pred.clear()

    for train_x, train_y in DataLoader(train_load, batch_size=args.batch_size, shuffle=True):
        optimizer.zero_grad()

        if args.flag:
            train_y = train_y.to(device)
        else:
            train_y = train_y.to(device).view(-1, 1)

        pred = model(data_feature, adj, train_x)
        if args.flag:
            pred = torch.softmax(pred, dim=1)
        else:
            pred = torch.sigmoid(pred)

        loss_BCE = F.binary_cross_entropy(pred, train_y)
        loss_BCE.backward()
        optimizer.step()
        # NOTE: stepped once per batch, so StepLR(step_size=1, gamma=0.99)
        # decays the learning rate by 1% every batch, not every epoch.
        scheduler.step()

        running_loss += loss_BCE.item()

        # Collect positive-class scores and the matching labels for the epoch's
        # training AUC/AUPR. In flag mode pred and train_y both have one column
        # per class (BCE requires equal shapes), so take the last column of each;
        # flattening the whole train_y would give twice as many labels as scores.
        if args.flag:
            pred_for_metric = pred[:, -1]
            true_for_metric = train_y[:, -1]
        else:
            pred_for_metric = pred
            true_for_metric = train_y
        all_train_y_true.extend(true_for_metric.detach().cpu().numpy().ravel())
        all_train_y_pred.extend(pred_for_metric.detach().cpu().numpy().ravel())

    # Training AUC/AUPR. The collected scores are already 1-D positive-class
    # probabilities, so Evaluation is called exactly as in the non-flag path
    # (flag=False) regardless of args.flag.
    AUC_train, AUPR_train, _ = Evaluation(y_pred=torch.tensor(np.asarray(all_train_y_pred)),
                                          y_true=torch.tensor(np.asarray(all_train_y_true)),
                                          flag=False)

    model.eval()
    with torch.no_grad():
        # Validation-set evaluation.
        score_val = model(data_feature, adj, validation_data)
        if args.flag:
            score_val = torch.softmax(score_val, dim=1)
        else:
            score_val = torch.sigmoid(score_val)
        AUC_val, AUPR_val, AUPR_norm_val = Evaluation(y_pred=score_val, y_true=validation_data[:, -1], flag=args.flag)

        # Test-set evaluation.
        score_test = model(data_feature, adj, test_data)
        if args.flag:
            score_test = torch.softmax(score_test, dim=1)
        else:
            score_test = torch.sigmoid(score_test)
        AUC_test, AUPR_test, AUPR_norm_test = Evaluation(y_pred=score_test, y_true=test_data[:, -1], flag=args.flag)

    print('Epoch:{}'.format(epoch + 1),
          'train loss:{}'.format(running_loss),
          'train AUC:{:.3F}'.format(AUC_train),
          'train AUPR:{:.3F}'.format(AUPR_train),
          '| Validation AUC:{:.3F}'.format(AUC_val),
          'Validation AUPR:{:.3F}'.format(AUPR_val),
          '| Test AUC:{:.3F}'.format(AUC_test),
          'Test AUPR:{:.3F}'.format(AUPR_test))

    # --- Early Stopping and Checkpointing Logic ---
    current_val_score = AUC_val if args.monitor == 'AUC' else AUPR_val
    if current_val_score > best_val_score:
        best_val_score = current_val_score
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps update in place; deep-copy so the snapshot
        # really keeps this epoch's weights.
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
        print(f"Validation {args.monitor} improved to {best_val_score:.4f}. Saving model checkpoint.")
    else:
        patience_counter += 1
        print(f"Validation {args.monitor} did not improve for {patience_counter} epoch(s).")

    if patience_counter >= args.patience:
        print(f"Early stopping triggered after {patience_counter} epochs with no improvement.")
        break
    # ---

# Save the best validation checkpoint. If no epoch produced a comparable
# validation score, fall back to the final weights. This keeps the saved file
# consistent with the best-model weights restored for the outputs below.
if best_model_state is not None:
    torch.save(best_model_state, model_path)
else:
    torch.save(model.state_dict(), model_path)
print('Model saved to:', model_path)

# Restore the best validation checkpoint (when one was recorded), then switch
# to inference mode before reading embeddings back out of the model.
if not best_model_state:
    print("\nWarning: No best model found. Using model from the last epoch.")
else:
    print(f"\nLoading best model from epoch with validation {args.monitor} = {best_val_score:.4f} for embedding generation.")
    model.load_state_dict(best_model_state)
model.eval()

# One forward pass on a single training pair so that the model's embeddings are
# computed before the getters below read them (per the original intent;
# assumes GlobalBranch stores them during forward -- TODO confirm).
with torch.no_grad():
    model(data_feature, adj, train_data[:1])

# Embeddings produced by the two GCN layers.
layer1_embed, layer2_embed = model.get_layer_embeddings()
for caption, emb in (("Layer 1 Embedding Shape:", layer1_embed),
                     ("Layer 2 Embedding Shape:", layer2_embed)):
    print(caption, emb.shape)

# Final TF-side and target-side embeddings.
tf_embed, target_embed = model.get_embedding()
for caption, emb in (("TF Embedding Shape:", tf_embed),
                     ("Target Embedding Shape:", target_embed)):
    print(caption, emb.shape)

# Write each embedding to its own CSV, keyed by output path.
embedding_outputs = {
    layer1_embed_path: layer1_embed,
    layer2_embed_path: layer2_embed,
    tf_embed_path: tf_embed,
    target_embed_path: target_embed,
}
embed2file(gene_file=exp_file, **embedding_outputs)

# --------- Save scores for every distinct gene pair in any split ---------
batch_size = 512
all_samples = torch.cat([train_data, validation_data, test_data], dim=0)
# Deduplicate on the (TF, target) columns only; the label column is ignored.
unique_samples = torch.unique(all_samples[:, :2], dim=0)
model.eval()
preds = []
with torch.no_grad():
    for start in range(0, len(unique_samples), batch_size):
        batch_pairs = unique_samples[start:start + batch_size]
        pred = model(data_feature, adj, batch_pairs)
        if args.flag:
            # Positive-class (last column) probability.
            pred = torch.softmax(pred, dim=1)[:, -1]
        else:
            # reshape(-1) rather than squeeze(): a final batch holding a single
            # pair would squeeze to a 0-d tensor, which torch.cat rejects.
            pred = torch.sigmoid(pred).reshape(-1)
        preds.append(pred.cpu())
all_preds = torch.cat(preds, dim=0).numpy()
tf_ids = unique_samples[:, 0].cpu().numpy()
target_ids = unique_samples[:, 1].cpu().numpy()
results_df = pd.DataFrame({
    'TFs': tf_ids,
    'Targets': target_ids,
    'Predicted Labels': all_preds
})
prediction_file = f'{base_path}/gcn_predictions.csv'
results_df.to_csv(prediction_file, index=False)
print(f'GCN predictions saved to {prediction_file}')






