import datetime
import time
from pathlib import Path

import numpy as np

from sgns.co_occurrence import construct_co_occurrence
from sgns.sbm import stochastic_block_model, expected_sbm
from sgns.sgns_numba import SkipgramNegativeSampling

# Draw a graph from a stochastic block model with parameters n, K, p, q.
# NOTE(review): presumably n = number of nodes, K = number of communities,
# p = within-community and q = between-community edge probability —
# confirm against sgns.sbm.
n = 1200
K = 4
q = 0.1
p = 0.6

# This run uses expected_sbm (presumably the expected adjacency / edge
# probability matrix rather than a random draw — confirm in sgns.sbm).
# To train on a sampled graph instead, use:
#     A = stochastic_block_model(n, K, p, q)
A = expected_sbm(n, K, p, q)

# Construct a co-occurrence matrix from A.
# NOTE(review): L and S look like the random-walk length and the context
# window size — confirm against sgns.co_occurrence.construct_co_occurrence.
L = 100_000
S = 5

# flush=True: this message has no trailing newline, so on a line-buffered
# stdout it could otherwise stay invisible until construction has finished.
print("Constructing Co-occurrence Matrix...", end="", flush=True)
C = construct_co_occurrence(A, L, S)
print("completed")

# Total SGD iterations per run, and the negative-sample counts to sweep over.
t_iter = 100_000_000
s_n = [1, 2, 3, 5, 10]

# Number of SGD iterations between successive entries of sgns.pot_ and
# sgns.pot_community_; used to label the rows of the potentials CSV.
# NOTE(review): this must match the logging interval used inside
# SkipgramNegativeSampling — confirm against sgns.sgns_numba.
POT_LOG_INTERVAL = 5000

# Results go to ../dat relative to the current working directory (same
# location as before). Create it up front so that saving cannot fail with
# FileNotFoundError only after a long training run has completed.
out_dir = Path("../dat")
out_dir.mkdir(parents=True, exist_ok=True)

for s in s_n:
    print("\n" + "-" * 30)
    print("negative samples: ", s)

    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments during long runs; wrapping it in a timedelta
    # keeps the printed "H:MM:SS.ffffff" format.
    start_time = time.perf_counter()
    sgns = SkipgramNegativeSampling(
        eta=1 / n, t_iter=t_iter, r=1 / n**2, s_n=s, verbose=True
    ).fit(C, K=K)
    elapsed = datetime.timedelta(seconds=time.perf_counter() - start_time)
    print("Running Time: ", elapsed)

    # Save potentials: one row per logged checkpoint, columns
    # (iteration, potential, community_potential).
    pot_arr = np.array(sgns.pot_)
    pot_community_arr = np.array(sgns.pot_community_)
    iterations = np.arange(1, len(pot_arr) + 1) * POT_LOG_INTERVAL
    potentials_data = np.vstack((iterations, pot_arr, pot_community_arr)).T
    np.savetxt(out_dir / f"potentials_s{s}_n{n}_K{K}.csv", potentials_data, delimiter=",",
               header="iteration,potential,community_potential", comments="")

    # Save the learned embedding matrices X and Y.
    np.savetxt(out_dir / f"X_s{s}_n{n}_K{K}.csv", sgns.X, delimiter=",")
    np.savetxt(out_dir / f"Y_s{s}_n{n}_K{K}.csv", sgns.Y, delimiter=",")

print("\n" + "-" * 30)
print("Finished saving results.")
