import os
from os import path
import time
import argparse
import traceback
# Pin this process to the GPU that currently has the most free memory.
# CUDA_VISIBLE_DEVICES is set here, at import time, ahead of the GPU library
# imports that happen later inside run().
# NOTE(review): an already-set CUDA_VISIBLE_DEVICES is overwritten, and NVML
# device indices are used as-is -- confirm this is intended on shared machines.
def select_gpu_with_most_free_memory(free_bytes_per_gpu):
    """Return ``(index, free_bytes)`` of the GPU with the most free memory.

    Args:
        free_bytes_per_gpu: Iterable of free-memory amounts (bytes), ordered
            by device index.

    Returns:
        A ``(index, free_bytes)`` tuple. Ties are resolved in favour of the
        lowest index; an empty input yields ``(0, 0)``.
    """
    return max(
        enumerate(free_bytes_per_gpu),
        key=lambda entry: entry[1],
        default=(0, 0),
    )


try:
    import pynvml
except ImportError:
    print("NVIDIA GPU monitoring is not available, please set CUDA_VISIBLE_DEVICES to the index of the GPU you want to use.")
else:
    try:
        pynvml.nvmlInit()
        try:
            device_count = pynvml.nvmlDeviceGetCount()
            free_bytes_per_gpu = [
                pynvml.nvmlDeviceGetMemoryInfo(
                    pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
                ).free
                for gpu_index in range(device_count)
            ]
        finally:
            # Release NVML even when one of the queries above fails.
            pynvml.nvmlShutdown()
    except pynvml.NVMLError as nvml_error:
        # E.g. missing driver / NVML library: fall back to manual selection
        # instead of crashing the whole script at import time.
        print(f"NVIDIA GPU monitoring failed ({nvml_error}), please set CUDA_VISIBLE_DEVICES to the index of the GPU you want to use.")
    else:
        if device_count > 0:
            best_gpu_index, max_free_mem = select_gpu_with_most_free_memory(free_bytes_per_gpu)
            # Set the best GPU
            os.environ["CUDA_VISIBLE_DEVICES"] = str(best_gpu_index)
            print(f"Using GPU {best_gpu_index} with {max_free_mem / 1024**2:.2f} MB free memory")
        else:
            # Do not claim "GPU 0" when NVML reports no devices at all.
            print("NVML reported no GPUs, leaving CUDA_VISIBLE_DEVICES unchanged.")


# Process-wide switches, set before the heavy imports that run() performs.
# NOTE(review): presumably consumed by Ray and the HF tokenizers library when
# they start up -- confirm the exact effect of each flag against their docs.
os.environ.update({
    "RAY_DISABLE_LOGGING": "1",
    "TOKENIZERS_PARALLELISM": "true",
    "RAY_memory_monitor_refresh_ms": "0",
})
import traceback
from args import Args as CALM_ARGS


