"""
This script is used to run NePS based on the given train template, dataset and NePS yaml configuration.

Example usage for running a grid search on the learning rate for the slimpajama dataset:

python warms/run_neps.py \
    --neps_config_path <root_path>/configs/neps/lr_grid_search.yaml \
    --dataset slimpajama \
    --output_tree neps/quick_debug \
    --neps_seed 123 \
    --grid_search \\ # Only set this if you are using grid search!
    --target_scale <path_to_target_scale_config>

```lr_grid_search.yaml
pipeline_space:
  max_lr:
    choices: [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]
max_evaluations_total: 6
```

If multi-fidelity is being used, the runtime parameter used as fidelity needs to be provided as a constant
to indicate the total runtime. Additionally, it needs to be provided as a hyperparameter while using the prefix
`early_stopping_` to control the fidelity brackets. This is needed for the lr scheduling to work correctly.
Note that you can provide other attributes as constants in the yaml file as well.

See the example below for a multi-fidelity hyperband search:
```
pipeline_space:
  max_lr:
    lower: 0.0001
    upper: 0.03
    log: True
  tokens_per_param: 20
  early_stopping_tokens_per_param:
    lower: 2.5
    upper: 20.0
    is_fidelity: true
max_evaluations_total: 100
searcher:
  strategy: "hyperband"
  eta: 2
```

Here, we are able to specify the search algorithm in the yaml file, whereas we had to pass it as an argument when using
grid search. This is because grid search is not natively supported in NePS.
"""

import argparse
import random
import time
from functools import reduce
from pathlib import Path
from typing import Callable

import lightning as L
import numpy as np
import pandas as pd
import torch
import yaml
from neps.api import run as run_neps

from freezes import CANVAS_BASE_PATH, DATASET_MAP, ExpCanvas, prepare_data_handler_from_file
from saws import TrainConfig, main
from saws.config.yaml_utils import path_constructor

# `GridSearch` is importable from a different module path depending on the installed NePS
# version, so the known locations are tried in turn. `version` records which one succeeded;
# the `__main__` section of this file uses it to choose the keyword name under which the
# evaluation function is handed to `neps.run`.
try:
    # for the `master` branch
    from neps.optimizers.grid_search import GridSearch

    version = "master"
except ImportError:
    try:
        # for v0.12.2
        from neps.optimizers.grid_search.optimizer import GridSearch

        version = "v0.12.2"
    except ImportError:
        print("Error: Could not import GridSearch from either location.")
        raise  # re-raises the last exception that occurred.
    except Exception as e:
        # Anything other than a missing module is reported and propagated unchanged.
        print(f"An unexpected error occurred during the second import: {e}")
        raise
except Exception as e:
    # Anything other than a missing module is reported and propagated unchanged.
    print(f"An unexpected error occurred during the first import: {e}")
    raise


