import click
import yaml
from datetime import datetime
from pathlib import Path
import random
import torch
import gc
import multiprocessing as mp

from es_llm.trainer import create_trainer
from es_llm.evaluator import create_evaluator

# import warnings
# warnings.filterwarnings("ignore")


def create_training_experiment_directory(config, seed):
    """Create a uniquely named directory for one training run.

    The directory name is
    ``<YYYYmmddHHMMSS>_<trainer name>_<param><value>_..._seed<seed>``, with one
    ``<param><value>`` segment per entry of ``config["trainer"]["parameters"]``
    (in dict order). It is created under
    ``config["experiment"]["experiment_directory"]``; missing parents are
    created and an already existing directory is reused.

    Args:
        config: Parsed experiment configuration.
        seed: Random seed of the run; appended to the directory name.

    Returns:
        The path of the created directory, as a string.

    Raises:
        ValueError: If a required config entry is missing or malformed, or the
            directory cannot be created. The underlying error is chained as
            ``__cause__`` so the real reason is visible in the traceback.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    try:
        base_dir = config["experiment"]["experiment_directory"]
        parts = [timestamp, config["trainer"]["name"]]
        parts.extend(
            f"{param}{value}"
            for param, value in config["trainer"]["parameters"].items()
        )
        parts.append(f"seed{seed}")
        directory = base_dir + "/" + "_".join(parts)
        Path(directory).mkdir(parents=True, exist_ok=True)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        raise ValueError("Cannot create experiment directory") from err
    return directory


def tuner(config, output_dir):
    """Train one model with a freshly drawn random seed.

    Creates the trainer named by ``config["trainer"]["name"]`` and saves a copy
    of ``config`` as ``config.yaml`` in the output directory. It then runs
    training and writes the returned log to ``log/experiment_log.csv``.

    Args:
        config: Parsed experiment configuration.
        output_dir: Directory to write results to. When ``None``, a new
            timestamped directory is created with
            ``create_training_experiment_directory``.

    Raises:
        ValueError: If the trainer or the output directory cannot be created
            (the original error is chained as ``__cause__``).
    """
    seed = random.randint(0, 2**32 - 1)
    try:
        trainer = create_trainer(name=config["trainer"]["name"], config=config, seed=seed)
        if output_dir is None:
            output_dir = create_training_experiment_directory(config=config, seed=seed)
        # Create the run directory and its model/ sub-folder whether or not the
        # caller supplied the path, so writing config.yaml below cannot fail
        # on a fresh --output directory.
        Path(output_dir, "model").mkdir(parents=True, exist_ok=True)
    except Exception as err:
        raise ValueError("Cannot create trainer") from err

    with open(Path(output_dir, "config.yaml"), "w", encoding="utf-8") as file:
        yaml.dump(config, file, default_flow_style=False)

    Path(output_dir, "log").mkdir(parents=True, exist_ok=True)
    log = trainer.run(output_dir=output_dir)
    log.to_csv(f"{output_dir}/log/experiment_log.csv")

    # Drop the trainer and hand cached GPU memory back. Collect garbage first
    # so tensors kept alive only by reference cycles are freed before the
    # CUDA caching allocator releases its unused blocks.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()


def evaluator(config, output_dir):
    """Evaluate one or more models and write one CSV of results per model.

    If ``config["experiment"]["model_directory"]`` is set and its ``model/``
    folder can be listed, every sub-directory of that folder is evaluated in
    name order. The full path of each sub-directory is passed as
    ``model_weights``. Otherwise a single evaluation runs with
    ``model_weights=None``. Presumably the evaluator then loads the model named
    by ``config["trained_model"]["hf_cache"]`` — TODO confirm in
    ``create_evaluator``.

    Args:
        config: Parsed experiment configuration.
        output_dir: Directory for the per-model CSV files. When ``None``, it is
            ``<experiment_directory>/<name>``, where ``<name>`` is the last
            component of ``model_directory`` (models found) or of ``hf_cache``
            (fallback).

    Raises:
        ValueError: If an evaluator cannot be created (cause chained).
    """
    try:
        all_models = config["experiment"]["model_directory"] + "/model/"
        model_names = sorted(
            f.name for f in Path(all_models).iterdir() if f.is_dir()
        )
        found_models = True
    except (KeyError, TypeError):
        # No model directory configured: evaluate a single default model.
        all_models, model_names, found_models = "", [None], False
    except OSError as err:
        # A configured but unreadable path is most likely a typo; keep the
        # historical fallback but make it visible.
        print(f"Warning: cannot list {all_models} ({err}); falling back to a single evaluation")
        all_models, model_names, found_models = "", [None], False

    if output_dir is None:
        # The decision depends on whether trained models were found, not on
        # how many: a run with a single model must still use its own path.
        if found_models:
            run_name = Path(config["experiment"]["model_directory"]).name
        else:
            run_name = Path(config["trained_model"]["hf_cache"]).name
        output_dir = str(Path(config["experiment"]["experiment_directory"]) / run_name)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    total = len(model_names)
    for idx, model_name in enumerate(model_names, start=1):
        print(f"*** Evaluating model {idx} / {total} ***")
        model_weights = all_models + model_name if found_models else None
        try:
            model_evaluator = create_evaluator(
                name=config["evaluator"]["name"], config=config, model_weights=model_weights
            )
        except Exception as err:
            raise ValueError("Cannot create evaluator") from err

        log = model_evaluator.run()
        # NOTE(review): in the fallback case model_name is None, so the file is
        # literally "None.csv" (unchanged from the original behavior).
        log.to_csv(f"{output_dir}/{model_name}.csv")

        print(f"*** Completed model {idx} / {total} ***")
        print(" ")

        # Release this model before the next one is loaded in the same process.
        del model_evaluator
        gc.collect()
        torch.cuda.empty_cache()


@click.command()
@click.option(
    '--cf',
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help='Configuration file to run experiments',
)
@click.option(
    '--mode',
    default="train",
    type=click.Choice(["train", "eval"]),
    help='Mode: train or eval',
)
@click.option(
    '--output', default=None, help='Output directory to write experiment results'
)
@click.option(
    '--runs',
    default=1,
    type=click.IntRange(min=1),
    help='Number of runs with different random seeds',
)
def run(cf, mode, output, runs):
    """Run training or evaluation RUNS times, each in a fresh child process."""
    # Parse the config once; a broken file is reported a single time and the
    # command exits non-zero instead of "succeeding" after RUNS error messages.
    try:
        with open(cf, 'r', encoding="utf-8") as file:
            loaded_config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(f"Error loading YAML file: {exc}")
        raise SystemExit(1) from exc

    target = tuner if mode == "train" else evaluator
    # PyTorch's CUDA runtime does not support the "fork" start method, so each
    # run gets a freshly spawned interpreter. When the child exits, the OS
    # reclaims all of its memory, including GPU memory. A local context avoids
    # force-mutating the global start method.
    ctx = mp.get_context("spawn")

    for i in range(runs):
        # flush so this line is not reordered after the child's own output.
        print(f"\n=== Starting run {i+1}/{runs} ({mode}) ===", flush=True)
        run_output = output
        if output is not None and runs > 1:
            # Give each run its own sub-directory; otherwise every run would
            # overwrite the previous run's config, logs and CSVs.
            run_output = f"{output}/run{i+1}"
        p = ctx.Process(target=target, args=(loaded_config, run_output))
        p.start()
        p.join()
        if p.exitcode != 0:
            raise SystemExit(f"Child process failed on run {i+1} with exit code {p.exitcode}")
        print(f"=== Completed run {i+1}/{runs} ===")


if __name__ == "__main__":
    # Start the click CLI only when executed as a script. The guard is also
    # required by the "spawn" start method: each child process re-imports this
    # module under a different __name__ and must not re-enter run().
    run()
