import copy
import os
import sys
from types import SimpleNamespace

import wandb

from utils.ace_utils import (
    fit_ace_ts,
)


import config.ace_params as ace_params

from config.experiment_params import (
    arch_dict,
    process_args,
)

from config.sweep_configs import (
    ace_config,
)

# Direct wandb's local run files to the ACE output root. This runs at import
# time, i.e. before any wandb.init() call made later in this script.
os.environ.update(WANDB_DIR=ace_params.args_ace.output_root)

def sweep_objective_from_config(config, fit_fun, base_params):
    """Train and test one model with sweep overrides applied; return its test loss.

    Starts from a deep copy of ``base_params`` and applies the overrides found
    in ``config``:

    * ``max_epochs`` and ``seed`` replace the top-level values;
    * ``datamodule_args`` / ``locationencoder_args`` are merged key by key into
      the corresponding nested dicts;
    * ``arch`` selects a preset from ``arch_dict`` that is merged into
      ``locationencoder_args`` *after* the per-key overrides, so preset values
      win on overlapping keys.

    The two nested arg dicts are then passed through ``process_args`` before
    ``fit_fun`` is called with all parameters as a ``SimpleNamespace``.

    Args:
        config: Mapping of overrides supporting ``in`` and ``[]`` lookup
            (a ``wandb.config`` object or a plain dict).
        fit_fun: Callable taking a ``SimpleNamespace`` of parameters and
            returning ``(model, trainer, datamodule)``.
        base_params: Default parameter dict containing ``"datamodule_args"``
            and ``"locationencoder_args"`` dicts. It is NOT modified.

    Returns:
        ``trainer.callback_metrics["test_loss"]`` after ``trainer.test`` runs.
    """
    # Work on a private copy: the caller passes a shared module-level defaults
    # dict, and wandb.agent can run many trials in one process. Mutating the
    # defaults in place would leak one trial's overrides (arch preset, epochs,
    # seed, already-processed args) into every subsequent trial.
    params = copy.deepcopy(base_params)

    if "max_epochs" in config:
        params["max_epochs"] = config["max_epochs"]

    for section in ("datamodule_args", "locationencoder_args"):
        if section in config:
            for key, value in config[section].items():
                params[section][key] = value

    if "arch" in config:
        # Applied after the per-key overrides above, so the architecture
        # preset takes precedence on any overlapping keys.
        params["locationencoder_args"] = {
            **params["locationencoder_args"],
            **arch_dict[config["arch"]],
        }

    params["locationencoder_args"], params["datamodule_args"] = process_args(
        params["locationencoder_args"],
        params["datamodule_args"],
    )

    if "seed" in config:
        params["seed"] = config["seed"]

    model, trainer, datamodule = fit_fun(SimpleNamespace(**params))
    trainer.test(model, datamodule=datamodule)

    return trainer.callback_metrics["test_loss"]

def main_generic(project):
    """Run a single sweep trial for ``project`` and log its test loss.

    Calls ``wandb.init(project=project)``, then selects the fit function and
    default argument dict registered for ``project``. It trains and tests via
    ``sweep_objective_from_config``, using ``wandb.config`` as the overrides,
    and logs the result under ``"test_loss"``.

    Raises:
        KeyError: if ``project`` has no registered entry (only ``"ace"``
            is currently registered).
    """
    wandb.init(project=project)

    # project name -> (fit function, default parameter dict)
    registry = {
        "ace": (fit_ace_ts, ace_params.args_ace_dict),
    }
    fit_fn, default_args = registry[project]

    loss = sweep_objective_from_config(wandb.config, fit_fn, default_args)
    wandb.log({"test_loss": loss})

if __name__ == "__main__":
    # Usage: python <script> <dataset> [debug | <sweep_id>]
    #   debug      -> run one tiny trial locally with wandb disabled
    #   <sweep_id> -> join an existing sweep instead of creating a new one
    config_dict = {
        "ace": ace_config,
    }

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <dataset> [debug | <sweep_id>]")

    dataset = sys.argv[1]
    if dataset not in config_dict:
        sys.exit(
            f"unknown dataset {dataset!r}; expected one of {sorted(config_dict)}"
        )

    # Only look for the flag among the arguments after the dataset name.
    debug_mode = "debug" in sys.argv[2:]

    if debug_mode:
        print("Running in debug mode...")
        # Small, fast configuration for smoke-testing the pipeline end to end.
        test_config = {
            "max_epochs": 1,  # Run for only 1 epoch
            "datamodule_args": {
                "train_fraction": 0.01,  # Use a small fraction of the dataset
                "val_fraction": 0.01,
                "test_fraction": 0.01,
                "mode": "spatio_temporal_interpolation",  # Use a simple mode
            },
        }

        # sweep_objective_from_config reads its overrides from wandb.config, so
        # the test configuration must be passed as the run config here;
        # otherwise it is silently ignored. mode="disabled" turns off wandb
        # logging in debug mode.
        # NOTE(review): main_generic calls wandb.init again; this relies on
        # wandb returning the already-active run in that case — confirm for
        # the installed wandb version.
        wandb.init(project=dataset, mode="disabled", config=test_config)
        main_generic(dataset)
    else:
        # Join an existing sweep if an id was given; otherwise create a new one.
        if len(sys.argv) > 2:
            sweep_id = sys.argv[2]
        else:
            sweep_id = wandb.sweep(config_dict[dataset], project=dataset)

        def run_trial():
            """Run one sweep trial for the selected dataset."""
            main_generic(dataset)

        wandb.agent(sweep_id, function=run_trial)