import os
import argparse
import numpy as np
import torch
from tqdm import tqdm

# Import the new (adaptive) model dictionary and the same utilities used by exp_airfoil.py
from model_dict_adaptive import get_model
from utils.testloss import TestLoss

def count_parameters(model):
    """Print and return the number of trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set contribute to the
    total; each contributes its element count (``numel()``).

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        The total element count over all trainable parameters.
    """
    trainable_sizes = [
        tensor.numel()
        for _, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    total_params = sum(trainable_sizes)
    print(f"Total Trainable Params: {total_params}")
    return total_params

def main(args):
    """Train the selected adaptive model on the airfoil dataset and keep the best checkpoint.

    Steps: load the ``NACA_Cylinder_{X,Y,Q}.npy`` arrays from ``args.data_path``,
    use the first 1000 samples for training and the following 200 for testing,
    build the model through ``get_model(args).Model``, then train it with AdamW
    under a OneCycleLR schedule. After every epoch the test error is computed
    and, whenever it improves, the weights are written to
    ``./checkpoints/<save_name>.pt``.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``gpu``,
            ``data_path``, ``batch_size``, ``capacity_ratios``, ``n_layers``,
            ``model``, ``n_hidden``, ``n_head``, ``mlp_ratio``, ``dropout``,
            ``slice_num``, ``ref``, ``unified_pos``, ``lr``, ``weight_decay``,
            ``epochs``, ``max_grad_norm`` and ``save_name``. The namespace is
            also handed as a whole to ``get_model``.

    Raises:
        ValueError: If ``len(args.capacity_ratios) != args.n_layers``.
    """
    # Restrict the visible GPUs before any ``.cuda()`` call below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Validate the adaptive-model arguments first so that a bad configuration
    # fails immediately instead of after the dataset has been loaded.
    if len(args.capacity_ratios) != args.n_layers:
        raise ValueError(
            f"参数错误: --capacity_ratios 的长度 ({len(args.capacity_ratios)}) "
            f"必须与 --n-layers ({args.n_layers}) 完全匹配。"
        )

    # --- 1. Data processing (mirrors exp_airfoil.py) ---
    print("--- 1. Loading and Processing Airfoil Dataset ---")
    INPUT_X = os.path.join(args.data_path, 'NACA_Cylinder_X.npy')
    INPUT_Y = os.path.join(args.data_path, 'NACA_Cylinder_Y.npy')
    OUTPUT_Sigma = os.path.join(args.data_path, 'NACA_Cylinder_Q.npy')

    ntrain = 1000
    ntest = 200

    # a. Grid size: r1/r2 are down-sampling strides (1 means no down-sampling),
    #    giving s1 = 221 and s2 = 51 points along the two grid axes.
    r1, r2 = 1, 1
    s1 = int(((221 - 1) / r1) + 1)
    s2 = int(((51 - 1) / r2) + 1)

    # b. Load the X and Y coordinate arrays and stack them on a new last axis.
    inputX = np.load(INPUT_X)
    inputX = torch.tensor(inputX, dtype=torch.float)
    inputY = np.load(INPUT_Y)
    inputY = torch.tensor(inputY, dtype=torch.float)
    input_coords = torch.stack([inputX, inputY], dim=-1)

    # c. Load the target quantity: index 4 along axis 1 of the Q array
    #    (the original author's note calls this the channel index).
    output_q = np.load(OUTPUT_Sigma)[:, 4]
    output_q = torch.tensor(output_q, dtype=torch.float)

    print(f"Initial shapes: Coords={input_coords.shape}, Output={output_q.shape}")

    # d. Train/test split: first ``ntrain`` samples, then the next ``ntest``.
    x_train = input_coords[:ntrain, ::r1, ::r2][:, :s1, :s2]
    y_train = output_q[:ntrain, ::r1, ::r2][:, :s1, :s2]
    x_test = input_coords[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]
    y_test = output_q[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]

    # e. Flatten the spatial axes into a single point axis.
    x_train = x_train.reshape(ntrain, -1, 2)
    x_test = x_test.reshape(ntest, -1, 2)
    y_train = y_train.reshape(ntrain, -1)
    y_test = y_test.reshape(ntest, -1)

    # f. Build the DataLoaders. The coordinates are stored twice per sample
    #    (position and a second input slot); the loops below only use the first.
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, x_train, y_train),
                                               batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, x_test, y_test),
                                              batch_size=args.batch_size, shuffle=False, num_workers=4)

    print("Data processing complete.")
    print("-" * 50)

    # --- 2. Model creation ---
    print(f"--- 2. Creating Model: {args.model} ---")
    model = get_model(args).Model(
        space_dim=2,
        out_dim=1,
        H=s1, W=s2,
        n_hidden=args.n_hidden,
        n_layers=args.n_layers,
        n_head=args.n_head,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
        slice_num=args.slice_num,
        capacity_ratios=args.capacity_ratios,
        # Arguments carried over from exp_airfoil.py.
        ref=args.ref,
        unified_pos=args.unified_pos,
    ).cuda()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # The scheduler is stepped once per batch, hence steps_per_epoch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
                                                    steps_per_epoch=len(train_loader))
    # NOTE(review): with size_average=False the loss is presumably summed over
    # the batch, which is why totals are divided by the sample count below --
    # confirm against utils.testloss.
    myloss = TestLoss(size_average=False)

    print(args)
    count_parameters(model)
    print("-" * 50)

    # --- 3. Training and evaluation loop (mirrors exp_airfoil.py) ---
    print(f"--- 3. Starting Training for {args.epochs} Epochs ---")
    best_rel_err = float('inf')
    ckpt_path = os.path.join('./checkpoints', args.save_name + '.pt')

    for ep in range(args.epochs):
        model.train()
        train_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {ep+1}/{args.epochs}")

        for pos, _fx, y in pbar:
            x, y = pos.cuda(), y.cuda()

            optimizer.zero_grad()
            # The model returns a pair; only the first element is used here.
            out, _ = model(x, None)
            loss = myloss(out.squeeze(-1), y)
            loss.backward()

            if args.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            # Read the scalar once; it feeds both the epoch total and the bar.
            step_loss = loss.item()
            train_loss += step_loss
            scheduler.step()
            pbar.set_postfix(step_loss=step_loss / x.size(0))

        train_loss = train_loss / ntrain

        # Evaluation on the held-out samples.
        model.eval()
        rel_err = 0.0
        with torch.no_grad():
            for pos, _fx, y in test_loader:
                x, y = pos.cuda(), y.cuda()
                out, _ = model(x, None)
                rel_err += myloss(out.squeeze(-1), y).item()

        rel_err /= ntest

        print(f"Epoch {ep+1} | Train loss: {train_loss:.5f} | Test Rel Error: {rel_err:.5f}")

        if rel_err < best_rel_err:
            best_rel_err = rel_err
            print(f"🚀 New best model found! Saving to checkpoints/{args.save_name}.pt")
            # exist_ok avoids the check-then-create race; using the parent of
            # the target file also covers a save_name that has sub-directories.
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)

    print("--- Training Complete ---")


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Adaptive Transolver Training for Airfoil Flow')

    # Command-line options, declared as (flag, add_argument keyword arguments)
    # pairs and registered in this exact order.
    option_table = [
        # a. Options identical to those of exp_airfoil.py
        ('--lr', dict(type=float, default=1e-3)),
        ('--epochs', dict(type=int, default=500)),
        ('--weight_decay', dict(type=float, default=1e-5)),
        ('--batch-size', dict(type=int, default=8)),
        ("--gpu", dict(type=str, default='0', help="GPU index to use")),
        ('--max_grad_norm', dict(type=float, default=None)),
        ('--save_name', dict(type=str, default='airfoil_adaptive')),
        ('--data_path', dict(type=str, required=True)),

        # b. Model-related options
        ('--model', dict(type=str, default='PipeAdaptiveTransolver', help="Model's name")),
        ('--n-hidden', dict(type=int, default=64)),
        ('--n-layers', dict(type=int, default=3)),
        ('--n-head', dict(type=int, default=4)),
        ('--mlp_ratio', dict(type=int, default=1)),
        ('--dropout', dict(type=float, default=0.0)),
        ('--slice_num', dict(type=int, default=32)),
        ('--unified_pos', dict(type=int, default=0)),
        ('--ref', dict(type=int, default=8)),

        # c. Options specific to the adaptive model (one ratio per layer;
        #    main() checks the list length against --n-layers)
        ('--capacity_ratios', dict(type=float, nargs='+', required=True,
                                   help='[必需] 每个递归层的保留比例列表，长度必须等于 n-layers')),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main(args)