import numpy as np
import matplotlib.pyplot as plt
from datasets import load_from_disk

# ——— Helper to extract & average embeddings for a given index set ———
def _band_mean(rec, key):
    """Return the axis-0 mean of ``rec[key]`` as a float32 array."""
    return np.array(rec[key], dtype=np.float32).mean(axis=0)


def get_avg_embeddings(ds, scenario, indices):
    """
    Build one L2-normalized embedding per requested record.

    ds: HuggingFace Dataset loaded via load_from_disk (anything indexable by
        int that returns a mapping per record works). Each record must hold
        "embeddings_g" and/or "embeddings_r"; each is averaged over its
        first axis (presumably the token/observation axis — confirm against
        the dataset schema).
    scenario: one of "r", "g", "avg", or "concat". Only consulted when a
        record has BOTH bands; a record with a single band always yields
        that band regardless of scenario. Any value other than "concat",
        "g" or "r" is treated as "avg".
    indices: list or array of integer indices to sample from ds
    returns: array with one row per index, rows stacked with np.stack
        ((len(indices), D), or 2*D columns for "concat", assuming every
        record yields the same length). Rows have unit L2 norm, except
        that an all-zero embedding is returned as zeros rather than NaN.
    raises: ValueError if a record has neither embeddings key, or (from
        np.stack) if indices is empty.
    """
    embs = []
    for i in indices:
        rec = ds[int(i)]

        has_g = "embeddings_g" in rec
        has_r = "embeddings_r" in rec

        if has_g and has_r:
            ag = _band_mean(rec, "embeddings_g")
            ar = _band_mean(rec, "embeddings_r")
            if scenario == "concat":
                x = np.concatenate([ag, ar])
            elif scenario == "g":
                x = ag
            elif scenario == "r":
                x = ar
            else:  # "avg" (also the fall-through for unrecognised values)
                x = 0.5 * (ag + ar)
        elif has_r:
            x = _band_mean(rec, "embeddings_r")
        elif has_g:
            x = _band_mean(rec, "embeddings_g")
        else:
            raise ValueError("No embeddings found in record")

        # L2-normalize. A zero vector has no direction: dividing would give
        # 0/0 = NaN, which then poisons every downstream dot product and
        # max(). Leave it as zeros so its cosine similarity is simply 0.
        norm = np.linalg.norm(x)
        if norm > 0:
            x = x / norm
        embs.append(x)
    return np.stack(embs)

# ——— Script: nearest-neighbour cosine similarity of CSDR1 vs MACHO embeddings ———
# On-disk locations of the two pre-computed embedding datasets.
csdr1_path = "/xxx/hf_csdr1_multiband_raw4_embeddings_astromer_1_subclass_pad_correct_gr/train"
macho_path = "/xxx/hf_macho_unlabel_embeddings_astromer_val/validation"

print("Loading CSDR1 dataset...")
csdr1_ds = load_from_disk(csdr1_path)
print("Loading MACHO dataset...")
macho_ds = load_from_disk(macho_path)

# Sampling without replacement raises ValueError when more items are requested
# than exist, so cap the sample count at the smaller dataset. With both
# datasets holding at least 500 records this is exactly the original 500.
n_samples = min(500, len(csdr1_ds), len(macho_ds))

print(f"CSDR1 dataset size: {len(csdr1_ds)}")
print(f"MACHO dataset size: {len(macho_ds)}")
if n_samples == 0:
    raise ValueError("Cannot sample embeddings: at least one dataset is empty")

# Fixed seed so the same records are drawn on every run. The CSDR1 draw must
# stay first: both draws consume the same generator stream.
rng = np.random.default_rng(seed=42)
cs_indices = rng.choice(len(csdr1_ds), size=n_samples, replace=False).tolist()
ma_indices = rng.choice(len(macho_ds), size=n_samples, replace=False).tolist()
print(f"Using {n_samples} samples from each dataset")

# CSDR1 is evaluated once per band scenario; MACHO is always requested as "r".
scenarios = ["r", "g"]
print("Preparing CSDR1 embeddings...")
csdr1_embs = {s: get_avg_embeddings(csdr1_ds, s, cs_indices) for s in scenarios}
print("Preparing MACHO embeddings...")
macho_embs = get_avg_embeddings(macho_ds, "r", ma_indices)

for scen, emb in csdr1_embs.items():
    print(f"Computing similarities for scenario: {scen}")
    # Rows come back L2-normalized from get_avg_embeddings, so the dot
    # product of two rows is their cosine similarity.
    sims = emb @ macho_embs.T               # shape (n_samples, n_samples)
    nn_cos = sims.max(axis=1)               # nearest-neighbour for each sampled CSDR1 star

    plt.figure()
    plt.hist(nn_cos, bins=50, density=True)
    plt.title(f"NN Cosine: CSDR1 ({scen}) vs MACHO (n={n_samples})")
    plt.xlabel("Cosine similarity")
    plt.ylabel("Density")
    plt.tight_layout()
    plt.show()

    print(f"Mean cosine similarity for {scen}: {nn_cos.mean():.4f}")
    print(f"Std cosine similarity for {scen}: {nn_cos.std():.4f}")