import sys
import os
import time
import pickle
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

# AMP (Automatic Mixed Precision) imports
from torch.cuda.amp import GradScaler, autocast

# Optimizer and Scheduler imports
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from args import get_args
# --- 核心修改: 从 data_utils_fix.py 导入新组件 ---
from data_utils_fix import get_segment_dataset, get_model, get_loss_func, collate_segment_batch
from utils import get_seed, get_num_params

def train(model, loss_func, metric_func,
              train_loader, valid_loader,
              optimizer, lr_scheduler,
              args,
              device="cuda",
              start_epoch: int = 0,
              model_filename='model.pt',
              result_filename='result.pt'):
    """Run the multi-step (rollout) training loop with periodic validation.

    Trains for ``args.epochs`` epochs, covering absolute epochs
    ``start_epoch .. start_epoch + args.epochs - 1``. Each epoch calls
    ``train_epoch_rollout``. ``validate_epoch_rollout`` runs every
    ``args.validation_freq`` epochs and always on the last epoch. Whenever the
    validation metric improves, a checkpoint dict (``args``, ``model``,
    ``optimizer`` state) is written to
    ``<args.model_save_path>/<model_filename stem>_best<ext>``. When training
    ends, the loss histories are pickled to
    ``<args.model_save_path>/<result_filename>``.

    Args:
        model: network being trained.
        loss_func: training loss, forwarded to ``train_epoch_rollout``.
        metric_func: validation metric, forwarded to ``validate_epoch_rollout``.
        train_loader / valid_loader: data loaders for training / validation.
        optimizer: optimizer to step (its state is stored in checkpoints).
        lr_scheduler: scheduler forwarded to ``train_epoch_rollout``.
        args: namespace with hyper-parameters. ``args.writer`` (optional) is a
            TensorBoard writer or ``None``.
        device: device the batches are moved to.
        start_epoch: absolute index of the first epoch (for resumed runs).
        model_filename: file name the "best" checkpoint name is derived from.
        result_filename: file name for the pickled result dict.

    Returns:
        dict with ``best_val_epoch`` (1-based absolute epoch, -1 if validation
        never ran), ``best_val_metric``, and NumPy arrays ``loss_train`` and
        ``loss_val``.
    """
    import copy  # local import: only needed to build a picklable copy of args

    loss_train_history, loss_val_history = [], []
    best_val_metric = float('inf')
    best_val_epoch = -1

    writer = getattr(args, 'writer', None)
    # Absolute index one past the last epoch; this also fixes the "last epoch"
    # check and the progress denominators when resuming (start_epoch > 0).
    end_epoch = start_epoch + args.epochs

    stem, ext = os.path.splitext(model_filename)
    best_model_path = os.path.join(args.model_save_path, f"{stem}_best{ext}")

    def _checkpoint_args():
        # A SummaryWriter holds thread locks and open files, so pickling it
        # inside torch.save fails. Store a shallow copy of args without it.
        ckpt_args = copy.copy(args)
        if hasattr(ckpt_args, 'writer'):
            ckpt_args.writer = None
        return ckpt_args

    scaler = GradScaler(enabled=args.amp)

    print("\n--- 开始长时序自回归训练 ---")
    if args.amp:
        print("混合精度训练 (AMP) 已启用。")
    print(f"训练步长 (Rollout Steps): {args.unroll_steps}")
    print(f"计划采样 (Scheduled Sampling): start={args.tf_start}, end={args.tf_end}, decay_epochs={args.tf_decay_epochs}")

    for epoch in range(start_epoch, end_epoch):

        avg_train_loss = train_epoch_rollout(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_func=loss_func,
            device=device,
            epoch=epoch,
            args=args,
            scaler=scaler
        )
        loss_train_history.append(avg_train_loss)

        is_last_epoch = (epoch + 1) == end_epoch
        if (epoch + 1) % args.validation_freq == 0 or is_last_epoch:
            avg_val_metric = validate_epoch_rollout(
                model=model,
                valid_loader=valid_loader,
                metric_func=metric_func,
                device=device,
                args=args
            )
            loss_val_history.append(avg_val_metric)

            best_tag = ""
            if avg_val_metric < best_val_metric:
                best_val_metric = avg_val_metric
                best_val_epoch = epoch + 1
                best_tag = "🚀 (Best)"
                checkpoint = {'args': _checkpoint_args(), 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(checkpoint, best_model_path)

            log_str = f"Epoch [{epoch+1}/{end_epoch}] | Train Loss: {avg_train_loss:.5f} | Val Metric ({args.val_unroll_steps}-step): {avg_val_metric:.5f} | Best: {best_val_metric:.5f} at Ep {best_val_epoch} {best_tag}"
            separator = "-" * (len(log_str) + 2)
            print(separator)
            print(f" {log_str} ")
            print(separator)

            if writer is not None:
                writer.add_scalar('Metric/validation_rollout', avg_val_metric, epoch)
        else:
            print(f"Epoch [{epoch+1}/{end_epoch}] | Train Loss: {avg_train_loss:.5f} | (Validation Skipped)")

        if writer is not None:
            writer.add_scalar('Loss/train_rollout', avg_train_loss, epoch)
            writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch)

    print(f"\n--- 训练完成 ---")
    if best_val_epoch != -1:
        print(f"最佳验证指标 (基于 {args.val_unroll_steps}-step rollout): {best_val_metric:.6f} (在第 {best_val_epoch} 个epoch达到)")

    result = {
        'best_val_epoch': best_val_epoch,
        'best_val_metric': best_val_metric,
        'loss_train': np.asarray(loss_train_history),
        'loss_val': np.asarray(loss_val_history),
    }
    result_path = os.path.join(args.model_save_path, result_filename)
    with open(result_path, 'wb') as f:
        pickle.dump(result, f)

    return result

