"""
Similar to the other test.py script, but the latents are now binarized by
thresholding each one at the midpoint between its minimum and maximum values.

The metrics are taken from Towards Robust Metrics For Concept Representation Evaluation [Zarlenga et al. 2023]

"""

from argparse import ArgumentParser
import os
import pickle
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
sys.path.insert(1, os.path.join(sys.path[0], '../..'))
import numpy as np

from shared_utils.dl_experiments import print_params
from DL_experiments.shared_utils.dl_experiments_binarized import (
    do_cbm_test,
    create_labels,
    set_seed
)

def load_datasets(data_dir: str) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
    """Load the observations, ground-truth latents and estimated latents.

    Args:
        data_dir: Directory containing ``z_gt_final.npy``,
            ``z_hat_final_ordered.npy`` and ``x_all.npy``.

    Returns:
        A ``(x_all, z_gt_all, z_hat_all)`` tuple of the arrays exactly as
        stored on disk.
    """
    z_gt_all = np.load(os.path.join(data_dir, "z_gt_final.npy"))
    # The "_ordered" file is used instead of "z_hat_final.npy"; presumably its
    # latents were permuted to line up with z_gt — TODO confirm.
    z_hat_all = np.load(os.path.join(data_dir, "z_hat_final_ordered.npy"))
    x_all = np.load(os.path.join(data_dir, "x_all.npy"))
    return x_all, z_gt_all, z_hat_all


def binarize_latents(all_latents, *, verbose=True):
    """Binarize every latent (column) by thresholding at its min/max midpoint.

    For each column ``i`` the threshold is ``(min + max) / 2`` taken over all
    rows. Entries strictly greater than the threshold become 1; all others,
    including entries equal to it, become 0.

    Args:
        all_latents: 2-D array-like indexed as ``[row, latent]``.
        verbose: When True (the default), print each latent's min, max and
            midpoint.

    Returns:
        A float64 array with the same shape as ``all_latents`` holding 0/1.
    """
    all_latents = np.asarray(all_latents)

    # One vectorized reduction per statistic instead of a Python loop per column.
    min_vals = all_latents.min(axis=0)
    max_vals = all_latents.max(axis=0)
    mid_points = (max_vals + min_vals) / 2

    if verbose:
        for i, (min_val, max_val, mid_point) in enumerate(
            zip(min_vals, max_vals, mid_points)
        ):
            print('concept ', i)
            print(f"min: {min_val}, max: {max_val}, mid: {mid_point}")

    # Strict ">" means a constant column (min == max == mid) maps entirely to 0.
    return (all_latents > mid_points).astype(np.float64)


def get_data_dir(model_args, base_dir):
    """Return the checkpoint directory for a given model type and dataset.

    Args:
        model_args: Mapping with a ``"model_type"`` key (one of ``"iVAE"``,
            ``"DMS-VAE"``, ``"TCVAE"``) and a ``"dataset"`` key (one of
            ``"action"``, ``"temporal"``).
        base_dir: Root directory containing the per-model folders.

    Returns:
        ``os.path.join(base_dir, <model folder>, <dataset folder>)``.

    Raises:
        ValueError: If ``model_type`` or ``dataset`` is not recognised. The
            original if/elif chains left the path variable unbound in that
            case, which failed later with an opaque UnboundLocalError.
    """
    model_subdirs = {
        "iVAE": "iVAE",
        "DMS-VAE": "DMS_VAE",
        "TCVAE": "TCVAE",
    }
    dataset_subdirs = {
        "action": "action_sparsity_non_trivial",
        "temporal": "temporal_sparsity_non_trivial",
    }

    model_type = model_args["model_type"]
    dataset = model_args["dataset"]

    if model_type not in model_subdirs:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(model_subdirs)}"
        )
    if dataset not in dataset_subdirs:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(dataset_subdirs)}"
        )

    return os.path.join(base_dir, model_subdirs[model_type], dataset_subdirs[dataset])


if __name__ == "__main__":
    repeats = 10
    seeds = list(range(100, 100 + repeats))

    # Values of N forwarded to do_cbm_test for every seed.
    Ns = [20, 100, 1000, 10000]

    parser = ArgumentParser()
    parser.add_argument('--cluster', action='store_true')

    args = parser.parse_args()
    # vars() returns the namespace's own __dict__, so keys set below are also
    # visible as attributes on `args`.
    model_args = vars(args)

    datasets = ["action", "temporal"]
    if args.cluster:
        base_dir = "checkpoints"

        for dataset in datasets:
            # The latents come from the DMS-VAE run: resolve its data directory
            # first, then relabel the args as the CEM model being evaluated.
            model_args["model_type"] = "DMS-VAE"
            model_args["dataset"] = dataset

            data_dir = get_data_dir(model_args, base_dir)
            x_all, z_gt, z_hat = load_datasets(data_dir=data_dir)

            model_args["model_type"] = "CEM"
            print_params("Permutation Concepts experiment", model_args)

            # Binarization is deterministic and does not depend on the seed,
            # so compute (and print) it once per dataset instead of per seed.
            binarized_latents = binarize_latents(z_gt)

            for seed in seeds:
                set_seed(seed)

                # Hand each run a fresh copy in case a callee mutates its input.
                bin_all_latents = binarized_latents.copy()
                # Kept after set_seed in case create_labels is stochastic.
                y_values = create_labels(bin_all_latents)

                do_cbm_test(
                    x_all,
                    bin_all_latents,
                    y_values,
                    Ns,
                    seed=seed,
                    model_type="CEM",
                    base_dir=base_dir,
                    cluster=args.cluster,
                    dataset_name=dataset,
                )
    else:
        # Only the --cluster path is implemented; say so instead of exiting silently.
        print("Nothing to run: pass --cluster to run the experiments.")
    
    

