import pickle
import numpy as np
import os
from lira import compute_score_lira
import argparse
import glob
# Number of models per stat file that main() attacks as the target model:
# it iterates idx over range(NUM_TARGET_MODELS) and reads stat[idx] and
# in_indices[idx], so each stat file must hold at least this many models.
NUM_TARGET_MODELS = 256


def _sibling_path(path, old_prefix, new_prefix):
    """Return `path` with the leading `old_prefix` of its file name replaced.

    Only the file-name component is rewritten; the directory part is copied
    verbatim. A plain ``str.replace`` over the whole path would also rewrite
    directory names (or later parts of the file name) that happen to contain
    the prefix.
    """
    filename = os.path.basename(path)
    directory = path[:len(path) - len(filename)]
    return directory + new_prefix + filename[len(old_prefix):]


def _load_pickle(path):
    """Load and return the object pickled at `path`."""
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point --data_path at directories whose contents you trust.
    with open(path, "rb") as f:
        return pickle.load(f)


def _split_shadow_stats(stat_all, in_all, target_idx):
    """Split shadow-model statistics into per-example "in" and "out" lists.

    Every model except `target_idx` is treated as a shadow model.

    Args:
        stat_all: array of per-model statistics, indexed (model, example, ...).
        in_all: boolean array indexed (model, example); True where the model
            was trained with the example.
        target_idx: index of the target model to leave out.

    Returns:
        ``(stat_in, stat_out)`` where ``stat_in[j]`` (resp. ``stat_out[j]``)
        holds the statistics of the shadow models trained with (resp.
        without) the j-th example.
    """
    stat_shadow = np.delete(stat_all, target_idx, axis=0)
    in_shadow = np.delete(in_all, target_idx, axis=0)
    num_examples = stat_shadow.shape[1]
    stat_in = [stat_shadow[:, j][in_shadow[:, j]] for j in range(num_examples)]
    stat_out = [stat_shadow[:, j][~in_shadow[:, j]] for j in range(num_examples)]
    return stat_in, stat_out


def _score_all_targets(stat, in_indices):
    """Score each of the first NUM_TARGET_MODELS models against the rest.

    Args:
        stat: sequence of per-model statistics (one entry per model).
        in_indices: sequence of per-model membership masks, same length as
            `stat`. Values are coerced to bool so that ``~`` is a logical
            negation even if the masks were stored as 0/1 integers.

    Returns:
        Dict with ``'scores'`` (outputs of `compute_score_lira`, concatenated
        over target models) and ``'y_true'`` (0 where the target model was
        trained with the example, 1 otherwise), both in the original sample
        order.

    Raises:
        ValueError: if there are fewer than NUM_TARGET_MODELS models or
            `stat` and `in_indices` disagree on the number of models.
    """
    num_models = len(stat)
    if num_models < NUM_TARGET_MODELS or len(in_indices) != num_models:
        # Fail before any scoring instead of hitting an IndexError part-way
        # through the target loop.
        raise ValueError(
            f"expected at least {NUM_TARGET_MODELS} models with matching "
            f"in_indices, got {num_models} stat entries and "
            f"{len(in_indices)} in_indices entries"
        )

    # Convert once; the conversion does not depend on the target index.
    stat_all = np.asarray(stat)
    in_all = np.asarray(in_indices, dtype=bool)

    per_target_scores = []
    per_target_labels = []
    for idx in range(NUM_TARGET_MODELS):
        print(f'Target model is #{idx}', flush=True)
        stat_in, stat_out = _split_shadow_stats(stat_all, in_all, idx)

        # Compute the scores and use them for MIA
        scores = compute_score_lira(stat[idx], stat_in, stat_out, fix_variance=True)

        # Keep samples in their original order (not grouped by membership).
        per_target_scores.append(scores)
        per_target_labels.append(np.where(in_all[idx], 0, 1))

    return {
        'y_true': np.hstack(per_target_labels),
        'scores': np.hstack(per_target_scores),
    }


def main():
    """Compute score files for every stat file that does not have one yet.

    For each ``stat<filter>.pkl`` under ``--data_path`` without a matching
    ``scores<filter>.pkl``, loads the corresponding ``in_indices`` and
    ``stat`` pickles, scores every target model, and pickles a dict with
    keys ``'y_true'`` and ``'scores'`` to the scores path.
    """
    parser = argparse.ArgumentParser()
    # Required: os.path.join cannot build the glob patterns without it.
    parser.add_argument("--data_path", required=True, help="Path to in_indices, stat, and output score files.")
    parser.add_argument("--filter", default="*", help="Further filtering of stat files e.g., by dataset.")

    args = parser.parse_args()

    stat_paths = glob.glob(os.path.join(args.data_path, f'stat{args.filter}.pkl'))
    existing_scores = set(glob.glob(os.path.join(args.data_path, f'scores{args.filter}.pkl')))
    expected_scores = {_sibling_path(p, "stat", "scores") for p in stat_paths}
    # Sorted so the processing order is the same on every run.
    pending = sorted(expected_scores - existing_scores)
    print(f"already processed {len(stat_paths)-len(pending)} still to be processed: {len(pending)}")

    for score_path in pending:
        print(score_path, flush=True)

        in_indices = _load_pickle(_sibling_path(score_path, "scores", "in_indices"))
        stat = _load_pickle(_sibling_path(score_path, "scores", "stat"))
        result = _score_all_targets(stat, in_indices)

        # Write to a temporary name and rename afterwards: a run interrupted
        # mid-write must not leave a truncated scores file behind, because
        # the mere existence of the file marks the stat file as processed.
        tmp_path = score_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(result, f)
        os.replace(tmp_path, score_path)


# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
