import numpy as np
import itertools
import torch
from torch import optim
from tqdm import tqdm
from sklearn.cluster import KMeans
import argparse

# Command-line interface. The three paths are mandatory: without them the
# script would only fail later inside np.load / np.save with an unclear error.
parser = argparse.ArgumentParser(description='Compute the final low-dimensional embeddings')
parser.add_argument('--embedding_path', type = str, required = True,
                    help = 'path to the embedding array (read with np.load)')
parser.add_argument('--num_params_path', type = str, required = True,
                    help = 'path to the per-item parameter counts (read with np.load); '
                           'log-scaled to [0, 1] and appended as the last output column')
parser.add_argument('--out_path', type = str, required = True,
                    help = 'destination passed to np.save for the final embeddings')
parser.add_argument('--emb_dim', type = int, default = 2,
                    help = 'dimensionality of the low-dimensional projection (default: 2)')
args = parser.parse_args()


arr = np.load(args.embedding_path)
p_log = np.log(np.load(args.num_params_path))

# Work in torch from here on. Previously `ndarray /= Tensor` produced a tensor
# only through implicit numpy/torch operator interop; later steps apply torch
# ops to `arr` and `D`, so make the conversion explicit.
# NOTE(review): the indexing below assumes arr is (N, views, features) —
# confirm against whatever produces --embedding_path.
arr = torch.as_tensor(arr)
arr = arr / arr.norm(dim=-1, keepdim=True)  # L2-normalise every feature vector

# D[i, j] = sum over view pairs (a, b), a, b < NUM_VIEWS, of the cosine
# distance 1 - <arr[i, a], arr[j, b]>. The sum of those NUM_VIEWS**2 Gram
# matrices equals (sum_a arr[:, a]) @ (sum_b arr[:, b]).T, so a single matmul
# replaces NUM_VIEWS**2 of them.
NUM_VIEWS = 4
if arr.shape[1] < NUM_VIEWS:
    raise ValueError(f'expected at least {NUM_VIEWS} views on axis 1, got {arr.shape[1]}')
view_sum = arr[:, :NUM_VIEWS].sum(1)
D = NUM_VIEWS * NUM_VIEWS - view_sum @ view_sum.T

# Clamp tiny negative rounding errors, then scale to unit standard deviation.
D = torch.nn.functional.relu(D)
D /= D.std()


