# Resume Training Script for SimuAgent
# This script resumes training from a checkpoint using hardcoded configuration values

import unsloth
import os, random, numpy as np, torch
import logging
import sys
from pathlib import Path

# Add the project root (two directory levels above this file) to the front of
# the import path.
# resolve() makes the path absolute before walking upwards: with a relative
# __file__ such as "script.py", Path("script.py").parent.parent evaluates to
# "." (the current working directory) rather than the directory above the
# script, which would put the wrong directory on sys.path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

# Resume configuration (hardcoded for this run)
#   resume_from_checkpoint: master switch for resuming
#   checkpoint_path:        checkpoint directory to resume from
resume_from_checkpoint = True
checkpoint_path = "outputs/simuagent_training/resume_run_001/checkpoint-500"

# Report the effective resume settings up front. Resuming needs both the
# switch to be on and a non-empty checkpoint path.
print(f"Resume from checkpoint: {resume_from_checkpoint}")
if not (resume_from_checkpoint and checkpoint_path):
    print("Resume not configured - will start fresh training")
else:
    print(f"Resume enabled - checkpoint path: {checkpoint_path}")

# Configuration for log suppression
suppress_logs = True

# Loggers raised to CRITICAL when full suppression is on: the project's
# system_parser hierarchy plus the pandapower library loggers.
_critical_only_loggers = (
    'system_parser',
    'system_parser.system_graph',
    'system_parser.pandapower',
    'system_parser.pandapower.pandapower_converter',
    'pandapower',
    'pandapower.diagnostic_reports',
)

# Other potentially noisy loggers, limited to ERROR and above.
_error_only_loggers = ('transformers', 'torch', 'unsloth')

if not suppress_logs:
    # Partial mode: only the system_graph logger is silenced.
    logging.getLogger('system_parser.system_graph').setLevel(logging.CRITICAL)
else:
    for _logger_name in _critical_only_loggers:
        logging.getLogger(_logger_name).setLevel(logging.CRITICAL)

    for _logger_name in _error_only_loggers:
        logging.getLogger(_logger_name).setLevel(logging.ERROR)

    # The root logger itself only lets critical messages through.
    logging.getLogger().setLevel(logging.CRITICAL)

# Single seed shared by every random number source used in this script.
SEED = 42

# NOTE(review): PYTHONHASHSEED is exported here after the interpreter has
# already started; it presumably only influences child processes - confirm.
os.environ["PYTHONHASHSEED"] = str(SEED)

# Seed Python's `random`, NumPy, and torch (CPU generator and all CUDA
# devices) with the same value.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# cuDNN: request deterministic behavior and switch benchmark mode off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import dependencies
from envs.environments import ToolEnvironment
from envs.validation_config import ValidationConfig
from trainers.grpo_env_trainer_unsloth import UnslothGRPOEnvTrainer
from rewards.power_system_reward import PowerSystemReward
from trl import GRPOConfig
from unsloth import FastLanguageModel
from utils.data_utils import preprocess_dataset
from utils.wandb_callbacks import create_script_upload_callback
from tools.search_blocks import search_blocks

# Project configuration: dataset identifier and the wandb project name.
# The project name is also exported through the WANDB_PROJECT environment
# variable.
wandb_project = "simuagent_training"
dataset_name = "simuagent_dataset"
os.environ.update(WANDB_PROJECT=wandb_project)

# Fixed name of this (resumed) run.
run_name = "resume_run_001"
print(f"Run name: {run_name}")

# Tool prompt for reasoning
# System prompt template passed to ToolEnvironment as `system_prompt`.
# It asks for reasoning inside <think> tags, tool calls as JSON inside <tool>
# tags (results arrive in <result> tags), and the final answer inside
# <answer> tags.
# `{tool_descriptions}` is a format placeholder; presumably it is filled in
# by ToolEnvironment from the `tools` list - confirm.
# NOTE(review): the last prompt line ends with a stray closing "</answer>"
# tag. It looks like a typo, but it is left untouched here because changing
# the prompt would alter what the resumed model sees - confirm before fixing.
SYSTEM_PROMPT = """
Think step-by-step inside <think>...</think> tags. Provide your final answer inside <answer>...</answer> tags.

You have access to tools to help solve problems:
{tool_descriptions}

Call tools using a JSON command within <tool> tags, including:

"name": tool name
"args": tool arguments
Tool output will appear in <result> tags. Multiple tool calls are allowed if needed.
<answer>...</answer> tags must contain only the final answer.</answer>
"""

# Setup environment
# Build the training data by calling the project's preprocess_dataset helper
# with the configured dataset name and the string "train" (presumably the
# split name - confirm against utils.data_utils).
dataset = preprocess_dataset(dataset_name, "train")

