from .experiment_params import (
    arch_dict,
    process_args
)

# Location-encoder configuration: a spherical-harmonics positional embedding
# combined by concatenation with a Fourier time embedding. The
# "baseline_arch_v1" preset from arch_dict is merged in last, so any key it
# defines takes precedence over the values listed here.
locationencoder_args = dict(
    ortho_weight=0,
    positional_embedding_type="sphericalharmonics",
    harmonics_calculation="analytic",
    # NOTE(review): presumably the spherical-harmonic degree cutoff — confirm
    legendre_polys=40,
    combination_type="concatenation",
    time_embedding_type="fourier",
    time_embedding_dim=60,
)
locationencoder_args.update(arch_dict["baseline_arch_v1"])

from pathlib import Path
data_root_dir = Path("./datasets/ace/")

# Data-module configuration. Only the January 2021 file is loaded for now.
# The full 2021-2024 monthly record would be:
#     [data_root_dir / f"{year}{month:02d}0100.nc"
#      for year in range(2021, 2025) for month in range(1, 13)]
# A "T" entry (e.g. 112) is currently not passed.
datamodule_args = dict(
    file_paths=[data_root_dir / "2021010100.nc"],
    # Alternative mode: "spatio_temporal_interpolation".
    mode="forecast",
    train_fraction=0.1,
    val_fraction=0.1,
    test_fraction=0.1,
    num_workers=8,
    batch_size=40_000,
    perturbed_training=True,
    perturbation_scale=0.0,
    shuffle_training_data=True,
    # Indices 15..22 are the 8 temperature variables; [54] would select the
    # total-water-path variable alone.
    variable_selection=list(range(15, 23)),
    subset_fraction=1.0,
)

# Top-level run configuration; converted to a SimpleNamespace at the end of
# this file once the location-encoder and data-module sub-configs are attached.
args_ace_dict = dict(
    # --- task ---
    dataset="era5dataset_multi",  # placeholder name, kept for now because it is simpler
    regression=True,
    # --- model ---
    min_radius=None,
    # --- training ---
    max_epochs=-1,  # reassigned further down in this file
    lr=1e-3,
    wd=1e-5,
    patience=5,
    gpus=1,
    accelerator="auto",
    # --- logging / visualization ---
    output_root="./output",
    results_dir="results/train/notebook/",
    save_model=False,
    log_wandb=True,
    expname=None,
    resume_ckpt_from_results_dir=False,
    matplotlib=False,
    matplotlib_show=False,
    use_expnamehps=False,
    # --- other ---
    hparams="./hparams.yaml",
    seed=1,
)

# Per-experiment overrides applied on top of the defaults defined above.
args_ace_dict["max_epochs"] = 5

# Variable subset: list(range(15, 23)) are the 8 temperature variables.
# Other choices used before: list(range(55)) for all variables, [54] for total
# water path only, [15] for a single temperature variable.
datamodule_args["variable_selection"] = list(range(15, 23))

locationencoder_args.update(
    # Originally written as 0*1e5 (== 0.0); presumably set to e.g. 1e5 to
    # enable the ortho term — confirm against the model code.
    ortho_weight=0.0,
    time_embedding_type="fourier",
    time_embedding_dim=120,
)
# Not currently applied: combined_encoding_args["num_layers"] = 4 and
# combined_encoding_args["name"] = "siren".

locationencoder_args, datamodule_args = process_args(locationencoder_args, datamodule_args)

from types import SimpleNamespace
# Attach the processed sub-configs, then expose the whole configuration with
# attribute access (args_ace.lr, args_ace.datamodule_args, ...).
args_ace_dict.update(
    locationencoder_args=locationencoder_args,
    datamodule_args=datamodule_args,
)
args_ace = SimpleNamespace(**args_ace_dict)
