import os
import sys
import numpy as np
import argparse
import configparser
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils import clip_grad_norm_
import time
import logging
from tqdm import tqdm
import warnings
from datetime import datetime
import pickle
from collections import defaultdict
import torch.nn.functional as F
from pathlib import Path

def verify_phase_differences(history, config, logger):
    """Average the recorded train and causal losses over the training phases.

    Args:
        history: Mapping with per-epoch lists under ``'train_loss'`` and
            ``'causal_loss'``.
        config: Mapping whose ``'Training'`` entry supplies the optional
            ``epochs_phase1``/``epochs_phase2``/``epochs_phase3`` lengths
            (defaults 40, 60 and 100).
        logger: Accepted for interface compatibility; not used here.

    Returns:
        None. The per-phase means are computed but neither logged nor
        returned. The function exits early when fewer than
        ``epochs_phase1 + epochs_phase2`` epochs were recorded.
    """
    first_len = int(config['Training'].get('epochs_phase1', 40))
    second_len = int(config['Training'].get('epochs_phase2', 60))
    # Parsed (and therefore validated) even though no slice depends on it.
    third_len = int(config['Training'].get('epochs_phase3', 100))

    second_end = first_len + second_len
    if len(history['train_loss']) < second_end:
        return

    train_curve = history['train_loss']
    phase1_loss = np.mean(train_curve[:first_len])
    phase2_loss = np.mean(train_curve[first_len:second_end])
    # Whatever follows phase 2 counts as phase 3, whether or not it is complete.
    phase3_loss = np.mean(train_curve[second_end:])

    causal_curve = history['causal_loss']
    phase1_causal = np.mean(causal_curve[:first_len])
    phase2_causal = np.mean(causal_curve[first_len:second_end])

def set_random_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also switches cuDNN to deterministic mode with benchmarking disabled and
    stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

# Process-wide setup executed at import time.
# Configure PyTorch's CUDA allocator through its environment variable.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings('ignore')

# The model lives in the sibling module Orion_model; without it nothing below
# can run, so print a message and exit with status 1. The message reads:
# "Error: cannot import Orion_model.py, make sure the file is in the current
# directory".
try:
    from Orion_model import OrionModel, create_orion_model
except ImportError:
    print("错误: 无法导入Orion_model.py，请确保文件在当前目录")
    sys.exit(1)

