import numpy as np
import pandas as pd
import umap
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler


def _infer_group_stats(activ, ys, attrs, verbose=False):
    """Infer sub-groups inside each class of ``activ`` and summarise them.

    Rows with ``ys == 0`` and rows with ``ys == 1`` are clustered separately
    by :func:`infer_labels`, each clustering aligned with the matching
    ``attrs``. Rows whose ``ys`` is neither 0 nor 1 are not used.

    Parameters
    ----------
    activ : numpy.ndarray
        Activations indexed by sample along axis 0.
    ys : numpy.ndarray
        Class labels; only the values 0 and 1 are selected.
    attrs : numpy.ndarray
        Attribute labels used to align the cluster ids.
    verbose : bool
        Forwarded to :func:`infer_labels`; also prints ``sc, ci, ai``.

    Returns
    -------
    tuple
        ``(sc, ci, ai, a_intra, a_inter)`` where, with ``groupCG`` the number
        of samples of class ``C`` assigned to inferred group ``G``:
        ``sc = (group00 + group11) / total`` (inferred group equals class),
        ``ci = (group00 + group01) / total`` (fraction in class 0),
        ``ai = (group00 + group10) / total`` (fraction in inferred group 0),
        and ``a_intra``/``a_inter`` are :func:`infer_dist` of the class-sorted
        activations w.r.t. the inferred group labels.
    """
    activ0, activ1 = activ[ys == 0], activ[ys == 1]
    attrs0, attrs1 = attrs[ys == 0], attrs[ys == 1]
    sub_class0 = infer_labels(activ0, attrs0, verbose=verbose)
    sub_class1 = infer_labels(activ1, attrs1, verbose=verbose)

    group00 = np.sum(sub_class0 == 0)
    group01 = np.sum(sub_class0 == 1)
    group10 = np.sum(sub_class1 == 0)
    group11 = np.sum(sub_class1 == 1)
    total = group00 + group01 + group10 + group11
    sc = (group00 + group11) / total
    ci = (group00 + group01) / total
    ai = (group00 + group10) / total
    if verbose:
        print(sc, ci, ai)

    # The inferred labels are ordered class-0 rows first, then class-1 rows,
    # so the activations must be concatenated in the same order to line up.
    inferred_attrs = np.concatenate([sub_class0, sub_class1])
    sorted_activ = np.concatenate([activ0, activ1])
    a_intra, a_inter = infer_dist(sorted_activ, inferred_attrs, verbose=False)
    return sc, ci, ai, a_intra, a_inter


def infer_feat(emb, verbose=False):
    """Compute class/attribute separation features for one embedding set.

    Parameters
    ----------
    emb : mapping
        Must provide ``"ys"`` (class labels), ``"attrs"`` (attribute labels),
        ``"activ"`` and ``"normed_activ"`` (activations indexed by sample
        along axis 0). Labels are compared against 0 and 1 — presumably
        binary; TODO confirm.
    verbose : bool
        If True, the label inference prints diagnostics.

    Returns
    -------
    tuple
        18 values, in this order:
        ``sc_, ci_, ai_`` — group statistics inferred from ``activ``;
        ``normed_sc_, normed_ci_, normed_ai_`` — the same from
        ``normed_activ``;
        ``c_intra, a_intra, c_inter, a_inter`` — :func:`infer_dist` of
        ``activ`` w.r.t. ``ys`` (``c_*``) and ``attrs`` (``a_*``);
        ``c_intra_n, a_intra_n, c_inter_n, a_inter_n`` — the same for
        ``normed_activ``;
        ``a_intra_i, a_inter_i`` — distances of ``activ`` w.r.t. the
        inferred groups;
        ``a_intra_in, a_inter_in`` — the same for ``normed_activ``.
    """
    ys, attrs = emb["ys"], emb["attrs"]
    activ = emb["activ"]
    normed_activ = emb["normed_activ"]

    # Distances w.r.t. the ground-truth class and attribute labels.
    c_intra, c_inter = infer_dist(activ, ys, verbose=False)
    a_intra, a_inter = infer_dist(activ, attrs, verbose=False)
    c_intra_n, c_inter_n = infer_dist(normed_activ, ys, verbose=False)
    a_intra_n, a_inter_n = infer_dist(normed_activ, attrs, verbose=False)

    # Group statistics and distances w.r.t. the inferred sub-group labels,
    # first on the raw activations, then on the normalised ones.
    sc_, ci_, ai_, a_intra_i, a_inter_i = _infer_group_stats(
        activ, ys, attrs, verbose=verbose
    )
    (
        normed_sc_,
        normed_ci_,
        normed_ai_,
        a_intra_in,
        a_inter_in,
    ) = _infer_group_stats(normed_activ, ys, attrs, verbose=verbose)

    return (
        sc_,
        ci_,
        ai_,
        normed_sc_,
        normed_ci_,
        normed_ai_,
        c_intra,
        a_intra,
        c_inter,
        a_inter,
        c_intra_n,
        a_intra_n,
        c_inter_n,
        a_inter_n,
        a_intra_i,
        a_inter_i,
        a_intra_in,
        a_inter_in,
    )


