import glob
import os
import argparse
import numpy as np
import pickle 

def sibling_path(path, old, new):
    """Return *path* with ``old`` replaced by ``new`` in the file name only.

    The directory part is left untouched, so a parent directory whose name
    happens to contain ``old`` is never rewritten.
    """
    directory, filename = os.path.split(path)
    return os.path.join(directory, filename.replace(old, new))


def load_pickled_array(path):
    """Unpickle the file at *path* and return its content as a NumPy array."""
    # NOTE: unpickling can execute arbitrary code -- only run this script on
    # files that come from a trusted source.
    with open(path, "rb") as f:
        return np.array(pickle.load(f))


def convert_logit_file(logit_path):
    """Expand one ``logits_*`` pickle into a ``*_all_files`` folder tree.

    Next to ``logit_path`` the matching ``in_indices_*`` and ``labels_*``
    pickles are read (same file name with ``logits`` swapped out). The
    output folder is the logit file name with ``.pkl`` replaced by
    ``_all_files`` and contains:

    * ``x_train.npy`` and ``y_train.npy`` -- both written from the labels;
    * one ``experiment-{m}_data`` folder per entry ``m`` along the first
      axis of the logits, holding ``logits/logits.npy`` (``logits[m]`` with
      two axes inserted after the first one, the first of them duplicated to
      length 2) and ``keep.npy`` (``indices[m]``).

    Returns:
        ``True`` if the folder was created, ``False`` if it already existed
        and the file was skipped.

    Raises:
        ValueError: if the indices pickle has fewer entries than the logits.
    """
    folder = sibling_path(logit_path, ".pkl", "_all_files")
    if os.path.isdir(folder):
        # Already converted on a previous run.
        return False

    logits = load_pickled_array(logit_path)
    indices = load_pickled_array(sibling_path(logit_path, "logits", "in_indices"))
    y_true = load_pickled_array(sibling_path(logit_path, "logits", "labels"))

    M = len(logits)
    # Validate before touching the file system: a half-written folder would
    # be silently skipped by the isdir() check on the next run.
    if len(indices) < M:
        raise ValueError(
            f"{logit_path}: found {M} logit entries but only "
            f"{len(indices)} in_indices entries"
        )

    os.mkdir(folder)

    # NOTE(review): x_train.npy is written from the label array, exactly like
    # y_train.npy; no feature data is loaded by this script. Presumably a
    # placeholder for downstream tooling -- confirm.
    np.save(file=os.path.join(folder, "x_train.npy"), arr=y_true)
    np.save(file=os.path.join(folder, "y_train.npy"), arr=y_true)

    print(f"Detected M={M} with file {logit_path}")

    for m, (model_logits, keep) in enumerate(zip(logits, indices)):
        m_folder = os.path.join(folder, f"experiment-{m}_data")
        os.mkdir(m_folder)

        logits_folder = os.path.join(m_folder, "logits")
        os.mkdir(logits_folder)

        # np.save appends ".npy", so this writes "logits/logits.npy".
        np.save(
            file=os.path.join(logits_folder, "logits"),
            arr=model_logits[:, None, None, :].repeat(2, axis=1),
        )
        np.save(file=os.path.join(m_folder, "keep.npy"), arr=keep)

    return True


def main(argv=None):
    """Convert every ``in_indices_*`` / ``logits_*`` pair under ``--data_path``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_path",
        required=True,  # os.path.join(None, ...) would raise a TypeError
        help="Path to in_indices, stat, and output score files.",
    )
    args = parser.parse_args(argv)

    # Logit files are discovered through their in_indices counterparts.
    index_paths = sorted(glob.glob(os.path.join(args.data_path, "in_indices_*")))
    for index_path in index_paths:
        convert_logit_file(sibling_path(index_path, "in_indices", "logits"))


if __name__ == "__main__":
    main()