import argparse
import logging
import random
from functools import partial
from pathlib import Path

import git
import neps
import numpy as np
import torch

from experiments.resnet.hpo.utils import training_pipeline

def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this NePS HPO launcher.

    Options:
        --group_name: experiment group; used as a path segment of the output
            directory (no default, so it is ``None`` when omitted).
        --dataset: dataset identifier forwarded to the training pipeline
            (default ``"c100"``).
        --num_dataloader_workers: dataloader worker count (default 8).
        --optimizer: NePS optimizer name; also forwarded to the training
            pipeline and used in the output path (default
            ``"successive_halving"``).
        --model_name: model identifier forwarded to the training pipeline
            (default ``"resnet18"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--group_name", type=str, help="Group name")
    parser.add_argument("--dataset", type=str, help="Dataset to use", default="c100")
    parser.add_argument(
        "--num_dataloader_workers", type=int, help="Number of dataloader workers", default=8
    )
    parser.add_argument(
        "--optimizer", type=str, help="Optimizer to use", default="successive_halving"
    )
    parser.add_argument("--model_name", type=str, help="Model to use", default="resnet18")
    return parser


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and torch RNGs and force deterministic cuDNN.

    Args:
        seed: value used for every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible CUDA device (it is a no-op
    # without CUDA), so a separate torch.cuda.manual_seed call is redundant.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    args = build_arg_parser().parse_args()

    # Set seeds for reproducibility
    SEED = 42
    seed_everything(SEED)

    pipeline_space = {
        "learning_rate": neps.Categorical(choices=[1e-4, 1e-3, 1e-2]),
        "beta1": neps.Categorical(choices=[0.9, 0.95, 0.99]),
        "beta2": neps.Categorical(choices=[0.9, 0.95, 0.99]),
        "weight_decay": neps.Categorical(choices=[1e-5, 1e-4, 1e-3, 1e-2]),
        "optimizer_name": neps.Categorical(choices=["adam", "sgd"]),
        # Fidelity dimension: multi-fidelity optimizers vary this value.
        "n_trainable": neps.Integer(lower=1, upper=11, is_fidelity=True),
    }

    # Anchor all results at the enclosing git repository's root, so the
    # script can be launched from any sub-directory of the checkout.
    repo_root = git.Repo(".", search_parent_directories=True).working_tree_dir
    root_directory = Path(repo_root) / "output"  # type: ignore
    # exist_ok makes creation a single step: no race between an exists()
    # check and mkdir(), and no misleading message when it already exists.
    root_directory.mkdir(parents=True, exist_ok=True)

    # NOTE(review): --group_name has no default; when it is omitted, this path
    # contains the literal segment "None". That behaviour is kept on purpose
    # so existing run directories stay resumable.
    run_directory = (
        f"{root_directory}/{args.group_name}/n_trainable/{args.optimizer}/{args.dataset}"
    )
    logging.info("Writing NePS results to %s", run_directory)

    neps.run(
        evaluate_pipeline=partial(
            training_pipeline,
            batch_size=1024,
            num_dataloader_workers=args.num_dataloader_workers,
            dataset=args.dataset,
            model_name=args.model_name,
            optimizer=args.optimizer,
            epochs=10,
            fidelity="n_trainable",
        ),
        pipeline_space=pipeline_space,
        optimizer={"name": args.optimizer, "eta": 2, "early_stopping_rate": 0},
        root_directory=run_directory,
        max_evaluations_total=200,
    )