def infer_labels(activ, attrs, verbose=False):
    """Cluster ``activ`` into two groups and align the cluster ids with ``attrs``.

    KMeans assigns cluster ids arbitrarily, so the labelling is flipped
    (``1 - labels``) whenever its complement does not agree less with
    ``attrs`` than the raw labelling does.

    Parameters
    ----------
    activ : array-like
        Samples to cluster, one row per sample (passed to ``KMeans.fit``).
    attrs : numpy.ndarray
        Reference labels compared element-wise with the cluster ids;
        presumably binary 0/1 — TODO confirm.
    verbose : bool
        If True, print the silhouette score and the better of the two
        agreement fractions.

    Returns
    -------
    numpy.ndarray
        Cluster ids in {0, 1}, oriented to best match ``attrs``.
    """
    # NOTE(review): scikit-learn changed the default ``n_init`` of KMeans
    # (10 -> "auto") in 1.4, so results may differ across versions; consider
    # pinning it explicitly once the intended value is confirmed.
    labels = KMeans(n_clusters=2, random_state=0).fit(activ).labels_

    # Agreement of the raw labelling and of its complement with ``attrs``;
    # computed once and reused for both the report and the decision.
    agree = np.mean(labels == attrs)
    agree_flipped = np.mean(labels == (1 - attrs))

    if verbose:
        print("silhouette_score", silhouette_score(activ, labels))
        print(max(agree, agree_flipped))

    # Strict ">" so that a tie selects the flipped labelling.
    return labels if agree > agree_flipped else 1 - labels


def infer_dist(activ, ys, verbose=False):
    """Measure how tightly samples gather around their label centroids.

    For label 0 and label 1, the centroid is the mean row of the samples
    carrying that label, and the spread is the mean Euclidean distance of
    those samples to their centroid.

    Parameters
    ----------
    activ : numpy.ndarray
        Samples indexed along axis 0; distances are taken over axis 1.
    ys : numpy.ndarray
        Per-sample labels; only the values 0 and 1 are used.
    verbose : bool
        Accepted for interface compatibility; not used.

    Returns
    -------
    tuple
        ``(intra, inter)``: the average of the two per-label spreads and the
        Euclidean distance between the two centroids.
    """
    centroids = []
    spreads = []
    for label in (0, 1):
        members = activ[ys == label]
        centroid = members.mean(axis=0)
        centroids.append(centroid)
        spreads.append(np.mean(np.linalg.norm(members - centroid, axis=1)))

    intra = (spreads[0] + spreads[1]) / 2
    inter = np.linalg.norm(centroids[0] - centroids[1])
    return intra, inter


import argparse
from tqdm import tqdm
import os

# Keys for the values returned by ``infer_feat``, in the same order; they are
# also the CSV column names (after the filename-derived columns).
_FEATURE_KEYS = (
    "sc_",
    "ci_",
    "ai_",
    "normed_sc_",
    "normed_ci_",
    "normed_ai_",
    "c_intra",
    "a_intra",
    "c_inter",
    "a_inter",
    "c_intra_n",
    "a_intra_n",
    "c_inter_n",
    "a_inter_n",
    "a_intra_i",
    "a_inter_i",
    "a_intra_in",
    "a_inter_in",
)


def main():
    """Compute ``infer_feat`` features for every reduced-embedding file.

    Reads every file in ``<exp_dir>/reduced_emb/`` (sorted by name), where
    ``exp_dir`` is ``//exps/div_explore/<datafolder>/<model>_embeddings_<split>``.
    Parses the ``sc``/``ci``/``ai``/``n`` values encoded in each filename and
    computes the features. Writes one row per file to ``<exp_dir>/proc_emb.csv``.
    """
    parser = argparse.ArgumentParser(description="Get embeddings from models.")
    parser.add_argument("--model", type=str, default="resnet", help="Model to use")
    parser.add_argument("--split", type=str, choices=["tr", "te"], default="tr")
    parser.add_argument("--datafolder", type=str, default="coco_v2")
    args = parser.parse_args()

    model = args.model
    split = args.split
    folder = args.datafolder

    exp_dir = f"//exps/div_explore/{folder}/{model}_embeddings_{split}"
    reduced_emb_dir = f"{exp_dir}/reduced_emb/"
    # Ensure the directory exists *before* listing it; listing first would
    # raise FileNotFoundError and make the existence check unreachable.
    os.makedirs(reduced_emb_dir, exist_ok=True)
    # Sorted so the CSV row order does not depend on filesystem listing order.
    files = sorted(os.listdir(reduced_emb_dir))

    all_res_feat = []
    for fname in tqdm(files):
        curr_res = {}

        # Values encoded in the filename. Assumes: "sc<float>_" and
        # "ci<float>_" appear, the text after the first "n" starts with the
        # integer count, and "ai<float>" is followed only by a 4-character
        # suffix (presumably the file extension) — TODO confirm the format.
        curr_res["sc"] = float(fname.split("sc")[1].split("_")[0])
        curr_res["ci"] = float(fname.split("ci")[1].split("_")[0])
        curr_res["ai"] = float(fname.split("ai")[1][:-4])
        curr_res["n"] = int(fname.split("n")[1].split("_")[0])

        # NOTE(review): allow_pickle=True lets np.load unpickle object data,
        # which can execute arbitrary code — only use on trusted files.
        emb = np.load(os.path.join(reduced_emb_dir, fname), allow_pickle=True)
        try:
            feats = infer_feat(emb, verbose=False)
        finally:
            # For .npz files np.load returns an open NpzFile; close it so file
            # handles do not accumulate across iterations. Plain arrays have
            # no close() and need nothing.
            close = getattr(emb, "close", None)
            if close is not None:
                close()

        curr_res.update(zip(_FEATURE_KEYS, feats))

        print(model, split, curr_res)

        all_res_feat.append(curr_res)

    pd.DataFrame(all_res_feat).to_csv(f"{exp_dir}/proc_emb.csv")


if __name__ == "__main__":
    main()
