import numpy as np
import torch
import sys

# name = sys.argv[0]


# def 
# def append_to_stream(S):

# Input file holding the precomputed embeddings; switch the active line to
# analyse a different dataset.
name = "./data/torch-embeddings/ImageNet21k_clip_embeddings.pt"
# name = "./data/torch-embeddings/CIFAR10_clip_embeddings.pt"
# name = "./data/torch-embeddings/CIFAR100_clip_embeddings.pt"

# Directory prefix used for the NumPy exports written later in this script.
path = "./data/numpy-embeddings/"

# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects,
# so this must only ever be pointed at trusted files.
data = torch.load(name, map_location="cpu", weights_only=False)

# `model` and `dataset` are only used to build output file names.
model, dataset = data["model"], data["dataset"]

# Cast once up front so everything downstream works on float / int tensors.
embeddings = data["embeddings"].float()
labels = data["labels"].int()

# similarity = (embeddings[:10] @ embeddings.T)
# print(similarity.shape)

# above = similarity > 0.85

# tots = above.sum(dim=-1) - 1
# print(tots.shape)
# print(tots)


# print(embeddings.shape)
# np.savez(f"{path}{dataset}-{model}", embeddings=embeddings.numpy(), labels=labels.numpy())


# Order every row by its label so that each class occupies one contiguous run,
# and L2-normalise the rows so dot products are cosine similarities.
sorted_indices = torch.argsort(labels)
sorted_labels = labels[sorted_indices]
sorted_embeddings = torch.nn.functional.normalize(embeddings[sorted_indices], p=2, dim=1)

# class_boundaries[k] .. class_boundaries[k + 1] is the row range of the k-th class.
unique_classes, class_counts = torch.unique(sorted_labels, return_counts=True)
class_boundaries = torch.cat([torch.tensor([0]), torch.cumsum(class_counts, dim=0)])

# partitions maps class id -> that class's normalised embeddings.
# Classes whose id is a multiple of 10 are additionally collected for export.
partitions = {}
sampled_embs = []
sampled_labels = []
class_ranges = zip(
    unique_classes.tolist(),
    class_boundaries[:-1].tolist(),
    class_boundaries[1:].tolist(),
)
for class_id, lo, hi in class_ranges:
    class_rows = sorted_embeddings[lo:hi]
    partitions[class_id] = class_rows
    if class_id % 10 != 0:
        continue
    sampled_embs.append(class_rows)
    sampled_labels.append(sorted_labels[lo:hi])

sampled_embs = torch.cat(sampled_embs)
sampled_labels = torch.cat(sampled_labels)

# Apply one shared random permutation so embeddings and labels stay aligned.
shuffle_order = torch.randperm(sampled_embs.shape[0])
sampled_embs, sampled_labels = sampled_embs[shuffle_order], sampled_labels[shuffle_order]

print(sampled_embs.shape)
np.savez(
    f"{path}{dataset}-{model}-SAMPLED",
    embeddings=sampled_embs.numpy(),
    labels=sampled_labels.numpy(),
)

# Intra-class similarity statistics.
#
# For every class the Gram matrix of its embeddings is computed and only the
# strict upper triangle is kept, so each unordered pair is counted once and
# self-similarities are excluded.
#
#   tot_avg : mean of the per-class averages, weighted by class size
#   grp_avg : unweighted mean of the per-class averages
#   sims    : every intra-class pair similarity, concatenated and sorted
tot_avg = 0
grp_avg = 0
tot_sims = []
# Classes with fewer than two members have no pairs and are skipped, so the
# averages must be normalised by what was actually accumulated rather than by
# the full dataset size / number of classes.
counted_rows = 0
counted_groups = 0
for i, embs_i in partitions.items():
    print(f"on class {i} have embeddings {embs_i.shape}")
    n_i = embs_i.shape[0]
    if n_i <= 1:
        continue
    sims_i = embs_i @ embs_i.T
    upper_idx = torch.triu_indices(n_i, n_i, offset=1)
    flat_i = sims_i[upper_idx[0], upper_idx[1]]
    sum_i = flat_i.sum().item()
    avg_i = sum_i / flat_i.shape[0]
    print(f"got avg similarity {avg_i}")
    tot_avg += avg_i * n_i
    grp_avg += avg_i
    counted_rows += n_i
    counted_groups += 1
    tot_sims.append(flat_i)