# Create validation configuration
# Constructed with no arguments, i.e. whatever defaults ValidationConfig
# defines; it is handed to ToolEnvironment further down.
validation_config = ValidationConfig()

# Per-component weights handed to PowerSystemReward as `power_system_weights`.
# The six weights add up to 1.0.
reward_weights = dict(
    convergence=0.3,
    voltage_violations=0.2,
    thermal_violations=0.2,
    power_balance=0.15,
    reactive_power=0.1,
    frequency=0.05,
)

# Reward object for the environment: it is given the same single tool
# (search_blocks) as the environment, plus the per-component weights.
_power_system_reward = PowerSystemReward(
    tools=[search_blocks],
    power_system_weights=reward_weights,
)

# Tool-using environment built from the dataset, the system prompt, the
# search_blocks tool, the reward object and the validation configuration.
env = ToolEnvironment(
    dataset=dataset,
    system_prompt=SYSTEM_PROMPT,
    tools=[search_blocks],
    max_steps=2000,
    reward=_power_system_reward,
    validation_config=validation_config,
)

print("Environment created successfully")

# Base model: settings for loading the 4-bit checkpoint through Unsloth.
# max_lora_rank matches the LoRA rank (r=64) used for the adapter below.
_base_model_kwargs = dict(
    model_name="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
    max_seq_length=8192,
    load_in_4bit=True,
    fast_inference=True,
    max_lora_rank=64,
    gpu_memory_utilization=0.8,
)
model, tokenizer = FastLanguageModel.from_pretrained(**_base_model_kwargs)

# Projection layers that receive LoRA adapters.
_attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
_mlp_projections = ["gate_proj", "up_proj", "down_proj"]

# Wrap the base model with a rank-64 LoRA adapter (alpha = 2 * r), seeded
# with the script-wide SEED.
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=_attention_projections + _mlp_projections,
    lora_alpha=128,
    use_gradient_checkpointing="unsloth",
    random_state=SEED,
)

# Training configuration for the GRPO trainer.
# Checkpoints go to outputs/<wandb_project>/<run_name>/ every `save_steps`
# steps.
training_args = GRPOConfig(
    seed=SEED,
    output_dir=f"outputs/{wandb_project}/{run_name}",
    run_name=run_name,
    # Optimization
    learning_rate=3e-5,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    # NOTE: per the Hugging Face TrainingArguments docs, a positive max_steps
    # overrides num_train_epochs.
    num_train_epochs=1,
    temperature=0.8,
    max_steps=1000,
    bf16=True,
    max_grad_norm=1.0,
    num_iterations=10,
    beta=0.1,
    # Sequence budget: prompt + completion = 8192 tokens in total.
    max_prompt_length=4096,
    max_completion_length=4096,
    per_device_train_batch_size=1,
    num_generations=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    # Checkpointing
    save_strategy="steps",
    save_steps=100,
    # This script exists to resume training, so checkpoints must also carry
    # the optimizer, scheduler and RNG state. With save_only_model=True the
    # Hugging Face docs state that training cannot be resumed from the
    # resulting checkpoints.
    # NOTE(review): checkpoints written earlier with save_only_model=True
    # contain no optimizer/scheduler state; a resume from one of those
    # presumably restores only the weights and trainer state - confirm.
    save_only_model=False,
    use_vllm=False,
    vllm_gpu_memory_utilization=0.8,
    # Logging
    logging_steps=10,
    log_on_each_node=False,
    log_completions=True,
    report_to=["wandb"],
    reward_weights=env.get_reward_weights(),
)

# Callback built by the project's create_script_upload_callback helper for
# this script file (per the helper's name it uploads the script to wandb);
# no additional files are attached.
script_callback = create_script_upload_callback(
    script_path=__file__,
    additional_files=[],
)

# Pull the reward functions and the train/eval datasets from the environment,
# in the same order the trainer arguments are evaluated.
_reward_funcs = env.get_reward_funcs()
_train_split = env.get_dataset()
_eval_split = env.get_eval_dataset()

# Initialize trainer
trainer = UnslothGRPOEnvTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=_reward_funcs,
    env=env,
    args=training_args,
    train_dataset=_train_split,
    eval_dataset=_eval_split,
    my_eval_steps=50,
    callbacks=[script_callback],
)

if __name__ == "__main__":
    # Entry point: train from scratch unless resuming is both enabled and
    # pointed at a checkpoint path; otherwise hand that path to the trainer.
    if not (resume_from_checkpoint and checkpoint_path):
        print("Starting fresh training (no valid checkpoint path provided)")
        trainer.train()
    else:
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        trainer.train(resume_from_checkpoint=checkpoint_path)