def set_seed(seed: int) -> None:
    """
    Seed the global random number generators of PyTorch, NumPy and `random` with the same value.

    :param seed: the seed handed to every generator
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)


def get_args():
    parser = argparse.ArgumentParser(description="Parser for ")

    parser.add_argument(
        "--neps_config_path",
        type=str,
        required=True,
        help="The path to config yaml file that defines the search space, searcher, and other neps arguments",
    )
    parser.add_argument(
        "--train_template",
        type=str,
        help="Path to the train config template file. "
        "If not provided, the default template in the selected canvas will be used.",
    )
    parser.add_argument(
        "--canvas_access",
        type=str,
        default="global-meta",
        help="The key to decide the access point of the experiment configuration",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="slimpajama",
        help="Dataset choice",
        choices=["wikitext", "slimpajama"],
    )
    parser.add_argument(
        "--output_tree",
        type=str,
        help="Creates a subdirectory tree starting from `results_root` in canvas configuration and uses it as the neps output directory. "
        "If not provided the `root_directory` in the config .yaml will be used. If both are given, an error will be raised.",
    )
    parser.add_argument(
        "--num_layers_train",
        type=int,
        default=None,
        help="Number of layers to train. If not provided, all layers will be trained.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Deletes and overwrites `root_directory` if it already exists.",
    )
    parser.add_argument(
        "--neps_seed",
        type=int,
        default=123,
        help="The seed used by the neps optimizer.",
    )
    parser.add_argument(
        "--no_continuations",
        action="store_true",
        help="Disables the default use of continuations in multi-fidelity evaluations.",
    )
    # For custom many fidelity search, esp. SH-like
    parser.add_argument(
        "--diagonal_SH",
        action="store_true",
        help="To use Successive Halving with diagonal progression over tokens/param and layers.",
    )

    args = parser.parse_args()

    return args


def warmstart_from_neps(train_config: dict, warmstart_neps_root_path: str):
    """
    Modifies the train_config in such a way that. If the search spaces do not align there is a lot that can go wrong.
    :param train_config:
    :param config:
    :param warmstart_neps_root_path:
    """
    if warmstart_neps_root_path is not None:
        summary_path = Path(warmstart_neps_root_path) / "summary_csv" / "config_data.csv"
        if not summary_path.exists():
            raise ValueError(
                "The summary .csv file does not exist. This might be because the NePS run did not finish successfully "
                "for all configurations."
            )

        summary = pd.read_csv(summary_path)

        # Check if all the runs finished successfully, if not raise an error
        if not all(summary["status"] == "complete"):
            raise ValueError(
                "Not all runs in the neps output directory finished successfully. "
                "Please make sure all runs are completed before using warmstart."
            )

        # select all the columns that are hyperparameters (more than 1 unique value and start with "config.")
        hp_columns = [
            col.removeprefix("config.")
            for col in summary.columns
            if col.startswith("config.") and summary[col].nunique() > 1
        ]

        # select the row that matches the train_config hp values
        row_selection = [summary["config." + col] == train_config[col] for col in hp_columns]
        row_selection = reduce(lambda x, y: x & y, row_selection)
        config = summary["config_id"][row_selection]

        assert len(config) == 1, "The warmstart config could not be found or might not be unique."
        config = config.iloc[0]

        train_config["warmstart_config"]["base_model_path"] = (
            Path(warmstart_neps_root_path) / "results" / f"config_{config}" / "output"
        )


def neps_training_wrapper(args: argparse.Namespace) -> Callable:
    """
    Wrapper function to create a pipeline for neps training.

    The returned function closes over `args`, so the command line options are available each
    time a configuration is evaluated.

    :param args: argparse arguments; the pipeline reads `canvas_access`, `train_template`,
        `no_continuations`, `diagonal_SH` and `dataset` from it
    :return: function that can be passed to neps.run for the run_pipeline argument
    """
    # Alternative ways of specifying the training time in the train config. Defined once here
    # instead of on every iteration of the loop over the NePS config.
    trainings_time_arguments = ("tokens_per_param", "max_tokens", "max_train_steps")

    def run_pipeline(pipeline_directory: Path, previous_pipeline_directory: Path, **config) -> dict:
        """
        Runs the pipeline with the given neps configuration and the arguments from the wrapper function.
        Supports resuming the pipeline from a previous pipeline directory to speed up multi-fidelity HPO.

        :param pipeline_directory: The directory where the information of the current pipeline is saved.
        :param previous_pipeline_directory: The directory from which the pipeline will be resumed if
            multifidelity is used. May be None, in which case nothing is resumed.
        :param config: The hyperparameters and constants to be used in the pipeline. A key containing
            '.' addresses an entry of a nested dictionary in the train config.
        :return: dictionary with the validation loss as `objective_to_minimize`, the wall-clock time
            of the training call in seconds as `cost` and all run metrics as `info_dict`
        """
        canvas = ExpCanvas(CANVAS_BASE_PATH, args.canvas_access)

        # Load the config from the train_template as base
        yaml.SafeLoader.add_constructor("!path", path_constructor)
        template_path = (
            canvas.train_template if args.train_template is None else Path(args.train_template)
        )
        with template_path.open(encoding="utf-8") as yaml_file:
            train_config = yaml.safe_load(yaml_file)

        ############################################################################################
        # Crucial update of training configuration, accounting for hyperparameters passed by NePS #
        ############################################################################################
        # Apply all the hyperparameters and constants from the config to the train_config
        for key, value in config.items():
            # Nested dictionaries can be specified by combining the keys of the nested dict with a '.'
            *parent_keys, leaf_key = key.split(".")

            # Walk down to the innermost dictionary in which the value will be set, creating
            # (or replacing non-dict entries by) empty dictionaries along the way.
            sub_dict = train_config
            for nested_key in parent_keys:
                if not isinstance(sub_dict.get(nested_key), dict):
                    sub_dict[nested_key] = {}
                sub_dict = sub_dict[nested_key]

            # Setting all training time arguments except the given one to None since we can't pass None constants via NePS.
            if key in trainings_time_arguments:
                for tt_arg in trainings_time_arguments:
                    sub_dict[tt_arg] = None

            sub_dict[leaf_key] = value

        # Resume run from previous fidelity
        if previous_pipeline_directory is not None and not args.no_continuations:
            # Wrapped in `Path` like `pipeline_directory` below, so a plain string works as well.
            train_config["load_state_path"] = Path(previous_pipeline_directory) / "output"

        # TODO: check role of this
        # warmstart_from_neps(train_config, args.warmstart_neps_root_path)

        fabric = L.Fabric(accelerator="auto", devices="auto", strategy="auto")
        train_config["devices"] = fabric.world_size

        # Handling custom SH with diagonal fidelities
        if args.diagonal_SH:
            ############################# WARNING ##########################
            # Super hard-coded for the s1 experimental setup with SH eta=2 #
            ################################################################
            LAYERS_TO_TKP_MAP = {  # for eta=2, for tokens/param=[2, 20], layers_to_train=[1, 10]
                1: 2,
                2: 5,
                5: 10,
                10: 20,
            }
            # Raises a KeyError if `layers_to_train` is missing or is not one of the keys above.
            train_config["tokens_per_param"] = LAYERS_TO_TKP_MAP[train_config["layers_to_train"]]
        # end of diagonal SH handling

        train_config = TrainConfig(**train_config)

        data_config = prepare_data_handler_from_file(
            data_config_path=canvas.data_handler_root / DATASET_MAP(args.dataset),
            train_config=train_config,
            root_data_path=canvas.data_root,
        )

        # The wall-clock duration of the training call is reported to NePS as the cost.
        start = time.time()
        run_metrics = main(
            fabric=fabric,
            data=data_config,
            train_args=train_config,
            out_dir=Path(pipeline_directory) / "output",
        )
        training_cost = time.time() - start

        # neps tensorboard logging
        # tblogger.log(
        #     objective_to_minimize=run_metrics["val_loss"],
        #     current_epoch=run_metrics["train_steps"],  # global step for tensorboard
        #     writer_config_scalar=True,
        #     writer_config_hparam=True,
        #     write_summary_incumbent=True,
        #     # extra_data={
        #     #     "train_loss": tblogger.scalar_logging(value=run_metrics["train_loss"]),
        #     #     "train_steps": tblogger.scalar_logging(value=run_metrics["train_steps"]),
        #     #     **{
        #     #         f"val_{k}": tblogger.scalar_logging(value=v) for k, v in config.items()
        #     #     }
        #     # }
        # )

        return {
            "objective_to_minimize": run_metrics["val_loss"],
            "cost": training_cost,
            "info_dict": run_metrics,
        }

    return run_pipeline


if __name__ == "__main__":
    args = get_args()

    set_seed(args.neps_seed)

    # Keyword arguments for `neps.run` that are derived from the command line instead of the yaml file.
    neps_args = {}
    if args.output_tree is not None:
        # The output directory is placed below the results root of the selected canvas.
        canvas = ExpCanvas(CANVAS_BASE_PATH, args.canvas_access)
        neps_args["root_directory"] = canvas.results_root / args.output_tree
    if args.overwrite:
        neps_args["overwrite_working_directory"] = True

    # Everything in the yaml file is forwarded to `neps.run` as keyword arguments.
    with Path(args.neps_config_path).open("r", encoding="utf-8") as config_file:
        neps_config = yaml.safe_load(config_file)

    # Crucial change specific to this repo: the number of layers to train is added to the
    # pipeline space as a constant.
    if args.num_layers_train:
        neps_config["pipeline_space"]["layers_to_train"] = args.num_layers_train

    # The keyword under which the evaluation function is passed depends on the NePS version
    # detected at import time.
    if version == "master":
        pipeline_keyword = "evaluate_pipeline"
    elif version == "v0.12.2":
        pipeline_keyword = "run_pipeline"
    else:
        raise NotImplementedError(f"Version {version} is not supported.")
    pipeline_call = {pipeline_keyword: neps_training_wrapper(args)}

    # A key given both in the yaml file and via the command line (e.g. `root_directory`) makes
    # this call fail with a duplicate keyword argument error.
    run_neps(**pipeline_call, **neps_config, **neps_args)
