from .experiment_params import (
    arch_dict,
    process_args
)

# Raw hyperparameters for the location encoder. They are post-processed by
# ``process_args`` (from .experiment_params) further down in this file.
locationencoder_args = {
    # Orthogonality-related weights, all set to 0 in this config.
    "ortho_weight": 0,
    "ortho_weight_space": 0,
    "ortho_weight_time": 0,
    # Spatial positional embedding.
    "positional_embedding_type": "sphericalharmonics",
    "harmonics_calculation": "analytic",
    "legendre_polys": 40,
    # How the spatial and temporal embeddings are combined.
    "combination_type": "concatenation",
    # Temporal embedding.
    "time_embedding_type": "fourier",
    "time_embedding_dim": 120,
    "time_grad_penalty_weight": 0,
    "ortho_exponent": 1,
    "normality_flag": True,
    "arch_name": "baseline_arch_v1",
    "number_of_timesteps": 365,
}

from pathlib import Path
# Raw data-module arguments for the birdsnap dataset. They are post-processed
# by ``process_args`` further down in this file.
datamodule_args = {
    # Absolute, machine-specific dataset location; adjust for your environment.
    "root": "/network/projects/location-embeddings/datasets/birdsnap/",
    "mode": "spacetime",
    "num_workers": 4,
    "batch_size": 2000,
    "num_classes": 500,
    # Fraction of the data to use for iid subset sampling
    # (1.0 = full data; e.g. 0.1 for a 10% subset).
    "subset_fraction": 1.0,
    # Currently disabled options (previously present as commented-out keys):
    #   "shuffle_training_data": True
    #   "num_timesteps": 365
}

# Top-level experiment configuration for birdsnap. The post-processed
# location-encoder and data-module argument groups are attached to this
# dict further down in this file.
args_birdsnap_dict = {
    # --- Task ---
    "dataset": "birdsnap",
    "regression": False,
    "presence_only_loss": False,

    # --- Model ---
    "min_radius": None,

    # --- Training ---
    "max_epochs": 10,
    "lr": 0.001,
    "wd": 1e-5,
    "patience": 5,
    "gpus": 1,
    "accelerator": "auto",

    # --- Logging and visualization ---
    # NOTE(review): this birdsnap config writes to './output_inat', which
    # looks copied from an iNat config; confirm this is intended.
    "output_root": "./output_inat",
    "results_dir": "./output_inat",
    "save_model": True,
    "log_wandb": True,
    "expname": None,
    "resume_ckpt_from_results_dir": False,
    "matplotlib": False,
    "matplotlib_show": False,
    "use_expnamehps": False,

    # --- Other ---
    "hparams": "./hparams.yaml",
    "seed": 15,
}

from types import SimpleNamespace

# Post-process both argument groups with the shared helper imported from
# .experiment_params before they are attached to the experiment config.
locationencoder_args, datamodule_args = process_args(locationencoder_args, datamodule_args)

# Nest the two groups inside the top-level dict. The dict is mutated in place,
# so anything reading ``args_birdsnap_dict`` also sees them. Then expose the
# full config with attribute access.
args_birdsnap_dict.update(
    locationencoder_args=locationencoder_args,
    datamodule_args=datamodule_args,
)
args_birdsnap = SimpleNamespace(**args_birdsnap_dict)
