
import os
import argparse
import numpy as np
import torch
from tqdm import tqdm

# Import the new (adaptive) model dictionary and the same utilities used by exp_pipe.py
from model_dict_adaptive import get_model
from utils.testloss import TestLoss
from utils.normalizer import UnitTransformer

def count_parameters(model):
    """Print and return the number of trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        The total element count over all trainable parameters.
    """
    total_params = sum(
        tensor.numel()
        for _, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    print(f"Total Trainable Params: {total_params}")
    return total_params

def main(args):
    """Train and evaluate an adaptive Transolver model on the Pipe dataset.

    The function loads ``Pipe_X.npy``, ``Pipe_Y.npy`` and ``Pipe_Q.npy`` from
    ``args.data_path``, uses the first 1000 samples for training and the last
    200 for testing, normalises the data, builds the model selected through
    ``get_model(args)``, trains it with AdamW + OneCycleLR and, after every
    epoch, evaluates on the test split. Whenever the test error improves, the
    model's ``state_dict`` is written to ``./checkpoints/<save_name>.pt``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``gpu``, ``data_path``, ``batch_size``, ``capacity_ratios``,
            ``n_layers``, ``n_hidden``, ``n_head``, ``mlp_ratio``,
            ``dropout``, ``slice_num``, ``model``, ``lr``, ``weight_decay``,
            ``epochs``, ``max_grad_norm`` and ``save_name``. The whole
            namespace is also forwarded to ``get_model``.

    Raises:
        ValueError: If ``len(args.capacity_ratios) != args.n_layers``, if the
            coordinate and output arrays disagree on the number of samples,
            or if the dataset holds fewer than ``ntrain + ntest`` samples
            (the train and test slices would overlap).
    """
    # Fail fast on inconsistent arguments, before any data is loaded.
    if len(args.capacity_ratios) != args.n_layers:
        raise ValueError(
            f"参数错误: --capacity_ratios 的长度 ({len(args.capacity_ratios)}) "
            f"必须与 --n-layers ({args.n_layers}) 完全匹配。"
        )

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # --- 1. Data processing (mirrors exp_pipe.py) ---
    print("--- 1. Loading and Processing Pipe Dataset ---")
    INPUT_X = os.path.join(args.data_path, 'Pipe_X.npy')
    INPUT_Y = os.path.join(args.data_path, 'Pipe_Y.npy')
    OUTPUT_Sigma = os.path.join(args.data_path, 'Pipe_Q.npy')

    ntrain = 1000
    ntest = 200

    # a. Load the two coordinate arrays and stack them on a new last axis.
    inputX = np.load(INPUT_X)
    inputX = torch.tensor(inputX, dtype=torch.float)
    inputY = np.load(INPUT_Y)
    inputY = torch.tensor(inputY, dtype=torch.float)
    input_coords = torch.stack([inputX, inputY], dim=-1)

    # b. Load the physical quantity, keeping only index 0 along axis 1
    #    (the first channel).
    output_q = np.load(OUTPUT_Sigma)[:, 0]
    output_q = torch.tensor(output_q, dtype=torch.float)

    print(f"Initial shapes: Coords={input_coords.shape}, Output={output_q.shape}")
    s1, s2 = input_coords.shape[1], input_coords.shape[2]  # grid size

    # The train split is taken from the front and the test split from the
    # back, so inputs and outputs must have matching sample counts and the
    # dataset must be large enough for the two slices not to overlap.
    n_samples = input_coords.shape[0]
    if output_q.shape[0] != n_samples:
        raise ValueError(
            f"Sample count mismatch: coordinates have {n_samples} samples "
            f"but the output field has {output_q.shape[0]}."
        )
    if n_samples < ntrain + ntest:
        raise ValueError(
            f"Dataset has {n_samples} samples but ntrain + ntest = "
            f"{ntrain + ntest} are required for non-overlapping splits."
        )

    # c. Train / test split.
    x_train = input_coords[:ntrain]
    y_train = output_q[:ntrain]
    x_test = input_coords[-ntest:]
    y_test = output_q[-ntest:]

    # d. Flatten the spatial dimensions into a single point axis.
    x_train = x_train.reshape(ntrain, -1, 2)
    x_test = x_test.reshape(ntest, -1, 2)
    y_train = y_train.reshape(ntrain, -1)
    y_test = y_test.reshape(ntest, -1)

    # e. Normalisation. Statistics come from the training split only.
    #    y_test is deliberately left un-encoded: at evaluation time the model
    #    output is decoded and compared against the raw targets.
    #    NOTE(review): UnitTransformer is presumably a mean/std standardiser;
    #    confirm in utils.normalizer.
    x_normalizer = UnitTransformer(x_train)
    y_normalizer = UnitTransformer(y_train)
    x_train = x_normalizer.encode(x_train)
    x_test = x_normalizer.encode(x_test)
    y_train = y_normalizer.encode(y_train)

    x_normalizer.cuda()
    y_normalizer.cuda()

    # f. Data loaders. Each item is a (pos, fx, y) triple; the coordinates are
    #    passed twice so the tuple layout matches exp_pipe.py, although the
    #    middle element is not used below.
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, x_train, y_train),
                                               batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, x_test, y_test),
                                              batch_size=args.batch_size, shuffle=False, num_workers=4)

    print("Data processing complete.")
    print("-" * 50)

    # --- 2. Model creation ---
    print(f"--- 2. Creating Model: {args.model} ---")
    model = get_model(args).Model(
        space_dim=2,
        fun_dim=0,
        out_dim=1,
        H=s1, W=s2,
        n_hidden=args.n_hidden,
        n_layers=args.n_layers,
        n_head=args.n_head,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
        slice_num=args.slice_num,
        capacity_ratios=args.capacity_ratios
    ).cuda()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # OneCycleLR is configured per batch, hence scheduler.step() is called
    # once per optimisation step inside the batch loop.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
                                                    steps_per_epoch=len(train_loader))
    # NOTE(review): with size_average=False the loss is presumably summed over
    # the batch, which is why totals are divided by ntrain / ntest below;
    # confirm in utils.testloss.
    myloss = TestLoss(size_average=False)

    print(args)
    count_parameters(model)
    print("-" * 50)

    # --- 3. Training and evaluation loop (mirrors exp_pipe.py) ---
    print(f"--- 3. Starting Training for {args.epochs} Epochs ---")
    best_rel_err = float('inf')

    for ep in range(args.epochs):
        model.train()
        train_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {ep+1}/{args.epochs}")

        for pos, _fx, y in pbar:
            x, y = pos.cuda(), y.cuda()

            optimizer.zero_grad()
            # The model returns a pair; only the first element is used here.
            out, _ = model(x, None)  # explicit call with fx=None

            # Decode both prediction and target before computing the loss, so
            # the loss is measured in the original (un-normalised) units.
            out = y_normalizer.decode(out.squeeze(-1))
            _y = y_normalizer.decode(y)  # separate name so y is not overwritten

            loss = myloss(out, _y)
            loss.backward()

            if args.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            # Read the scalar once; each .item() call copies from the device.
            batch_loss = loss.item()
            train_loss += batch_loss
            scheduler.step()
            pbar.set_postfix(step_loss=batch_loss / x.size(0))

        train_loss = train_loss / ntrain

        # Evaluation on the held-out split.
        model.eval()
        rel_err = 0.0
        with torch.no_grad():
            for pos, _fx, y in test_loader:
                x, y = pos.cuda(), y.cuda()
                out, _ = model(x, None)
                out = y_normalizer.decode(out.squeeze(-1))

                # y_test was never encoded, so it is compared directly.
                rel_err += myloss(out, y).item()

        rel_err /= ntest

        print(f"Epoch {ep+1} | Train Loss: {train_loss:.5f} | Test Rel Error: {rel_err:.5f}")

        if rel_err < best_rel_err:
            best_rel_err = rel_err
            print(f"🚀 New best model found! Saving to checkpoints/{args.save_name}.pt")
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs('./checkpoints', exist_ok=True)
            torch.save(model.state_dict(), os.path.join('./checkpoints', args.save_name + '.pt'))

    print("--- Training Complete ---")


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Adaptive Transolver Training for Pipe Flow (exp_pipe style)')

    # Command-line options (based on exp_pipe.py and extended), declared as an
    # ordered table of (flag, add_argument keyword options). Flags spelled
    # with hyphens are exposed on the namespace with underscores by argparse
    # (e.g. --batch-size -> args.batch_size).
    # NOTE(review): --unified_pos, --ref and --eval are not read by main()
    # directly; presumably kept for get_model(args) / parity with exp_pipe.py.
    _cli_options = [
        # Paths and device selection.
        ('--data_path', dict(type=str, required=True,
                             help='Path to the directory containing pipe .npy files')),
        ('--save_name', dict(type=str, default='pipe_adaptive',
                             help='File name for the saved model')),
        ('--gpu', dict(type=str, default='0', help="GPU index to use")),
        # Model selection.
        ('--model', dict(type=str, default='PipeAdaptiveTransolver', help="Model's name")),
        # Optimisation hyper-parameters.
        ('--lr', dict(type=float, default=1e-3)),
        ('--epochs', dict(type=int, default=500)),
        ('--weight_decay', dict(type=float, default=1e-5)),
        ('--batch-size', dict(type=int, default=8)),
        ('--max_grad_norm', dict(type=float, default=None)),
        # Architecture hyper-parameters.
        ('--n-hidden', dict(type=int, default=64)),
        ('--n-layers', dict(type=int, default=8)),
        ('--n-head', dict(type=int, default=4)),
        ('--mlp_ratio', dict(type=int, default=1)),
        ('--dropout', dict(type=float, default=0.0)),
        ('--slice_num', dict(type=int, default=32)),
        ('--unified_pos', dict(type=int, default=0)),
        ('--ref', dict(type=int, default=8)),
        ('--eval', dict(type=int, default=0)),
        # Required: one keep-ratio per layer; main() checks the length
        # against --n-layers.
        ('--capacity_ratios', dict(type=float, nargs='+', required=True,
                                   help='[必需] 每个递归层的保留比例列表，长度必须等于 n-layers')),
    ]
    for _flag, _options in _cli_options:
        parser.add_argument(_flag, **_options)

    args = parser.parse_args()
    main(args)