import argparse
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, davies_bouldin_score, accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Command-line interface. All options are consumed further down the script
# through the module-level ``args`` namespace.
parser = argparse.ArgumentParser(
    description=(
        "Embed a labelled dataset in 2-D: run t-SNE on a stratified seed "
        "subset, then place the remaining points batch by batch."
    ),
)
parser.add_argument("--X_path", type=str, required=True,
                    help="Path to a NumPy file holding the feature matrix.")
parser.add_argument("--y_path", type=str, required=True,
                    help="Path to a NumPy file holding the class labels.")
parser.add_argument("--split_ratio", type=float, default=0.5,
                    help="Fraction of each class used for the seed "
                         "embedding, in (0, 1] (default: %(default)s).")
parser.add_argument("--K", type=int, default=200,
                    help="Number of k-means clusters fitted on the seed "
                         "subset (default: %(default)s).")
parser.add_argument("--batch_size", type=int, default=100,
                    help="Number of remaining points embedded per "
                         "incremental step (default: %(default)s).")
args = parser.parse_args()

# Fail fast on values that would otherwise only blow up (or silently do
# nothing) after the expensive k-means / t-SNE stage has already run:
# a zero batch size makes range() raise, a negative one skips every batch,
# and a non-positive ratio leaves the seed subset empty.
if not 0.0 < args.split_ratio <= 1.0:
    parser.error("--split_ratio must be in the interval (0, 1]")
if args.K < 1:
    parser.error("--K must be a positive integer")
if args.batch_size < 1:
    parser.error("--batch_size must be a positive integer")

# Load the feature matrix (cast to float32) and the label vector.
X = np.load(args.X_path).astype(np.float32)
y = np.load(args.y_path)
if X.shape[0] != y.shape[0]:
    raise ValueError(
        f"X has {X.shape[0]} rows but y has {y.shape[0]} labels"
    )

# Work from the label values that are actually present. Iterating over
# range(num_classes) would silently drop classes whenever the labels are
# not exactly 0..num_classes-1 (e.g. 1-based or sparse label ids).
classes = np.unique(y)
num_classes = len(classes)

split_ratio = args.split_ratio

# Stratified split: for every class, a shuffled `split_ratio` share of its
# samples seeds the initial embedding and the rest is streamed in later.
X_init, y_init, X_rem, y_rem = [], [], [], []
for cls in classes:
    idxs = np.where(y == cls)[0]
    np.random.shuffle(idxs)
    n_init = int(split_ratio * len(idxs))
    X_init.append(X[idxs[:n_init]])
    y_init.append(y[idxs[:n_init]])
    X_rem.append(X[idxs[n_init:]])
    y_rem.append(y[idxs[n_init:]])

# Concatenate the per-class pieces; rows end up grouped by class.
X_init = np.vstack(X_init)
y_init = np.hstack(y_init)
X_remaining = np.vstack(X_rem)
y_remaining = np.hstack(y_rem)

# Partition the seed subset into K groups with k-means (fixed seed).
K = args.K
kmeans = KMeans(n_clusters=K, random_state=42)
kmeans.fit(X_init)
labels_init = kmeans.labels_

# The timer starts after k-means, so the elapsed time reported later covers
# the seed t-SNE run and everything that follows it, but not the clustering.
start = time.time()
d = 2  # dimensionality of the embedding
tsne = TSNE(n_components=d, method="barnes_hut", random_state=42)
Y_init = tsne.fit_transform(X_init).astype(np.float32)

D = X_init.shape[1]  # dimensionality of the input features
class Cluster:
    """Running summary statistics for one cluster.

    Attributes:
        high_mean: centroid of the member points in the input space.
        low_mean: centroid of the member points in the embedding; it is
            stored as given and never modified by ``update_stats``.
        sum_sq: mean squared Euclidean norm of the member points.
        std_dev: spread of the cluster; after an update it equals
            ``sqrt(sum_sq - |high_mean|^2)`` with the variance floored
            at 1e-9.
        count: number of points summarised so far.
    """

    def __init__(self, high_mean, low_mean, sum_sq, std_dev, count):
        self.high_mean, self.low_mean = high_mean, low_mean
        self.sum_sq, self.std_dev = sum_sq, std_dev
        self.count = count

    def update_stats(self, new_pts):
        """Fold the rows of ``new_pts`` into the running statistics.

        ``new_pts`` is a 2-D array with one point per row. Updates
        ``high_mean``, ``sum_sq``, ``std_dev`` and ``count`` in place.
        """
        n_before = self.count
        n_after = n_before + len(new_pts)

        # Turn the stored mean squared norm back into a plain sum so the
        # batch's squared norms can be added before re-normalising.
        prior_sq_total = self.sum_sq * n_before
        batch_sq_norms = np.sum(new_pts * new_pts, axis=1)
        mean_sq_norm = (prior_sq_total + batch_sq_norms.sum()) / n_after

        # Count-weighted combination of the old centroid and the batch sum.
        centroid = (n_before*self.high_mean + new_pts.sum(axis=0)) / n_after

        # E|x|^2 - |mean|^2, floored so the square root stays positive.
        variance = max(mean_sq_norm - np.dot(centroid, centroid), 1e-9)

        self.high_mean = centroid
        self.sum_sq = mean_sq_norm
        self.std_dev = np.sqrt(variance)
        self.count = n_after

