import argparse
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, davies_bouldin_score

# Command-line interface: input arrays plus the three tunables used further
# down in this script.
parser = argparse.ArgumentParser(
    description="Embed an initial subset with t-SNE, then place the "
                "remaining samples incrementally using per-cluster anchors."
)
parser.add_argument("--X_path", type=str, required=True,
                    help="Feature array file readable by numpy.load.")
parser.add_argument("--y_path", type=str, required=True,
                    help="Label array file readable by numpy.load.")
parser.add_argument("--split_ratio", type=float, default=0.5,
                    help="Fraction of every class used for the initial "
                         "embedding; must lie in (0, 1].")
parser.add_argument("--K", type=int, default=1000,
                    help="Number of k-means clusters fit on the initial "
                         "subset; must be at least 1.")
parser.add_argument("--eta", type=float, default=10.0,
                    help="Step size of the gradient update applied to each "
                         "incrementally added point.")
args = parser.parse_args()

# A negative ratio would turn into a negative slice bound and silently
# mis-split every class, and a ratio above 1 would silently leave nothing to
# stream.  Fail here with a clear message instead (the chained comparison is
# also False for NaN, so NaN is rejected too).
if not 0.0 < args.split_ratio <= 1.0:
    parser.error("--split_ratio must be in the interval (0, 1]")
if args.K < 1:
    parser.error("--K must be a positive integer")

# Load the features (cast to float32) and the labels.
X = np.load(args.X_path).astype(np.float32)
y = np.load(args.y_path)
if X.shape[0] != y.shape[0]:
    raise ValueError(
        f"X has {X.shape[0]} rows but y has {y.shape[0]} labels"
    )

# Iterate over the label values that actually occur.  Looping over
# range(num_classes) is only correct when the labels are exactly 0..C-1; for
# any other labelling it would drop or mismatch classes.
class_labels = np.unique(y)
num_classes = len(class_labels)

split_ratio = args.split_ratio

# Per-class split: shuffle each class's indices, send the first
# `split_ratio` fraction to the initial subset and the rest to the
# remainder that is streamed in later.
X_init, y_init, X_rem, y_rem = [], [], [], []
for cls in class_labels:
    idxs = np.where(y == cls)[0]
    np.random.shuffle(idxs)
    n0 = int(split_ratio * len(idxs))
    X_init.append(X[idxs[:n0]])
    y_init.append(y[idxs[:n0]])
    X_rem.append(X[idxs[n0:]])
    y_rem.append(y[idxs[n0:]])

# Concatenate the per-class pieces back into flat arrays.
X_init = np.vstack(X_init)
y_init = np.hstack(y_init)
X_rem = np.vstack(X_rem)
y_rem = np.hstack(y_rem)

# Group the initial subset into K k-means clusters; the per-sample cluster
# ids are what the cluster statistics below are built from.
K = args.K
kmeans = KMeans(n_clusters=K, random_state=42)
kmeans.fit(X_init)
labels_init = kmeans.labels_

# Batch t-SNE (Barnes-Hut) of the initial subset into two dimensions, stored
# as float32.  t0 marks the start of this step; it is not read again in the
# visible part of the script.
t0 = time.time()
tsne = TSNE(n_components=2, method="barnes_hut", random_state=42)
Y_init = tsne.fit_transform(X_init)
Y_init = Y_init.astype(np.float32)

# D: input dimensionality, d: embedding dimensionality.
D = X_init.shape[1]
d = 2
# Small positive floor used to keep divisions and square roots well defined.
_eps = 1e-9

