import os, sys
from pathlib import Path
from metrics import aggregate_metrics
import json, os
import argparse
from datasets import load_dataset

sys.path.append(str(Path(__file__).parent.parent))
from utils.utils import prepare_model, save_metrics

def get_args():
    """Parse the command-line arguments of the metric-aggregation script.

    ``--dataset_path`` and ``--raw_values_path`` are mandatory: when either is
    missing argparse prints a usage error and exits, instead of returning
    ``None`` for a path that the rest of the script unconditionally opens.

    Returns:
        argparse.Namespace: The parsed arguments, read from ``sys.argv``.

    Raises:
        SystemExit: If a required argument is missing or an argument is invalid.
    """
    parser = argparse.ArgumentParser(description='Dataset Inference on a language model')
    parser.add_argument('--dataset_path', type=str, required=True, help='The path to the dataset to use')
    parser.add_argument('--raw_values_path', type=str, required=True, help='The path to the raw values to use')
    parser.add_argument('--cache_dir', type=str, default="~/.cache", help='The directory to cache the model')
    parser.add_argument('--output_dir', type=str, default="results", help='The directory to save the results')
    parser.add_argument('--result_file_name', type=str, default="metrics.json", help='The name of the result file')
    # nargs='*' allows zero or more paths; the default keeps the value a list
    # even when the flag is not passed at all.
    parser.add_argument('--reference_models_metrics_path', type=str, nargs='*', default=[], help='List of paths to reference models metrics files')
    parser.add_argument('--loss_estimation_method', type=str, default="raw", help='Method for loss estimation (raw, ref, sigmoid)')
    return parser.parse_args()

def load_file(filepath):
    """
    Read the UTF-8 encoded JSON file at `filepath` and return the parsed content.
    """
    with open(filepath, mode='r', encoding='utf-8') as handle:
        text = handle.read()
    parsed = json.loads(text)
    return parsed

def main():
    """Load the inputs, aggregate the metrics and write them to the results file.

    Reads the dataset (``json`` loader, ``train`` split), the raw values file
    and every reference-model metrics file named on the command line, passes
    them to ``aggregate_metrics`` and stores the result at
    ``<output_dir>/<result_file_name>`` via ``save_metrics``.
    """
    args = get_args()
    results_file = f"{args.output_dir}/{args.result_file_name}"

    # Create the output directory up front so an unusable --output_dir is
    # reported before the data is loaded and the metrics are computed.
    os.makedirs(args.output_dir, exist_ok=True)

    dataset = load_dataset("json", data_files=args.dataset_path, split="train")
    raw_values = load_file(args.raw_values_path)

    # One loaded file per given path; stays empty when no paths were passed.
    reference_models_metrics = [
        load_file(ref_path) for ref_path in args.reference_models_metrics_path
    ]

    print("Data loaded")

    metric_list = ["k_min_probs", "k_strip_probs", "ppl", "zlib_ratio", "k_max_probs", "perturbation", "reference_model", "petal"]

    metrics = aggregate_metrics(
        raw_values,
        dataset,
        metric_list,
        args,
        reference_models_dics=reference_models_metrics,
        loss_estimation_method=args.loss_estimation_method,
    )

    # save the metrics
    save_metrics(results_file, metrics)

# Run the script only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