# Build one Cluster summary per k-means label from the seed subset.
clusters = []
for k in range(K):
    members = labels_init == k
    pts_hd = X_init[members]
    pts_ld = Y_init[members]
    if pts_hd.shape[0] == 0:
        # No seed point carries this label: insert a placeholder at the
        # origin with unit statistics so the list keeps exactly K entries.
        clusters.append(Cluster(np.zeros(D, dtype=np.float32),
                                np.zeros(d, dtype=np.float32),
                                1.0,
                                1.0,
                                1))
        continue

    centroid = pts_hd.mean(axis=0)
    # Spread is the RMS distance of the members to their centroid, i.e. the
    # same quantity Cluster.update_stats maintains (sqrt(E|x|^2 - |mean|^2)),
    # with the same 1e-9 floor. Using the std of the point norms instead
    # would put untouched clusters on a different scale from updated ones.
    sq_dist = np.sum((pts_hd - centroid) ** 2, axis=1)
    clusters.append(Cluster(
        high_mean=centroid,
        low_mean=pts_ld.mean(axis=0),
        sum_sq=(np.linalg.norm(pts_hd, axis=1)**2).mean(),
        std_dev=np.sqrt(max(sq_dist.mean(), 1e-9)),
        count=len(pts_hd),
    ))

class IncrementalTSNEBatch:
    """Place new points into an existing embedding using cluster summaries.

    Each new point is compared against the cluster centroids only (not
    against every stored point), initialised near the low-dimensional
    centroid of its nearest cluster and refined with a few gradient steps.

    Args:
        clusters: sequence of objects exposing ``high_mean``, ``low_mean``,
            ``std_dev`` and ``update_stats(points)``.
        X_init: input-space points that are already embedded.
        Y_init: their embedding coordinates.
        y_init: their labels.
        eta: step size of the gradient update.
        max_iters: number of gradient steps per batch.

    Attributes:
        X_all, Y_all: float32 arrays of all points / coordinates so far.
        y_all: list of all labels so far.
    """

    def __init__(self, clusters, X_init, Y_init, y_init,
                 eta=10.0, max_iters=1):
        self.clusters = clusters
        self.eta = eta
        self.max_iters = max_iters
        self.X_all = X_init.astype(np.float32)
        self.Y_all = Y_init.astype(np.float32)
        self.y_all = list(y_init)

    def add_new_batch(self, Xb, yb_labels):
        """Embed the rows of ``Xb`` and append them to the stored data.

        Side effects: draws from NumPy's global random generator, calls
        ``update_stats`` on every cluster that is nearest to at least one
        new point, and extends ``X_all``, ``Y_all`` and ``y_all``. The
        clusters' ``low_mean`` is not changed here.
        """
        Xb = Xb.astype(np.float32)
        B = Xb.shape[0]

        C_high = np.stack([c.high_mean for c in self.clusters])
        C_low = np.stack([c.low_mean for c in self.clusters])
        sigma = np.array([c.std_dev for c in self.clusters], dtype=np.float32)
        sigma = np.clip(sigma, 1e-3, None)  # keep the kernel width positive

        # Squared input-space distances from every new point to every
        # centroid; used for both the nearest-cluster lookup and P.
        d2_high = np.sum((Xb[:, None, :] - C_high[None, :, :])**2, axis=2)
        nearest = np.argmin(d2_high, axis=1)

        # Start each point at its nearest cluster's low-dimensional centroid
        # plus a small jitter. The embedding dimension comes from the
        # cluster data rather than from a module-level constant.
        low_dim = C_low.shape[1]
        Yb = C_low[nearest] + (np.random.randn(B, low_dim).astype(np.float32) * 0.1)

        # Gaussian point-to-cluster affinities, normalised per point. They
        # depend only on Xb and the centroids, so compute them once instead
        # of inside the optimisation loop.
        P = np.exp(-d2_high / (2 * sigma[None, :]**2))
        P /= (P.sum(axis=1, keepdims=True) + 1e-12)

        for _ in range(self.max_iters):
            diff_low = Yb[:, None, :] - C_low[None, :, :]
            d2_low = np.sum(diff_low**2, axis=2)
            # Heavy-tailed 1 / (1 + d^2) affinities in the embedding.
            Q = 1.0 / (1.0 + d2_low)
            Q /= (Q.sum(axis=1, keepdims=True) + 1e-12)

            coef = 2.0 * (P - Q) / (1.0 + d2_low)
            grads = np.einsum('ik,ikj->ij', coef, diff_low)
            Yb -= self.eta * grads

        # Only clusters that actually received points need updating; this
        # also removes the dependence on a global cluster count.
        for k in np.unique(nearest):
            self.clusters[k].update_stats(Xb[nearest == k])

        self.X_all = np.vstack([self.X_all, Xb])
        self.Y_all = np.vstack([self.Y_all, Yb])
        self.y_all.extend(yb_labels)

# Seed the incremental embedder with the t-SNE result for the seed subset.
rtsne = IncrementalTSNEBatch(
    clusters, X_init, Y_init, y_init,
    eta=10.0, max_iters=1,
)

# Stream the held-out points through in consecutive slices.
batch_size = args.batch_size
n_remaining = len(X_remaining)
for lo in range(0, n_remaining, batch_size):
    hi = lo + batch_size
    rtsne.add_new_batch(X_remaining[lo:hi], y_remaining[lo:hi])
print(f"Bi-RSNE finished in {time.time() - start:.1f}s")

# Final embedding (seed + streamed points) and the matching labels.
Y = rtsne.Y_all
labels = np.array(rtsne.y_all)

# Cluster-quality scores of the embedding with respect to the class labels.
sil = silhouette_score(Y, labels)
db = davies_bouldin_score(Y, labels)
print("Silhouette:", sil, "DB Index:", db)

# Scatter plot with one scatter call (and thus one colour) per class.
plt.figure(figsize=(8, 8))
for cls in np.unique(labels):
    class_pts = Y[labels == cls]
    plt.scatter(class_pts[:, 0], class_pts[:, 1], alpha=0.5, s=5)
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.savefig("bi-rsne_embedding.png", dpi=300)
plt.show()