class Cluster:
    """Running summary of one cluster in the input space plus its 2-D anchor.

    The summary is kept as first and second moments so that new points can be
    folded in without revisiting the old ones.
    """

    def __init__(self, high_mean, low_mean, sum_sq, std_dev, count):
        """Store the initial statistics.

        Args:
            high_mean: Mean of the cluster's points in the input space.
            low_mean: Anchor position of the cluster in the embedding.
            sum_sq: Mean squared Euclidean norm of the cluster's points
                (``update`` multiplies it by ``count`` to recover the sum).
            std_dev: Spread of the cluster; floored at the module-level
                ``_eps`` so it can safely be used as a divisor.
            count: Number of points summarised so far.
        """
        self.high_mean = high_mean
        self.low_mean = low_mean
        self.sum_sq = sum_sq
        self.std_dev = max(std_dev, _eps)
        self.count = count

    def update(self, new_pts):
        """Fold new input-space points into the running statistics.

        Updates ``high_mean``, ``sum_sq``, ``std_dev`` and ``count`` in
        place; ``low_mean`` is left untouched.

        Args:
            new_pts: Array of shape (n, D), or a single point of shape (D,).
                An empty batch is a no-op.
        """
        # Accumulate in float64: the variance below is the difference of two
        # nearly equal numbers (E||x||^2 - ||mean||^2), which loses every
        # significant digit in float32 once the norms are large compared
        # with the spread.
        pts = np.atleast_2d(np.asarray(new_pts, dtype=np.float64))
        if pts.size == 0:
            # Nothing to add; also avoids 0/0 when the cluster is empty too.
            return

        old_mean = np.asarray(self.high_mean)
        n_old = self.count
        n_new = pts.shape[0]
        total = n_old + n_new

        mean = (n_old * old_mean.astype(np.float64) + pts.sum(axis=0)) / total
        mean_sq_norm = (float(self.sum_sq) * n_old + np.sum(pts * pts)) / total

        # Total variance around the mean, summed over all dimensions.
        var = mean_sq_norm - np.dot(mean, mean)

        # Hand the mean back in the dtype it came in with so the arrays that
        # callers stack from it keep their precision.
        if np.issubdtype(old_mean.dtype, np.floating):
            mean = mean.astype(old_mean.dtype)
        self.high_mean = mean
        self.sum_sq = float(mean_sq_norm)
        # Floor the variance so std_dev stays strictly positive.
        self.std_dev = np.sqrt(max(var, _eps))
        self.count = total

# Build one Cluster per k-means label from the initial subset.
clusters = []
for k in range(K):
    mask = (labels_init == k)
    hd = X_init[mask]   # members of cluster k in the input space
    ld = Y_init[mask]   # the same members in the 2-D embedding
    if hd.shape[0] == 0:
        # No members: keep a placeholder so list index == cluster id.
        clusters.append(Cluster(np.zeros(D), np.zeros(d), 1.0, 1.0, 1))
        continue

    centroid = hd.mean(axis=0)

    # Spread = sqrt of the mean squared distance to the centroid, i.e.
    # sqrt(E||x||^2 - ||mean||^2).  This is the quantity Cluster.update
    # recomputes after every insertion; seeding with the std of the norms
    # (a different and usually much smaller number) would make sigma change
    # definition the first time a cluster is updated.  Computed in float64
    # on centred data to avoid cancellation.
    centred = hd.astype(np.float64) - centroid
    total_var = np.sum(centred * centred, axis=1).mean()

    clusters.append(Cluster(
        high_mean=centroid,
        low_mean=ld.mean(axis=0),
        sum_sq=(np.linalg.norm(hd, axis=1) ** 2).mean(),
        std_dev=np.sqrt(max(total_var, _eps)),
        count=hd.shape[0],
    ))

