import argparse
from pathlib import Path

import optuna
import yaml

import data_generator
import train
from models import EnumMethod
from utils.functions import potentials_all

# Base YAML configs. `objective` re-reads both for every trial, overrides some entries with
# sampled hyperparameters, and writes per-trial copies. Both paths are relative, so they
# resolve against the current working directory.
CONFIG_PATH = Path("config.yaml")  # main config; its "energy" and "train" sections are overridden
CONFIG_EXTRA_PATH = Path("config-inverse-jkonet-extra.yaml")  # extra config; its "otmap" section is overridden


def _suggest_optimizer(trial: optuna.Trial, optim_cfg: dict, prefix: str) -> None:
    """Sample optimizer hyperparameters into ``optim_cfg`` (mutated in place).

    Each Optuna parameter is named ``f"{prefix}_<key>"``, so the OT-map and energy
    optimizers get independent search dimensions within the same study.
    """
    optim_cfg["lr"] = trial.suggest_float(f"{prefix}_lr", 1e-5, 1e-2, log=True)
    optim_cfg["weight_decay"] = trial.suggest_float(f"{prefix}_weight_decay", 1e-4, 1e-1, log=True)
    optim_cfg["beta1"] = trial.suggest_float(f"{prefix}_beta1", 0.5, 0.99)
    optim_cfg["beta2"] = trial.suggest_float(f"{prefix}_beta2", 0.9, 0.9999)
    optim_cfg["eps"] = trial.suggest_float(f"{prefix}_eps", 1e-9, 1e-6, log=True)
    optim_cfg["grad_clip"] = trial.suggest_float(f"{prefix}_grad_clip", 1.0, 20.0)


def _suggest_hidden_layers(trial: optuna.Trial, prefix: str) -> list:
    """Sample a constant-width hidden-layer layout: ``num_layers`` copies of one width."""
    num_layers = trial.suggest_int(f"{prefix}_num_layers", 1, 4)
    width = trial.suggest_categorical(f"{prefix}_layer_size", [32, 64, 128, 256])
    return [width] * num_layers


