import os
from pathlib import Path
import logging
from dataclasses import asdict

import hydra
from omegaconf import OmegaConf
from dotenv import load_dotenv

from redflag.sft_runner import RedFlagScriptArguments, RedFlagConfig, RedFlagModelConfig, run
from redflag.utils import (
    setup_distributed_environment,
    create_output_dir_with_overrides, 
    log_hydra_override_information, 
    get_global_rank,
    init_wandb,
)

# Configure the root logger once at import time. At INFO level the
# logging.debug(...) calls in main() are suppressed unless the level is
# lowered elsewhere (NOTE(review): Hydra may apply its own job-logging
# configuration on top of this - confirm which one wins in practice).
logging.basicConfig(level=logging.INFO)


@hydra.main(config_path="configs_v2", config_name="config", version_base='1.3')
def main(cfg) -> None:
    """Hydra entry point: build the RedFlag argument objects and launch a run.

    Steps, in order:

    1. Set up the distributed environment and log the composed config.
    2. Convert the ``script_args``, ``training_args`` and ``model_config``
       config groups into their dataclasses.
    3. If ``update_output_dir`` is truthy, replace ``training_args.output_dir``
       with the override-derived directory.
    4. Optionally load a ``.env`` file, honour ``dry_run``, and initialise
       Weights & Biases on global rank 0.
    5. Decide whether to resume from an existing ``checkpoint-*`` directory
       and hand everything to ``run``.

    Args:
        cfg: The config object composed by Hydra from ``configs_v2/config``
            plus any command-line overrides.

    Raises:
        SystemExit: With status 0 when ``dry_run`` is truthy, after the config
            has been logged and converted but before W&B init and training.
    """
    setup_distributed_environment()

    logging.info("----- Hydra Config -----")
    logging.info(OmegaConf.to_yaml(cfg))

    # Dump the environment only when DEBUG records would actually be emitted,
    # so the sort and per-variable formatting are skipped otherwise.
    # NOTE(review): every variable is printed verbatim, which can include
    # credentials (API tokens etc.) - consider redacting before enabling DEBUG
    # on shared log sinks.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug("----- Environment Variables -----")
        for key, value in sorted(os.environ.items()):
            logging.debug("%s=%s", key, value)
        logging.debug("-" * 50)

    # Get and log override information using utils
    overrides_info = log_hydra_override_information(logging)

    # Convert the Hydra config groups into their dataclass counterparts.
    script_args = RedFlagScriptArguments(**OmegaConf.to_container(cfg.script_args))
    training_args = RedFlagConfig(**OmegaConf.to_container(cfg.training_args))
    model_config = RedFlagModelConfig(**OmegaConf.to_container(cfg.model_config))

    # Read the opt-in flag once; it gates both the output-dir swap and the
    # W&B group name below.
    use_override_dir = cfg.get('update_output_dir', False)

    # Update output directory based on overrides using utils
    original_output_dir = training_args.output_dir
    updated_output_dir = create_output_dir_with_overrides(
        original_output_dir, overrides_info['filtered_override_dirname']
    )
    if updated_output_dir != original_output_dir and use_override_dir:
        logging.info(f"Updated output_dir from '{original_output_dir}' to '{updated_output_dir}'")
        training_args.output_dir = updated_output_dir

    if cfg.get('adv_evals', False):
        # important to do this before OmegaConf.resolve()!!!
        # NOTE(review): cfg is not read again in this function after
        # resolve(); presumably it is consumed elsewhere - confirm.
        cfg.adv_evals.llmqc_cfg._model_path = training_args.output_dir

    # Snapshot taken after the output_dir update so the logged config matches
    # what the run actually uses.
    config = {
        'script_args': asdict(script_args),
        'training_args': asdict(training_args),
        'model_config': asdict(model_config),
    }

    if script_args.load_dotenv:
        load_dotenv()

    # master flag for resuming from checkpoints
    force_no_resume = not script_args.resume_checkpoint
    output_dir = Path(training_args.output_dir)

    if cfg.get('dry_run', False):
        logging.info("Dry run - exiting")
        # SystemExit rather than the `exit()` helper: the latter is injected
        # by the `site` module for interactive use and may be missing.
        raise SystemExit(0)

    # `.name`, not `.stem`: directory names routinely contain dots (model
    # versions, float-valued overrides) and `.stem` would cut the name at the
    # last one.
    group = output_dir.parent.name if use_override_dir else None
    OmegaConf.resolve(cfg)

    if get_global_rank() == 0:
        init_wandb(config, force_no_resume, output_dir, group)

    # checkpoint logic for HF Trainer
    if force_no_resume:
        logging.info("`resume_checkpoint`=False => ignoring old checkpoints; starting fresh.")
        script_args.resume_checkpoint = False
    else:
        # Resume only if at least one checkpoint-* entry exists in output_dir.
        has_checkpoint = any(output_dir.glob("checkpoint-*"))
        if has_checkpoint:
            logging.info("Found checkpoint => will resume training.")
        else:
            logging.info("No checkpoint found => fresh training.")
        script_args.resume_checkpoint = has_checkpoint

    run(script_args, training_args, model_config)
    logging.info("Done")

# Script entry point. main() is called without arguments; the @hydra.main
# decorator is what supplies `cfg` to it.
if __name__ == "__main__":
    main()
