import copy
import gc
import sys
from collections.abc import Callable
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Literal

import optuna
import torch
import torch.nn as nn
from optuna.study import MaxTrialsCallback
from optuna.trial import TrialState
from transformers import (
    AutoModelForCausalLM,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from scripts.merge.mask import mask_model_weights
from scripts.merge.task_vector import TaskVector
from scripts.merge.ties import (
    _get_param_signs,
    _mask_smallest,
    _mask_smallest_magnitude_param_values,
    _single_vector_to_task_vector_param_dict,
    _task_vector_param_dict_to_single_vector,
)
from scripts.utils.pure import (
    SimpleDataLoder,
    copy_params_to_model,
    generate_response,
    print_cpu_memory_usage,
    print_gpu_memory_usage,
    print_metrics,
)


def run_inference(
    cfg: SimpleNamespace,
    model: nn.Module,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    dataloader: SimpleDataLoder,
    metrics: dict[str, Callable[[str, str], float]],
    device: torch.device,
) -> dict[str, float]:
    res = {k: 0.0 for k in metrics.keys()}
    skipped = 0
    for i, (text, image, dicom_id) in enumerate(dataloader):
        if not text:
            skipped += 1
            continue
        response = generate_response(cfg, model, tokenizer, image, device)
        if cfg.print_text:
            print(f"[Dicom ID {i}]\n{dicom_id}")
            print(f"[Text {i}]\n{text}")
            print(f"[Generated {i}]\n{response}")
        for metric_name, metric_fn in metrics.items():
            metric_value = metric_fn(response, text)
            res[metric_name] += metric_value
            if cfg.print_text:
                print(f"{metric_name}: {metric_value:.4f}")
    res = {k: v / (len(dataloader) - skipped) for k, v in res.items()}
    # wandb.log(res)
    return res


@dataclass
class DAREConfig:
    """Settings for the weight-masking step applied to each task vector.

    NOTE(review): presumably the schema of ``cfg.dare`` as read by the
    ``run_evolve*`` functions -- confirm against the config loader.
    """

    # "each": a separate mask rate is searched per model;
    # "same": a single mask rate is shared by all models.
    mask_rate_type: Literal["each", "same"]
    # Per-model flag, keyed by model name, forwarded to ``mask_model_weights``.
    use_weight_rescale: dict[str, bool]
    # Per-model strategy name, keyed by model name, forwarded to
    # ``mask_model_weights``.
    mask_strategy: dict[str, str]


@dataclass
class TiesConfig:
    param_mask_type: Literal["each", "same"]
    scaling_coefficient: float | None
    use_each_model_weight: bool


# def run_evolve(
#     cfg: SimpleNamespace,
#     exclude_param_names_regex: list[str],
#     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
#     dataloader: SimpleDataLoder,
#     metrics: dict[str, Callable[[str, str], float]],
#     device: torch.device,
# ):
#     with torch.no_grad():
#         base_model = (
#             AutoModelForCausalLM.from_pretrained(
#                 cfg.base_model_name,
#                 torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
#                 trust_remote_code=True,
#                 # device_map="auto",
#             ).eval()
#             # .to(device)
#         )

#         finetuned_models = {
#             name: AutoModelForCausalLM.from_pretrained(
#                 model_name,
#                 torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
#                 trust_remote_code=True,
#                 # device_map="auto",
#             ).eval()
#             # .to(device)
#             for name, model_name in cfg.finetuned_models.items()
#         }
#         finetuned_models["vlm"] = (
#             AutoModelForCausalLM.from_pretrained(
#                 cfg.vlm.name,
#                 torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
#                 trust_remote_code=True,
#                 # device_map="auto",
#             ).eval()
#             # .to(device)
#         )

#     def objective(trial):
#         with torch.no_grad():
#             if cfg.dare.mask_rate_type == "each":
#                 mask_rates = {
#                     finetuned_model_name: trial.suggest_float(
#                         f"dare_{finetuned_model_name}_mask_rate", 0.0, 1.0
#                     )
#                     for finetuned_model_name in finetuned_models.keys()
#                 }
#             elif cfg.dare.mask_rate_type == "same":
#                 mask_rates = trial.suggest_float("dare_mask_rate", 0.0, 1.0)
#             else:
#                 raise NotImplementedError

#             print("trial started")
#             print_gpu_memory_usage()

#             task_vectors = {}
#             base_model_gpu = copy.deepcopy(base_model).to(device)
#             for model_name, finetuned_model in finetuned_models.items():
#                 finetuned_model_gpu = copy.deepcopy(finetuned_model).to(device)
#                 task_vectors[model_name] = TaskVector(
#                     base_model=base_model_gpu,
#                     finetuned_model=finetuned_model_gpu,
#                     exclude_param_names_regex=exclude_param_names_regex,
#                     finetuned_param_name_convert_fn=cfg.finetuned_param_name_convert_fns.get(
#                         model_name
#                     ),
#                 )
#                 del finetuned_model
#                 del finetuned_model_gpu
#                 torch.cuda.empty_cache()
#             del base_model_gpu
#             torch.cuda.empty_cache()

#             print("task vectors created")
#             print_gpu_memory_usage()

#             for finetuned_model_name, task_vector in task_vectors.items():
#                 print(f"Masking {finetuned_model_name}...")
#                 mask_model_weights(
#                     task_vector,
#                     weight_mask_rate=(
#                         mask_rates[finetuned_model_name]
#                         if cfg.dare.mask_rate_type == "each"
#                         else mask_rates
#                     ),
#                     use_weight_rescale=cfg.dare.use_weight_rescale[
#                         finetuned_model_name
#                     ],
#                     mask_strategy=cfg.dare.mask_strategy[finetuned_model_name],
#                 )
#                 del task_vector
#                 torch.cuda.empty_cache()

#             print("masking done")
#             print_gpu_memory_usage()

#             # >> ties merging ---------------------------------------------------------------------
#             if cfg.ties.param_mask_type == "each":
#                 param_value_mask_rate = {
#                     n: trial.suggest_float(f"ties_mask_rate_{n}", 0.0, 1.0)
#                     for n in finetuned_models.keys()
#                 }
#             elif cfg.ties.param_mask_type == "same":
#                 param_value_mask_rate = trial.suggest_float("ties_mask_rate", 0.0, 1.0)
#             else:
#                 raise NotImplementedError

#             scaling_coefficient = cfg.ties.scaling_coefficient or trial.suggest_float(
#                 "ties_scaling_coefficient", 0.0, 2.0
#             )
#             each_model_weight = (
#                 {
#                     n: trial.suggest_float(f"ties_weight_{n}", 0.0, 1.0)
#                     for n in finetuned_models.keys()
#                 }
#                 if cfg.ties.use_each_model_weight
#                 else None
#             )

#             flattened_model_names = list(task_vectors.keys())
#             flattened_models_to_merge_param_ls = []
#             for name in flattened_model_names:
#                 flattened_models_to_merge_param_ls.append(
#                     _task_vector_param_dict_to_single_vector(
#                         task_vector=task_vectors[name]
#                     )
#                 )
#                 del task_vectors[name]
#                 torch.cuda.empty_cache()
#                 gc.collect()
#             print("making flattened models done")
#             print_gpu_memory_usage()

#             # Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged
#             flattened_models_to_merge_param = torch.vstack(
#                 flattened_models_to_merge_param_ls
#             )
#             del flattened_models_to_merge_param_ls

#             # Tensor, shape (num_models_to_merge, num_total_params), mask the smallest-magnitude parameter values using param_value_mask_rate
#             _mask_smallest_magnitude_param_values(
#                 flattened_models_to_merge_param,
#                 flattened_model_names,
#                 param_value_mask_rate,
#             )
#             print("masking smallest magnitude done")
#             print_gpu_memory_usage()

#             # Tensor, shape (num_total_params, ), get the signs for each parameter in flattened_models_to_merge_param
#             param_signs = _get_param_signs(flattened_models_to_merge_param)
#             print("getting param signs done")
#             print_gpu_memory_usage()

#             print("[[trial params]]")
#             print(trial.params)

#             ## >> >> merge -----------------------------------------------------------------------
#             num_preserved = 0.0
#             merged_param = 0.0
#             flattened_models_to_merge_param_dic = {
#                 name: param if i == 0 else param.to("cpu")
#                 for i, (name, param) in enumerate(
#                     zip(flattened_model_names, flattened_models_to_merge_param)
#                 )
#             }
#             del flattened_models_to_merge_param
#             torch.cuda.empty_cache()
#             gc.collect()
#             print("making flattened models dict done")
#             print_gpu_memory_usage()
#             for i, name in enumerate(flattened_model_names):
#                 print(f"Merging {name}...")
#                 print_gpu_memory_usage()
#                 param = flattened_models_to_merge_param_dic[name]
#                 del flattened_models_to_merge_param_dic[name]
#                 if i != 0:
#                     param = param.to(device)
#                 mask = ((param_signs > 0) & (param > 0)) | (
#                     (param_signs < 0) & (param < 0)
#                 )
#                 param *= mask
#                 if each_model_weight is not None:
#                     param *= each_model_weight[name]
#                 print("mmmm")
#                 if isinstance(merged_param, float):
#                     merged_param = param.to("cpu")
#                 else:
#                     merged_param += param.to("cpu")
#                 del param
#                 torch.cuda.empty_cache()
#                 gc.collect()
#                 print(">mmmm")
#                 if isinstance(num_preserved, float):
#                     num_preserved = mask.half().to("cpu")
#                 else:
#                     num_preserved += mask.half().to("cpu")
#                 print(">>mmmm")
#                 del mask
#                 torch.cuda.empty_cache()
#                 gc.collect()
#             del param_signs
#             del flattened_models_to_merge_param_dic
#             torch.cuda.empty_cache()
#             gc.collect()
#             print("ZZZZZZZ")
#             print("merging done")
#             print_gpu_memory_usage()
#             assert isinstance(merged_param, torch.Tensor)
#             assert isinstance(num_preserved, torch.Tensor)
#             merged_param = merged_param.to(device)
#             merged_param /= torch.clamp(num_preserved.to(device), min=1.0)
#             del num_preserved
#             torch.cuda.empty_cache()
#             gc.collect()
#             print("rescaling done")
#             print_gpu_memory_usage()
#             ## << << merge -----------------------------------------------------------------------
#             merged_task_vector_param_dict = _single_vector_to_task_vector_param_dict(
#                 single_vector=merged_param,
#                 base_model=copy.deepcopy(base_model).to(device),
#             )
#             del merged_param
#             torch.cuda.empty_cache()
#             gc.collect()
#             print("making merged task vector param dict done")
#             print_gpu_memory_usage()
#             merged_task_vector = TaskVector.from_param_dict(
#                 merged_task_vector_param_dict,
#                 base_model,
#             )
#             del merged_task_vector_param_dict
#             torch.cuda.empty_cache()
#             gc.collect()
#             print("making merged task vector done")
#             print_gpu_memory_usage()
#             merged_params = merged_task_vector.combine_with_pretrained_model(
#                 base_model=copy.deepcopy(base_model).to(device),
#                 scaling_coefficient=scaling_coefficient,
#             )
#             del merged_task_vector.task_vector_param_dict
#             del merged_task_vector
#             torch.cuda.empty_cache()
#             gc.collect()
#             print("merging with pretrained model done")
#             print_gpu_memory_usage()
#             ## << ties merging ---------------------------------------------------------------------

#             vlm = copy.deepcopy(finetuned_models["vlm"]).to(device)
#             copy_params_to_model(vlm, merged_params)
#             del merged_params
#             torch.cuda.empty_cache()
#             gc.collect()
#             # vlm = vlm.to(device)
#             print("copying params to model done")
#             print_gpu_memory_usage()

#             val_metrics = run_inference(
#                 cfg, vlm, tokenizer, dataloader, metrics, device
#             )
#             print_metrics(val_metrics)
#             # wandb.log(val_metrics)

#             del vlm
#             torch.cuda.empty_cache()
#             gc.collect()
#             print("trial finished")
#             print_gpu_memory_usage()

#             return (
#                 tuple([val_metrics[met] for met in cfg.use_metrics])
#                 if len(cfg.use_metrics) > 1
#                 else val_metrics[cfg.use_metrics[0]]
#             )

#     with torch.no_grad():
#         study = optuna.create_study(
#             study_name=cfg.config_name,
#             storage=f"sqlite:///outputs/{cfg.config_name}.db",
#             sampler={
#                 "cmaes": optuna.samplers.CmaEsSampler,
#                 "tpe": optuna.samplers.TPESampler,
#             }[cfg.sampler](),
#             direction=cfg.direction[0] if len(cfg.use_metrics) == 1 else None,
#             directions=cfg.direction if len(cfg.use_metrics) > 1 else None,
#             load_if_exists=True,
#         )
#         study.optimize(
#             objective,
#             n_trials=cfg.n_trials,
#             callbacks=[MaxTrialsCallback(cfg.n_trials, states=(TrialState.COMPLETE,))],
#         )


def _load_fp16_causal_lm(
    model_name: str, device: torch.device | None = None
) -> nn.Module:
    """Load ``model_name`` as a float16 causal LM in eval mode.

    No ``device_map`` is passed, so the model is left wherever
    ``from_pretrained`` places it unless ``device`` is given, in which case it
    is moved there after loading.

    NOTE: ``trust_remote_code=True`` lets the model repository run its own
    Python code at load time; only use it with trusted model names.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).eval()
    return model if device is None else model.to(device)


def run_evolve_cpu(
    cfg: SimpleNamespace,
    exclude_param_names_regex: list[str],
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    dataloader: SimpleDataLoder,
    metrics: dict[str, Callable[[str, str], float]],
    device: torch.device,
) -> None:
    """Search model-merging hyperparameters with Optuna, merging off-device.

    Every trial reloads all models from scratch so that only one fine-tuned
    model is held next to the base model at a time. A trial:

    1. loads the base model, then each entry of ``cfg.finetuned_models`` plus
       ``cfg.vlm.name`` (under the key ``"vlm"``) one by one, building a
       ``TaskVector`` for each;
    2. masks every task vector with ``mask_model_weights`` using the
       ``cfg.dare`` settings and a suggested mask rate;
    3. flattens the task vectors, masks small-magnitude entries, elects a sign
       per parameter, and averages the entries that agree with that sign;
    4. moves the merged vector to ``device`` and combines it with a freshly
       loaded base model, scaled by the scaling coefficient;
    5. copies the merged parameters into a freshly loaded VLM on ``device`` and
       scores it with ``run_inference``.

    The study is stored at ``sqlite:///outputs/{cfg.config_name}.db`` and is
    resumed if it already exists.

    Args:
        cfg: Run configuration. Fields read here: ``base_model_name``,
            ``finetuned_models``, ``vlm.name``, ``dare``, ``ties``,
            ``finetuned_param_name_convert_fns``, ``use_metrics``,
            ``config_name``, ``sampler``, ``direction`` and ``n_trials``.
        exclude_param_names_regex: Forwarded to ``TaskVector``.
        tokenizer: Forwarded to ``run_inference``.
        dataloader: Validation data forwarded to ``run_inference``.
        metrics: Metric functions forwarded to ``run_inference``.
        device: Device used for the final rescaling, recombination and
            inference steps.

    Raises:
        NotImplementedError: If ``cfg.dare.mask_rate_type`` or
            ``cfg.ties.param_mask_type`` is neither ``"each"`` nor ``"same"``.
        KeyError: If ``cfg.sampler`` is not one of the known sampler names.
    """

    def objective(trial):
        with torch.no_grad():
            print("trial started")
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            base_model = _load_fp16_causal_lm(cfg.base_model_name)
            print("base model loaded")
            print_cpu_memory_usage()
            # The VLM is merged like any other fine-tuned model, under a
            # fixed key.
            flattened_model_names = list(cfg.finetuned_models.keys()) + ["vlm"]

            if cfg.dare.mask_rate_type == "each":
                mask_rates = {
                    finetuned_model_name: trial.suggest_float(
                        f"dare_{finetuned_model_name}_mask_rate", 0.0, 1.0
                    )
                    for finetuned_model_name in flattened_model_names
                }
            elif cfg.dare.mask_rate_type == "same":
                mask_rates = trial.suggest_float("dare_mask_rate", 0.0, 1.0)
            else:
                raise NotImplementedError

            # Load the fine-tuned models one at a time so that at most one of
            # them is alive next to the base model.
            task_vectors = {}
            for model_name in flattened_model_names:
                print(f"loading {model_name}...")
                print_cpu_memory_usage()
                finetuned_model = _load_fp16_causal_lm(
                    cfg.finetuned_models[model_name]
                    if model_name != "vlm"
                    else cfg.vlm.name
                )
                sys.stdout.flush()
                print(f"Creating task vector for {model_name}...")
                print_cpu_memory_usage()
                task_vectors[model_name] = TaskVector(
                    base_model=base_model,
                    finetuned_model=finetuned_model,
                    exclude_param_names_regex=exclude_param_names_regex,
                    finetuned_param_name_convert_fn=cfg.finetuned_param_name_convert_fns.get(
                        model_name
                    ),
                )
                print_cpu_memory_usage()
                del finetuned_model
                gc.collect()
                print_cpu_memory_usage()
                sys.stdout.flush()
            del base_model
            gc.collect()

            print("task vectors created")
            print_cpu_memory_usage()

            # The return value is ignored, so mask_model_weights presumably
            # updates the task vector in place -- TODO confirm.
            for finetuned_model_name, task_vector in task_vectors.items():
                print(f"Masking {finetuned_model_name}...")
                print_cpu_memory_usage()
                mask_model_weights(
                    task_vector,
                    weight_mask_rate=(
                        mask_rates[finetuned_model_name]
                        if cfg.dare.mask_rate_type == "each"
                        else mask_rates
                    ),
                    use_weight_rescale=cfg.dare.use_weight_rescale[
                        finetuned_model_name
                    ],
                    mask_strategy=cfg.dare.mask_strategy[finetuned_model_name],
                )
                print_cpu_memory_usage()
                del task_vector
                gc.collect()
                sys.stdout.flush()

            print("masking done")
            print_cpu_memory_usage()

            # >> ties merging ---------------------------------------------------------------------
            if cfg.ties.param_mask_type == "each":
                param_value_mask_rate = {
                    n: trial.suggest_float(f"ties_mask_rate_{n}", 0.0, 1.0)
                    for n in flattened_model_names
                }
            elif cfg.ties.param_mask_type == "same":
                param_value_mask_rate = trial.suggest_float("ties_mask_rate", 0.0, 1.0)
            else:
                raise NotImplementedError

            # Only search the coefficient when none is configured. An explicit
            # 0.0 is a valid fixed value and must not be mistaken for "unset",
            # which a plain truthiness test (`or`) would do.
            scaling_coefficient = (
                cfg.ties.scaling_coefficient
                if cfg.ties.scaling_coefficient is not None
                else trial.suggest_float("ties_scaling_coefficient", 0.0, 2.0)
            )
            each_model_weight = (
                {
                    n: trial.suggest_float(f"ties_weight_{n}", 0.0, 2.0)
                    for n in flattened_model_names
                }
                if cfg.ties.use_each_model_weight
                else None
            )

            # Flatten each task vector and drop it immediately afterwards to
            # keep peak memory down.
            flattened_models_to_merge_param_ls = []
            for name in flattened_model_names:
                print(f"Making flattened model for {name}...")
                print_cpu_memory_usage()
                flattened_models_to_merge_param_ls.append(
                    _task_vector_param_dict_to_single_vector(
                        task_vector=task_vectors[name]
                    )
                )
                print_cpu_memory_usage()
                del task_vectors[name]
                gc.collect()
                print_cpu_memory_usage()
            print("making flattened models done")
            print_cpu_memory_usage()
            sys.stdout.flush()

            # Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged
            flattened_models_to_merge_param = torch.vstack(
                flattened_models_to_merge_param_ls
            )
            del flattened_models_to_merge_param_ls

            # Mask the smallest-magnitude parameter values using
            # param_value_mask_rate. The return value is ignored, so the
            # helper presumably edits the tensor in place -- TODO confirm.
            _mask_smallest_magnitude_param_values(
                flattened_models_to_merge_param,
                flattened_model_names,
                param_value_mask_rate,
            )
            print("masking smallest magnitude done")
            print_cpu_memory_usage()
            sys.stdout.flush()

            # Tensor, shape (num_total_params, ), get the signs for each parameter in flattened_models_to_merge_param
            param_signs = _get_param_signs(flattened_models_to_merge_param)
            print("getting param signs done")
            print_cpu_memory_usage()
            sys.stdout.flush()

            print("[[trial params]]")
            print(trial.params)

            ## >> >> merge -----------------------------------------------------------------------
            # Float placeholders mark "nothing accumulated yet"; both become
            # tensors on the first loop iteration.
            num_preserved = 0.0
            merged_param = 0.0
            flattened_models_to_merge_param_dic = {
                name: param
                for name, param in zip(
                    flattened_model_names, flattened_models_to_merge_param
                )
            }
            del flattened_models_to_merge_param
            gc.collect()
            print("making flattened models dict done")
            print_cpu_memory_usage()
            for name in flattened_model_names:
                print(f"Merging {name}...")
                print_cpu_memory_usage()
                param = flattened_models_to_merge_param_dic[name]
                del flattened_models_to_merge_param_dic[name]
                # Keep only the entries whose sign matches the elected sign.
                mask = ((param_signs > 0) & (param > 0)) | (
                    (param_signs < 0) & (param < 0)
                )
                print_cpu_memory_usage()
                sys.stdout.flush()
                mask = mask.half()
                gc.collect()
                param *= mask
                print_cpu_memory_usage()
                if each_model_weight is not None:
                    param *= each_model_weight[name]
                print("*")
                if isinstance(merged_param, float):
                    merged_param = param
                else:
                    merged_param += param
                print_cpu_memory_usage()
                del param
                gc.collect()
                print("**")
                # Count, per parameter, how many models contributed a value.
                if isinstance(num_preserved, float):
                    num_preserved = mask
                else:
                    num_preserved += mask
                print("***")
                print_cpu_memory_usage()
                del mask
                gc.collect()
                print_cpu_memory_usage()
                sys.stdout.flush()
            del param_signs
            del flattened_models_to_merge_param_dic
            gc.collect()
            print("merging done")
            print_cpu_memory_usage()
            sys.stdout.flush()
            assert isinstance(merged_param, torch.Tensor)
            assert isinstance(num_preserved, torch.Tensor)
            merged_param = merged_param.to(device)
            # Average over the contributing models; the clamp avoids dividing
            # by zero where no model kept a value.
            merged_param /= torch.clamp(num_preserved.to(device), min=1.0)
            del num_preserved
            torch.cuda.empty_cache()
            gc.collect()
            print("rescaling done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()
            ## << << merge -----------------------------------------------------------------------
            # The first base-model instance was released above, so reload it,
            # this time directly onto `device`.
            base_model = _load_fp16_causal_lm(cfg.base_model_name, device)
            print("base model loaded")
            print_cpu_memory_usage()

            # NOTE(review): a deep copy is handed to the helper, presumably so
            # that base_model itself cannot be modified -- confirm whether the
            # copy is needed, since it temporarily duplicates the model.
            merged_task_vector_param_dict = _single_vector_to_task_vector_param_dict(
                single_vector=merged_param,
                base_model=copy.deepcopy(base_model),
            )
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_param
            gc.collect()
            print("making merged task vector param dict done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            merged_task_vector = TaskVector.from_param_dict(
                merged_task_vector_param_dict,
                base_model,
            )
            del merged_task_vector_param_dict
            torch.cuda.empty_cache()
            gc.collect()
            print("making merged task vector done")
            print_cpu_memory_usage()

            merged_params = merged_task_vector.combine_with_pretrained_model(
                base_model=base_model,
                scaling_coefficient=scaling_coefficient,
            )
            print("combining with pretrained model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_task_vector.task_vector_param_dict
            del merged_task_vector
            del base_model
            torch.cuda.empty_cache()
            gc.collect()
            print("merging with pretrained model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()
            ## << ties merging ---------------------------------------------------------------------

            vlm = _load_fp16_causal_lm(cfg.vlm.name, device)
            print("vlm loaded")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            copy_params_to_model(vlm, merged_params)
            print("copying params to model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_params
            torch.cuda.empty_cache()
            gc.collect()
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            print("running inference...")
            val_metrics = run_inference(
                cfg, vlm, tokenizer, dataloader, metrics, device
            )
            print_metrics(val_metrics)

            del vlm
            torch.cuda.empty_cache()
            gc.collect()
            print("trial finished")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()

            # A single metric yields a scalar objective; several metrics
            # yield a tuple for multi-objective optimization.
            return (
                tuple([val_metrics[met] for met in cfg.use_metrics])
                if len(cfg.use_metrics) > 1
                else val_metrics[cfg.use_metrics[0]]
            )

    # The objective disables autograd itself, so no torch.no_grad() wrapper is
    # needed around the study.
    study = optuna.create_study(
        study_name=cfg.config_name,
        storage=f"sqlite:///outputs/{cfg.config_name}.db",
        sampler={
            "cmaes": optuna.samplers.CmaEsSampler,
            "tpe": optuna.samplers.TPESampler,
            "gp": optuna.samplers.GPSampler,
            "nsga": optuna.samplers.NSGAIISampler,
        }[cfg.sampler](),
        direction=cfg.direction[0] if len(cfg.use_metrics) == 1 else None,
        directions=cfg.direction if len(cfg.use_metrics) > 1 else None,
        load_if_exists=True,
    )
    study.optimize(
        objective,
        n_trials=cfg.n_trials,
        # Stop once the stored study holds cfg.n_trials COMPLETE trials.
        callbacks=[MaxTrialsCallback(cfg.n_trials, states=(TrialState.COMPLETE,))],
    )



def run_evolve_cpu2(
    cfg: SimpleNamespace,
    exclude_param_names_regex: list[str],
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    dataloader: SimpleDataLoder,
    metrics: dict[str, Callable[[str, str], float]],
    device: torch.device,
) -> None:
    """Run an Optuna search over DARE + TIES merge hyperparameters, merging on CPU.

    Every trial loads the base model and each fine-tuned model (plus the VLM)
    in float16 without moving them to ``device``, builds and masks task
    vectors, merges them TIES-style, writes the merged parameters into a
    freshly loaded VLM, and only then moves that VLM to ``device`` for
    evaluation with ``run_inference``.

    Args:
        cfg: Run configuration. Fields read here: ``base_model_name``,
            ``finetuned_models`` (mapping of name -> checkpoint),
            ``vlm.name``, ``finetuned_param_name_convert_fns``,
            ``dare.mask_rate_type``, ``dare.use_weight_rescale``,
            ``dare.mask_strategy``, ``ties.param_mask_type``,
            ``ties.scaling_coefficient``, ``ties.use_each_model_weight``,
            ``use_metrics``, ``direction``, ``sampler``, ``n_trials`` and
            ``config_name``. The per-model DARE mappings are also indexed
            with the key ``"vlm"``.
        exclude_param_names_regex: Forwarded unchanged to ``TaskVector``.
        tokenizer: Forwarded to ``run_inference``.
        dataloader: Evaluation data, forwarded to ``run_inference``.
        metrics: Metric name -> scoring function, forwarded to ``run_inference``.
        device: Device the merged VLM is moved to for inference.

    Returns:
        None. Results are persisted in the Optuna study stored at
        ``outputs/<cfg.config_name>.db``.

    Raises:
        NotImplementedError: From within a trial when ``cfg.dare.mask_rate_type``
            or ``cfg.ties.param_mask_type`` is neither ``"each"`` nor ``"same"``.
        KeyError: When ``cfg.sampler`` is not one of the keys of the sampler
            table at the bottom of this function.
    """

    def objective(trial: optuna.Trial) -> float | tuple[float, ...]:
        """Build one merged model from the trial's parameters and score it.

        Returns a single metric value when ``cfg.use_metrics`` has one entry,
        otherwise a tuple of values ordered like ``cfg.use_metrics``.
        """
        # The whole trial is inference-only; no autograd graph is needed.
        with torch.no_grad():
            print("trial started")
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            # Base model stays on the CPU (no .to(device)) for this variant.
            base_model = (
                AutoModelForCausalLM.from_pretrained(
                    cfg.base_model_name,
                    torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                    trust_remote_code=True,
                    # device_map="auto",
                ).eval()
                # .to(device)
            )
            print("base model loaded")
            print_cpu_memory_usage()
            # Disabled alternative: load every fine-tuned model up front.
            # The loop further below loads them one at a time instead.
            # finetuned_models = {
            #     name: AutoModelForCausalLM.from_pretrained(
            #         model_name,
            #         torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
            #         trust_remote_code=True,
            #         # device_map="auto",
            #     ).eval()
            #     # .to(device)
            #     for name, model_name in cfg.finetuned_models.items()
            # }
            # finetuned_models["vlm"] = (
            #     AutoModelForCausalLM.from_pretrained(
            #         cfg.vlm.name,
            #         torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
            #         trust_remote_code=True,
            #         # device_map="auto",
            #     ).eval()
            #     # .to(device)
            # )
            # Models taking part in the merge: every configured fine-tuned
            # model plus the VLM itself under the fixed key "vlm".
            flattened_model_names = list(cfg.finetuned_models.keys()) + ["vlm"]

            # print("models loaded")
            # print_cpu_memory_usage()

            # DARE mask rate: one searched value per model ("each") or a
            # single value shared by all models ("same").
            if cfg.dare.mask_rate_type == "each":
                mask_rates = {
                    finetuned_model_name: trial.suggest_float(
                        f"dare_{finetuned_model_name}_mask_rate", 0.0, 1.0
                    )
                    for finetuned_model_name in flattened_model_names
                }
            elif cfg.dare.mask_rate_type == "same":
                mask_rates = trial.suggest_float("dare_mask_rate", 0.0, 1.0)
            else:
                raise NotImplementedError

            # Build one task vector per model, loading a single fine-tuned
            # checkpoint at a time and dropping it before loading the next.
            task_vectors = {}
            for model_name in flattened_model_names:
                print(f"loading {model_name}...")
                print_cpu_memory_usage()
                finetuned_model = AutoModelForCausalLM.from_pretrained(
                    cfg.finetuned_models[model_name]
                    if model_name != "vlm"
                    else cfg.vlm.name,
                    torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                    trust_remote_code=True,
                    # device_map="auto",
                ).eval()
                sys.stdout.flush()
                print(f"Creating task vector for {model_name}...")
                print_cpu_memory_usage()
                task_vectors[model_name] = TaskVector(
                    base_model=base_model,
                    finetuned_model=finetuned_model,
                    exclude_param_names_regex=exclude_param_names_regex,
                    # .get() yields None when no converter is configured for
                    # this model name.
                    finetuned_param_name_convert_fn=cfg.finetuned_param_name_convert_fns.get(
                        model_name
                    ),
                )
                print_cpu_memory_usage()
                del finetuned_model
                gc.collect()
                print_cpu_memory_usage()
                sys.stdout.flush()
            # The base model is reloaded later, once the merge is finished.
            del base_model
            gc.collect()

            print("task vectors created")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()

            # DARE step. The return value of mask_model_weights is discarded,
            # so it presumably modifies task_vector in place - TODO confirm.
            for finetuned_model_name, task_vector in task_vectors.items():
                print(f"Masking {finetuned_model_name}...")
                print_cpu_memory_usage()
                mask_model_weights(
                    task_vector,
                    weight_mask_rate=(
                        mask_rates[finetuned_model_name]
                        if cfg.dare.mask_rate_type == "each"
                        else mask_rates
                    ),
                    use_weight_rescale=cfg.dare.use_weight_rescale[
                        finetuned_model_name
                    ],
                    mask_strategy=cfg.dare.mask_strategy[finetuned_model_name],
                    # device=device,
                )
                print_cpu_memory_usage()
                # Only unbinds the loop variable; task_vectors still holds
                # the object.
                del task_vector
                gc.collect()
                sys.stdout.flush()

            print("masking done")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()

            # >> ties merging ---------------------------------------------------------------------
            # TIES magnitude-mask rate: per model ("each") or shared ("same").
            if cfg.ties.param_mask_type == "each":
                param_value_mask_rate = {
                    n: trial.suggest_float(f"ties_mask_rate_{n}", 0.0, 1.0)
                    for n in flattened_model_names
                }
            elif cfg.ties.param_mask_type == "same":
                param_value_mask_rate = trial.suggest_float("ties_mask_rate", 0.0, 1.0)
            else:
                raise NotImplementedError

            # A truthy configured coefficient is used as-is; otherwise it is
            # searched. NOTE(review): because of `or`, a configured 0.0 is
            # also treated as "not set" - confirm that is intended.
            scaling_coefficient = cfg.ties.scaling_coefficient or trial.suggest_float(
                "ties_scaling_coefficient", 0.0, 2.0
            )
            # Optional per-model multiplier applied during the merge loop.
            each_model_weight = (
                {
                    n: trial.suggest_float(f"ties_weight_{n}", 0.0, 2.0)
                    for n in flattened_model_names
                }
                if cfg.ties.use_each_model_weight
                else None
            )

            # Flatten each task vector into a single 1-D tensor and drop the
            # structured task vector right away.
            flattened_models_to_merge_param_ls = []
            for name in flattened_model_names:
                print(f"Making flattened model for {name}...")
                print_cpu_memory_usage()
                flattened_models_to_merge_param_ls.append(
                    _task_vector_param_dict_to_single_vector(
                        task_vector=task_vectors[name]
                    )
                )
                print_cpu_memory_usage()
                # del task_vectors[name].task_vector_param_dict
                del task_vectors[name]
                # torch.cuda.empty_cache()
                gc.collect()
                print_cpu_memory_usage()
            print("making flattened models done")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()

            # Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged
            flattened_models_to_merge_param = torch.vstack(
                flattened_models_to_merge_param_ls
            )
            del flattened_models_to_merge_param_ls

            # Tensor, shape (num_models_to_merge, num_total_params), mask the smallest-magnitude parameter values using param_value_mask_rate
            # The result is not captured, so the helper presumably masks the
            # stacked tensor in place - TODO confirm.
            _mask_smallest_magnitude_param_values(
                flattened_models_to_merge_param,
                flattened_model_names,
                param_value_mask_rate,
            )
            print("masking smallest magnitude done")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()

            # Tensor, shape (num_total_params, ), get the signs for each parameter in flattened_models_to_merge_param
            param_signs = _get_param_signs(flattened_models_to_merge_param)
            print("getting param signs done")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()

            print("[[trial params]]")
            print(trial.params)

            ## >> >> merge -----------------------------------------------------------------------
            # Float placeholders mark "nothing accumulated yet"; both names
            # are rebound to tensors on the first loop iteration.
            num_preserved = 0.0
            merged_param = 0.0
            # Iterating the stacked tensor yields one row per model.
            # NOTE(review): these rows are views of the stacked tensor, so
            # deleting the stacked name below should not free its storage
            # while any row (including merged_param) is alive - confirm.
            flattened_models_to_merge_param_dic = {
                name: param
                for name, param in zip(
                    flattened_model_names, flattened_models_to_merge_param
                )
            }
            del flattened_models_to_merge_param
            # torch.cuda.empty_cache()
            gc.collect()
            print("making flattened models dict done")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()
            # `i` is not used inside the loop body.
            for i, name in enumerate(flattened_model_names):
                print(f"Merging {name}...")
                # print_gpu_memory_usage()
                print_cpu_memory_usage()
                param = flattened_models_to_merge_param_dic[name]
                del flattened_models_to_merge_param_dic[name]
                # True where this model's entry has the same (non-zero) sign
                # as param_signs.
                mask = ((param_signs > 0) & (param > 0)) | (
                    (param_signs < 0) & (param < 0)
                )
                print_cpu_memory_usage()
                sys.stdout.flush()
                # Bool -> float16 so the mask can be multiplied and summed.
                mask = mask.half()
                gc.collect()
                # Zero out entries whose sign disagrees (in place).
                param *= mask
                print_cpu_memory_usage()
                if each_model_weight is not None:
                    param *= each_model_weight[name]
                print("*")
                # Running sum of the sign-consistent (and weighted) values.
                if isinstance(merged_param, float):
                    merged_param = param
                else:
                    merged_param += param
                print_cpu_memory_usage()
                del param
                # torch.cuda.empty_cache()
                gc.collect()
                print("**")
                # Running count of models that contributed to each entry.
                if isinstance(num_preserved, float):
                    num_preserved = mask
                else:
                    num_preserved += mask
                print("***")
                print_cpu_memory_usage()
                del mask
                # torch.cuda.empty_cache()
                gc.collect()
                print_cpu_memory_usage()
                sys.stdout.flush()
            del param_signs
            del flattened_models_to_merge_param_dic
            # torch.cuda.empty_cache()
            gc.collect()
            print("merging done")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()
            # Narrow the float-or-tensor accumulators to tensors.
            assert isinstance(merged_param, torch.Tensor)
            assert isinstance(num_preserved, torch.Tensor)
            # merged_param = merged_param.to(device)
            # Average over the contributing models; the clamp keeps the
            # divisor at >= 1 where no model contributed.
            merged_param /= torch.clamp(num_preserved, min=1.0)
            del num_preserved
            torch.cuda.empty_cache()
            gc.collect()
            print("rescaling done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()
            ## << << merge -----------------------------------------------------------------------
            # Reload the base model (released after task-vector creation) to
            # turn the flat merged vector back into named parameters.
            base_model = (
                AutoModelForCausalLM.from_pretrained(
                    cfg.base_model_name,
                    torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                    trust_remote_code=True,
                    # device_map="auto",
                )
                .eval()
                # .to(device)
            )
            print("base model loaded")
            print_cpu_memory_usage()
            # sys.stdout.flush()

            # A deep copy is handed to the helper; presumably so that the
            # helper cannot modify base_model itself - TODO confirm.
            merged_task_vector_param_dict = _single_vector_to_task_vector_param_dict(
                single_vector=merged_param,
                base_model=copy.deepcopy(base_model),
            )
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_param
            # torch.cuda.empty_cache()
            gc.collect()
            print("making merged task vector param dict done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            merged_task_vector = TaskVector.from_param_dict(
                merged_task_vector_param_dict,
                base_model,
            )
            del merged_task_vector_param_dict
            torch.cuda.empty_cache()
            gc.collect()
            print("making merged task vector done")
            # print_gpu_memory_usage()
            print_cpu_memory_usage()

            # Combine the merged task vector with the base model's weights
            # using the (configured or searched) scaling coefficient.
            merged_params = merged_task_vector.combine_with_pretrained_model(
                base_model=base_model,
                scaling_coefficient=scaling_coefficient,
            )
            print("combining with pretrained model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_task_vector.task_vector_param_dict
            del merged_task_vector
            del base_model
            torch.cuda.empty_cache()
            gc.collect()
            print("merging with pretrained model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()
            ## << ties merging ---------------------------------------------------------------------

            # Fresh VLM that receives the merged parameters.
            vlm = (
                AutoModelForCausalLM.from_pretrained(
                    cfg.vlm.name,
                    torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                    trust_remote_code=True,
                    # device_map="auto",
                )
                .eval()
                # .to(device)
            )
            print("vlm loaded")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            copy_params_to_model(vlm, merged_params)
            print("copying params to model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_params
            torch.cuda.empty_cache()
            gc.collect()
            # Only now is anything moved to the target device.
            vlm = vlm.to(device)
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            print("running inference...")
            val_metrics = run_inference(
                cfg, vlm, tokenizer, dataloader, metrics, device
            )
            print_metrics(val_metrics)
            # wandb.log(val_metrics)

            del vlm
            torch.cuda.empty_cache()
            gc.collect()
            print("trial finished")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()

            # Multi-objective studies get a tuple, single-objective a scalar.
            return (
                tuple([val_metrics[met] for met in cfg.use_metrics])
                if len(cfg.use_metrics) > 1
                else val_metrics[cfg.use_metrics[0]]
            )

    #with torch.no_grad():
    # The study is persisted in SQLite and resumed when one with the same
    # name already exists (load_if_exists=True).
    study = optuna.create_study(
        study_name=cfg.config_name,
        storage=f"sqlite:///outputs/{cfg.config_name}.db",
        # Look up the sampler class by its config key and instantiate it
        # with default arguments.
        sampler={
            "cmaes": optuna.samplers.CmaEsSampler,
            "tpe": optuna.samplers.TPESampler,
            "gp": optuna.samplers.GPSampler,
            "nsga": optuna.samplers.NSGAIISampler,
        }[cfg.sampler](),
        # Exactly one of direction / directions is non-None, depending on
        # how many metrics are optimised.
        direction=cfg.direction[0] if len(cfg.use_metrics) == 1 else None,
        directions=cfg.direction if len(cfg.use_metrics) > 1 else None,
        load_if_exists=True,
    )
    study.optimize(
        objective,
        n_trials=cfg.n_trials,
        # Stop once cfg.n_trials trials are in the COMPLETE state.
        callbacks=[MaxTrialsCallback(cfg.n_trials, states=(TrialState.COMPLETE,))],
    )


def run_evolve_gpu(
    cfg: SimpleNamespace,
    exclude_param_names_regex: list[str],
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    dataloader: SimpleDataLoder,
    metrics: dict[str, Callable[[str, str], float]],
    device: torch.device,
) -> None:
    """Run an Optuna search over DARE + TIES merge hyperparameters, merging on ``device``.

    Every trial handles one fine-tuned model (plus the VLM) at a time: the
    model and a copy of the base model are moved to ``device``, a task vector
    is built, DARE-masked, flattened to float16, magnitude-masked and parked
    on the CPU. The per-model vectors are then merged TIES-style on
    ``device`` (with the running sums kept on the CPU), combined with a
    freshly loaded base model, written into a freshly loaded VLM and scored
    with ``run_inference``.

    Args:
        cfg: Run configuration. Fields read here: ``base_model_name``,
            ``finetuned_models`` (mapping of name -> checkpoint),
            ``vlm.name``, ``finetuned_param_name_convert_fns``,
            ``dare.mask_rate_type``, ``dare.use_weight_rescale``,
            ``dare.mask_strategy``, ``ties.param_mask_type``,
            ``ties.scaling_coefficient``, ``ties.use_each_model_weight``,
            ``use_metrics``, ``direction``, ``sampler``, ``n_trials`` and
            ``config_name``. The per-model DARE mappings are also indexed
            with the key ``"vlm"``.
        exclude_param_names_regex: Forwarded unchanged to ``TaskVector``.
        tokenizer: Forwarded to ``run_inference``.
        dataloader: Evaluation data, forwarded to ``run_inference``.
        metrics: Metric name -> scoring function, forwarded to ``run_inference``.
        device: Device used for task-vector construction, merging and inference.

    Returns:
        None. Results are persisted in the Optuna study stored at
        ``outputs/<cfg.config_name>.db``.

    Raises:
        NotImplementedError: From within a trial when ``cfg.dare.mask_rate_type``
            or ``cfg.ties.param_mask_type`` is neither ``"each"`` nor ``"same"``.
        KeyError: When ``cfg.sampler`` is not one of ``"cmaes"``, ``"tpe"``,
            ``"gp"`` or ``"nsga"``.
    """

    def objective(trial):
        """Build one merged model from the trial's parameters and score it.

        Returns a single metric value when ``cfg.use_metrics`` has one entry,
        otherwise a tuple of values ordered like ``cfg.use_metrics``.
        """
        with torch.no_grad():
            print("trial started")
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            # Kept on the CPU; a copy is moved to `device` for each task vector.
            base_model = (
                AutoModelForCausalLM.from_pretrained(
                    cfg.base_model_name,
                    torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                    trust_remote_code=True,
                    # device_map="auto",
                ).eval()
                # .to(device)
            )
            print("base model loaded")
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            # Models taking part in the merge: every configured fine-tuned
            # model plus the VLM itself under the fixed key "vlm".
            use_model_names = list(cfg.finetuned_models.keys()) + ["vlm"]

            # [DARE params] -------------------
            # One searched mask rate per model ("each") or one shared ("same").
            if cfg.dare.mask_rate_type == "each":
                mask_rates = {
                    model_name: trial.suggest_float(
                        f"dare_{model_name}_mask_rate", 0.0, 1.0
                    )
                    for model_name in use_model_names
                }
            elif cfg.dare.mask_rate_type == "same":
                mask_rates = trial.suggest_float("dare_mask_rate", 0.0, 1.0)
            else:
                raise NotImplementedError

            # [TIES params] -------------------
            if cfg.ties.param_mask_type == "each":
                param_value_mask_rate = {
                    n: trial.suggest_float(f"ties_mask_rate_{n}", 0.0, 1.0)
                    for n in use_model_names
                }
            elif cfg.ties.param_mask_type == "same":
                param_value_mask_rate = trial.suggest_float("ties_mask_rate", 0.0, 1.0)
            else:
                raise NotImplementedError

            # A truthy configured coefficient is used as-is; otherwise it is
            # searched. NOTE(review): because of `or`, a configured 0.0 is
            # also treated as "not set" - confirm that is intended.
            scaling_coefficient = cfg.ties.scaling_coefficient or trial.suggest_float(
                "ties_scaling_coefficient", 0.0, 2.0
            )
            # NOTE(review): run_evolve_cpu2 searches these weights in
            # [0.0, 2.0]; the [0.0, 1.0] range here is kept as-is so that
            # existing studies remain loadable - confirm which is intended.
            each_model_weight = (
                {
                    n: trial.suggest_float(f"ties_weight_{n}", 0.0, 1.0)
                    for n in use_model_names
                }
                if cfg.ties.use_each_model_weight
                else None
            )
            # ---------------------------------

            print("[[trial params]]")
            print(trial.params)

            # One flattened, masked float16 task vector per model, held on CPU.
            flattened_task_vectors = []
            for model_name in use_model_names:
                print(f"[[{model_name}]]")
                print(f"loading {model_name}...")
                print_gpu_memory_usage()
                print_cpu_memory_usage()
                finetuned_model = (
                    AutoModelForCausalLM.from_pretrained(
                        cfg.finetuned_models[model_name]
                        if model_name != "vlm"
                        else cfg.vlm.name,
                        torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                        trust_remote_code=True,
                        # device_map="auto",
                    )
                    .eval()
                    .to(device)
                )
                print(f"Creating task vector for {model_name}...")
                task_vector = TaskVector(
                    # The CPU base model is left untouched; a copy goes to
                    # `device` to match the fine-tuned model.
                    base_model=copy.deepcopy(base_model).to(device),
                    finetuned_model=finetuned_model,
                    exclude_param_names_regex=exclude_param_names_regex,
                    # .get() yields None when no converter is configured for
                    # this model name.
                    finetuned_param_name_convert_fn=cfg.finetuned_param_name_convert_fns.get(
                        model_name
                    ),
                )
                print_cpu_memory_usage()
                del finetuned_model
                torch.cuda.empty_cache()
                gc.collect()

                print(f"task vectors created: {model_name}")
                print_gpu_memory_usage()
                print_cpu_memory_usage()
                sys.stdout.flush()

                print(f"Masking {model_name}...")

                # DARE step. The return value is discarded, so the helper
                # presumably modifies task_vector in place - TODO confirm.
                mask_model_weights(
                    task_vector,
                    weight_mask_rate=(
                        mask_rates[model_name]
                        if cfg.dare.mask_rate_type == "each"
                        else mask_rates
                    ),
                    use_weight_rescale=cfg.dare.use_weight_rescale[model_name],
                    mask_strategy=cfg.dare.mask_strategy[model_name],
                    # device=device,
                )
                print("masking done")
                print_gpu_memory_usage()
                print_cpu_memory_usage()
                gc.collect()
                sys.stdout.flush()

                # >> ties merging ---------------------------------------------------------------------

                print(f"Making flattened model for {model_name}...")
                print_cpu_memory_usage()

                flattened_task_vector = _task_vector_param_dict_to_single_vector(
                    task_vector=task_vector
                ).half()
                print_cpu_memory_usage()
                del task_vector
                torch.cuda.empty_cache()
                gc.collect()
                print("making flattened models done")
                print(f"shape: {flattened_task_vector.shape}")
                print_gpu_memory_usage()
                print_cpu_memory_usage()
                sys.stdout.flush()

                # TIES trimming, done per model instead of on a stacked
                # (num_models, num_params) tensor: mask the requested
                # fraction of smallest-magnitude entries.
                mask_rate = (
                    param_value_mask_rate[model_name]
                    if isinstance(param_value_mask_rate, dict)
                    else param_value_mask_rate
                )
                num_mask_params = int(mask_rate * flattened_task_vector.shape[0])
                # Return value discarded; presumably masks in place.
                _mask_smallest(flattened_task_vector, num_mask_params)
                print("masking smallest magnitude done")
                print_gpu_memory_usage()
                print_cpu_memory_usage()
                sys.stdout.flush()

                # Park the result on the CPU so the device is free for the
                # next model.
                flattened_task_vectors.append(flattened_task_vector.to("cpu").half())
                print_cpu_memory_usage()
                del flattened_task_vector
                torch.cuda.empty_cache()
                gc.collect()
                print(f"flattened task vector all done: {model_name}")
                print_gpu_memory_usage()
                print_cpu_memory_usage()
                sys.stdout.flush()

            # The CPU base model is not needed again until it is reloaded
            # after the merge; release it now (as run_evolve_cpu2 does) so it
            # does not sit in memory during the merge and the reload.
            del base_model
            gc.collect()

            print("!!all task vectors done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            # Tensor, shape (num_total_params, ): elected sign per parameter.
            # Note the helper receives the list of per-model vectors here,
            # not a stacked tensor.
            param_signs = _get_param_signs(flattened_task_vectors)
            print("getting param signs done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()

            ## >> >> merge -----------------------------------------------------------------------
            # Float placeholders mark "nothing accumulated yet"; both names
            # are rebound to CPU tensors on the first loop iteration.
            num_preserved = 0.0
            merged_param = 0.0
            flattened_models_to_merge_param_dic = {
                name: param
                for name, param in zip(use_model_names, flattened_task_vectors)
            }
            del flattened_task_vectors
            gc.collect()
            param_signs = param_signs.to(device)
            print("making flattened models dict done")
            print_cpu_memory_usage()
            for name in use_model_names:
                print(f"Merging {name}...")
                print_cpu_memory_usage()
                # Move just this model's vector to the device for masking.
                param = flattened_models_to_merge_param_dic[name]
                param = param.to(device)
                del flattened_models_to_merge_param_dic[name]
                # True where this model's entry has the same (non-zero) sign
                # as the elected sign.
                mask = ((param_signs > 0) & (param > 0)) | (
                    (param_signs < 0) & (param < 0)
                )
                print_cpu_memory_usage()
                sys.stdout.flush()
                # Bool -> float16 so the mask can be multiplied and summed.
                mask = mask.half()
                gc.collect()
                param *= mask
                print_cpu_memory_usage()
                if each_model_weight is not None:
                    param *= each_model_weight[name]
                print("*")
                # Running sum of the sign-consistent values, kept on the CPU.
                if isinstance(merged_param, float):
                    merged_param = param.to("cpu")
                else:
                    merged_param += param.to("cpu")
                print_cpu_memory_usage()
                del param
                torch.cuda.empty_cache()
                gc.collect()
                print("**")
                # Running count of contributing models, kept on the CPU.
                if isinstance(num_preserved, float):
                    num_preserved = mask.to("cpu")
                else:
                    num_preserved += mask.to("cpu")
                print("***")
                print_cpu_memory_usage()
                del mask
                torch.cuda.empty_cache()
                gc.collect()
                print_cpu_memory_usage()
                sys.stdout.flush()
            del param_signs
            del flattened_models_to_merge_param_dic
            gc.collect()
            print("merging done")
            print_cpu_memory_usage()
            sys.stdout.flush()
            # Narrow the float-or-tensor accumulators to tensors.
            assert isinstance(merged_param, torch.Tensor)
            assert isinstance(num_preserved, torch.Tensor)
            merged_param = merged_param.to(device)
            # Average over the contributing models; the clamp keeps the
            # divisor at >= 1 where no model contributed.
            merged_param /= torch.clamp(num_preserved.to(device), min=1.0)
            del num_preserved
            torch.cuda.empty_cache()
            gc.collect()
            print("rescaling done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()
            ## << << merge -----------------------------------------------------------------------
            # Reload the base model, this time directly onto the device, to
            # turn the flat merged vector back into named parameters.
            base_model = (
                AutoModelForCausalLM.from_pretrained(
                    cfg.base_model_name,
                    torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                    trust_remote_code=True,
                    # device_map="auto",
                )
                .eval()
                .to(device)
            )
            print("base model loaded")
            print_cpu_memory_usage()

            # A deep copy is handed to the helper; presumably so that the
            # helper cannot modify base_model itself - TODO confirm.
            merged_task_vector_param_dict = _single_vector_to_task_vector_param_dict(
                single_vector=merged_param,
                base_model=copy.deepcopy(base_model),
            )
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_param
            gc.collect()
            print("making merged task vector param dict done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            merged_task_vector = TaskVector.from_param_dict(
                merged_task_vector_param_dict,
                base_model,
            )
            del merged_task_vector_param_dict
            torch.cuda.empty_cache()
            gc.collect()
            print("making merged task vector done")
            print_cpu_memory_usage()

            merged_params = merged_task_vector.combine_with_pretrained_model(
                base_model=base_model,
                scaling_coefficient=scaling_coefficient,
            )
            print("combining with pretrained model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_task_vector.task_vector_param_dict
            del merged_task_vector
            del base_model
            torch.cuda.empty_cache()
            gc.collect()
            print("merging with pretrained model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()
            ## << ties merging ---------------------------------------------------------------------

            # Fresh VLM on the device that receives the merged parameters.
            vlm = (
                AutoModelForCausalLM.from_pretrained(
                    cfg.vlm.name,
                    torch_dtype=torch.float16,  # if cfg.fp16 else torch.float32,
                    trust_remote_code=True,
                    # device_map="auto",
                )
                .eval()
                .to(device)
            )
            print("vlm loaded")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            copy_params_to_model(vlm, merged_params)
            print("copying params to model done")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            del merged_params
            torch.cuda.empty_cache()
            gc.collect()
            print_gpu_memory_usage()
            print_cpu_memory_usage()

            print("running inference...")
            val_metrics = run_inference(
                cfg, vlm, tokenizer, dataloader, metrics, device
            )
            print_metrics(val_metrics)

            del vlm
            torch.cuda.empty_cache()
            gc.collect()
            print("trial finished")
            print_gpu_memory_usage()
            print_cpu_memory_usage()
            sys.stdout.flush()

            # Multi-objective studies get a tuple, single-objective a scalar.
            return (
                tuple([val_metrics[met] for met in cfg.use_metrics])
                if len(cfg.use_metrics) > 1
                else val_metrics[cfg.use_metrics[0]]
            )

    with torch.no_grad():
        # The study is persisted in SQLite and resumed when one with the
        # same name already exists (load_if_exists=True).
        study = optuna.create_study(
            study_name=cfg.config_name,
            storage=f"sqlite:///outputs/{cfg.config_name}.db",
            # Same sampler table as the other run_evolve_* entry points, so
            # a config that selects "gp" or "nsga" no longer raises KeyError.
            sampler={
                "cmaes": optuna.samplers.CmaEsSampler,
                "tpe": optuna.samplers.TPESampler,
                "gp": optuna.samplers.GPSampler,
                "nsga": optuna.samplers.NSGAIISampler,
            }[cfg.sampler](),
            # Exactly one of direction / directions is non-None, depending
            # on how many metrics are optimised.
            direction=cfg.direction[0] if len(cfg.use_metrics) == 1 else None,
            directions=cfg.direction if len(cfg.use_metrics) > 1 else None,
            load_if_exists=True,
        )
        study.optimize(
            objective,
            n_trials=cfg.n_trials,
            # Stop once cfg.n_trials trials are in the COMPLETE state.
            callbacks=[MaxTrialsCallback(cfg.n_trials, states=(TrialState.COMPLETE,))],
        )
