# %%
import itertools
import torch
from config.birdsnap import args_birdsnap

from birdsnap.birdsnap_experiment_runner import BirdsnapExperimentRunner

import logging

# Configure root logging for the whole sweep. Every record carries a
# timestamp and the "BIRDSNAP" marker so output from these runs is easy to grep.
logging.basicConfig(
    format="%(asctime)s BIRDSNAP %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Use PyTorch's "medium" float32 matmul precision, which permits faster
# lower-precision internal computation in exchange for some accuracy.
torch.set_float32_matmul_precision("medium")

# Hyper-parameter grids. Each dict maps a parameter name to the list of values
# to sweep; the driver expands it with ``itertools.product`` into one run per
# combination (here 10 seeds x 2 ortho weights x 6 embeddings = 120 runs).
# Key order matters: it fixes the order in which combinations are produced.
experiments = [
    dict(
        seed=list(range(10)),  # random seeds 0..9
        max_epochs=[100],
        patience=[0],
        ortho_weight=[0, 1e-4],
        combination_type=["outer_product"],
        ortho_exponent=[2],
        time_embedding_dim=[8],
        time_embedding_type=[
            "constant",
            "fourier",
            "time_copy",
            "legendre",
            "monomial",
            "triangle",
        ],
        subset_fraction=[1.0],
        arch_name=["baseline_arch_v1_large"],
    ),
]

log_dir = "./logs"

# Expand each grid into its Cartesian product and run one experiment per
# combination, in ``itertools.product`` order (the last key varies fastest).
for experiment in experiments:
    keys = list(experiment)  # hoisted: identical for every combination
    sub_experiments = [
        dict(zip(keys, values))
        for values in itertools.product(*experiment.values())
    ]
    n_runs = len(sub_experiments)
    logger.info("Running %d sub-experiments", n_runs)

    for run_idx, sub_experiment in enumerate(sub_experiments, start=1):
        # Report progress through the module-level logger. A full sweep is long,
        # and this records which combination was running if a run fails.
        logger.info("Run %d/%d: %s", run_idx, n_runs, sub_experiment)
        runner = BirdsnapExperimentRunner(
            args_birdsnap=args_birdsnap,
            sub_experiment=sub_experiment,
            log_dir=log_dir,
            project="location-encoder",
            dataset="birdsnap",
            tag="",
        )
        runner.run()