def objective(trial: optuna.Trial, solver: EnumMethod, potential: str, mode: str, dim: int) -> float:
    """Run one Optuna trial: sample hyperparameters, generate data, train, and return the score.

    The base configs (``CONFIG_PATH`` / ``CONFIG_EXTRA_PATH``) are re-read on every call and
    overridden with sampled OT-map and energy-model settings. The modified copies are
    written to ``configs/<dataset or potential>/config_{main,extra}_<trial.number>.yaml``
    and passed to ``train.main``.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        solver: Solver forwarded to the training script.
        potential: Potential name; only used in ``"synthetic"`` mode.
        mode: ``"biology"`` (load ``RNA_PCA_<dim>`` from file) or ``"synthetic"``
            (simulate particles with the given potential).
        dim: PCA dimension (biology) or particle dimension (synthetic).

    Returns:
        Whatever ``train.main`` returns; the study minimizes this value.

    Raises:
        ValueError: If ``mode`` is not ``"biology"`` or ``"synthetic"``.
        optuna.exceptions.TrialPruned: If data generation or training raises. Mapping
            failures to pruning lets the study continue instead of aborting.
    """
    # Validate before doing any file I/O or sampling; otherwise an unknown mode would only
    # surface later as an UnboundLocalError on the config paths.
    if mode not in ("biology", "synthetic"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'biology' or 'synthetic'.")

    with open(CONFIG_PATH, encoding="utf-8") as f:
        config_main = yaml.safe_load(f)
    with open(CONFIG_EXTRA_PATH, encoding="utf-8") as f:
        config_extra = yaml.safe_load(f)

    # OT-map optimizer, inner iterations and MLP structure (suggestion order kept stable).
    _suggest_optimizer(trial, config_extra["otmap"]["optim"], "otmap")
    config_extra["otmap"]["optim"]["inner_iter"] = trial.suggest_int("otmap_inner_iter", 5, 15)
    config_extra["otmap"]["model"]["type"] = "MLP"
    config_extra["otmap"]["models"]["MLP"]["dim_hidden"] = _suggest_hidden_layers(trial, "otmap")

    # Energy-model optimizer and layer layout.
    _suggest_optimizer(trial, config_main["energy"]["optim"], "energy")
    config_main["energy"]["model"]["layers"] = _suggest_hidden_layers(trial, "energy")

    seed = 0
    dt = 0.01

    if mode == "biology":
        dataset_name = f"RNA_PCA_{dim}"
        group_name = f"param_search_for_{dataset_name}"
        config_dir = Path("configs") / dataset_name
        n_timesteps = 5
        data_generator_namespace = argparse.Namespace(
            load_from_file=dataset_name,
            test_ratio=0.1,
            split_population=True,
        )
        # Data is loaded from file here, so the time step is a free training hyperparameter.
        dt = trial.suggest_float("dt", 0.001, 1.0, log=True)
    else:  # mode == "synthetic"
        n_particles = 2000
        simulator = "jko"
        n_gmm_components = 0
        paired_train_ratio = 0.0
        horizon = 5
        test_ratio = 0.5
        # NOTE(review): presumably this must match the dataset name data_generator produces
        # for these settings; keep the two in sync.
        dataset_name = (
            f"simulator_{simulator}_potential_{potential}_internal_none_beta_0.0_interaction_none"
            f"_dt_{dt}_T_{horizon}_dim_{dim}_N_{n_particles}_gmm_{n_gmm_components}_seed_{seed}"
            f"_split_{test_ratio}_split_trajectories_True"
            f"_lo_-1_sinkhorn_0.0_paired_train_ratio_{paired_train_ratio}"
        )
        group_name = f"param_search_for_{potential}"
        config_dir = Path("configs") / potential
        n_timesteps = horizon + 1  # presumably horizon steps plus the initial snapshot
        data_generator_namespace = argparse.Namespace(
            simulator=simulator,
            potential=potential,
            n_particles=n_particles,
            test_ratio=test_ratio,
            dimension=dim,
            paired_train_ratio=paired_train_ratio,
            seed=seed,
            dt=dt,
            n_gmm_components=n_gmm_components,
        )

    config_main["train"]["dt"] = dt

    config_main_path = config_dir / f"config_main_{trial.number}.yaml"
    config_extra_path = config_dir / f"config_extra_{trial.number}.yaml"
    config_dir.mkdir(parents=True, exist_ok=True)
    config_main_path.write_text(yaml.dump(config_main), encoding="utf-8")
    config_extra_path.write_text(yaml.dump(config_extra), encoding="utf-8")

    args_data_generator = data_generator.get_parser().parse_args([], namespace=data_generator_namespace)
    try:
        data_generator.main(args_data_generator)
    except Exception as e:
        raise optuna.exceptions.TrialPruned(f"Data generation failed: {e}") from e

    args_main = train.get_parser().parse_args(
        [],  # nothing comes from the CLI; all values are supplied via the namespace
        namespace=argparse.Namespace(
            solver=solver,
            dataset=dataset_name,
            wandb=True,
            debug=True,
            seed=seed,
            device=0,
            group_name=group_name,
            epochs=5000,
            n_timesteps=n_timesteps,
            eval="test_data",
            config=str(config_main_path),
            extra_config=str(config_extra_path),
        ),
    )

    try:
        return train.main(args_main)
    except Exception as e:
        raise optuna.exceptions.TrialPruned(f"Training failed: {e}") from e


if __name__ == "__main__":
    # CLI entry point: run an Optuna search over `objective` and report the best trial.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--solver",
        "-s",
        type=EnumMethod,
        choices=list(EnumMethod),
        default=EnumMethod.INVERSE_JKO_NET_MULTIMAP_POTENTIAL,
        help="Name of the solver to use.",
    )
    parser.add_argument(
        "--potential",
        type=str,
        default="flowers",
        choices=list(potentials_all.keys()) + ["none"],
        help="Name of the potential energy to use.",
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=10,
        help="""
        Dimensionality of the particles generated in the synthetic data or used in PCA for single-cell.
        """,
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["synthetic", "biology"],
        default="biology",
        help="Which dataset regime to use: 'synthetic' or 'biology'",
    )
    parser.add_argument(
        "--n-trials",
        type=int,
        default=30,
        help="Number of Optuna trials to run.",
    )
    args = parser.parse_args()

    study = optuna.create_study(direction="minimize", study_name=f"optuna_search_{args.mode}")
    study.optimize(
        lambda trial: objective(trial, solver=args.solver, potential=args.potential, mode=args.mode, dim=args.dim),
        n_trials=args.n_trials,
    )

    try:
        best_trial = study.best_trial
    except ValueError:
        # `objective` converts every data-generation/training failure into TrialPruned, so
        # it is possible that no trial completed; `best_trial` raises ValueError then.
        raise SystemExit("No trial completed successfully; all trials were pruned or failed.") from None

    print("Best trial:")
    print(f"  Value: {best_trial.value}")
    print("  Params: ")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")
