# %%
import itertools

import torch

from config.inat_params import args_inat
from inat.InatExperimentRunner import InatExperimentRunner

# %%
# Trade float32 matmul precision for throughput: "medium" lets PyTorch use
# reduced-precision internal formats (e.g. on GPU tensor cores) when a fast
# kernel for them is available. See torch.set_float32_matmul_precision docs.
torch.set_float32_matmul_precision("medium")

# %%
experiments = [
{ # e68 outer_product, no ortho-weight, result table inat
    "seed": list(range(10)),  # random seeds
    "max_epochs": [100],
    "patience": [0],
    "ortho_weight": [0, 1e-4],
    "ortho_exponent": [2],
    "combination_type": ["concatenation"],
    "time_embedding_dim": [8],
    "time_embedding_type": ["fourier", "legendre", "constant", 
                            "monomial", "triangle", "time_copy"],
    "legendre_polys": [8],  # determines the dimension of spatial embeddings
    "variable_cut_off": [8142], # 8142 is the number of classes in iNat2018
    "batch_size": [1000], # batch size for training
    "arch_name": ["baseline_arch_v1"], 
   },
]

# %%
# Weights & Biases / logging settings shared by every run.
# NOTE(review): ``project`` is empty -- confirm InatExperimentRunner handles
# that (or fill in the project name) before launching.
project = ""
dataset = "iNat2018"
log_dir = "./logs"
tag = "aaai2026"  # identical for every run, so defined once outside the loop

for experiment in experiments:
    # Expand the grid: one dict per element of the Cartesian product of the
    # value lists, keyed in the grid's own key order.
    keys = list(experiment.keys())
    sub_experiments = [
        dict(zip(keys, value_combination))
        for value_combination in itertools.product(*experiment.values())
    ]

    for sub_experiment in sub_experiments:
        runner = InatExperimentRunner(
            args_inat=args_inat,
            # Pass a copy so anything the runner does to its dict cannot
            # alter the entry kept in ``sub_experiments``.
            sub_experiment={**sub_experiment},
            log_dir=log_dir,
            tag=tag,
            project=project,
            dataset=dataset,
        )
        runner.run()

