
from src.plotting_results.storing_results import create_file_path

from factory import (
    make_model,
    make_dataset
)
import pickle as pk
from retraining_task_processing.f_generator import PredictorFGenerator
from experiment_retrain import RetrainingExp
import os
# Intended as the project root; used as the base directory for cached
# experiment-data files (see create_file_path_experiment_data).
# NOTE: not ideal -- this is the process's current working directory, captured
# once at import time, so it only equals the project root if run.py is always
# launched from the same fixed location. A later os.chdir() does not update it.
# TODO: derive it from __file__ or from the config instead.
project_path = os.getcwd()


def create_file_path_experiment_data(trial, cfg_exp, cfg_dataset, cfg_reporter):
    """Return the cache-file path for one trial's precomputed experiment data.

    ``create_file_path`` receives two lists:

    * name components: the trial index, the literal tag ``'seed'``,
      ``cfg_exp['seed']``, then ``cfg_dataset['T']``, ``cfg_dataset['N']``
      and ``cfg_dataset['offline_t']``;
    * location components: the module-level ``project_path``,
      ``cfg_reporter['experiment_path']`` and ``cfg_dataset['name']``.

    How these are joined into a path is decided by ``create_file_path``.
    """
    stem_parts = [trial, 'seed', cfg_exp['seed'],
                  cfg_dataset['T'], cfg_dataset['N'], cfg_dataset['offline_t']]
    dataset_dir = cfg_dataset['name']
    return create_file_path(
        stem_parts,
        [project_path, cfg_reporter['experiment_path'], dataset_dir],
    )


def load_experiment_data(trial, cfg_exp, cfg_dataset, cfg_reporter):
    """Load one trial's cached experiment data, or return None if absent.

    The cache location is computed by ``create_file_path_experiment_data``
    from the same arguments that ``store_experiment_data`` uses.

    Args:
        trial: Trial index (part of the cache key).
        cfg_exp, cfg_dataset, cfg_reporter: Config sections forwarded to
            ``create_file_path_experiment_data``.

    Returns:
        The unpickled object, or ``None`` when no cache file exists.

    Note:
        ``pickle.load`` can execute arbitrary code. Only point the experiment
        path at files this project wrote itself.
    """
    filepath = create_file_path_experiment_data(
        trial, cfg_exp, cfg_dataset, cfg_reporter)
    # EAFP instead of os.path.exists() + open(). This avoids the race where the
    # file disappears between the check and the open, and it saves a stat call.
    try:
        file = open(filepath, 'rb')
    except FileNotFoundError:
        return None
    with file:
        return pk.load(file)


def store_experiment_data(dict_things_needed, trial, cfg_exp, cfg_dataset, cfg_reporter):
    """Pickle one trial's precomputed experiment data to its cache file.

    The target path comes from ``create_file_path_experiment_data`` with the
    same arguments, so ``load_experiment_data`` finds it later.

    The data is written to a temporary file beside the target and then moved
    into place with ``os.replace``. An interrupted run therefore never leaves
    a truncated cache file that would make every later load fail. Missing
    parent directories are created.

    Args:
        dict_things_needed: Object to pickle (in ``run_one_model`` this is the
            value returned by ``exp.precompute()``).
        trial, cfg_exp, cfg_dataset, cfg_reporter: Forwarded to
            ``create_file_path_experiment_data`` to build the path.

    Returns:
        None.
    """
    filepath = create_file_path_experiment_data(
        trial, cfg_exp, cfg_dataset, cfg_reporter)
    directory = os.path.dirname(filepath)
    if directory:
        # open(..., 'wb') does not create parent directories.
        os.makedirs(directory, exist_ok=True)
    # Use a per-process temp name so concurrent runs do not clobber each
    # other's partial writes. Keep it in the same directory so os.replace
    # stays a same-filesystem rename.
    tmp_path = f"{filepath}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'wb') as file:
            pk.dump(dict_things_needed, file)
        os.replace(tmp_path, filepath)
    finally:
        # After a successful replace the temp file is gone and this is a no-op.
        # On failure it removes the partial write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def run_one_model(cfg, reporter):
    """Run every trial of a retraining experiment and collect the results.

    For each trial index in ``range(cfg["experiment"]["trials"])``:

    1. Build a fresh model (``make_model``) and ``PredictorFGenerator``.
    2. Try to load the trial's precomputed data from the on-disk cache.
       * On a miss, generate a dataset seeded with the trial index, run
         ``exp.precompute()``, and store its output in the cache.
       * On a hit, build the experiment without a dataset and feed the
         cached data to ``exp.precompute``.
    3. Call ``exp.train_offline()`` and ``exp.eval()``.

    Returns:
        dict mapping trial index to
        ``{'result': exp.experiment_results, 'PE_matrix': exp.loss_dict_matrix}``.
    """
    results_by_trial = {}
    cfg_exp = cfg["experiment"]
    for trial in range(cfg_exp["trials"]):
        cfg_model = cfg["model"]
        cfg_dataset = cfg["dataset"]
        model = make_model(cfg_model, cfg_dataset, cfg_exp)
        f_generator = PredictorFGenerator(cfg_exp, cfg_dataset)

        cfg_reporter = cfg["reporter"]
        cached = load_experiment_data(trial, cfg_exp, cfg_dataset, cfg_reporter)
        cache_miss = cached is None

        # A dataset is only needed when the precomputation has to be redone.
        dataset = make_dataset(cfg_dataset, cfg_exp, seed=trial) if cache_miss else None
        exp = RetrainingExp([model], dataset, reporter, f_generator, cfg)

        if cache_miss:
            cached = exp.precompute()
            store_experiment_data(cached, trial, cfg_exp, cfg_dataset, cfg_reporter)
        else:
            exp.precompute(cached)

        exp.train_offline()
        exp.eval()
        results_by_trial[trial] = {
            'result': exp.experiment_results,
            'PE_matrix': exp.loss_dict_matrix,
        }
    return results_by_trial