def train_epoch_rollout(model, train_loader, optimizer, lr_scheduler, loss_func, device, epoch, args, scaler):
    """Train ``model`` for one epoch using multi-step autoregressive rollout.

    For every batch the model is unrolled ``args.unroll_steps`` times,
    starting from ``features_batch[:, 0]``. At step ``k`` the loss against
    ``targets_batch[:, k]`` is accumulated. The mean loss over the rollout
    then drives a single AMP-scaled backward pass and optimizer step.

    Between steps, scheduled sampling picks the next input. With probability
    ``tf_rate`` it uses the ground truth ``features_batch[:, k + 1]`` (teacher
    forcing); otherwise it uses the detached model prediction. ``tf_rate``
    decays linearly from ``args.tf_start`` to ``args.tf_end`` over
    ``args.tf_decay_epochs`` epochs.

    Args:
        model: network called as ``model(g, coords, u_p, inputs_f)`` and
            returning ``(pred, _)``.
        train_loader: yields ``(g, coords, u_p, inputs_f, features, targets)``.
        optimizer / scaler: optimizer and AMP ``GradScaler`` to step.
        lr_scheduler: stepped once per batch unless it is ``ReduceLROnPlateau``.
        loss_func: called as ``loss_func(g, pred, target)``; its first return
            value is the loss.
        device: device the batch tensors are moved to.
        epoch: absolute epoch index (drives the teacher-forcing schedule).
        args: namespace with the rollout / AMP / clipping hyper-parameters.

    Returns:
        float: mean rollout loss over the batches of this epoch, or NaN if the
        loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Scheduled sampling: linear decay of the teacher-forcing probability.
    tf_decay_epochs = args.tf_decay_epochs
    if tf_decay_epochs <= 0:
        tf_rate = args.tf_end
    else:
        progress = min(epoch / tf_decay_epochs, 1.0)
        tf_rate = args.tf_start * (1 - progress) + args.tf_end * progress

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Training Rollout, TF Rate: {tf_rate:.2f}]", leave=False)

    for g_batch, coords_batch, u_p_batch, inputs_f_batch, features_batch, targets_batch in pbar:
        # 1. Move the batch to the target device.
        g_batch = g_batch.to(device)
        coords_batch = coords_batch.to(device, non_blocking=True)
        u_p_batch = u_p_batch.to(device, non_blocking=True)
        # inputs_f_batch is forwarded to the model at every step, so it must be
        # on the same device. It may be a non-tensor placeholder (e.g. None),
        # which is passed through unchanged.
        if torch.is_tensor(inputs_f_batch):
            inputs_f_batch = inputs_f_batch.to(device, non_blocking=True)
        features_batch = features_batch.to(device, non_blocking=True)
        targets_batch = targets_batch.to(device, non_blocking=True)

        # features_batch is unpacked as (batch, time, nodes, state channels).
        B, _, N, C_state = features_batch.shape

        # 2. The rollout starts from the first frame of the segment.
        current_state = features_batch[:, 0, :, :]

        optimizer.zero_grad(set_to_none=True)
        rollout_loss = torch.tensor(0.0, device=device)

        with autocast(enabled=args.amp):
            # 3. Autoregressive loop.
            for k in range(args.unroll_steps):
                # a. The model reads the current state from the DGL graph's node data.
                g_batch.ndata['x'] = current_state.reshape(-1, C_state)

                # b. Forward pass; coordinates and state are passed separately.
                pred_out, _ = model(g_batch, coords_batch, u_p_batch, inputs_f_batch)
                pred_state = pred_out.view(B, N, C_state)

                # c./d. Loss against the ground truth of step k, summed over the rollout.
                target_state = targets_batch[:, k, :, :]
                loss_k, _, _ = loss_func(g_batch, pred_out, target_state.reshape(-1, C_state))
                rollout_loss = rollout_loss + loss_k

                # e. Scheduled sampling for the next input. Both branches are
                #    detached, so no gradient flows across rollout steps.
                if k < args.unroll_steps - 1:
                    use_teacher = torch.rand(1).item() < tf_rate
                    if use_teacher:
                        current_state = features_batch[:, k + 1, :, :].detach()
                    else:
                        current_state = pred_state.detach()

        # 4. Average over the rollout, then take one AMP-scaled optimizer step.
        avg_rollout_loss = rollout_loss / args.unroll_steps

        scaler.scale(avg_rollout_loss).backward()
        if args.grad_clip > 0:
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        # Per-batch schedulers step here; ReduceLROnPlateau needs a metric instead.
        if lr_scheduler and not isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step()

        # A single host sync per batch for logging.
        loss_value = avg_rollout_loss.item()
        total_loss += loss_value
        num_batches += 1
        pbar.set_postfix(loss=loss_value)

    return total_loss / num_batches if num_batches else float('nan')

def validate_epoch_rollout(model, valid_loader, metric_func, device, args):
    """Evaluate ``model`` with a pure autoregressive rollout (no teacher forcing).

    Each batch starts from ``features_batch[:, 0]`` and is unrolled
    ``args.val_unroll_steps`` times. Each step's input is the previous
    prediction. The third value returned by ``metric_func`` is reduced with
    ``np.mean`` at every step and averaged over the rollout.

    Args:
        model: network called as ``model(g, coords, u_p, inputs_f)`` and
            returning ``(pred, _)``.
        valid_loader: yields ``(g, coords, u_p, inputs_f, features, targets)``.
        metric_func: called as ``metric_func(g, pred, target)``; its third
            return value is the metric.
        device: device the batch tensors are moved to.
        args: namespace providing ``val_unroll_steps`` and ``amp``.

    Returns:
        float: per-batch rollout metric averaged over the loader, weighted by
        batch size, or NaN if the loader yielded no batches.
    """
    model.eval()

    weighted_metric_sum = 0.0
    num_samples = 0

    pbar = tqdm(valid_loader, desc=f"Validating (Rollout {args.val_unroll_steps} steps)", leave=False)

    with torch.no_grad():
        for g_batch, coords_batch, u_p_batch, inputs_f_batch, features_batch, targets_batch in pbar:
            g_batch = g_batch.to(device)
            coords_batch = coords_batch.to(device, non_blocking=True)
            u_p_batch = u_p_batch.to(device, non_blocking=True)
            # inputs_f_batch is forwarded to the model, so it must share the
            # device. Non-tensor placeholders (e.g. None) pass through unchanged.
            if torch.is_tensor(inputs_f_batch):
                inputs_f_batch = inputs_f_batch.to(device, non_blocking=True)
            features_batch = features_batch.to(device, non_blocking=True)
            targets_batch = targets_batch.to(device, non_blocking=True)

            B, _, N, C_state = features_batch.shape

            current_state = features_batch[:, 0, :, :]
            step_metric_sum = 0.0

            with autocast(enabled=args.amp):
                for k in range(args.val_unroll_steps):
                    g_batch.ndata['x'] = current_state.reshape(-1, C_state)

                    pred_out, _ = model(g_batch, coords_batch, u_p_batch, inputs_f_batch)
                    pred_state = pred_out.view(B, N, C_state)

                    target_state = targets_batch[:, k, :, :]

                    _, _, metric_k = metric_func(g_batch, pred_out, target_state.reshape(-1, C_state))

                    # Reduce to a scalar before accumulating. metric_k is presumably
                    # a NumPy array of per-sample values — confirm against get_loss_func.
                    step_metric_sum += float(np.mean(metric_k))

                    # Pure autoregression: the prediction is always the next input.
                    if k < args.val_unroll_steps - 1:
                        current_state = pred_state

            batch_metric = step_metric_sum / args.val_unroll_steps
            # Weight by batch size: the last batch may be smaller (drop_last=False),
            # and an unweighted mean of batch means would over-weight it.
            weighted_metric_sum += batch_metric * B
            num_samples += B

    return weighted_metric_sum / num_samples if num_samples else float('nan')

if __name__ == "__main__":
    import copy  # used to strip the unpicklable SummaryWriter from saved args

    args = get_args()

    # --- Defaults for the multi-step training arguments (if args.py lacks them) ---
    defaults = {
        'unroll_steps': 5, 'val_unroll_steps': 10, 'segment_length': 15,
        'tf_start': 1.0, 'tf_end': 0.0, 'tf_decay_epochs': 50,
        'train_num': 50000, 'test_num': 5000, 'num_train_sims': 1000,
        'validation_freq': 1, 'grad_clip': 1.0
    }
    for key, value in defaults.items():
        if not hasattr(args, key):
            setattr(args, key, value)

    # Argument validation: both rollout lengths must fit within one segment.
    if args.unroll_steps >= args.segment_length:
        raise ValueError(f"unroll_steps ({args.unroll_steps}) must be less than segment_length ({args.segment_length})")
    if args.val_unroll_steps >= args.segment_length:
        raise ValueError(f"val_unroll_steps ({args.val_unroll_steps}) must be less than segment_length ({args.segment_length})")

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() and not getattr(args, 'no_cuda', False) else 'cpu')
    get_seed(args.seed, printout=True)

    # --- 1. Load the segmented datasets via data_utils_fix.py ---
    print("正在加载和创建分段数据集...")
    train_dataset, test_dataset = get_segment_dataset(args)

    # --- 2. Build the DataLoaders ---
    dl_args = {'num_workers': args.num_workers, 'pin_memory': device.type == 'cuda'}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_segment_batch, **dl_args)
    valid_loader = DataLoader(test_dataset, batch_size=args.val_batch_size, shuffle=False, drop_last=False, collate_fn=collate_segment_batch, **dl_args)
    print("数据加载器准备完毕。")

    # --- 3. Model, loss/metric functions, optimizer and scheduler ---
    args.normalizer = args.y_normalizer.to(device) if hasattr(args, 'y_normalizer') and args.y_normalizer is not None else None

    # get_model reads its configuration from args (e.g. args.dataset_config).
    model = get_model(args).to(device)
    # A plain nn.Module has no __name__; fall back to the class name if the model
    # does not set one itself.
    model_name = getattr(model, '__name__', type(model).__name__)
    print(f"\nModel: {model_name}\t Number of params: {get_num_params(model)}")

    loss_func = get_loss_func(name=args.loss_name, args=args, normalizer=args.normalizer)
    metric_func = get_loss_func(name='rel2', args=args, normalizer=args.normalizer)

    path_prefix = f"{args.dataset}_{args.model_name}_{args.comment}_rollout{args.unroll_steps}_{time.strftime('%m%d_%H%M')}"
    model_filename = path_prefix + '.pt'
    result_filename = path_prefix + '.pkl'
    args.model_save_path = './data/checkpoints/'
    os.makedirs(args.model_save_path, exist_ok=True)

    if args.use_tb:
        writer_path = './data/logs/' + path_prefix
        args.writer = SummaryWriter(log_dir=writer_path)
    else:
        args.writer = None

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # OneCycleLR is stepped once per training batch inside train_epoch_rollout.
    scheduler = OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs, steps_per_epoch=len(train_loader))

    time_start = time.time()

    try:
        # --- 4. Start training ---
        train(model, loss_func, metric_func,
              train_loader, valid_loader,
              optimizer, scheduler,
              args=args, device=device,
              model_filename=model_filename,
              result_filename=result_filename)

        print('总训练时间: {:.2f} 秒'.format(time.time() - time_start))

        final_model_path = os.path.join(args.model_save_path, model_filename)
        # The SummaryWriter (locks, open files) cannot be pickled by torch.save,
        # so save a shallow copy of args without it.
        ckpt_args = copy.copy(args)
        ckpt_args.writer = None
        checkpoint = {'args': ckpt_args, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, final_model_path)
        print(f"最终模型已保存到: {final_model_path}")

        # NOTE(review): valid_loader wraps the same test_dataset used for model
        # selection during training, so this is not an independent hold-out set.
        print("\n--- 在独立的测试集上进行最终评估 (自回归) ---")
        final_test_metric = validate_epoch_rollout(model, valid_loader, metric_func, device, args)
        print(f"最终模型在测试集上的 {args.val_unroll_steps}-step rollout 指标: {final_test_metric:.6f}")
    finally:
        # Flush and close TensorBoard event files even if training fails.
        if args.writer is not None:
            args.writer.close()