def run(calm_args: CALM_ARGS):
    """Set up the model and drive the CALM trainer until the step budget is used.

    Two modes are selected by ``calm_args.model_name``:

    * Name starts with ``online`` (case-insensitive): no local model is loaded.
      The second ``/``-separated component of the name is passed to the trainer
      as a plain string and no trainer config is built.
    * Otherwise: the model is loaded through unsloth in 4-bit with vLLM fast
      inference, wrapped with LoRA adapters, and a TRL ``GRPOConfig`` is built.

    The loop then repeatedly calls ``trainer.train()`` (local) or
    ``trainer.query_all()`` (online) followed by ``trainer.prepare_dataset()``
    until ``trainer.log_step`` reaches ``calm_args.max_steps`` (multiplied by
    ``calm_args.n_generations`` in online mode). Exceptions raised inside one
    iteration are printed and the loop continues.

    Args:
        calm_args: Run configuration. This function reads ``model_name``,
            ``lr``, ``n_prompts``, ``n_generations`` and ``max_steps``; the
            whole object is also handed to the trainer.
    """
    max_seq_length = 4096 # Can increase for longer reasoning traces
    lora_rank = 32 # Larger rank = smarter, but slower

    online = calm_args.model_name.lower().startswith('online')
    if not online:
        # unsloth is imported (and GRPO patched) before trl is imported below.
        from unsloth import FastLanguageModel, PatchFastRL
        PatchFastRL("GRPO", FastLanguageModel)

        from unsloth import is_bfloat16_supported
        from trl import GRPOConfig as TrainerConfig

        # The tokenizer returned alongside the model is not used here.
        model, _tokenizer = FastLanguageModel.from_pretrained(
            model_name = calm_args.model_name,
            max_seq_length = max_seq_length,
            load_in_4bit = True, # False for LoRA 16bit
            fast_inference = True, # Enable vLLM fast inference
            max_lora_rank = lora_rank,
            temperature = 1.0,
            gpu_memory_utilization = 0.8, # Reduce if out of memory
        )

        model = FastLanguageModel.get_peft_model(
            model,
            r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
            target_modules = [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ], # Remove QKVO if out of memory
            lora_alpha = 2*lora_rank,
            use_gradient_checkpointing = "unsloth", # Enable long context finetuning
        )

        # NOTE(review): max_prompt_length + max_completion_length (5000) is
        # larger than max_seq_length (4096) -- confirm this is intended.
        training_args = TrainerConfig(
            use_vllm = True, # use vLLM for fast inference!
            learning_rate = calm_args.lr,
            adam_beta1 = 0.9,
            adam_beta2 = 0.99,
            weight_decay = 0.01,
            warmup_ratio = 0,
            lr_scheduler_type = "constant",
            optim = "adamw_8bit",
            logging_steps = 1,
            bf16 = is_bfloat16_supported(),
            fp16 = not is_bfloat16_supported(),
            per_device_train_batch_size = 1,
            gradient_accumulation_steps = 1 if calm_args.n_prompts == 1 else 2, # Increase to 4 for smoother training
            num_generations = calm_args.n_generations, # Decrease if out of memory
            max_prompt_length = 4000,
            max_completion_length = 1000,
            num_train_epochs = 1, # Set to 1 for a full training run
            max_grad_norm = 0.1,
            output_dir = None,
        )
    else:
        # Online mode: only the model identifier string is needed.
        model = calm_args.model_name.split('/')[1]
        print(model)
        training_args = None

    from calm_trainer import Trainer
    trainer = Trainer(
        model = model,
        calm_args=calm_args,
        args = training_args,
    )

    # Counts successfully completed iterations, starting at 1; it is not
    # advanced when an iteration raises.
    iteration = 1
    while trainer.log_step < calm_args.max_steps * (calm_args.n_generations if online else 1):
        try:
            if not online:
                trainer.train()
            else:
                trainer.query_all()
            trainer.prepare_dataset()
            # Periodic checkpoint, plus a save once max_steps is reached.
            # NOTE(review): in online mode the loop limit is
            # max_steps * n_generations, so the second condition stays true
            # (saving every iteration) well before the loop ends -- confirm.
            if (iteration + 1) % 100 == 0 or trainer.log_step >= calm_args.max_steps:
                trainer.save_model()
            iteration += 1
        except Exception:
            # Best-effort: report the failure and retry. Catching Exception
            # (not a bare except) keeps Ctrl-C / SystemExit able to stop the
            # loop.
            # NOTE(review): a failure that repeats without advancing
            # trainer.log_step will retry forever.
            print(traceback.format_exc())

if __name__ == '__main__':
    # CLI entry point: load the YAML config, choose an unused log name, run.
    cli = argparse.ArgumentParser(description='Run CALM with one or more YAML config files.')
    cli.add_argument('--config', type=str, help='Path to the YAML config file', default='')
    # Disabled option, kept for reference:
    # cli.add_argument('--diff-seeds-for-duplicate-exps', action='store_false', help='If set, use different seeds for running same config')
    config_path = cli.parse_args().config
    calm_args = CALM_ARGS.from_yaml(config_path)

    # Base log name: the config file name without directory or extension.
    calm_args.log_name = path.splitext(path.basename(config_path))[0]

    # Find the smallest suffix _0, _1, ... for which
    # <this file's dir>/calm_saved/<problem_name>/<log_name>_<suffix>
    # does not exist yet, so earlier runs are not reused.
    script_dir = path.abspath(path.dirname(__file__))
    suffix = 0
    while path.exists(path.join(script_dir, 'calm_saved', calm_args.problem_name,
                                calm_args.log_name + f"_{suffix}")):
        suffix += 1
    calm_args.log_name = calm_args.log_name + f"_{suffix}"

    run(calm_args)