# Baseline architecture template (v1).
#
# The combined encoder is an "fcnet" with dim_hidden=1024 and num_layers=2;
# the time and spatial encoders are both "identity". Every -1 is a
# placeholder: process_args() fills in the input/output dimensions it can
# derive from the run configuration.
baseline_arch_v1 = dict(
    combined_encoding_args=dict(
        name="fcnet",
        input_dim=-1,
        output_dim=-1,
        dim_hidden=1024,
        num_layers=2,
    ),
    time_encoding_args=dict(
        name="identity",
        input_dim=-1,
        output_dim=-1,
        dim_hidden=-1,
        num_layers=-1,
    ),
    spatial_encoding_args=dict(
        name="identity",
        input_dim=-1,
        output_dim=-1,
        dim_hidden=-1,
        num_layers=-1,
    ),
)

# Larger variant of the v1 baseline template.
#
# The combined "fcnet" encoder keeps dim_hidden=1024 but uses num_layers=4;
# the time and spatial encoders are both "identity". Every -1 is a
# placeholder: process_args() fills in the input/output dimensions it can
# derive from the run configuration.
baseline_arch_v1_large = dict(
    combined_encoding_args=dict(
        name="fcnet",
        input_dim=-1,
        output_dim=-1,
        dim_hidden=1024,
        num_layers=4,
    ),
    time_encoding_args=dict(
        name="identity",
        input_dim=-1,
        output_dim=-1,
        dim_hidden=-1,
        num_layers=-1,
    ),
    spatial_encoding_args=dict(
        name="identity",
        input_dim=-1,
        output_dim=-1,
        dim_hidden=-1,
        num_layers=-1,
    ),
)

# Registry of all architecture templates, keyed by the name that
# process_args() looks up under locationencoder["arch_name"].
arch_dict = dict(
    baseline_arch_v1=baseline_arch_v1,
    baseline_arch_v1_large=baseline_arch_v1_large,
)


import copy
from pathlib import Path

def process_args(locationencoder, datamodule):
    """Fill in the derived entries of the encoder and datamodule configs.

    Both mappings are modified in place and the same objects are returned.
    The function

    1. converts every entry of ``datamodule["file_paths"]`` to a ``Path``;
    2. if ``locationencoder["arch_name"]`` is set, merges a private copy of
       the matching template from ``arch_dict`` into ``locationencoder``;
    3. sets the combined encoder's ``output_dim`` to the number of selected
       variables, or to ``datamodule["num_classes"]`` when there is no
       ``"variable_selection"`` entry;
    4. copies ``datamodule["num_timesteps"]`` to
       ``locationencoder["number_of_timesteps"]`` when present;
    5. derives the spatial/time ``input_dim`` values, the ``output_dim`` of
       ``"identity"`` encoders, and the combined encoder's ``input_dim``
       according to ``locationencoder["combination_type"]``.

    Args:
        locationencoder: Mutable mapping with ``"legendre_polys"``,
            ``"time_embedding_dim"`` and ``"combination_type"``, plus either
            ``"arch_name"`` or explicit ``"combined_encoding_args"``,
            ``"time_encoding_args"`` and ``"spatial_encoding_args"`` dicts.
        datamodule: Mutable mapping with ``"variable_selection"`` or
            ``"num_classes"``; ``"file_paths"`` and ``"num_timesteps"`` are
            optional.

    Returns:
        The ``(locationencoder, datamodule)`` pair that was passed in.

    Raises:
        ValueError: If ``arch_name`` is not a key of ``arch_dict``, or if the
            ``"hadamard_product"`` combination is requested while the time
            and spatial output dimensions differ.
        KeyError: If a required entry is missing from either mapping.
    """
    # Ensure file_paths are Path objects
    if "file_paths" in datamodule:
        datamodule["file_paths"] = [
            fp if isinstance(fp, Path) else Path(fp)
            for fp in datamodule["file_paths"]
        ]

    # apply arch choice
    if "arch_name" in locationencoder:
        arch_name = locationencoder["arch_name"]
        if arch_name not in arch_dict:
            raise ValueError(f"Unknown arch name: {arch_name}")
        # Deep-copy the template: the nested *_encoding_args dicts are
        # written to below, and a plain update() would share them with the
        # module-level template and with every other config resolved from
        # the same arch name.
        locationencoder.update(copy.deepcopy(arch_dict[arch_name]))

    combined_args = locationencoder["combined_encoding_args"]
    if "variable_selection" in datamodule:
        combined_args["output_dim"] = len(datamodule["variable_selection"])
    else:
        combined_args["output_dim"] = datamodule["num_classes"]

    if "num_timesteps" in datamodule:
        locationencoder["number_of_timesteps"] = datamodule["num_timesteps"]

    spatial_args = locationencoder["spatial_encoding_args"]
    time_args = locationencoder["time_encoding_args"]
    spatial_args["input_dim"] = locationencoder["legendre_polys"] ** 2
    time_args["input_dim"] = locationencoder["time_embedding_dim"]

    # An "identity" encoder has the same output size as its input.
    for encoding_args in (time_args, spatial_args):
        if encoding_args["name"] == "identity":
            encoding_args["output_dim"] = encoding_args["input_dim"]

    time_dim = time_args["output_dim"]
    spatial_dim = spatial_args["output_dim"]
    combination_type = locationencoder["combination_type"]
    if combination_type == "concatenation":
        combined_args["input_dim"] = time_dim + spatial_dim
    elif combination_type == "outer_product":
        combined_args["input_dim"] = time_dim * spatial_dim
    elif combination_type == "hadamard_product":
        # Validate before assigning so a failed call does not leave a
        # half-updated combined input size behind.
        if time_dim != spatial_dim:
            raise ValueError("Time and spatial encoding output dimensions must be equal for hadamard product combination type")
        combined_args["input_dim"] = time_dim
    elif combination_type == "forget_time":
        combined_args["input_dim"] = spatial_dim
    # NOTE(review): any other combination_type leaves input_dim untouched
    # (possibly still the -1 placeholder) — confirm whether that is intended.

    return locationencoder, datamodule
