import numpy as np
import pandas as pd
import umap
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler


def infer_feat(emb, verbose=False):
    """Project the activations in ``emb`` to 2-D with UMAP, raw and standardized.

    Two embeddings are computed: one from the activations as stored, and one
    from the activations after per-feature standardization with
    ``StandardScaler``.

    Args:
        emb: Mapping-like object whose ``"activ"`` entry holds the activation
            matrix (presumably ``(n_samples, n_features)`` -- TODO confirm
            against the code that writes the embedding files).
        verbose: Currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(activ, normed_activ)``: the 2-component UMAP embeddings of
        the raw and of the standardized activations, in that order.
    """
    # Only "activ" is needed here. Other entries are deliberately not read:
    # if ``emb`` is a lazily loaded ``.npz`` archive, every key access reads
    # that array from disk.
    activ = emb["activ"]
    normed_activ = StandardScaler().fit_transform(activ)

    # One reducer per embedding, so neither fit depends on state left behind
    # by the other.
    reduced_activ = umap.UMAP(n_components=2).fit_transform(activ)
    reduced_normed_activ = umap.UMAP(n_components=2).fit_transform(normed_activ)

    return reduced_activ, reduced_normed_activ


import argparse
from tqdm import tqdm
import os

# Script body: reduce every embedding file of one model/split to 2-D and save
# the result under a "reduced_emb/" sub-directory of the embedding directory.
parser = argparse.ArgumentParser(description="Get embeddings from models.")
parser.add_argument("--model", type=str, default="resnet", help="Model to use")
parser.add_argument("--split", type=str, choices=["tr", "te"], default="tr")

args = parser.parse_args()

model = args.model
split = args.split


all_res_feat = []  # NOTE(review): not used in the visible part of the file
emb_dir = f"//exps/div_explore/celeba_v2/{model}_embeddings_{split}/"

# Keep regular files only. The output directory created below lives inside
# emb_dir, so on a re-run os.listdir() returns it as well and np.load() would
# fail on a directory. Listing before creating the output directory also keeps
# the FileNotFoundError for a missing emb_dir instead of silently creating it.
files = [
    name
    for name in os.listdir(emb_dir)
    if os.path.isfile(os.path.join(emb_dir, name))
]

reduced_emb_dir = emb_dir + "reduced_emb/"
os.makedirs(reduced_emb_dir, exist_ok=True)

for f in tqdm(files):
    target_path = os.path.join(reduced_emb_dir, "reduced_" + f)
    # np.savez appends ".npz" when the name does not already end with it.
    # Mirror that here so the existence check looks at the file that is
    # actually written.
    if not target_path.endswith(".npz"):
        target_path += ".npz"

    # Check before loading so already-processed inputs are never opened.
    if os.path.exists(target_path):
        print(f"File {f} already exists, skipping...")
        continue

    # NOTE(review): allow_pickle=True unpickles object arrays, which is unsafe
    # for untrusted files; drop it if the embeddings hold plain arrays only.
    emb = np.load(os.path.join(emb_dir, f), allow_pickle=True)

    reduced_activ, normed_reduced_activ = infer_feat(emb, verbose=False)
    np.savez(
        target_path,
        activ=reduced_activ,
        normed_activ=normed_reduced_activ,
        ys=emb["ys"],
        attrs=emb["attrs"],
    )