def setup_logging(config):
    """Configure root logging to a timestamped file plus stdout.

    Args:
        config: Mapping where ``config['Training']['log_path']`` names the
            directory for log files; it is created if missing.

    Returns:
        The logger for this module, after two start-up messages were emitted.
    """
    log_dir = Path(config['Training']['log_path'])
    log_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"orion_training_{stamp}.log"

    # Both handlers are created up front and handed to basicConfig together.
    sinks = [
        logging.FileHandler(log_file, encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=sinks,
    )

    logger = logging.getLogger(__name__)
    for message in ("Orion Model Training Started", f"Log file: {log_file}"):
        logger.info(message)

    return logger

def load_data(config, logger):
    """Load the prepared Orion ``.npz`` dataset and its normalisation stats.

    The archive path is derived from ``config['Data']``: the directory and
    stem of ``graph_signal_matrix_filename`` combined with
    ``_r{num_of_hours}_d{num_of_days}_w{num_of_weeks}_Orion.npz``.

    Args:
        config: Mapping with a ``'Data'`` section holding the keys above.
        logger: Logger used for progress and error messages.

    Returns:
        ``(train_data, val_data, stats, adj_matrix)`` where

        * ``train_data`` / ``val_data`` map ``x_h``, ``x_w``, ``x_d`` and
          ``target`` to float tensors and ``time_indices`` to a long tensor;
          an entry is ``None`` when the archive lacks the array.
        * ``stats`` holds ``min_flow``, ``max_flow``, ``mean_flow`` and
          ``std_flow`` (defaults 0, 1, 0, 1) and ``feature_stats``, a dict
          keyed by integer feature index with ``min``/``max``/``mean``/``std``
          and, when stored, ``median``/``iqr``.
        * ``adj_matrix`` is the stored array or ``None``.

    Raises:
        FileNotFoundError: If the derived archive path does not exist.
    """
    logger.info("Loading dataset...")

    graph_signal_matrix_filename = config['Data']['graph_signal_matrix_filename']
    base_dir = os.path.dirname(graph_signal_matrix_filename)
    base_name = os.path.basename(graph_signal_matrix_filename).split('.')[0]

    num_hours = config['Data']['num_of_hours']
    num_days = config['Data']['num_of_days']
    num_weeks = config['Data']['num_of_weeks']
    data_file = os.path.join(
        base_dir, f'{base_name}_r{num_hours}_d{num_days}_w{num_weeks}_Orion.npz'
    )

    if not os.path.exists(data_file):
        logger.error(f"Data file not found: {data_file}")
        logger.error("Please run prepare_data.py first")
        raise FileNotFoundError(f"Data file not found: {data_file}")

    logger.info(f"Loading data from: {data_file}")

    def _load_split(archive, prefix):
        """Build the tensor dict for one split (``'train'`` or ``'val'``)."""
        split = {}
        for name in ('x_h', 'x_w', 'x_d', 'target'):
            key = f'{prefix}_{name}'
            split[name] = torch.from_numpy(archive[key]).float() if key in archive else None
        key = f'{prefix}_time_indices'
        split['time_indices'] = torch.from_numpy(archive[key]).long() if key in archive else None
        return split

    # The archive keeps a file handle open; read everything needed inside the
    # context manager so the handle is released before returning.
    with np.load(data_file) as data:
        train_data = _load_split(data, 'train')
        val_data = _load_split(data, 'val')

        stats = {}
        for name, default in (('min_flow', 0.0), ('max_flow', 1.0),
                              ('mean_flow', 0.0), ('std_flow', 1.0)):
            key = f'train_{name}'
            stats[name] = float(data[key]) if key in data else default

        stats['feature_stats'] = {}
        for key in data.keys():
            if not (key.startswith('train_feature_') and key.endswith('_min')):
                continue
            # Keys look like train_feature_<idx>_min; the third token is the index.
            feature_idx = key.split('_')[2]
            prefix = f'train_feature_{feature_idx}'
            try:
                feature_num = int(feature_idx)
                feature_stat = {
                    'min': float(data[f'{prefix}_min']),
                    'max': float(data[f'{prefix}_max']),
                    'mean': float(data[f'{prefix}_mean']),
                    'std': float(data[f'{prefix}_std']),
                }
                # Median and IQR are optional and stored independently.
                for extra in ('median', 'iqr'):
                    if f'{prefix}_{extra}' in data:
                        feature_stat[extra] = float(data[f'{prefix}_{extra}'])
            except ValueError:
                # Non-numeric index (or unparsable value): skip this feature.
                continue
            stats['feature_stats'][feature_num] = feature_stat

        adj_matrix = data['adj_matrix'] if 'adj_matrix' in data else None

    logger.info("Data loaded successfully")
    for name in ('x_h', 'x_w', 'x_d'):
        if train_data[name] is not None:
            logger.info(f"Train {name} shape: {train_data[name].shape}")

    return train_data, val_data, stats, adj_matrix

class OrionDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(x_h, x_w, x_d, target, time_indices)`` tuples.

    Any of the five sources may be ``None``; the corresponding tuple slot is
    then an empty tensor (float by default, long for ``time_indices``).
    The dataset length is taken from ``x_h`` alone and is 0 when ``x_h`` is
    ``None``.
    """

    def __init__(self, x_h, x_w, x_d, target, time_indices):
        self.x_h, self.x_w, self.x_d = x_h, x_w, x_d
        self.target = target
        self.time_indices = time_indices

        # Only x_h determines how many samples the dataset reports.
        self.length = 0 if x_h is None else x_h.shape[0]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        def pick(source, dtype=None):
            # Index into the source, or substitute an empty placeholder.
            if source is not None:
                return source[idx]
            if dtype is None:
                return torch.empty(0)
            return torch.empty(0, dtype=dtype)

        return (
            pick(self.x_h),
            pick(self.x_w),
            pick(self.x_d),
            pick(self.target),
            pick(self.time_indices, torch.long),
        )

def create_data_loaders(train_data, val_data, config, logger):
    """Wrap the train and validation splits in ``DataLoader`` objects.

    Args:
        train_data: Dict with ``x_h``, ``x_w``, ``x_d``, ``target`` and
            ``time_indices`` entries for training.
        val_data: Dict with the same keys for validation.
        config: Mapping providing ``config['Training']['batch_size']``.
        logger: Logger used for progress messages.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and drops
        the last incomplete batch; the validation loader does neither.
    """
    logger.info("Creating data loaders...")

    batch_size = int(config['Training']['batch_size'])
    field_order = ('x_h', 'x_w', 'x_d', 'target', 'time_indices')

    def build_loader(split, training):
        # Shuffling and drop_last are both enabled only for the training split.
        dataset = OrionDataset(*(split[field] for field in field_order))
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=training,
            num_workers=4,
            pin_memory=True,
            drop_last=training,
        )

    train_loader = build_loader(train_data, True)
    val_loader = build_loader(val_data, False)

    for message in (
        f"Train batches: {len(train_loader)}",
        f"Val batches: {len(val_loader)}",
        f"Batch size: {batch_size}",
    ):
        logger.info(message)

    return train_loader, val_loader

def setup_model_and_optimizer(config, logger, device):
    """Build the Orion model, its AdamW optimizer and a StepLR scheduler.

    Args:
        config: ``configparser``-style object with ``Data``, ``Model`` and
            ``Training`` sections supplying the hyper-parameters read below.
        logger: Logger used for progress messages.
        device: Device the model is moved to.

    Returns:
        ``(model, optimizer, scheduler)``. The model is wrapped in
        ``nn.DataParallel`` when ``use_multi_gpu`` is set and more than one
        CUDA device is visible.
    """
    logger.info("Initializing model and optimizer...")

    # Integer hyper-parameters, grouped by the config section they come from.
    int_fields = (
        ('Data', ('num_of_vertices', 'in_channels', 'target_len', 'source_len')),
        ('Model', ('d_model', 'd_ff_emb', 'd_ff_belt', 'd_ff_fusion',
                   'd_ff_reverse', 'n_belt_block', 'head_s', 'head_t',
                   'head_f', 'num_time_segments')),
    )
    model_config = {
        key: int(config[section][key])
        for section, keys in int_fields
        for key in keys
    }
    for key in ('dropout', 'causal_threshold'):
        model_config[key] = float(config['Model'][key])

    model = create_orion_model(model_config)

    multi_gpu_requested = config['Training'].getboolean('use_multi_gpu')
    if multi_gpu_requested and torch.cuda.device_count() > 1:
        gpu_ids = [int(token.strip()) for token in config['Training']['gpu_ids'].split(',')]
        model = nn.DataParallel(model, device_ids=gpu_ids)
        logger.info(f"Using multi-GPU: {gpu_ids}")

    model = model.to(device)

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    learning_rate = float(config['Training']['learning_rate'])
    weight_decay = float(config['Training']['weight_decay'])
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8,
    )

    lr_decay_factor = float(config['Training']['lr_decay_factor'])
    lr_decay_step_size = int(config['Training']['lr_decay_step_size'])
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=lr_decay_step_size,
        gamma=lr_decay_factor,
    )

    logger.info(f"Initial learning rate: {learning_rate}")
    logger.info(f"Weight decay: {weight_decay}")

    return model, optimizer, scheduler

def _denormalization_params(stats):
    """Return ``(scale, offset)`` for de-normalisation, or ``None`` for identity."""
    feature0 = stats.get('feature_stats', {}).get(0, {})

    if 'median' in feature0:
        # Robust scaling needs both median and IQR. A missing IQR is treated
        # like a degenerate one and falls through to mean/std.
        iqr_flow = feature0.get('iqr', 0.0)
        if iqr_flow > 1e-8:
            return iqr_flow, feature0['median']
        mean_flow, std_flow = stats['mean_flow'], stats['std_flow']
        if std_flow > 1e-8:
            return std_flow, mean_flow
        return None

    min_flow, max_flow = stats['min_flow'], stats['max_flow']
    if max_flow - min_flow > 1e-8:
        return max_flow - min_flow, min_flow
    return None


def compute_metrics(predictions, targets, stats):
    """Compute MAE, RMSE and MAPE on de-normalised predictions.

    De-normalisation is chosen from ``stats``:

    * if feature 0 has a ``median``: ``x * iqr + median`` when a usable IQR
      (> 1e-8) is present, otherwise ``x * std_flow + mean_flow`` when
      ``std_flow`` > 1e-8, otherwise the values are used as they are;
    * otherwise min-max: ``x * (max_flow - min_flow) + min_flow`` when the
      range exceeds 1e-8, else the values are used as they are.

    MAPE only considers targets whose de-normalised magnitude exceeds 5 and is
    0.0 when none qualify. If it exceeds 1000, it is capped by a smoothed MAPE
    computed over all elements with 1.0 added to the denominator.

    Args:
        predictions: Tensor of normalised model outputs.
        targets: Tensor of normalised ground truth, same shape.
        stats: Dict with ``min_flow``, ``max_flow``, ``mean_flow``,
            ``std_flow`` and optionally ``feature_stats``.

    Returns:
        Dict with float entries ``'mae'``, ``'rmse'`` and ``'mape'``.
    """
    params = _denormalization_params(stats)
    if params is None:
        pred_denorm, target_denorm = predictions, targets
    else:
        scale, offset = params
        pred_denorm = predictions * scale + offset
        target_denorm = targets * scale + offset

    mae = F.l1_loss(pred_denorm, target_denorm).item()
    rmse = torch.sqrt(F.mse_loss(pred_denorm, target_denorm)).item()

    # Ignore near-zero targets, which would blow up the percentage error.
    threshold = 5
    mask = torch.abs(target_denorm) > threshold

    if mask.sum() > 0:
        valid_targets = target_denorm[mask]
        valid_preds = pred_denorm[mask]
        relative_error = torch.abs((valid_targets - valid_preds) / torch.abs(valid_targets))
        mape = (torch.mean(relative_error) * 100).item()
    else:
        mape = 0.0

    if mape > 1000:
        epsilon = 1.0
        mape_safe = torch.mean(
            torch.abs((target_denorm - pred_denorm) / (torch.abs(target_denorm) + epsilon))
        ) * 100
        mape = min(mape, mape_safe.item())

    return {'mae': mae, 'rmse': rmse, 'mape': mape}

def debug_causal_learning(model, logger, epoch):
    """Read the causal gate value of every belt block of the model.

    Looks at ``hour_belt_blocks``, ``week_belt_blocks`` and
    ``day_belt_blocks`` on the (unwrapped) model and reads
    ``block.te_causgat.causal_gate.item()`` for each block found. The values
    are collected locally only; nothing is logged or returned.

    Args:
        model: Model, optionally wrapped so the real one sits in ``.module``.
        logger: Accepted for interface compatibility; not used here.
        epoch: Accepted for interface compatibility; not used here.
    """
    with torch.no_grad():
        actual_model = getattr(model, 'module', model)

        gate_values = []
        for group_name in ('hour_belt_blocks', 'week_belt_blocks', 'day_belt_blocks'):
            group = getattr(actual_model, group_name, None)
            if not group:
                # Missing or empty block list: nothing to read.
                continue
            gate_values.extend(
                member.te_causgat.causal_gate.item() for member in group
            )

def train_epoch(model, train_loader, optimizer, device, logger, lambda_causal=0.1, 
                epoch=0, phase="训练", config=None):
    """Run one training epoch and return sample-weighted average losses.

    Args:
        model: Model (optionally wrapped with the real one in ``.module``).
            It is called with the three inputs and the time indices plus
            ``perform_intervention``/``num_interventions`` and must return a
            dict containing ``'predictions'``; it must also provide
            ``compute_loss``.
        train_loader: Loader yielding ``(x_h, x_w, x_d, targets, time_indices)``.
        optimizer: Optimizer stepped once per successful batch.
        device: Device the batch tensors are moved to.
        logger: Logger for warnings and errors.
        lambda_causal: Weight forwarded to ``compute_loss``.
        epoch: Global epoch index; gates interventions and the progress label.
        phase: Phase label; not used inside this function.
        config: Optional ``configparser`` object whose ``Model`` section can
            override the intervention defaults.

    Returns:
        Dict with ``total_loss``, ``pred_loss``, ``causal_loss`` (``inf`` when
        no batch succeeded), ``intervention_count``, and the concatenated
        CPU ``predictions`` / ``targets`` (``None`` when no batch succeeded).
        Batches that raise or yield a non-finite loss are logged and skipped.
    """
    model.train()

    loss_sums = {'total': 0.0, 'pred': 0.0, 'causal': 0.0}
    total_samples = 0
    pred_chunks = []
    target_chunks = []
    intervention_count = 0
    total_batches = len(train_loader)

    # Intervention defaults; any of them may be overridden from [Model].
    settings = {
        'enable_intervention': True,
        'intervention_start_epoch': 5,
        'intervention_frequency': 10,
        'num_interventions': 2,
        # Parsed for parity with the config file; not used in this function.
        'intervention_weight': 0.5,
    }
    if config and config.has_section('Model'):
        overrides = (
            ('enable_intervention', 'enable_intervention', config.getboolean),
            ('intervention_start_epoch', 'intervention_start_epoch', config.getint),
            ('intervention_frequency', 'intervention_frequency', config.getint),
            ('num_interventions', 'num_interventions_per_batch', config.getint),
            ('intervention_weight', 'intervention_validation_weight', config.getfloat),
        )
        for setting, option, reader in overrides:
            if config.has_option('Model', option):
                settings[setting] = reader('Model', option)

    num_interventions = settings['num_interventions']

    pbar = tqdm(total=len(train_loader.dataset),
                desc=f"Epoch {epoch+1}",
                unit="samples",
                leave=False)

    # Forward and loss always go through the unwrapped model.
    core = model.module if hasattr(model, 'module') else model

    for batch_idx, (x_h, x_w, x_d, targets, time_indices) in enumerate(train_loader):
        try:
            x_h, x_w, x_d, targets, time_indices = (
                tensor.to(device, non_blocking=True)
                for tensor in (x_h, x_w, x_d, targets, time_indices)
            )

            # Make the time-index width match the target's last dimension:
            # truncate if longer, repeat the final column if shorter.
            target_len = targets.shape[-1]
            index_len = time_indices.shape[1]
            if index_len > target_len:
                time_indices = time_indices[:, :target_len]
            elif index_len < target_len:
                filler = time_indices[:, -1:].expand(-1, target_len - index_len)
                time_indices = torch.cat([time_indices, filler], dim=1)

            optimizer.zero_grad()

            perform_intervention = (
                settings['enable_intervention']
                and epoch >= settings['intervention_start_epoch']
                and batch_idx % settings['intervention_frequency'] == 0
            )
            if perform_intervention:
                intervention_count += 1

            outputs = core(x_h, x_w, x_d, time_indices,
                           perform_intervention=perform_intervention,
                           num_interventions=num_interventions)

            predictions = outputs['predictions']
            loss_dict = core.compute_loss(
                predictions,
                targets,
                outputs.get('causal_matrices', []),
                outputs.get('validity_scores', None),
                lambda_causal=lambda_causal,
            )

            pred_loss = loss_dict['pred_loss']
            causal_loss = loss_dict['causal_loss']
            batch_loss = loss_dict['total_loss']

            if torch.isnan(batch_loss) or torch.isinf(batch_loss):
                logger.warning(f"Invalid loss at batch {batch_idx}, skipping")
                continue

            batch_loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            batch_size = targets.size(0)
            loss_sums['total'] += batch_loss.item() * batch_size
            loss_sums['pred'] += pred_loss.item() * batch_size
            loss_sums['causal'] += causal_loss.item() * batch_size
            total_samples += batch_size

            pred_chunks.append(predictions.detach().cpu())
            target_chunks.append(targets.detach().cpu())

            pbar.update(batch_size)
            pbar.set_postfix({
                'Loss': f'{batch_loss.item():.4f}',
                'Batch': f'{batch_idx+1}/{total_batches}'
            })

        except Exception as e:
            logger.error(f"Error in training batch {batch_idx}: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            continue

    pbar.close()

    if total_samples > 0:
        averages = {name: value / total_samples for name, value in loss_sums.items()}
    else:
        averages = {name: float('inf') for name in loss_sums}

    if epoch % 5 == 0:
        debug_causal_learning(model, logger, epoch)

    return {
        'total_loss': averages['total'],
        'pred_loss': averages['pred'],
        'causal_loss': averages['causal'],
        'intervention_count': intervention_count,
        'predictions': torch.cat(pred_chunks, dim=0) if pred_chunks else None,
        'targets': torch.cat(target_chunks, dim=0) if target_chunks else None
    }

def validate_epoch(model, val_loader, device, logger, stats):
    """Evaluate the model on the validation loader without gradients.

    Args:
        model: Model called as ``model(x_h, x_w, x_d, time_indices)`` and
            returning a dict containing ``'predictions'``.
        val_loader: Loader yielding ``(x_h, x_w, x_d, targets, time_indices)``.
        device: Device the batch tensors are moved to.
        logger: Logger for warnings and errors.
        stats: Normalisation statistics forwarded to ``compute_metrics``.

    Returns:
        Dict with the sample-weighted mean L1 ``loss`` on normalised values,
        ``mae``/``rmse``/``mape`` from ``compute_metrics``, and the
        concatenated CPU ``predictions`` / ``targets``. When no batch
        succeeded the numeric entries are ``inf`` and the tensors ``None``.
        Batches that raise or yield a non-finite loss are logged and skipped.
    """
    model.eval()

    loss_sum = 0.0
    total_samples = 0
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        pbar = tqdm(total=len(val_loader.dataset),
                    desc="Validation",
                    unit="samples",
                    leave=False)

        for batch_idx, (x_h, x_w, x_d, targets, time_indices) in enumerate(val_loader):
            try:
                x_h, x_w, x_d, targets, time_indices = (
                    tensor.to(device, non_blocking=True)
                    for tensor in (x_h, x_w, x_d, targets, time_indices)
                )

                # Make the time-index width match the target's last dimension:
                # truncate if longer, repeat the final column if shorter.
                target_len = targets.shape[-1]
                index_len = time_indices.shape[1]
                if index_len > target_len:
                    time_indices = time_indices[:, :target_len]
                elif index_len < target_len:
                    filler = time_indices[:, -1:].expand(-1, target_len - index_len)
                    time_indices = torch.cat([time_indices, filler], dim=1)

                predictions = model(x_h, x_w, x_d, time_indices)['predictions']
                pred_loss = F.l1_loss(predictions, targets)

                if torch.isnan(pred_loss) or torch.isinf(pred_loss):
                    logger.warning(f"Invalid loss at validation batch {batch_idx}, skipping")
                    continue

                batch_size = targets.size(0)
                loss_sum += pred_loss.item() * batch_size
                total_samples += batch_size

                pred_chunks.append(predictions.cpu())
                target_chunks.append(targets.cpu())

                pbar.update(batch_size)
                pbar.set_postfix({
                    'Loss': f'{pred_loss.item():.4f}',
                    'Batch': f'{batch_idx+1}/{len(val_loader)}'
                })

            except Exception as e:
                logger.error(f"Error in validation batch {batch_idx}: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                continue

        pbar.close()

    avg_loss = loss_sum / total_samples if total_samples > 0 else float('inf')

    all_pred = None
    all_tgt = None
    if pred_chunks and target_chunks:
        all_pred = torch.cat(pred_chunks, dim=0)
        all_tgt = torch.cat(target_chunks, dim=0)
        metrics = compute_metrics(all_pred, all_tgt, stats)
    else:
        metrics = {name: float('inf') for name in ('mae', 'rmse', 'mape')}

    return {
        'loss': avg_loss,
        'mae': metrics['mae'],
        'rmse': metrics['rmse'],
        'mape': metrics['mape'],
        'predictions': all_pred,
        'targets': all_tgt
    }

def save_checkpoint(model, optimizer, scheduler, epoch, metrics, config, logger, 
                    is_best_mae=False, is_best_rmse=False, is_best_mape=False):
    """Write the current training state to disk.

    Always writes ``latest_checkpoint.pth`` under
    ``config['Training']['save_path']`` (created if missing). Each ``is_best_*``
    flag additionally writes the matching ``best_model_<metric>.pth``; a best
    MAE also writes ``best_model.pth``. Every 20th epoch (``epoch + 1``
    divisible by 20) a timestamped ``checkpoint_epoch_*`` file is written.

    Args:
        model, optimizer, scheduler: Objects whose ``state_dict()`` is stored.
        epoch: Zero-based epoch index stored in the checkpoint.
        metrics: Arbitrary metrics object stored in the checkpoint.
        config: ``configparser`` object; its sections are stored as a dict.
        logger: Logger used to report each extra file written.
        is_best_mae, is_best_rmse, is_best_mape: Flags marking a new best.
    """
    save_dir = Path(config['Training']['save_path'])
    save_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'metrics': metrics,
        'config': dict(config._sections)
    }

    torch.save(checkpoint, save_dir / 'latest_checkpoint.pth')

    # (flag, file name, label used in the log line). Best MAE is listed twice
    # because it also defines the generic "best model".
    extra_targets = (
        (is_best_mae, 'best_model_mae.pth', 'best MAE model'),
        (is_best_rmse, 'best_model_rmse.pth', 'best RMSE model'),
        (is_best_mape, 'best_model_mape.pth', 'best MAPE model'),
        (is_best_mae, 'best_model.pth', 'best model'),
        ((epoch + 1) % 20 == 0, f'checkpoint_epoch_{epoch+1}_{timestamp}.pth', 'checkpoint'),
    )
    for wanted, file_name, label in extra_targets:
        if not wanted:
            continue
        destination = save_dir / file_name
        torch.save(checkpoint, destination)
        logger.info(f"Saved {label}: {destination}")

def three_phase_training(model, train_loader, val_loader, optimizer, scheduler, config, stats, logger, device):
    logger.info("Starting three-phase training strategy")
    
    phase_config = {
        'phase1': {
            'name': 'Feature Learning',
            'epochs': int(config['Training'].get('epochs_phase1', 40)),
            'lambda_causal': float(config['Training'].get('lambda_causal_phase1', 0.1)),
            'freeze_modules': config['Training'].get('freeze_modules_phase1', '').split(',') if config['Training'].get('freeze_modules_phase1') else [],
            'description': 'Learning spatiotemporal features'
        },
        'phase2': {
            'name': 'Causal Structure Learning',
            'epochs': int(config['Training'].get('epochs_phase2', 60)),
            'lambda_causal': float(config['Training'].get('lambda_causal_phase2', 0.5)),
            'freeze_modules': config['Training'].get('freeze_modules_phase2', 'embedding_process').split(',') if config['Training'].get('freeze_modules_phase2') else ['embedding_process'],
            'description': 'Learning causal relationships'
        },
        'phase3': {
            'name': 'End-to-End Fine-tuning',
            'epochs': int(config['Training'].get('epochs_phase3', 100)),
            'lambda_causal': float(config['Training'].get('lambda_causal_phase3', 0.3)),
            'freeze_modules': config['Training'].get('freeze_modules_phase3', '').split(',') if config['Training'].get('freeze_modules_phase3') else [],
            'description': 'Final optimization'
        }
    }
    
    for phase in phase_config.values():
        phase['freeze_modules'] = [m.strip() for m in phase['freeze_modules'] if m.strip()]
    
    history = {
        'train_loss': [],
        'val_loss': [],
        'causal_loss': [],
        'val_mae': [],
        'val_rmse': [],
        'val_mape': [],
        'learning_rate': [],
        'phase': []
    }
    
    best_val_loss = float('inf')
    best_val_mae = float('inf')
    best_val_rmse = float('inf')
    best_val_mape = float('inf')
    
    patience = int(config['Training']['patience'])
    total_epochs = 0
    
    current_lr = float(config['Training']['learning_rate'])
    
    total_scheduler_steps = 0
    
    for phase_idx, (phase_name, phase_cfg) in enumerate(phase_config.items()):
        logger.info(f"Starting Phase {phase_idx+1}/3: {phase_cfg['name']}")
        logger.info(f"Epochs: {phase_cfg['epochs']}, Lambda_causal: {phase_cfg['lambda_causal']}")
        
        phase_patience_counter = 0
        phase_best_loss = float('inf')
        
        def set_module_requires_grad(model, freeze_modules, freeze=True):
            frozen_params = 0
            total_params = 0
            frozen_modules_list = []
            
            actual_model = model.module if hasattr(model, 'module') else model
            
            for name, module in actual_model.named_modules():
                if name == '':
                    continue
                    
                module_params = sum(p.numel() for p in module.parameters(recurse=False))
                total_params += module_params
                
                should_freeze = False
                for freeze_name in freeze_modules:
                    if freeze_name and freeze_name in name:
                        should_freeze = True
                        break
                
                if should_freeze and freeze:
                    for param in module.parameters(recurse=False):
                        param.requires_grad = False
                    frozen_params += module_params
                    if name not in frozen_modules_list:
                        frozen_modules_list.append(name)
                else:
                    for param in module.parameters(recurse=False):
                        param.requires_grad = True
            
            return frozen_params, total_params, frozen_modules_list
        
        if phase_cfg['freeze_modules']:
            for param in model.parameters():
                param.requires_grad = True
                
            frozen_params, total_params, frozen_modules = set_module_requires_grad(
                model, phase_cfg['freeze_modules'], freeze=True
            )
            
            logger.info(f"Frozen parameters: {frozen_params:,}/{total_params:,}")
            
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            logger.info(f"Trainable parameters: {trainable_params:,}")
            
            trainable_params_list = [p for p in model.parameters() if p.requires_grad]
            if trainable_params_list:
                optimizer = optim.AdamW(
                    trainable_params_list,
                    lr=current_lr,
                    weight_decay=float(config['Training']['weight_decay']),
                    betas=(0.9, 0.999),
                    eps=1e-8
                )
                
                lr_decay_step_size = int(config['Training']['lr_decay_step_size'])
                lr_decay_factor = float(config['Training']['lr_decay_factor'])
                
                scheduler = optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=lr_decay_step_size,
                    gamma=lr_decay_factor
                )
                
                for _ in range(total_scheduler_steps // lr_decay_step_size):
                    scheduler.step()
        else:
            for param in model.parameters():
                param.requires_grad = True
            
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            logger.info(f"All parameters trainable: {trainable_params:,}")
            
            if phase_idx == 0 or (phase_idx > 0 and phase_config[f'phase{phase_idx}']['freeze_modules']):
                optimizer = optim.AdamW(
                    model.parameters(),
                    lr=current_lr,
                    weight_decay=float(config['Training']['weight_decay']),
                    betas=(0.9, 0.999),
                    eps=1e-8
                )
                
                lr_decay_step_size = int(config['Training']['lr_decay_step_size'])
                lr_decay_factor = float(config['Training']['lr_decay_factor'])
                
                scheduler = optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=lr_decay_step_size,
                    gamma=lr_decay_factor
                )
                
                for _ in range(total_scheduler_steps // lr_decay_step_size):
                    scheduler.step()
        
        phase_epochs_completed = 0
        for epoch in range(phase_cfg['epochs']):
            epoch_start_time = time.time()
        
            train_metrics = train_epoch(
                model, train_loader, optimizer, device, logger,
                lambda_causal=phase_cfg['lambda_causal'],
                epoch=total_epochs,
                phase=phase_cfg['name'],
                config=config
            )
            
            val_metrics = validate_epoch(model, val_loader, device, logger, stats)
            
            scheduler.step()
            total_scheduler_steps += 1
            current_lr = optimizer.param_groups[0]['lr']
            
            history['train_loss'].append(train_metrics['total_loss'])
            history['val_loss'].append(val_metrics['loss'])
            history['causal_loss'].append(train_metrics['causal_loss'])
            history['val_mae'].append(val_metrics['mae'])
            history['val_rmse'].append(val_metrics['rmse'])
            history['val_mape'].append(val_metrics['mape'])
            history['learning_rate'].append(current_lr)
            history['phase'].append(phase_name)
            
            epoch_time = time.time() - epoch_start_time
            
            logger.info(
                f"Phase {phase_idx+1} Epoch {epoch+1}/{phase_cfg['epochs']} | "
                f"Train Loss: {train_metrics['total_loss']:.4f} | "
                f"Val Loss: {val_metrics['loss']:.4f} | "
                f"MAE: {val_metrics['mae']:.4f} | "
                f"RMSE: {val_metrics['rmse']:.4f} | "
                f"MAPE: {val_metrics['mape']:.2f}% | "
                f"LR: {current_lr:.2e} | "
                f"Time: {epoch_time:.1f}s"
            )
            
            is_best_mae = val_metrics['mae'] < best_val_mae
            is_best_rmse = val_metrics['rmse'] < best_val_rmse
            is_best_mape = val_metrics['mape'] < best_val_mape
            is_best_loss = val_metrics['loss'] < best_val_loss
            
            if is_best_mae:
                best_val_mae = val_metrics['mae']
                logger.info(f"New best MAE: {best_val_mae:.4f}")
            
            if is_best_rmse:
                best_val_rmse = val_metrics['rmse']
                logger.info(f"New best RMSE: {best_val_rmse:.4f}")
            
            if is_best_mape:
                best_val_mape = val_metrics['mape']
                logger.info(f"New best MAPE: {best_val_mape:.2f}%")
            
            if is_best_loss:
                best_val_loss = val_metrics['loss']
            
            if val_metrics['loss'] < phase_best_loss:
                phase_best_loss = val_metrics['loss']
                phase_patience_counter = 0
            else:
                phase_patience_counter += 1
            
            if (total_epochs + 1) % 10 == 0 or is_best_mae or is_best_rmse or is_best_mape:
                save_checkpoint(
                    model, optimizer, scheduler, total_epochs,
                    {**train_metrics, **val_metrics},
                    config, logger, 
                    is_best_mae=is_best_mae,
                    is_best_rmse=is_best_rmse,
                    is_best_mape=is_best_mape
                )
            
            phase_patience = patience * 2
            if phase_patience_counter >= phase_patience:
                logger.warning(f"Phase {phase_idx+1} early stopping triggered")
                break
            
            total_epochs += 1
            phase_epochs_completed += 1
        
        logger.info(f"Phase {phase_idx+1} completed. Epochs trained: {phase_epochs_completed}")
    
    logger.info("Three-phase training completed")
    logger.info(f"Best MAE: {best_val_mae:.4f}")
    logger.info(f"Best RMSE: {best_val_rmse:.4f}")
    logger.info(f"Best MAPE: {best_val_mape:.2f}%")
    logger.info(f"Total epochs: {total_epochs}")
    
    return history

def main():
    """Command-line entry point: parse arguments, build the pipeline, train.

    Steps, in order:
      1. Parse ``--config``, ``--resume``, ``--debug`` and ``--seed``.
      2. Seed the random number generators via ``set_random_seed``.
      3. Read the configuration file and pick the device.
      4. Set up logging, load the data, build the loaders, model, optimizer
         and scheduler.
      5. Optionally restore model/optimizer/scheduler state from ``--resume``.
      6. Run ``three_phase_training`` and ``verify_phase_differences``, then
         log the total time and the best validation metrics.

    Raises:
        SystemExit: if argument parsing fails or the configuration file
            cannot be read (reported through ``parser.error``).
        Exception: any error raised inside the training pipeline is logged
            together with its traceback and then re-raised.
            ``KeyboardInterrupt`` is logged as a warning and not re-raised.
    """
    parser = argparse.ArgumentParser(description='Orion Model Training')
    parser.add_argument('--config', default='configurations/Orion_PEMS08_config.conf',
                        help='Configuration file path')
    parser.add_argument('--resume', default=None,
                        help='Resume training from checkpoint')
    # NOTE(review): --debug is parsed but never read in this function.
    parser.add_argument('--debug', action='store_true',
                        help='Debug mode')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    args = parser.parse_args()

    set_random_seed(args.seed)

    config = configparser.ConfigParser()
    # ConfigParser.read() skips files it cannot open and returns the list of
    # files it actually parsed. An empty list means nothing was loaded, which
    # would otherwise surface later as an opaque KeyError on config['Training'].
    if not config.read(args.config):
        parser.error(f"Cannot read configuration file: {args.config}")

    # Fall back to CPU whenever CUDA is unavailable, whatever the config says.
    device = torch.device(config['Training']['device'] if torch.cuda.is_available() else 'cpu')

    logger = setup_logging(config)

    try:
        logger.info(f"Config: {args.config}")
        logger.info(f"Device: {device}")
        logger.info(f"Random seed: {args.seed}")

        train_data, val_data, stats, adj_matrix = load_data(config, logger)

        train_loader, val_loader = create_data_loaders(train_data, val_data, config, logger)

        model, optimizer, scheduler = setup_model_and_optimizer(config, logger, device)

        # NOTE(review): start_epoch is only logged below; it is not passed to
        # three_phase_training, so confirm that resuming continues the epoch
        # count as intended.
        start_epoch = 0
        if args.resume:
            if os.path.exists(args.resume):
                logger.info(f"Resuming from checkpoint: {args.resume}")
                checkpoint = torch.load(args.resume, map_location=device)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                start_epoch = checkpoint['epoch'] + 1
                logger.info(f"Resuming from epoch {start_epoch}")
            else:
                # Keep the original best-effort behaviour (train from
                # scratch) but make the ignored --resume visible in the log.
                logger.warning(
                    f"Checkpoint not found: {args.resume}; starting training from scratch"
                )

        training_start_time = time.time()
        logger.info("Starting training...")

        history = three_phase_training(
            model, train_loader, val_loader, optimizer, scheduler,
            config, stats, logger, device
        )

        verify_phase_differences(history, config, logger)

        training_time = time.time() - training_start_time
        logger.info("Training completed")
        logger.info(f"Total training time: {training_time/3600:.2f} hours")

        # (label, history key, format spec, unit suffix) for each summary line.
        best_metric_report = (
            ("loss", "val_loss", ".4f", ""),
            ("MAE", "val_mae", ".4f", ""),
            ("RMSE", "val_rmse", ".4f", ""),
            ("MAPE", "val_mape", ".2f", "%"),
        )
        for label, key, spec, unit in best_metric_report:
            values = history[key]
            if values:
                logger.info(f"Best validation {label}: {min(values):{spec}}{unit}")
            else:
                # min() of an empty sequence raises ValueError; a run that
                # recorded no epochs should not be reported as a failure.
                logger.warning(
                    f"No '{key}' values recorded; skipping best validation {label}"
                )

    except KeyboardInterrupt:
        logger.warning("Training interrupted by user")
    except Exception as e:
        logger.error(f"Error during training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        raise
    finally:
        logger.info("Training script finished")

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()