import argparse
from src.load import get_model, eval_model
from examples.configs import get_config
from examples.energies import get_problem

# Command-line interface: which target problem to evaluate and, optionally,
# which numbered run of it.
parser = argparse.ArgumentParser(description="Choosing problem")
parser.add_argument("--problem", default="funnel", type=str)
# numbering for several runs; may be left unset, in which case it is None
parser.add_argument("--experiment_id", default=None, type=int)

inp = parser.parse_args()
experiment_id = inp.experiment_id

# specify problem and dimension
# Maps every accepted --problem value (including aliases) to the pair
# (problem_name, dim) used by the rest of the script.
_PROBLEM_SPECS = {
    "funnel": ("funnel", 10),
    "mustache": ("schnauzbart", 2),
    "schnauzbart": ("schnauzbart", 2),
    "8peaky": ("8peaky", 2),
    "8mixtures": ("8mixtures", 2),
    "8modes": ("8mixtures", 2),
    "GMM10": ("mixtures", 10),
    "GMM20": ("mixtures", 20),
    "GMM50": ("mixtures", 50),
    "GMM100": ("mixtures", 100),
    "GMM200": ("mixtures", 200),
    "lgcp": ("lgcp", 1600),
}

if inp.problem not in _PROBLEM_SPECS:
    # Fail with a readable message instead of leaving problem_name/dim
    # unbound, which would surface later as a NameError.
    raise ValueError(
        f"Unknown problem {inp.problem!r}; expected one of: "
        f"{', '.join(sorted(_PROBLEM_SPECS))}"
    )

problem_name, dim = _PROBLEM_SPECS[inp.problem]

# The GMM examples need the run number so that the same model as in training
# is loaded.
if problem_name == "mixtures" and experiment_id is None:
    raise ValueError(
        "experiment_id is required for the GMM example for loading the same model as for training!"
    )

# Make the choice of the means reproducible if this is a mixture example;
# for the other problems this argument will be ignored.
additional_info = {} if experiment_id is None else {"mean_id": experiment_id}

# Load the target energy (energy = negative log density up to some constant).
# Note that dim and additional_info are overwritten by the returned values.
target_energy, sampler, dim, axis_scale, additional_info = get_problem(
    problem_name,
    dim=dim,
    additional_info=additional_info,
)

# Means of the mixture modes are used for evaluation when they are available.
means = None
if "means" in additional_info:
    means = additional_info["means"]

# Hyperparameters are loaded exactly as for training for convenience, even
# though evaluation needs only a few of them.
args = get_config(problem_name, dim)

# Stack size for sampling: this has no influence on the results, only on the
# computation time.
if problem_name == "lgcp":
    stack_size = 1000
elif problem_name == "mixtures" and dim == 200:
    stack_size = 5000
else:
    stack_size = 25000

# Load the model for this problem (and run number, if one was given).
model = get_model(target_energy, dim, args, problem_name, experiment_id)

# Samples with corresponding energies (= normalized negative log densities)
# can be generated by
#     samples, sample_energies = model.sample(n_samples)

# Compute the evaluation metrics.
log_Z, energy_distance, mode_MSE, sampling_time = eval_model(
    model,
    target_energy,
    sampler,
    means,
    stack_size=stack_size,
)

# Collect the report; the optional metrics are listed only when not None.
# NOTE(review): the sample count 50000 is hard-coded in the message; confirm
# that it matches what eval_model actually generates.
report_lines = [
    f"The generation of 50000 samples took {sampling_time} seconds.",
    f"log(Z) estimate: {log_Z}",
]
if energy_distance is not None:
    report_lines.append(f"Energy distance to target distribution: {energy_distance}")
if mode_MSE is not None:
    report_lines.append(f"Mode MSE: {mode_MSE}")

for report_line in report_lines:
    print(report_line)
