import os
import yaml
from ml_collections import ConfigDict

import torch

from qrcp.helpers.lightner import Output


def smooth_prediction_filename(dataset_name, model_sigma, n_datapoints, smoothing_sigma, n_samples, r=0):
    """Return the ``.pth`` file name for a cached smoothed-prediction result.

    The name is a dash-separated list of fields:
    ``<dataset>-clean-npoints_<n>-modelsigma_<s>-smoothing_<s>-samples_<n>``.
    An extra ``-r_<r>`` field is appended only when ``r`` is strictly positive.
    Dots in the two sigma values are replaced by underscores
    (e.g. ``0.25`` -> ``"0_25"``). ``r`` is written as-is.
    """
    def _dotless(value):
        # Sigma values are stringified with '.' swapped for '_'.
        return str(value).replace(".", "_")

    fields = [
        f"{dataset_name}-clean",
        f"npoints_{n_datapoints}",
        f"modelsigma_{_dotless(model_sigma)}",
        f"smoothing_{_dotless(smoothing_sigma)}",
        f"samples_{n_samples}",
    ]
    if r > 0:
        fields.append(f"r_{str(r)}")
    return "-".join(fields) + ".pth"

def load_smooth_prediction(dataset_name, model_sigma, n_datapoints, smoothing_sigma, n_samples, r=0,
                           config_path="../conf/general.yaml"):
    """Load a cached smoothed-prediction file and wrap it in an ``Output``.

    The directory containing the cached files is taken from the ``logits_dir``
    entry of the ``general`` section of the YAML config at ``config_path``.
    The file name is built by ``smooth_prediction_filename`` from the same
    identifying arguments.

    Args:
        dataset_name, model_sigma, n_datapoints, smoothing_sigma, n_samples, r:
            Forwarded unchanged to ``smooth_prediction_filename``.
        config_path: Path to the general YAML config. The default is the
            previously hard-coded ``"../conf/general.yaml"``, which is
            resolved relative to the current working directory.

    Returns:
        An ``Output`` built from the ``"y_pred"``, ``"logits"`` and
        ``"y_true"`` entries of the object loaded with ``torch.load``.

    Raises:
        FileNotFoundError: if the config file or the prediction file is missing.
        KeyError: if the config lacks a ``general`` section or the loaded
            object lacks one of the expected keys.
    """
    # Use a context manager so the config handle is closed deterministically;
    # previously it was left open until garbage collection.
    with open(config_path, "r", encoding="utf-8") as config_file:
        general_config = yaml.safe_load(config_file)
    conf = ConfigDict(general_config["general"])

    logits_file_name = smooth_prediction_filename(dataset_name, model_sigma, n_datapoints, smoothing_sigma, n_samples, r)
    clean_d = torch.load(os.path.join(conf.logits_dir, logits_file_name))
    return Output(y_pred=clean_d["y_pred"], logits=clean_d["logits"], y_true=clean_d["y_true"])