# Override FileLock globally to use SoftFileLock
import filelock
from filelock import SoftFileLock
from stuned.utility.helpers_for_main import prepare_wrapper_for_experiment
from stuned.utility.logger import try_to_log_in_csv, try_to_log_in_wandb
from stuned.utility.utils import AttrDict

# Rebind the module attribute so that any code resolving ``filelock.FileLock``
# after this line gets SoftFileLock instead of the original FileLock class
# (presumably to avoid OS-level file locking, e.g. on shared filesystems —
# TODO confirm the motivation).
# NOTE(review): this only affects lookups made after this statement. A module
# that already executed ``from filelock import FileLock`` (possibly one of the
# stuned imports above) keeps the original class. That is why the remaining
# imports, e.g. transformers, are placed below this line — keep that order.
filelock.FileLock = SoftFileLock

import os
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.config import load_yaml, update_config
from utils.data_utils import load_data
from utils.model_utils import configure_padding_token, seed_everything


def get_model_settings(model_config, model_name):
    """Look up the settings entry for ``model_name`` in ``model_config``.

    Args:
        model_config: Mapping from model config keys (e.g. ``"llama2_7b"``)
            to their settings.
        model_name: The config key to look up.

    Returns:
        The value stored under ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_config``. The
            message lists the keys that are available.
    """
    if model_name in model_config:
        return model_config[model_name]

    known_keys = list(model_config.keys())
    raise ValueError(
        f"Model '{model_name}' not found in config. Available models: {known_keys}"
    )


def main(experiment_config, logger, processes_to_kill_before_exiting):
    """Load configs and data, then fine-tune the configured model.

    Args:
        experiment_config: Mutable mapping of experiment options supplied by
            the stuned wrapper. Values equal to the string ``"None"`` are
            replaced by ``None`` in place. This function reads
            ``config_dir``, ``temp_saving_path`` and ``model_name``. The
            mapping (as an ``AttrDict``) is also passed to ``update_config``
            together with each YAML config.
        logger: Supplied by the wrapper; unused in this function.
        processes_to_kill_before_exiting: Supplied by the wrapper; unused in
            this function.

    Raises:
        ValueError: If no model name is provided, or if it is not a key of
            the model config.
    """
    # Replace literal "None" strings with real None. Assigning to existing
    # keys while iterating is safe because the dict's size does not change.
    for key, value in experiment_config.items():
        if value == "None":
            experiment_config[key] = None

    args = AttrDict(experiment_config)
    config_dir = Path(args.config_dir)

    training_config = load_yaml(config_dir / "training_config.yaml")
    model_config = load_yaml(config_dir / "model_config.yaml")
    training_config = update_config(training_config, args)
    model_config = update_config(model_config, args)

    # Keep only the last path component of temp_saving_path. Path(...).name
    # ignores a trailing separator, whereas str.split("/")[-1] would yield ""
    # for a value such as "runs/exp1/".
    if args.temp_saving_path is not None:
        args.temp_saving_path = Path(args.temp_saving_path).name

    if not args.model_name:
        available_models = list(model_config.keys())
        raise ValueError(
            f"Model name must be provided via config key. Available models: {available_models}"
        )

    model_settings = get_model_settings(model_config, args.model_name)
    seed_everything(training_config["training"]["seed"])

    # A configured "path" takes precedence over "name" wherever the model is
    # loaded or reported.
    model_path = (
        model_settings["path"] if "path" in model_settings else model_settings["name"]
    )
    use_llama_factory = (
        training_config["training"]["training_backend"] == "llama_factory"
    )

    model = None
    tokenizer = None
    if not use_llama_factory:
        # The llama_factory backend is called with model=None and
        # tokenizer=None, so loading them in this process would only cost
        # load time and GPU memory. Load them only for the other backend.
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="bfloat16",
            device_map="auto",
        )
        tokenizer, model = configure_padding_token(tokenizer, model, model_settings)

    # Diagnostic: print the GPU state before data loading and training.
    os.system("nvidia-smi")

    # Load and prepare data using the centralized load_data function.
    datasets = load_data(training_config["data"])

    # Import the selected backend lazily so only its dependencies are loaded.
    if use_llama_factory:
        from unlocking.train_llama_factory import train_model
    else:
        from unlocking.train import train_model

    train_model(
        model=model,
        tokenizer=tokenizer,
        data=datasets,
        training_config=training_config,
        output_dir=training_config["output"]["base_path"],
        model_name=model_path,
        model_settings=model_settings,
        temp_saving_path=args.temp_saving_path,
    )


def check_config_for_demo_experiment(config, config_path, logger):
    """Config-check hook handed to ``prepare_wrapper_for_experiment``.

    Performs no checks on ``config`` and always returns ``None``.
    """


def run_experiment():
    """Wrap ``main`` with the stuned experiment wrapper and invoke it.

    The wrapper is built from ``check_config_for_demo_experiment``. The
    wrapped ``main`` is then called with no arguments, and the wrapper
    supplies ``main``'s parameters.
    """
    make_wrapped = prepare_wrapper_for_experiment(check_config_for_demo_experiment)
    wrapped_main = make_wrapped(main)
    wrapped_main()


if __name__ == "__main__":
    # Launch the experiment only when run as a script, not when imported.
    run_experiment()
