
import os
import sys
sys.path.append(os.getcwd())
from tqdm import tqdm
import toml
from src.run import run_one_model
from src.plotting_results.storing_results import create_file_path, store_results
import numpy as np
# Project root: the current working directory. The same directory is appended
# to sys.path at the top of the file so the `src.*` imports resolve. Config
# files (configs/...) and the results store are located relative to it, so the
# script presumably has to be launched from the project root.
project_path = os.getcwd()


def create_file_name(cfg):
    """Build the results file path for the current experiment configuration.

    The name parts are, in order: ``cfg['dataset']['T']``, the retrain cost,
    the experiment mode, the model name, and then one fragment per entry of
    ``cfg['model']`` (in dict order):

    * ``True`` booleans contribute ``_key``;
    * ``False`` booleans contribute nothing;
    * strings contribute ``_key_value`` with spaces replaced by ``_``;
    * any other value contributes ``_key_<str(value)>``.

    The dataset directory component is the dataset name, suffixed with
    ``_<mode>`` unless the experiment mode is ``'normal'``.

    Args:
        cfg: Experiment configuration providing ``dataset`` (``T``, ``name``),
            ``experiment`` (``retrain_cost``, ``mode``), ``model`` (must
            contain ``name``) and ``reporter`` (``store_path``) sections.

    Returns:
        The value returned by ``create_file_path`` for these name parts and
        the ``[project_path, store_path, dataset_dir]`` components; the
        caller uses it as the file path passed to ``store_results``.
    """
    list_name = [
        str(cfg['dataset']['T']),
        str(cfg["experiment"]["retrain_cost"]),
        str(cfg["experiment"]["mode"]),
        cfg["model"]['name'],
    ]
    for key, val in cfg['model'].items():
        if type(val) is bool:
            # Only enabled flags tag the file name. A disabled flag must add
            # nothing, rather than reusing a fragment from another key.
            if val:
                list_name.append('_' + key)
        elif type(val) is str:
            list_name.append('_' + key + '_' + val.replace(' ', '_'))
        else:
            list_name.append('_' + key + '_' + str(val))

    dataset_path = cfg['dataset']['name']
    if cfg["experiment"]["mode"] != 'normal':
        # Keep results of non-default modes apart from the 'normal' ones.
        dataset_path = dataset_path + '_' + cfg["experiment"]["mode"]
    return create_file_path(
        list_name, [project_path, cfg['reporter']['store_path'], dataset_path])


if __name__ == "__main__":
    # Usage: python <script> <dataset_config_name> <mode>
    #   <dataset_config_name>: stem of a .toml file in configs/datasets_config/
    #   <mode>: stored in cfg['experiment']['mode'] for every run
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <dataset_config_name> <mode>")

    config_name = 'configs/retraining.toml'
    # TOML files are UTF-8 by specification; don't depend on the locale.
    with open(os.path.join(project_path, config_name),
              mode="r", encoding="utf-8") as f:
        cfg = toml.load(f)
    dataset_config_name = f'configs/datasets_config/{sys.argv[1]}.toml'

    with open(os.path.join(project_path, dataset_config_name),
              mode="r", encoding="utf-8") as f:
        cfg_dataset = toml.load(f)
    cfg['dataset'] = cfg_dataset['dataset']
    mode = sys.argv[2]
    relative_pe = False
    cfg['experiment']['relative_pe'] = relative_pe
    cfg['experiment']['mode'] = mode

    if relative_pe:
        # With a relative loss, the retrain costs should be smaller.
        retrain_costs = [t / 5 for t in [0, 0.01, 0.05, 0.1, 0.2, 0.4, 0.5, 1]]
    else:
        # Upper end of the retrain-cost sweep for each supported dataset.
        a_max_by_dataset = {
            'yelp': 0.1,
            'circles': 0.25,
            'airplanes': 0.7,
            'electricity': 1,
            'epicgames': 0.1,
            'gauss': 0.5,
            'wild': 1,
        }
        dataset_name = cfg['dataset']['name']
        a_max = a_max_by_dataset.get(dataset_name)
        if a_max is None:
            # Previously this fell through and crashed with a NameError.
            sys.exit(
                f"no retrain-cost range defined for dataset {dataset_name!r}; "
                f"known datasets: {', '.join(sorted(a_max_by_dataset))}")
        # 11 evenly spaced costs from 0 to a_max, both ends included.
        retrain_costs = np.linspace(0, a_max, 11).tolist()

    print(retrain_costs)
    for retrain_cost in tqdm(retrain_costs):
        cfg["experiment"]["retrain_cost"] = retrain_cost
        for model_cfg in cfg["models"]:
            # cfg is reused across runs: the current model section is swapped
            # in before running, naming, and storing the results.
            cfg["model"] = model_cfg
            store_per_trial = run_one_model(cfg, reporter=None)
            filepath = create_file_name(cfg)
            store_results(store_per_trial, cfg, filepath)
