import pandas as pd
import yaml
import json
from datetime import datetime
import os
import sys

# The only command-line argument is the path to the YAML configuration file.
if len(sys.argv) < 2:
    sys.exit(f"Usage: {sys.argv[0]} <config.yaml>")
config_path = sys.argv[1]

# Load YAML configuration
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

base_dir = config["base_dir"]
benchmarks = config["benchmarks"]
datasets = config["datasets"]

# Load data.
# File paths are built by plain string concatenation, so `base_dir` is used
# verbatim as a prefix (no path separator is inserted here).
experiment_name = config.get("experiment_name", "default")
benchmark_data = {name: pd.read_csv(base_dir + file) for name, file in benchmarks.items()}
# `datasets` is rebound from {name: relative file} to {name: loaded DataFrame}.
datasets = {name: pd.read_parquet(base_dir + file) for name, file in datasets.items()}
eval_dataset = config.get("eval_dataset")
eval_data = datasets.get("eval_data")

# Fail early with a clear message instead of an opaque error on None below.
if eval_data is None:
    raise ValueError('config["datasets"] must contain an "eval_data" entry')

# Pick the ground-truth rows. "zinc" takes the first 2000 rows; the other
# datasets draw a reproducible random sample (fixed random_state).
if eval_dataset == "zinc":
    ground_truth = eval_data[:2000]
elif eval_dataset in ("chembl", "moses"):
    ground_truth = eval_data.sample(n=2000, random_state=21)
elif eval_dataset == "oled":
    ground_truth = eval_data.sample(n=100, random_state=21)
else:
    # Without this branch `ground_truth` would stay undefined and the script
    # would only fail later with a NameError when the metrics are computed.
    raise ValueError(
        f"Unsupported eval_dataset {eval_dataset!r}; "
        "expected one of: 'zinc', 'chembl', 'moses', 'oled'"
    )

# Create a unique log filename using experiment name and timestamp
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = f"logs/eval/{experiment_name}_{timestamp}.yaml"
os.makedirs("logs/eval", exist_ok=True)  # Ensure logs directory exists

def compute_qed_mae(df: pd.DataFrame, ground_truth: pd.DataFrame) -> float:
    """Return the mean absolute error between target and generated QED.

    Only rows of ``df`` whose ``valid_smiles`` equals 1 are considered.
    ``ground_truth['QED']`` and ``df['qed_value']`` are subtracted as pandas
    Series, so rows are paired by index label; labels present on only one
    side produce NaN, which ``mean()`` skips.

    NOTE(review): pairing is by label, not by position - confirm that the
    benchmark rows and the ground-truth rows share the same index labels.
    """
    valid_rows = df.loc[df['valid_smiles'] == 1]
    qed_error = ground_truth['QED'].sub(valid_rows['qed_value'])
    return qed_error.abs().mean()

def compute_logp_mae(df: pd.DataFrame, ground_truth: pd.DataFrame) -> float:
    """Return the mean absolute error between target and generated logP.

    Only rows of ``df`` whose ``valid_smiles`` equals 1 are considered.
    ``ground_truth['LOGP']`` and ``df['clogp_value']`` are subtracted as
    pandas Series, so rows are paired by index label; labels present on only
    one side produce NaN, which ``mean()`` skips.

    NOTE(review): pairing is by label, not by position - confirm that the
    benchmark rows and the ground-truth rows share the same index labels.
    """
    is_valid = df['valid_smiles'] == 1
    generated_logp = df.loc[is_valid, 'clogp_value']
    logp_error = ground_truth['LOGP'].sub(generated_logp)
    return logp_error.abs().mean()

def validity(df: pd.DataFrame) -> float:
    """Return the mean of the ``valid_smiles`` column of ``df``.

    With 0/1 flags (as the MAE helpers' ``== 1`` filter suggests) this is the
    fraction of rows marked valid.
    """
    validity_flags = df['valid_smiles']
    return validity_flags.mean()

# Helpers for serializing and logging results
def convert_numpy_types(obj):
    """Recursively convert NumPy/pandas values to standard Python types.

    - dicts and lists are rebuilt with converted values (dict keys are left
      unchanged);
    - pandas Series/DataFrames become dicts via ``to_dict()``, and that result
      is converted recursively as well;
    - any object exposing ``.item()`` (e.g. a NumPy scalar or single-element
      array) is unwrapped to a Python scalar;
    - arrays for which ``.item()`` raises ``ValueError`` (size other than 1)
      are converted element-wise through ``.tolist()``.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(value) for value in obj]
    if isinstance(obj, (pd.Series, pd.DataFrame)):
        # to_dict() can still contain NumPy objects (e.g. object-dtype cells),
        # so run its result back through the converter.
        return convert_numpy_types(obj.to_dict())
    if hasattr(obj, "item"):
        try:
            return obj.item()
        except ValueError:
            # .item() only collapses single-element containers; other sizes
            # are turned into (nested) lists instead of crashing.
            if hasattr(obj, "tolist"):
                return convert_numpy_types(obj.tolist())
            raise
    return obj

def save_results(results, config, log_file):
    """Write evaluation results together with the config to a YAML file.

    Args:
        results: Metric values to log; passed through ``convert_numpy_types``
            before being dumped.
        config: The loaded configuration mapping; stored verbatim under
            ``config`` and used to look up ``experiment_name``.
        log_file: Path of the YAML file to (over)write. Its parent directory
            is created if it does not exist.
    """
    log_data = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        # Same fallback as the one used to build the log filename, so the
        # file name and its contents agree when no name is configured.
        "experiment_name": config.get("experiment_name", "default"),
        "config": config,
        "results": convert_numpy_types(results),
    }
    # Do not rely on the caller having created the target directory.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_file, "w", encoding="utf-8") as f:
        yaml.dump(log_data, f, default_flow_style=False)
    print(f"\nResults logged in {log_file}")

# Score every benchmark: property MAEs against the ground truth, plus validity
qed_mae_results = {
    bench: compute_qed_mae(frame, ground_truth)
    for bench, frame in benchmark_data.items()
}
logp_mae_results = {
    bench: compute_logp_mae(frame, ground_truth)
    for bench, frame in benchmark_data.items()
}
validity_rate = {bench: validity(frame) for bench, frame in benchmark_data.items()}

# Report each metric under its own banner: (banner line, row label, values)
report_sections = (
    ("\n------------------------- QED MAE Results -------------------------", "QED MAE", qed_mae_results),
    ("\n------------------------- LogP MAE Results -------------------------", "LogP MAE", logp_mae_results),
    ("\n------------------------- Validity Rate -------------------------", "Validity Rate", validity_rate),
)
for banner, label, scores in report_sections:
    print(banner)
    for name, score in scores.items():
        print(f"{name.capitalize()} {label}: {score}")

# Bundle all metrics into one structure for the log
log_results = {
    "qed_mae": qed_mae_results,
    "logp_mae": logp_mae_results,
    "validity_rate": validity_rate,
}

# Persist the metrics alongside the config
save_results(log_results, config, log_filename)