# Step one: metric-preserving projection. Initialise from the PCA of the
# view-averaged embeddings, then fit the pairwise Euclidean distances of the
# projected points to sqrt(D) with Adam on random batches of 1000 rows.
# Uses CUDA when available and falls back to the CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embs_lowd = torch.pca_lowrank(arr.mean(1), q = args.emb_dim)[0].to(device)
embs_lowd.requires_grad = True
D = D.to(device)
optimizer = optim.Adam([embs_lowd], lr = 1e-1)
# NOTE(review): the lr is multiplied by 0.9 every 5 iterations (0.9**400 over
# the whole run). It drops below 1e-3 of its start after ~330 iterations, so
# most of the 2000 iterations barely move the embedding. Confirm this is intended.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
progress = tqdm(range(2000))
for i in progress:
    # 1000 rows sampled with replacement, each compared against every point.
    idx = torch.randint(len(D), (1000,))

    dist = torch.norm(embs_lowd[idx, None] - embs_lowd, dim = -1)
    loss = ((dist - D[idx].sqrt()) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    progress.set_description(f'loss:{loss.item():.3f}')
    if i % 5 == 0:
        scheduler.step()

embs_lowd = embs_lowd.cpu()

# step two -> uniform area

from sklearn.cluster import KMeans

def _sd_assign_and_ascend(x, points, weights, potentials, step):
    """Assign one sample to a point and take one ascent step on the potentials.

    The sample ``x`` goes to ``argmin_j ||x - points[j]||^2 - potentials[j]``.
    The winner's potential is then lowered by ``step``, and every potential is
    raised by ``weights[j] * step``. Points that win more often than their
    weight warrants therefore become less attractive. ``potentials`` is
    modified in place.

    Returns:
        Index into ``points`` of the point ``x`` was assigned to.
    """
    scores = ((x - points) ** 2).sum(-1) - potentials
    winner = scores.argmin()
    potentials[winner] -= step
    potentials += weights * step
    return winner


def ml_asgd(m, p, its = 2000, lr = 0.2, beta = 0.75):
    """Relocate the points ``p`` so each one covers a share of [0, 1]^d set by ``m``.

    This is a two-level scheme. First, ``p`` is grouped into
    ``floor(sqrt(len(m)))`` KMeans clusters. Each of the ``its`` iterations then:

    - draws one uniform sample ``x`` in [0, 1]^d;
    - assigns it to a cluster, then to a point inside that cluster, using
      ``_sd_assign_and_ascend`` (squared distance minus a learned potential).
      These updates have the form of stochastic gradient ascent on the
      semi-discrete optimal-transport dual;
    - folds ``x`` into that point's exponential moving average.

    Args:
        m: per-point weights, shape ``(N,)`` (the script passes ``unif(N)``).
        p: point positions, shape ``(N, d)``. Samples are drawn from
            [0, 1]^d, so ``p`` is presumably expected to lie in that cube.
        its: number of samples to draw.
        lr: base step size. The cluster level uses ``lr / sqrt(k)`` at
            iteration ``k``. The point level uses ``lr / sqrt(n_c)``, where
            ``n_c`` counts the samples routed to the chosen cluster so far.
        beta: decay of the per-point exponential moving average.

    Returns:
        ``(positions, counts)``. ``positions`` holds the bias-corrected EMA
        positions, shape ``(N, d)``, float64. ``counts`` is the number of
        samples each point received. Points that received no sample keep
        their input position from ``p`` (previously they became NaN).

    Raises:
        ValueError: if ``m`` is empty.
    """
    num_clusters = int(np.sqrt(len(m)))
    if num_clusters < 1:
        raise ValueError('ml_asgd needs at least one point')
    km = KMeans(n_clusters=num_clusters).fit(p)

    # KMeans can leave clusters empty (e.g. with many duplicate points). An
    # empty cluster has no points for the layer-one argmin, so drop it.
    members = [np.where(km.labels_ == i)[0] for i in range(num_clusters)]
    keep = [i for i, idx in enumerate(members) if len(idx) > 0]
    c_idx = [members[i] for i in keep]
    c_p = km.cluster_centers_[keep]
    c_m = np.array([m[idx].sum() for idx in c_idx])

    # Per-cluster views of the points and of their weights, renormalised to
    # sum to one inside each cluster.
    p_in_centroid = [p[idx] for idx in c_idx]
    m_in_centroid = [m[idx] / m[idx].sum() for idx in c_idx]

    # Potentials for the cluster level and for the points of each cluster.
    v_tilde_c = np.zeros(c_m.shape)
    v_tilde_in_centroid = [np.zeros(w.shape) for w in m_in_centroid]

    centroids = np.zeros(p.shape)  # EMA accumulators, not yet bias-corrected
    samples_per_centroid = np.zeros(len(p))
    samples_per_cluster = np.zeros(c_m.shape)

    for k in tqdm(range(1, its + 1)):
        x = np.random.uniform(0, 1, (1, p.shape[-1]))

        # Layer zero: pick a cluster.
        cluster = _sd_assign_and_ascend(x, c_p, c_m, v_tilde_c, lr / np.sqrt(k))
        samples_per_cluster[cluster] += 1

        # Layer one: pick a point inside that cluster. The step size is based
        # on how many samples this cluster has received.
        local_index = _sd_assign_and_ascend(
            x,
            p_in_centroid[cluster],
            m_in_centroid[cluster],
            v_tilde_in_centroid[cluster],
            lr / np.sqrt(samples_per_cluster[cluster]),
        )
        global_index = c_idx[cluster][local_index]

        centroids[global_index] = centroids[global_index] * beta + (1 - beta) * x
        samples_per_centroid[global_index] += 1

    # EMA bias correction divides by 1 - beta**n. For n == 0 that divisor is 0
    # and the row would become 0/0 = NaN, so unsampled points keep their
    # input position instead.
    positions = np.array(p, dtype=np.float64)
    hit = samples_per_centroid > 0
    positions[hit] = centroids[hit] / (1 - beta ** samples_per_centroid[hit])[:, None]
    return positions, samples_per_centroid


def unif(n):
    """Return a float32 vector of ``n`` equal weights, each 1/n (empty for n == 0)."""
    weights = np.ones(n, dtype=np.float64)
    weights /= n
    return weights.astype(np.float32)
    
def _min_max_scale(a, axis=None):
    """Rescale ``a`` to [0, 1] along ``axis``; zero-range slices map to 0 instead of NaN."""
    lo = a.min(axis=axis, keepdims=True)
    span = a.max(axis=axis, keepdims=True) - lo
    return (a - lo) / np.where(span > 0, span, 1)


# Spread the projected points over the unit cube: normalise each coordinate
# to [0, 1], then run two ml_asgd passes with uniform weights.
w = embs_lowd.detach().numpy().astype(np.float32)
w_ = _min_max_scale(w, axis=0)
centroids, masses = ml_asgd(unif(len(w_)), w_, its = 5000000, lr = 0.1)
# The second pass uses a slow EMA (beta=0.99), roughly a Lloyd-relaxation step.
centroids, masses = ml_asgd(unif(len(w_)), centroids, its = 5000000, lr = 0.1, beta = 0.99)

# Append the log parameter count, rescaled to [0, 1], as the last column.
p_log = _min_max_scale(p_log)
final_embeddings = np.hstack((centroids, p_log[:, None]))
np.save(args.out_path, final_embeddings)