# Script: run_single_trial.py (runs a single trial of the Edgeexist experiment)

import argparse
import torch
import numpy as np
import random
import sys
import os
import csv # <<< MODIFIED: Import csv module

# =============================================================================
PROJECT_ROOT = '/home/wxy/rcx/DYNOTEARS*/rub_LIN'
# =============================================================================
sys.path.append(PROJECT_ROOT)
from LIN.models.LIN_GaussianNet import LIN_GaussianNet
from LIN.utils import get_ri, get_ari, causalGraph_metrics, Logger

# <<< MODIFIED: New function to handle CSV writing >>>
def write_results_to_csv(filepath, data_dict):
    """Append one result row to a CSV file, writing the header when needed.

    Columns are the keys of ``data_dict`` in sorted order, so rows appended
    by separate runs with the same keys line up under the same header.

    Args:
        filepath: Path of the CSV file to append to. Missing parent
            directories are created.
        data_dict: Mapping of column name to value for the row to append.
    """
    # Create the target directory up front so a finished run does not lose
    # its results just because the output folder was never made.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # A file that exists but is still empty (e.g. pre-created with `touch`)
    # has no header yet, so treat it the same as a missing file.
    needs_header = not os.path.isfile(filepath) or os.path.getsize(filepath) == 0
    fieldnames = sorted(data_dict)

    with open(filepath, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        # DictWriter emits the values in `fieldnames` order.
        writer.writerow(data_dict)

# 1. Command-line arguments
parser = argparse.ArgumentParser(description="Run a single trial for the Edgeexist experiment.")
parser.add_argument('--E', type=int, default=3, help='Number of clusters for the model')
parser.add_argument('--K', type=int, default=4,
                    help='Number of clusters kept as K_TRUE (not passed to the model by this script)')
parser.add_argument('--T', type=int, default=5000,
                    help='Number of leading samples of the loaded data to use')
parser.add_argument('--pnt', type=float, default=1e-8, help='Penalty coefficient for structure learning')
parser.add_argument('--rdseeds', type=int, default=42, help='Random seed for reproducibility')
parser.add_argument('--edge_pnt', type=float, default=0.0, help='Penalty for the edge existence prior.')
parser.add_argument('--prior_rate', type=float, default=1.0, help='Fraction of true edges to use as prior knowledge.')
parser.add_argument('--threshold', type=float, default=0.5,
                    help='Threshold on sigmoid(G) above which an edge counts as present')

# Input data location and optional aggregated-results file.
# The data files are read from <data_path>/<data_file_prefix>_{G,data,intvs}.pt.
parser.add_argument('--data_path', type=str, required=True, help='Directory containing the data files to load.')
parser.add_argument('--data_file_prefix', type=str, required=True, help='Prefix of the data files to load.')
parser.add_argument('--results_csv', type=str, default=None, help='Path to the CSV file to save aggregated results.')

opt = parser.parse_args()

# 2. Fixed experiment sizes plus the values taken from the command line
D_VAL = 5
P_VAL = 2
T_VAL = opt.T
K_TRUE = opt.K

# 3. Seed every random number generator used by this script
torch.manual_seed(opt.rdseeds)
torch.cuda.manual_seed(opt.rdseeds)
np.random.seed(opt.rdseeds)
random.seed(opt.rdseeds)

# 4. Output directory and per-trial file name
results_save_path = './results'
if opt.pnt > 0:
    pnt_tag = int(-np.log10(opt.pnt))
else:
    # log10 is undefined for non-positive penalties and int() of the
    # resulting inf/nan raises, so fall back to the raw value in the tag.
    pnt_tag = f"raw{opt.pnt:g}"
suffix = f"_pnt{pnt_tag}_epnt{opt.edge_pnt}_prate{opt.prior_rate}_seed{opt.rdseeds}"
file_name = f"trial_d{D_VAL}_P{P_VAL}_E{opt.E}" + suffix
# exist_ok avoids a crash when several trials start at once and race to
# create the same directory.
os.makedirs(results_save_path, exist_ok=True)
log_path = os.path.join(results_save_path, file_name + '.log')
# Redirecting stdout to the log file is currently disabled:
# print(f": {os.path.abspath(log_path)}")
# sys.stdout = Logger(log_path)

print(f": d={D_VAL}, P={P_VAL}, T={T_VAL}, pnt={opt.pnt}, edge_pnt={opt.edge_pnt}, prior_rate={opt.prior_rate}, seed={opt.rdseeds}")

# 5. Load the true graph, the data and the intervention-type tensor
data_load_path = opt.data_path
data_file_name = opt.data_file_prefix

# map_location='cpu' lets files that were saved from a GPU be read on a
# machine without one; the tensors are moved to the compute device below.
G_true = torch.load(os.path.join(data_load_path, data_file_name + '_G.pt'), map_location='cpu')
data = torch.load(os.path.join(data_load_path, data_file_name + '_data.pt'), map_location='cpu')
intervention_type = torch.load(os.path.join(data_load_path, data_file_name + '_intvs.pt'), map_location='cpu')

# Keep only the first T_VAL entries before transferring, so no more than the
# requested slice is copied to the device.
data = data[:T_VAL]
intervention_type = intervention_type[:T_VAL]

# 6. Move everything to the compute device (first GPU if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = data.to(device)
intervention_type = intervention_type.to(device)
G_true = G_true.to(device)

# Build the edge-existence prior: a matrix shaped like G_true collapsed over
# dim 0 (assumed to be the lag axis -- TODO confirm), holding 1.0 on a random
# subset of the true edges and 0.0 everywhere else.
true_edges_any_lag = torch.any(G_true, dim=0).float()
G_prior_existence = torch.zeros_like(true_edges_any_lag)

# The prior is only populated when both the rate and its penalty are positive.
if opt.prior_rate > 0 and opt.edge_pnt > 0:
    edge_idx = torch.nonzero(true_edges_any_lag)
    n_edges = edge_idx.size(0)
    # Truncates towards zero, so a small rate can select no edge at all.
    n_prior = int(n_edges * opt.prior_rate)
    if n_prior > 0:
        perm = torch.randperm(n_edges)
        chosen = edge_idx[perm[:n_prior]]
        G_prior_existence[chosen[:, 0], chosen[:, 1]] = 1.0

# 7. Evaluation helper
# <<< MODIFIED: evaluate_model now returns a dictionary of metrics >>>
def evaluate_model(model, G_prior, calc_dist=False):
    """
    Score a fitted model, print a summary table and return the metrics.

    Reads the module-level ``G_true``, ``intervention_type`` and ``opt``.

    :param model: Object exposing ``G`` and ``intv_p``.
    :param G_prior: Prior edge-existence matrix; its sum is taken as the
        number of prior edges.
    :param calc_dist: Forwarded unchanged to ``causalGraph_metrics``.
    :return: Dict with keys accu, recall, F1, shd, sid, ri, ari and
        prior_recovery_rate.
    """
    # Graph-recovery metrics against the true graph.
    accu, recall, shd, sid, F1 = causalGraph_metrics(
        model.G, G_true, calc_dist=calc_dist, allow_skeletion=False
    )

    # Agreement between the reference intervention types and the model's
    # highest-scoring assignment per row of intv_p.
    predicted_intv = model.intv_p.argmax(1)
    ri = get_ri(intervention_type, predicted_intv)
    ari = get_ari(intervention_type, predicted_intv)

    # Prior recovery rate: share of prior edges that the thresholded
    # estimate contains along dim 0.
    print(f"threshold:{opt.threshold}")
    edge_probs = torch.sigmoid(model.G)
    print(edge_probs.min())
    print(edge_probs.max())
    print(model.G)
    est_G_any_lag = (edge_probs > opt.threshold).any(dim=0).float()

    num_priors = G_prior.sum().item()
    # Without any prior edge the rate is reported as 0 rather than NaN.
    prior_recovery_rate = 0.0
    if num_priors > 0:
        recovered_priors = (est_G_any_lag * G_prior).sum().item()
        print(recovered_priors)
        print(num_priors)
        prior_recovery_rate = recovered_priors / num_priors

    metric_names = ('accu', 'recall', 'F1', 'shd', 'sid', 'ri', 'ari', 'prior_recovery_rate')
    metric_values = (accu, recall, F1, shd, sid, ri, ari, prior_recovery_rate)

    # Summary table: one header row, one row of values.
    print('\n---  ---')
    print('\taccu\trecall\tF1\tshd\tsid\tri\tari\tprior_rec')
    print(''.join('\t{:6.4f}'.format(value) for value in metric_values))
    print('--------------------------\n')

    return dict(zip(metric_names, metric_values))


# 8. Model and training configuration
intv_args = dict(hidden_dim=D_VAL, n_hidden_lyr=1)
fit_args = dict(
    epoch=1000,
    lr_net=1e-2,
    lr=1e-2,
    batch_size=256,
    train_sample=0.8,
    struct_pnt_coeff=(opt.pnt, opt.pnt, 0),
    patient=100,
    update_patient=3,
    tol_rate=0,
    itr_per_epoch=100,
    verbose_period=10,  # Set high to reduce log spam
    lag_delta=0.99,
    G_prior_existence=G_prior_existence,
    edge_existence_pnt=opt.edge_pnt,
    threshold=opt.threshold,
)

# 9. Build the model and fit it on the loaded data
# (best_model_path is deliberately not passed, for speed.)
model = LIN_GaussianNet(
    D_VAL,
    P_VAL,
    opt.E,
    intv_args,
    device=device,
    lgr_init=1e-8,
)
model = model.fit(data, **fit_args)

# 10. Evaluate the fitted model and optionally append the results to a CSV
print('---  ---')
final_metrics = evaluate_model(model, G_prior_existence, calc_dist=True)

# Results are only exported when --results_csv was given.
if opt.results_csv:
    # Experiment settings first, then the evaluation metrics, then the
    # model's stored criterion under the 'bic' column.
    results_data = dict(
        d=D_VAL,
        P=P_VAL,
        T=T_VAL,
        true_edges=int(G_true.sum().item()),
        seed=opt.rdseeds,
        threshold=opt.threshold,
        pnt=opt.pnt,
        edge_pnt=opt.edge_pnt,
        prior_rate=opt.prior_rate,
        E=opt.E,
    )
    results_data.update(final_metrics)
    results_data['bic'] = model._criteria

    write_results_to_csv(opt.results_csv, results_data)
    print(f": {os.path.abspath(opt.results_csv)}")

print('---  ---')
