import os

from .experiment_params import (
    arch_dict,
    process_args
)

# Location-encoder configuration.
# The literal keys below come first; the entries of
# arch_dict["baseline_arch_v1"] are unpacked last. Any key present in both
# places therefore takes the architecture preset's value.
locationencoder_args = {
    # ortho_* settings
    "ortho_weight": 0,
    "ortho_weight_space": 0,
    "ortho_weight_time": 0,
    # spatial positional embedding
    "positional_embedding_type": "sphericalharmonics",
    "harmonics_calculation": "analytic",
    "legendre_polys": 40,
    # how the embeddings are combined
    "combination_type": "concatenation",
    # time embedding
    "time_embedding_type": "fourier",
    "time_embedding_dim": 120,
    "time_grad_penalty_weight": 0.0,
    "normality_flag": True,
    "ortho_exponent": 1,
    # Architecture preset; unpacked last so it overrides overlapping keys.
    **arch_dict["baseline_arch_v1"],
}

# from pathlib import Path
from typing import List

def get_ace_file_paths(
    data_root_dir: str,  
    n_months: int = 12,
    start_year: int = 2021,
    start_month: int = 1
) -> List[str]:
    """
    Returns file paths for a configurable number of consecutive months (n_months)
    for each year, starting from start_year for n_years.
    """
    first_year = start_year
    available_years = 4
    file_paths = []
    n_months = min(n_months, 12)
    # Use provided start_month and start_year for deterministic tests
    months = [(start_month - 1 + i) % 12 for i in range(n_months)]
    years = [first_year + (start_month - 1 + i) // 12 for i in range(n_months)]

    for year, month in zip(years, months):
        file_name = f"{year}{month+1:02d}0100.nc"
        file_paths.append(data_root_dir + file_name)

    return file_paths

data_root_dir = "./datasets/ace/"
# Data-module configuration.
# "file_paths" lists the twelve monthly files of 2021
# (2021010100.nc ... 2021120100.nc) inside data_root_dir. To train on more
# years, concatenate further calls, e.g.
#   get_ace_file_paths(data_root_dir, start_year=2022)
# Each call returns at most 12 consecutive months.
datamodule_args = {
    "file_paths": get_ace_file_paths(
        data_root_dir, n_months=12, start_year=2021, start_month=1
    ),
    "mode": "spatio_temporal_interpolation",
    "train_fraction": 0.01,
    "val_fraction": 0.01,
    "test_fraction": 0.01,
    # "T": 112,
    "num_workers": 8,
    "batch_size": 40_000,
    "perturbed_training": True,
    "perturbation_scale": 0.0,

    "shuffle_training_data": True,
    "variable_selection": list(range(15, 23)),  # indices 15-22: 8 temperature variables
    "subset_fraction": 1.,
}

# Run-level settings (task, model, training, logging), built with dict()
# keyword syntax; insertion order matches the grouping below.
args_ace_dict = dict(
    # --- task ---
    dataset="era5dataset_multi",  # Use this name for now; easier
    regression=True,
    # --- model ---
    min_radius=None,
    # --- training ---
    max_epochs=-1,  # NOTE(review): presumably "no epoch cap" for the trainer — confirm
    lr=0.001,
    wd=1e-5,
    patience=5,
    gpus=1,
    accelerator="auto",
    # --- logging / visualization ---
    wandb_project="ace-quickstart",
    output_root="./output",
    results_dir="results/train/notebook/",
    save_model=True,
    log_wandb=True,
    expname=None,
    resume_ckpt_from_results_dir=False,
    matplotlib=False,
    matplotlib_show=False,
    use_expnamehps=False,
    # --- other ---
    hparams="./hparams.yaml",
    seed=1,
)

# args_ace_dict["max_epochs"] = 20
# # args_ace_dict["lr"] = 0.038608701121877934
# # args_ace_dict["wd"] = 1e-5
# # # datamodule_args["variable_selection"] = list(range(55)) # ALL VARIABLES
# datamodule_args["variable_selection"] = list(range(15, 23)) # 8 TEMPERATURE VARIABLES
# # # datamodule_args["variable_selection"] = [54] # Total water path variable (single variable)
# # # datamodule_args["variable_selection"] = [15] # Single temperature variable
# # datamodule_args["batch_size"] = 4_493
# locationencoder_args["ortho_weight"] = 1.401736663782711
# locationencoder_args["time_embedding_type"] = "fourier"
# locationencoder_args["time_embedding_dim"] = 72
# # locationencoder_args["combined_encoding_args"]["num_layers"] = 4

locationencoder_args, datamodule_args = process_args(locationencoder_args, datamodule_args)

from types import SimpleNamespace
args_ace_dict["locationencoder_args"] = locationencoder_args
args_ace_dict["datamodule_args"] = datamodule_args
args_ace = SimpleNamespace(**args_ace_dict)
