import os
import subprocess
import warnings
from typing import Callable

import numpy as np
import spaces

# Global parameter search space for SpeedyWeather.jl runs: "<group>" -> "<param>" -> range.
# Only the uncommented entries are varied. The commented-out ones are kept as a record of
# other candidate parameters; most of them have very narrow or degenerate ranges.
# Key order matters: it fixes the column order of the parameter grid built from this
# dict and the order of the "--<group>.<param>=<value>" flags passed to the Julia script.
PARAM_SPACES = {
    # "planet": {
    #     # "rotation": spaces.NBoxSpace(min_=7.29e-5, max_=7.291e-5),
    #     # "gravity": spaces.NBoxSpace(min_=9.7639, max_=9.8337),
    #     # "axial_tilt": spaces.NBoxSpace(min_=23.44, max_=23.45),
    #     # "solar_constant": spaces.NBoxSpace(min_=1360, max_=1376),
    # },
    "atmosphere": {
        # "mol_mass_dry_air": spaces.NBoxSpace(min_=28.88, max_=28.97),
        # "mol_mass_vapour": spaces.NBoxSpace(min_=18.0153, max_=18.016),
        # "heat_capacity": spaces.NBoxSpace(min_=900, max_=1600),
        # "R_gas": spaces.NBoxSpace(min_=8.3145, max_=8.3146),
        # "R_dry": spaces.NBoxSpace(min_=287.05, max_=287.06),
        # "R_vapour": spaces.NBoxSpace(min_=461.5, max_=461.5),
        # "mol_ratio": spaces.NBoxSpace(min_=0.62197006, max_=0.62197007),
        # "mu_virt_temp": spaces.NBoxSpace(min_=0.60779446, max_=0.60779447),
        # "kappa": spaces.NBoxSpace(min_=0.2859107, max_=0.2859108),
        # "latent_heat_condensation": spaces.NBoxSpace(min_=2.501e6, max_=2.501e6),
        # "latent_heat_sublimation": spaces.NBoxSpace(min_=2.801e6, max_=2.802e6),
        # "stefan_boltzmann": spaces.NBoxSpace(min_=5.67e-8, max_=5.68e-8),
        # Reference pressure, assumed to be in Pa as the 9.2e4 lower bound suggests.
        # The upper bound used to be 100e4 (= 1e6 Pa, ten times sea-level pressure).
        # That is taken to be a typo for 10e4 (= 1e5 Pa, SpeedyWeather's default).
        # NOTE(review): confirm before regenerating any dataset from this grid.
        "pres_ref": spaces.NBoxSpace(min_=9.2e4, max_=10e4),
        "temp_ref": spaces.NBoxSpace(min_=275, max_=300),
        "moist_lapse_rate": spaces.NBoxSpace(min_=3.5e-3, max_=9.8e-3),
        # "dry_lapse_rate": spaces.NBoxSpace(min_=0.0098, max_=0.0099),
        "layer_thickness": spaces.NBoxSpace(min_=8e3, max_=20e3),
    },
}

# Discrete (grid) sampling of PARAM_SPACES.
#   param_spaces : flattened {"<group>.<param>": space} in PARAM_SPACES order.
#   latent_sizes : number of evenly spaced samples per parameter (endpoints included).
#   param_samples: {"<group>.<param>": 1-D array of sample values}.
#   PARAM_GRID   : Cartesian product of all samples, shape (prod(latent_sizes), n_params).
#                  One row per run; columns follow param_samples order. With indexing="ij",
#                  the first parameter varies slowest.
param_spaces = {f"{k}.{kk}": vv for k, v in PARAM_SPACES.items() for kk, vv in v.items()}
# Derived from the parameter count so that (un)commenting an entry in PARAM_SPACES cannot
# desynchronise it. A hand-written per-parameter list must have one entry per parameter.
latent_sizes = [2] * len(param_spaces)
param_samples = {
    k: np.linspace(v.min_, v.max_, latent_sizes[i])
    for i, (k, v) in enumerate(param_spaces.items())
}
PARAM_GRID = np.stack(np.meshgrid(*param_samples.values(), indexing="ij"), axis=-1).reshape(
    -1, len(param_samples)
)

def discrete_uniform(julia_project_dir, output_path, simulation_days=30, skip_run_ids=(4, 5)):
    """Run one SpeedyWeather.jl simulation per row of ``PARAM_GRID``.

    Each row is passed to ``src/speedy_weather_jl.jl`` as ``--<key>=<value>`` flags,
    with keys taken from ``param_samples`` in column order. The output location,
    simulation length and a 1-based, zero-padded run id are passed as well.

    Args:
        julia_project_dir: Julia project directory. It is used as the working directory
            of every ``julia`` process, so relative paths in the command, including a
            relative ``output_path``, resolve against it. The caller's own working
            directory is not changed.
        output_path: Passed verbatim as ``--output.path``.
        simulation_days: Passed as ``--simulation_days``.
        skip_run_ids: 0-based grid row indices that are not simulated. The default
            ``(4, 5)`` keeps the rows the original implementation skipped.
            NOTE(review): the reason those rows are skipped is not visible here.

    A run that exits with a non-zero status emits a ``RuntimeWarning``, and the sweep
    continues with the next row.
    """
    skipped = set(skip_run_ids)
    keys = list(param_samples)
    for run_id, values in enumerate(PARAM_GRID):
        if run_id in skipped:
            continue
        # An argv list (no shell) keeps paths with spaces/metacharacters intact.
        cmd = [
            "julia",
            "--project=.",
            "src/speedy_weather_jl.jl",
            *(f"--{k}={v}" for k, v in zip(keys, values)),
            "--output.path",
            f"{output_path}",
            "--simulation_days",
            f"{simulation_days}",
            "--output.id",
            f"{run_id + 1:04d}",
        ]
        # Use cwd= instead of os.chdir. Repeated chdir to a relative julia_project_dir
        # would resolve it from inside itself on the second iteration, and it would
        # also leave the caller's process cwd modified.
        result = subprocess.run(cmd, cwd=julia_project_dir, check=False)
        if result.returncode != 0:
            warnings.warn(
                f"julia run {run_id + 1:04d} exited with status {result.returncode}",
                RuntimeWarning,
                stacklevel=2,
            )


def simulate_weather(
    sim_fn: Callable,
    num_simulations=1000,
    simulation_days=30,
    output_path="",
):
    """
    Simulate weather using the Speedy Weather model.

    Ensures ``output_path`` exists (creating parent directories as needed), then
    delegates to ``sim_fn(num_simulations, output_path, simulation_days)``.

    Args:
        sim_fn: Simulation strategy, called positionally with
            ``(num_simulations, output_path, simulation_days)``.
        num_simulations: Forwarded to ``sim_fn`` as its first argument.
        simulation_days: Forwarded to ``sim_fn`` as its third argument.
        output_path: Output directory. An empty string (the default) means there is
            nothing to create; it is forwarded to ``sim_fn`` unchanged.

    Returns:
        Whatever ``sim_fn`` returns.

    NOTE(review): ``discrete_uniform`` takes ``julia_project_dir`` as its first
    parameter. Passing it directly as ``sim_fn`` would therefore hand it
    ``num_simulations`` as the project directory; confirm how callers adapt it.
    """
    # os.makedirs("") raises FileNotFoundError, so the default argument used to crash.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    return sim_fn(num_simulations, output_path, simulation_days)