class IncrementalRSNE:
    """Place new points into an existing embedding using cluster anchors.

    Each cluster object must expose ``high_mean``, ``low_mean``, ``std_dev``
    and an ``update(points)`` method.  All points (initial and added) are
    accumulated in ``X_list`` / ``Y_list`` / ``y_all``.
    """

    def __init__(self, clusters, X0, Y0, y0, eta=10.0):
        """Store the clusters and the already-embedded points.

        Args:
            clusters: List of cluster summaries (mutated by ``add_point``).
            X0: Initial points in the input space, one per row.
            Y0: Their embedding coordinates, one per row.
            y0: Their labels.
            eta: Step size of the gradient update applied to each new point.
        """
        self.clusters = clusters
        self.eta = eta
        self.X_list = list(X0)
        self.Y_list = list(Y0)
        self.y_all = list(y0)
        # Stacked per-cluster arrays, built on first use and then kept in
        # sync row by row, so add_point does not restack every cluster for
        # every sample.
        self._C_high = None
        self._C_low = None
        self._sigma = None

    def _refresh_cache(self):
        """Rebuild the stacked centre/anchor/sigma arrays from the clusters."""
        self._C_high = np.stack([c.high_mean for c in self.clusters])
        self._C_low = np.stack([c.low_mean for c in self.clusters])
        self._sigma = np.array([c.std_dev for c in self.clusters],
                               dtype=np.float32)

    def add_point(self, x_new, y_label):
        """Embed one new point and fold it into its nearest cluster.

        The point starts near the anchor of its nearest cluster (plus small
        Gaussian noise) and takes a single t-SNE-style gradient step against
        all cluster anchors.  The nearest cluster's statistics are then
        updated and the point is appended to the stored lists.

        Args:
            x_new: 1-D array, the point in the input space.
            y_label: Label recorded alongside the point.

        Note:
            The stacked arrays are refreshed only for the cluster updated
            here; if the cluster objects are changed from outside, call
            ``_refresh_cache()`` before the next insertion.
        """
        if self._C_high is None:
            self._refresh_cache()
        C_high, C_low = self._C_high, self._C_low
        sigma = np.clip(self._sigma, _eps, None)

        # Squared distance to every cluster centre; the nearest one owns the
        # new point.
        d2h = np.sum((C_high - x_new[None, :]) ** 2, axis=1)
        k = int(np.argmin(d2h))

        # Start next to the owning cluster's anchor.  The embedding dimension
        # is taken from the anchors rather than from a module-level global.
        noise = np.random.randn(C_low.shape[1]).astype(np.float32) * 0.1
        y_new = C_low[k] + noise

        # Input-space affinities.  Shifting the exponents by their maximum
        # before exponentiating keeps the largest term at exp(0) = 1, so the
        # weights no longer underflow to all zeros for a point far from every
        # centre (which would silently turn the step below into pure
        # repulsion).
        logits = -d2h.astype(np.float64) / (2.0 * sigma.astype(np.float64) ** 2)
        logits -= logits.max()
        P = np.exp(logits)
        P /= P.sum()

        # Embedding-space affinities with the heavy-tailed 1 / (1 + d^2)
        # kernel.
        d2l = np.sum((C_low - y_new[None, :]) ** 2, axis=1)
        Q = 1.0 / (1.0 + d2l)
        Q /= (Q.sum() + _eps)

        # One gradient step pulling towards anchors with P > Q and pushing
        # away from those with P < Q.
        coef = 2.0 * (P - Q) / (1.0 + d2l)
        diff = y_new[None, :] - C_low
        grad = np.einsum('k,kj->j', coef, diff)
        y_new -= self.eta * grad

        # Fold the point into its cluster and refresh that cluster's rows.
        owner = self.clusters[k]
        owner.update(x_new[None, :])
        self._C_high[k] = owner.high_mean
        self._C_low[k] = owner.low_mean
        self._sigma[k] = owner.std_dev

        self.X_list.append(x_new)
        self.Y_list.append(y_new)
        self.y_all.append(y_label)

# Stream the held-out samples through the incremental embedder, reporting
# progress every 1000 samples and once more on the last sample.
rtsne = IncrementalRSNE(clusters, X_init, Y_init, y_init, eta=args.eta)
t1 = time.time()
total = len(X_rem)
for i, (sample, target) in enumerate(zip(X_rem, y_rem), start=1):
    rtsne.add_point(sample, target)
    if i != total and i % 1000 != 0:
        continue
    elapsed = time.time() - t1
    print(f"  → processed {i}/{total} samples ({i/total*100:.1f}%) in {elapsed:.1f}s")

print("Incremental SNE done in", time.time() - t1, "s")

# Gather everything the embedder holds: initial points followed by the
# streamed ones, in insertion order.
X_all = np.vstack(rtsne.X_list)
Y_all = np.vstack(rtsne.Y_list)
labels = np.array(rtsne.y_all)

# Cluster-quality scores of the 2-D embedding with respect to the labels.
sil = silhouette_score(Y_all, labels)
db = davies_bouldin_score(Y_all, labels)
print(f"Silhouette: {sil:.4f}, DBI: {db:.4f}")

# Scatter plot with one scatter call per label, axes ticks removed; saved to
# disk and then shown.
plt.figure(figsize=(8, 8))
for cls in np.unique(labels):
    members = Y_all[labels == cls]
    plt.scatter(members[:, 0], members[:, 1], s=5, alpha=0.5)
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.savefig("i_rsne_final.png", dpi=300)
plt.show()
