import datetime
import numpy as np
import os

from sgns.co_occurrence import construct_co_occurrence
from sgns.sbm import expected_sbm
from sgns.sgns_numba import SkipgramNegativeSampling

# Maps the LaTeX characters that are awkward in filenames: "/" -> "_", and
# "$", "\", "{", "}" are dropped entirely.
_FILENAME_TRANSLATION = str.maketrans(
    {"/": "_", "$": None, "\\": None, "{": None, "}": None}
)


def _safe_filename_label(label):
    r"""Turn a LaTeX radius label into a filename-safe token.

    ``/`` becomes ``_`` and ``$``, ``\``, ``{``, ``}`` are removed, e.g.
    ``$1/\sqrt{n}$`` -> ``1_sqrtn`` and ``$1/n^2$`` -> ``1_n^2``.
    """
    return label.translate(_FILENAME_TRANSLATION)


def _save_potentials(output_filename, pot, pot_community, log_every):
    """Write potential traces to a CSV file.

    The file has the header ``iteration,potential,community_potential`` (no
    leading comment marker). Row ``i`` (1-based) is labelled with iteration
    ``i * log_every``.

    Raises:
        ValueError: if ``pot`` and ``pot_community`` differ in length, since
            the rows could not be aligned.
    """
    pot_arr = np.asarray(pot)
    pot_community_arr = np.asarray(pot_community)
    if pot_arr.shape != pot_community_arr.shape:
        raise ValueError(
            "potential and community potential traces differ in shape: "
            f"{pot_arr.shape} vs {pot_community_arr.shape}"
        )
    iterations = np.arange(1, len(pot_arr) + 1) * log_every
    potentials_data = np.column_stack((iterations, pot_arr, pot_community_arr))
    np.savetxt(
        output_filename,
        potentials_data,
        delimiter=",",
        header="iteration,potential,community_potential",
        comments="",  # plain CSV header, not "# iteration,..."
    )


def run_radius_experiment(
    *,
    n=1000,
    K=2,
    s_n=1.0,
    t_iter=100_000_000,
    data_dir="../dat",
    log_every=5000,
):
    """
    Run the SGNS algorithm for a fixed configuration, varying only the initialization radius 'r'.

    A single co-occurrence matrix (built from ``expected_sbm``) is shared by
    every run, so the runs differ only in ``r``. The radii tested are
    1/n^2, 1/n, 1/sqrt(n) and 1. For each radius, the model's ``pot_`` and
    ``pot_community_`` traces are saved to a CSV file in ``data_dir`` for
    later analysis.

    All parameters are keyword-only. The defaults reproduce the original
    hard-coded configuration, so ``run_radius_experiment()`` keeps working.

    Args:
        n: Passed to ``expected_sbm``. It also sets the step ``eta = 1/n``
            and the radius grid.
        K: Passed to ``expected_sbm`` and to ``SkipgramNegativeSampling.fit``.
        s_n: Forwarded to ``SkipgramNegativeSampling`` and embedded in the
            output filename. NOTE(review): the original script contained the
            truncated literal ``1e`` (a syntax error). ``1.0`` is a
            placeholder default; confirm the intended value.
        t_iter: Forwarded to ``SkipgramNegativeSampling``.
        data_dir: Output directory. It is created if missing and is resolved
            relative to the current working directory.
        log_every: Iterations between consecutive recorded potential values,
            used only to label the CSV rows. Assumed to match the model's
            internal logging interval; TODO confirm against
            ``SkipgramNegativeSampling``.

    Returns:
        None. Progress is printed to stdout and CSV files are written.
    """
    # Radii to test, paired with LaTeX labels.
    # NOTE: keep this order consistent with the plotting script.
    r_labels = [
        (1 / n**2, r"$1/n^2$"),
        (1 / n, r"$1/n$"),
        (1 / np.sqrt(n), r"$1/\sqrt{n}$"),
        (1.0, r"$1$"),
    ]

    # --- Setup ---
    os.makedirs(data_dir, exist_ok=True)

    print("--- Starting Experiment 1: Effect of Initialization Radius (r) ---")
    print(f"Configuration: n={n}, K={K}, s_n={s_n}\n")

    # Construct a single co-occurrence matrix for all runs.
    # L and S are passed positionally to construct_co_occurrence; presumably
    # walk length and window size. Confirm against that function.
    L = 100_000
    S = 5
    print("Constructing Co-occurrence Matrix...", end="")
    A = expected_sbm(n, K, p=0.6, q=0.1)
    C = construct_co_occurrence(A, L, S)
    print("Completed.\n")

    # --- Main experiment loop ---
    for r_val, r_label in r_labels:
        print("-" * 30)
        print(f"Running for radius r = {r_label} (numerical value: {r_val:.2e})")

        start_time = datetime.datetime.now()
        sgns = SkipgramNegativeSampling(
            eta=1 / n,
            t_iter=t_iter,
            r=r_val,
            s_n=s_n,
            verbose=True,
        ).fit(C, K=K)
        print(f"Completed in {datetime.datetime.now() - start_time}")

        # --- Save results ---
        output_filename = os.path.join(
            data_dir,
            f"potentials_n{n}_K{K}_sn{s_n}_r{_safe_filename_label(r_label)}.csv",
        )
        _save_potentials(output_filename, sgns.pot_, sgns.pot_community_, log_every)
        print(f"Results saved to {output_filename}\n")

    print("-" * 30)
    print("Experiment 1 finished.")

if __name__ == "__main__":
    # Script entry point: run the radius sweep only when executed directly,
    # not when this module is imported.
    run_radius_experiment()
