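# Import unsloth first (it is expected to be loaded before transformers/trl so its patches apply).
# NOTE: the global random seeds are set further below, after the project imports.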
import unsloth
import os, random, numpy as np, torch
import logging
import sys
from pathlib import Path

# Add the project root (the parent of this script's directory) to sys.path so
# the project packages imported below can be resolved.
# resolve() makes the path absolute before walking up: with a relative
# __file__ such as "train.py", Path("train.py").parent.parent collapses to "."
# and the wrong directory would be placed on sys.path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

from prompts.few_shots import CONNECT_SYSTEM_FEW_SHOT

# Log suppression switch: True silences the noisy loggers listed below,
# False only silences the system_parser.system_graph logger.
suppress_logs = True

# Loggers limited to CRITICAL when full suppression is enabled
# (system_parser modules and the pandapower library).
_CRITICAL_ONLY_LOGGERS = (
    'system_parser',
    'system_parser.system_graph',
    'system_parser.pandapower',
    'system_parser.pandapower.pandapower_converter',
    'pandapower',
    'pandapower.diagnostic_reports',
)

# Other potentially noisy loggers, limited to ERROR and above.
_ERROR_ONLY_LOGGERS = ('transformers', 'torch', 'unsloth')

if not suppress_logs:
    # Partial mode: only the system_graph logger is quietened.
    logging.getLogger('system_parser.system_graph').setLevel(logging.CRITICAL)
else:
    for _logger_name in _CRITICAL_ONLY_LOGGERS:
        logging.getLogger(_logger_name).setLevel(logging.CRITICAL)
    for _logger_name in _ERROR_ONLY_LOGGERS:
        logging.getLogger(_logger_name).setLevel(logging.ERROR)

    # Finally restrict the root logger to critical messages only.
    logging.getLogger().setLevel(logging.CRITICAL)

# Import the refactored ToolEnvironment class
from envs.environments import ToolEnvironment
from envs.validation_config import ValidationConfig
from trainers.grpo_env_trainer_unsloth import UnslothGRPOEnvTrainer
from rewards.power_system_reward import PowerSystemReward

# Single seed shared by every random number generator used in this script.
SEED = 1000

# NOTE: PYTHONHASHSEED is read at interpreter start-up, so assigning it here
# does not change hashing in the already-running process; the value is still
# visible to anything that reads the environment afterwards.
os.environ["PYTHONHASHSEED"] = str(SEED)

# Seed Python, NumPy and PyTorch (CPU, then all CUDA devices), in that order.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import dependencies
from trl import GRPOConfig
from unsloth import FastLanguageModel
from utils.data_utils import preprocess_dataset
from utils.wandb_callbacks import create_script_upload_callback
from tools.search_blocks import search_blocks
from tools.electrical_check import electrical_check

# Project-level identifiers: dataset to train on, W&B project, and run name.
dataset_name = "SimuAgent/sps_complete"
wandb_project = "sps_complete"
run_name = "execution_run_001"  # fixed name; not generated dynamically

# Publish the W&B project through the environment.
os.environ.update(WANDB_PROJECT=wandb_project)

print(f"Run name: {run_name}")

# System prompt template for tool-assisted reasoning.
# It contains a single `{tool_descriptions}` placeholder (presumably filled in
# by the environment that receives this prompt -- verify against
# ToolEnvironment). Every <answer> tag mentioned in the text is balanced, so
# the model is never shown a malformed tag example.
SYSTEM_PROMPT = """
Think step-by-step inside <think>...</think> tags. Provide your final answer inside <answer>...</answer> tags.

You have access to tools to help solve problems:
{tool_descriptions}

Call tools using a JSON command within <tool> tags, including:

"name": tool name
"args": tool arguments
Tool output will appear in <result> tags. Multiple tool calls are allowed if needed.
<answer>...</answer> tags must contain only the final answer.

Some blocks, such as the Three-Phase PI Section Line, Three-Phase Transformer (Two Windings), and Three-Phase V-I Measurement, have ports on both sides—(a1,b1,c1) on one side and (a2,b2,c2) on the other.
"""

# Load the "train" split of the configured dataset through the project's
# preprocessing helper.
train_split = "train"
dataset = preprocess_dataset(dataset_name, train_split)

# Validation settings, constructed with the class defaults.
validation_config = ValidationConfig()

# Per-component reward weights handed to PowerSystemReward.
# Intended to mirror training_config.yaml (that file is not visible from here,
# so keep the two in sync by hand). A weight of 0.0 leaves the component listed
# but contributes nothing through its weight.
reward_weights = dict(
    connectivity=0.0,
    validation=0.0,
    parameter=0.0,
    conversion=0.0,
    diagnostic=0.0,
    load_satisfaction=0.2,
    structure=0.0,
    tool_execution=0.10,
    format=0.05,
    xml=0.05,
    connection_addition=0.25,
    block_addition=0.25,
    frequency_coherence=0.0,
    voltage_coherence=0.0,
    port_connectivity=0.0,
    block_effectiveness=0.0,
    dictionary_match=0.0,
    answer_comparison=1.0,
)

# Build the tool environment used for rollouts.
# The environment and the reward object each receive their own list holding
# the single electrical_check tool; the reward additionally takes the weight
# table defined above.
power_system_reward = PowerSystemReward(
    tools=[electrical_check],
    power_system_weights=reward_weights,
)

env = ToolEnvironment(
    dataset=dataset,
    system_prompt=SYSTEM_PROMPT,
    tools=[electrical_check],
    max_steps=3,
    reward=power_system_reward,
    validation_config=validation_config,
    # Only the first few-shot example is passed on.
    few_shot=CONNECT_SYSTEM_FEW_SHOT[0],
)

# Show the system prompt as exposed by the environment.
print(env.system_prompt)

# Load the 4-bit base model and tokenizer, then wrap the model with LoRA adapters.
lora_rank = 32  # used both as the loader's max_lora_rank and as the adapter rank r

base_model_kwargs = {
    "model_name": "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
    "max_seq_length": 4096,
    "load_in_4bit": True,
    "fast_inference": True,
    "max_lora_rank": lora_rank,
    "gpu_memory_utilization": 0.6,
}
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_kwargs)

# Projection layers that receive adapters: attention first, then the MLP.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=attention_projections + mlp_projections,
    lora_alpha=32,
    use_gradient_checkpointing="unsloth",
    random_state=SEED,
)

# GRPO training configuration, collected as a keyword table grouped by concern
# and unpacked into GRPOConfig.
grpo_config_kwargs = {
    # Run identity and reproducibility
    "seed": SEED,
    "output_dir": f"outputs/{wandb_project}/{run_name}",
    "run_name": run_name,
    # Optimisation schedule
    "learning_rate": 5e-6,
    "lr_scheduler_type": "constant_with_warmup",
    "warmup_steps": 10,
    "num_train_epochs": 1,
    "temperature": 1.0,
    # A positive max_steps takes precedence over num_train_epochs in
    # Hugging Face TrainingArguments.
    "max_steps": 2000,
    "bf16": True,
    "max_grad_norm": 0.1,
    # GRPO-specific settings
    "num_iterations": 2,
    "beta": 0.002,
    "max_prompt_length": 512,
    "max_completion_length": 3584,  # 4096 - 512
    "per_device_train_batch_size": 4,
    "num_generations": 4,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    # Checkpointing
    "save_strategy": "steps",
    "save_steps": 100,
    "save_only_model": True,
    # Generation backend
    "use_vllm": True,
    # NOTE(review): the model loader is given gpu_memory_utilization=0.6 while
    # this is 0.9 -- confirm which value actually takes effect.
    "vllm_gpu_memory_utilization": 0.9,
    # Logging / reporting
    "logging_steps": 1,
    "log_on_each_node": False,
    "log_completions": True,
    "report_to": ["wandb"],
    # Reward weights come from the environment.
    "reward_weights": env.get_reward_weights(),
}
training_args = GRPOConfig(**grpo_config_kwargs)

# Callback created by the project's W&B helper for this script file; no
# additional files are attached.
script_callback = create_script_upload_callback(script_path=__file__, additional_files=[])

# Pieces supplied by the environment: reward functions and the two datasets.
reward_funcs = env.get_reward_funcs()
train_dataset = env.get_dataset()
eval_dataset = env.get_eval_dataset()

# Assemble the trainer from the model, tokenizer, environment and config.
trainer = UnslothGRPOEnvTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,
    env=env,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Trainer-specific option; presumably the evaluation interval in steps --
    # confirm against UnslothGRPOEnvTrainer.
    my_eval_steps=100,
    callbacks=[script_callback],
)

# Everything above runs on import; only the training run itself is guarded.
if __name__ == "__main__":
    trainer.train()