if not tot_sims:
    # Without this guard the divisions / torch.cat below fail with an
    # unhelpful error when no class contains a pair.
    raise SystemExit("no class has at least two embeddings; nothing to compare")

tot_avg /= counted_rows
grp_avg /= counted_groups
print(f"got total avg {tot_avg}, grp avg {grp_avg}")
sims = torch.cat(tot_sims, dim=0)
print(f"have tensor of {sims.shape[0]} sims")

sims = sims.numpy()
print(f"have mean {np.mean(sims)}, std {np.std(sims)}")
# np.info prints its report itself and returns None, so it is not wrapped in print().
np.info(sims)
sims.sort()

for p in [50, 60, 70, 80, 90, 95, 99, 99.9]:
    # Index into the sorted similarities; clamped so it can never run past the end.
    i = min(int(p * sims.shape[0] / 100), sims.shape[0] - 1)
    print(f"{p}th %ile @ {i} -> {sims[i]}")

# print("all sims list results: ", np.percentile(sims, [50, 60, 70, 80, 90, 95, 99])) 



# Rebind `embeddings` from a torch tensor to a NumPy array; the sampling helpers
# below index and normalise it with NumPy operations.
embeddings=embeddings.numpy()

def make_sample(K):
    """Return the pairwise similarity matrix of K randomly chosen embeddings.

    K distinct rows are drawn (without replacement) from the module-level
    ``embeddings`` array, each row is scaled to unit L2 norm, and the
    K x K matrix of dot products between the scaled rows is returned.

    Args:
        K: number of rows to draw. ``np.random.choice`` raises ``ValueError``
            if K exceeds the number of available rows.

    Returns:
        A (K, K) array of similarities between the sampled rows.
    """
    picked_rows = np.random.choice(embeddings.shape[0], size=K, replace=False)
    # Indexing with an integer array yields a copy, so the in-place division
    # below leaves the original `embeddings` untouched.
    sample = embeddings[picked_rows, :]
    row_norms = np.sqrt((sample ** 2).sum(-1))[..., np.newaxis]
    sample /= row_norms
    return sample @ sample.T

# Number of embeddings drawn for the random-pair sample; the resulting
# similarity matrix is K x K.
K = 50_000
samp_sims = make_sample(K)

def count_over(samp_sims, thresh, population_size=None):
    """Print how many off-diagonal similarities in a sample exceed a threshold.

    Three numbers are reported: the count of off-diagonal entries of
    ``samp_sims`` that are strictly greater than ``thresh``, that count scaled
    up to the full population (divided by the squared sampling fraction), and
    the count as a percentage of all ``n * n`` entries of the sample matrix.
    Because both ``[i, j]`` and ``[j, i]`` are inspected, a symmetric matrix
    contributes each unordered pair twice.

    Args:
        samp_sims: square (n, n) similarity matrix of the sampled items.
        thresh: similarity threshold; entries must be strictly above it.
        population_size: number of items the sample was drawn from. Defaults
            to the row count of the module-level ``embeddings``.

    Returns:
        None. The result is printed.
    """
    n = samp_sims.shape[0]
    if population_size is None:
        population_size = embeddings.shape[0]

    over = samp_sims > thresh
    # Remove exactly the diagonal entries that passed the threshold. Blindly
    # subtracting n assumes every self-similarity exceeds `thresh`, which is
    # false for thresh >= 1 (or NaN rows) and would yield a negative count.
    num_over = over.sum() - over.diagonal().sum()

    p = n / population_size
    est = num_over / p / p
    ratio = num_over / n / n
    print(f"got {num_over} over {thresh}, est {est}, % is {100*ratio}%")

# Report the sampled similarity matrix's shape, then its percentiles taken
# over every entry of the matrix (diagonal included).
print(samp_sims.shape)
sample_percentiles = np.percentile(samp_sims, [50, 60, 70, 80, 90, 95, 99, 99.9])
print(sample_percentiles)
# print(np.percentile(sims.numpy(), [90, 95, 99])) 
# CIFAR10
# [0.4102352  0.43609664 0.46642071 0.50487423 0.56163967 0.61434799 0.72647297]
# CIFAR100
# [0.40391621 0.42926472 0.4572176  0.49077576 0.53828603 0.57847041 0.65903586]