"""Run sweeps over different hyperparameters by relation."""

import argparse
import logging
from typing import Union

from src import data, hparams, models, sweeps
from src.utils import experiment_utils, logging_utils
from src.utils.typing import Layer

import torch

logger = logging.getLogger(__name__)


def main(args: argparse.Namespace) -> None:
    """Sweep hyperparameters for every relation and save the best settings.

    Runs ``sweeps.sweep`` over the dataset selected by ``args``. For every
    relation that completed at least ``args.n_trials`` trials, it saves a
    ``hparams.RelationHParams``. That record uses the layer and beta of the
    best-by-faithfulness trial and the layer of the best-by-efficacy trial.
    Finally, all sweep results are written to
    ``<experiment.results_dir>/results_all.json``.

    Args:
        args: Parsed command-line arguments (see the ``__main__`` block).
    """
    logging_utils.configure(args)
    experiment = experiment_utils.setup_experiment(args)

    # Honor an explicitly requested device; only auto-detect when none was
    # given. The parentheses matter: a conditional expression binds looser
    # than `or`, so without them an explicit device was replaced by "cpu"
    # whenever CUDA was unavailable.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    dataset = data.load_dataset_from_args(args)
    mt = models.load_model(args.model, fp16=args.fp16, device=device)

    # argparse yields strings: numeric layer indices become ints, while the
    # named layers "emb" and "ln_f" are passed through unchanged.
    h_layers = args.h_layers
    if h_layers is not None:
        h_layers = [
            h_layer if h_layer in ("emb", "ln_f") else int(h_layer)
            for h_layer in h_layers
        ]

    results = sweeps.sweep(
        mt=mt,
        dataset=dataset,
        h_layers=h_layers,
        n_trials=args.n_trials,
        n_train_samples=args.n_train_samples,
        recall_k=args.recall_k,
        batch_size=args.batch_size,
        results_dir=experiment.results_dir,
        resume=args.resume,
        subj_token_filter=args.subj_token_filter,
        use_bare_prompt=args.use_bare_prompt,
        # An empty string (the CLI default) means "no precomputed approxes".
        o1_approxes_path=args.o1_approxes_path or None,
    )

    # Require the number of trials that was actually requested rather than
    # the library default; otherwise `--n-trials` below the default would
    # cause every relation to be skipped.
    min_trials = args.n_trials
    for relation in results.relations:
        n_trials = len(relation.trials)
        if n_trials < min_trials:
            logger.info(
                f"{relation.relation_name} -- not enough number of trials "
                f"({n_trials} < {min_trials}) --> skipping"
            )
            continue
        logger.info(f"{relation.relation_name} (n_trials={n_trials})")

        best_by_f = relation.best_by_faithfulness()
        best_by_e = relation.best_by_efficacy()
        hparams.RelationHParams(
            relation_name=relation.relation_name,
            h_layer=best_by_f.layer,  # type: ignore
            h_layer_edit=best_by_e.layer,  # type: ignore
            z_layer=-1,
            beta=best_by_f.beta.mean,
            # NOTE: rank is intentionally left unset; it is not clear what it
            # should be, if anything (one candidate was
            # math.floor(best_by_e.rank.mean)).
            model_name=mt.name,
        ).save()

    results_file = experiment.results_dir / "results_all.json"
    results_file.parent.mkdir(exist_ok=True, parents=True)
    with results_file.open("w", encoding="utf-8") as handle:
        handle.write(results.to_json(indent=4))


if __name__ == "__main__":
    # Build the CLI: shared data/experiment/logging/model flags come from the
    # project helpers; the rest configure the sweep itself.
    parser = argparse.ArgumentParser(description="sweep over hyperparameters")
    data.add_data_args(parser)
    experiment_utils.add_experiment_args(parser)
    logging_utils.add_logging_args(parser)
    models.add_model_args(parser)
    parser.add_argument(
        "--h-layers",
        nargs="+",
        help='h layers to try (integer indices, or "emb" / "ln_f"), defaults to all',
    )
    parser.add_argument(
        "--recall-k",
        type=int,
        default=sweeps.DEFAULT_RECALL_K,
        help="compute up to recall@k",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=sweeps.DEFAULT_BATCH_SIZE,
        help="max batch size for lm",
    )
    parser.add_argument(
        "--subj-token-filter",
        type=str,
        default="all",
        choices=["all", "multi", "single"],
        help="allows filtering out samples with multiple or single subj tokens. defaults to all",
    )
    parser.add_argument(
        "--n-trials",
        type=int,
        default=sweeps.DEFAULT_N_TRIALS,
        help="number of trials per relation",
    )
    parser.add_argument(
        "--n-train-samples",
        type=int,
        default=sweeps.DEFAULT_N_TRAIN_SAMPLES,
        help="number of train samples to use per trial",
    )
    parser.add_argument(
        "--use-bare-prompt",
        action="store_true",
        default=False,
        help='will use bare prompt "{subj} {obj}"',
    )
    parser.add_argument(
        "--o1-approxes-path",
        type=str,
        default="",
        help="if set, try to load precomputed o1 approximations from this path "
        "instead of calculating them",
    )
    args = parser.parse_args()
    # NOTE(review): this runs before logging_utils.configure(args) is called in
    # main(), so it may be filtered out by the default logging level — confirm.
    logger.info(args)
    main(args)
