import numpy as np 
from string import Template
import json
from pathlib import Path
from types import SimpleNamespace
import os
import subprocess

EXP_ID = 1
# Results are placed under the directory the launcher is started from.
ROOT_FOLDER = Path.cwd()


# Complete experiment configuration. It is forwarded to deviation_script.py
# as command-line flags and also dumped verbatim to <RESULT_FOLDER>/config.json.
config_dct = {
    # --- Sweep / experiment parameters ---
    "EXP_ID": EXP_ID,
    "N_TRIALS": 1,
    "FREQ_MIN": 0.6,
    "FREQ_MAX": 2.5,
    "FREQ_N_STEPS": 39,
    "DECAY_MIN": -0.3,
    "DECAY_MAX": 3.0,
    "DECAY_N_STEPS": 67,  # NOTE(review): a value of 52 was left in a comment here
    "POWER_MIN": 1.0,
    "POWER_MAX": 50.0,
    "POWER_N_STEPS": 50,
    "SAMPFREQ_MIN": 100,
    "SAMPFREQ_MAX": 300,
    "SAMPFREQ_N_STEPS": 19,
    # 0.1, 0.2, ..., 0.9: the 11-point grid on [0, 1] with both endpoints dropped;
    # converted to a plain list so it is JSON-serializable.
    "ALPHA": np.linspace(0, 1, 11)[1:-1].tolist(),
    "REAL_SCALE": 1.0,
    "IMAG_SCALE": 1.0,
    "SEED": 1,
    "RESULT_FOLDER": str(ROOT_FOLDER / f"results/exp_{EXP_ID}"),

    # --- Base signal parameters (one entry per component) ---
    "BASE_FREQ": [0.5, 1.0],
    "BASE_AMP": [1.0, 1.0],
    "BASE_DECAY": [0.0, 0.0],
    "BASE_PHASE": [0.0, 0.0],
    "BASE_FUNC": ["cos", "cos"],
    "BASE_STD": 1e-2,
    "SAMPFREQ": 200,
    "N_SAMPLES": 4001,

    # --- Koopman solver configuration ---
    "POLY_ORDER": 1,
    "RANK": 4,
    "N_JOBS": 10,
    "TIKHONOV_REG": 1e-8,
}

def build_command(config, python="python", script="deviation_script.py"):
    """Return the argv list that launches *script* with the settings in *config*.

    Every setting becomes a ``--flag value`` pair. List/tuple-valued settings
    (ALPHA and the BASE_* entries) expand to one argument per element. All
    values are converted with ``str()``, so floats render exactly as they
    would inside an f-string (e.g. ``1e-08``).

    Args:
        config: attribute namespace exposing the keys of ``config_dct``
            (e.g. ``SimpleNamespace(**config_dct)``).
        python: interpreter executable placed first in the argv list.
        script: path of the script to launch.

    Returns:
        list[str]: argument vector suitable for ``subprocess.run`` without a
        shell, so paths containing spaces stay a single argument.
    """
    options = [
        ("--exp_id", config.EXP_ID),
        ("--n_trials", config.N_TRIALS),
        ("--min_frequency", config.FREQ_MIN),
        ("--max_frequency", config.FREQ_MAX),
        ("--n_step_frequency", config.FREQ_N_STEPS),
        ("--min_decay", config.DECAY_MIN),
        ("--max_decay", config.DECAY_MAX),
        ("--n_step_decay", config.DECAY_N_STEPS),
        ("--min_power", config.POWER_MIN),
        ("--max_power", config.POWER_MAX),
        ("--n_step_power", config.POWER_N_STEPS),
        ("--min_sampfreq", config.SAMPFREQ_MIN),
        ("--max_sampfreq", config.SAMPFREQ_MAX),
        ("--n_step_sampfreq", config.SAMPFREQ_N_STEPS),
        ("--alpha", config.ALPHA),
        ("--real_scale", config.REAL_SCALE),
        ("--imag_scale", config.IMAG_SCALE),
        ("--base_freq", config.BASE_FREQ),
        ("--base_amp", config.BASE_AMP),
        ("--base_decay", config.BASE_DECAY),
        ("--base_phase", config.BASE_PHASE),
        ("--base_func", config.BASE_FUNC),
        ("--base_std", config.BASE_STD),
        ("--base_sampfreq", config.SAMPFREQ),
        ("--n_samples", config.N_SAMPLES),
        ("--poly_order", config.POLY_ORDER),
        ("--rank", config.RANK),
        ("--n_jobs", config.N_JOBS),
        ("--tikhonov_reg", config.TIKHONOV_REG),
        ("--seed", config.SEED),
        ("--save_folder", config.RESULT_FOLDER),
    ]

    argv = [python, script]
    for flag, value in options:
        argv.append(flag)
        if isinstance(value, (list, tuple)):
            argv.extend(str(item) for item in value)
        else:
            argv.append(str(value))
    return argv


if __name__ == "__main__":

    config = SimpleNamespace(**config_dct)
    command = build_command(config)

    # Save the exact configuration next to the results before launching.
    config_path = Path(config.RESULT_FOLDER)
    config_path.mkdir(parents=True, exist_ok=True)
    with open(config_path / "config.json", "w", encoding="utf-8") as f:
        json.dump(config_dct, f)

    # Run without a shell: RESULT_FOLDER is derived from the current working
    # directory, which may contain spaces or shell metacharacters.
    completed = subprocess.run(command)
    # Propagate the child's exit status so a failed experiment is visible
    # to whoever launched this script.
    raise SystemExit(completed.returncode)