import os
import matplotlib.pyplot as plt
import argparse
import scipy.io as scio
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# 导入新的模型字典和与 exp_ns.py 相同的损失函数
from model_dict_adaptive import get_model
from utils.testloss import TestLoss

def get_args():
    """Parse command-line arguments for adaptive-model training on NS2D.

    The options follow the exp_ns.py training script, plus
    ``--capacity_ratios``, which is specific to the adaptive models.

    Returns:
        argparse.Namespace: the parsed options. Hyphenated flags are exposed
        with underscores (``--batch-size`` -> ``batch_size``,
        ``--n-hidden`` -> ``n_hidden``, and so on).

    Raises:
        SystemExit: raised by argparse when ``--data_path`` or
        ``--capacity_ratios`` is missing, or when a value cannot be parsed.
    """
    # The title must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, so passing the title
    # positionally replaces the script name in the usage/error output.
    parser = argparse.ArgumentParser(
        description='Adaptive Transolver Training for NS2D (exp_ns style)')

    # --- Training hyper-parameters (kept consistent with exp_ns.py) ---
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--batch-size', type=int, default=8)
    # None disables gradient clipping in the training loop.
    parser.add_argument('--max_grad_norm', type=float, default=None)

    # --- Core paths and names ---
    parser.add_argument('--data_path', type=str, required=True, help='NavierStokes .mat 文件的完整路径')
    parser.add_argument('--save_name', type=str, default='ns_adaptive', help='检查点保存的文件名')
    parser.add_argument("--gpu", type=str, default='0', help="GPU index to use")

    # --- Model selection ---
    parser.add_argument('--model', type=str, default='StructuredAdaptiveTransolver',
                        help="要训练的自适应模型的名称")

    # --- Model architecture parameters ---
    parser.add_argument('--n-hidden', type=int, default=64, help='隐藏层维度')
    parser.add_argument('--n-layers', type=int, default=3, help='递归层数 (深度)')
    parser.add_argument('--n-head', type=int, default=4, help="注意力头数")
    parser.add_argument('--mlp_ratio', type=int, default=1, help='MLP层的膨胀比例')
    parser.add_argument('--dropout', type=float, default=0.0, help='Dropout比率')
    parser.add_argument('--slice_num', type=int, default=32, help='Physics-Attention中的切片数')
    parser.add_argument('--unified_pos', type=int, default=0)
    parser.add_argument('--ref', type=int, default=8)

    # --- Parameters specific to the adaptive models ---
    # One or more floats; required, so there is no default.
    parser.add_argument('--capacity_ratios', type=float, nargs='+', required=True,
                        help='[必需] 自适应模型的容量规划列表')

    return parser.parse_args()

def count_parameters(model):
    """Print and return the number of trainable parameter elements.

    Iterates ``model.named_parameters()``, skips every parameter whose
    ``requires_grad`` is false, and sums ``numel()`` over the rest.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        The summed element count of the trainable parameters.
    """
    trainable = (p for _, p in model.named_parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trainable)
    print(f"Total Trainable Params: {total_params}")
    return total_params

def _rollout(model, loss_fn, x, fx, yy, horizon, step, teacher_forcing):
    """Roll the single-step model forward over ``horizon`` time steps.

    Args:
        model: callable invoked as ``model(x, fx=fx)``; only the first
            element of the returned pair is used as the prediction.
        loss_fn: callable taking ``(prediction, target)``, both flattened to
            ``[batch, -1]``.
        x: position tensor for the batch (``[B, N, 2]`` as built in ``main``).
        fx: input history window (``[B, N, T_in]`` as built in ``main``).
        yy: ground-truth future frames (``[B, N, T_pred]`` as built in ``main``).
        horizon: number of future time steps to predict.
        step: number of frames produced per model call.
        teacher_forcing: when true, the ground-truth frame is appended to the
            history window after each step; otherwise the model's own
            prediction is appended (pure autoregression).

    Returns:
        ``(step_loss, pred)``: the per-step losses summed over the rollout,
        and all predictions concatenated along the last axis.
    """
    bsz = x.shape[0]
    step_loss = 0
    preds = []
    for t in range(0, horizon, step):
        y = yy[..., t:t + step]
        im, _ = model(x, fx=fx)  # single-step prediction
        step_loss = step_loss + loss_fn(im.reshape(bsz, -1), y.reshape(bsz, -1))
        preds.append(im)
        # Slide the history window: drop the oldest frame(s), append the new one.
        appended = y if teacher_forcing else im
        fx = torch.cat((fx[..., step:], appended), dim=-1)
    return step_loss, torch.cat(preds, dim=-1)


def main():
    """Train and evaluate an adaptive model on the NS2D dataset.

    Steps: parse the command line, load ``data['u']`` from the .mat file at
    ``args.data_path``, split it into a 10-frame input history and a 10-frame
    target, build the model returned by ``get_model(args)``, then train with
    teacher forcing and evaluate with a pure autoregressive rollout after
    every epoch. A checkpoint is written to ``./checkpoints`` every 100
    completed epochs and once more at the end.

    Raises:
        ValueError: if ``data['u']`` is not 4-D, or holds fewer samples or
            time steps than the hard-coded split requires.
    """
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- Data-loading parameters (kept consistent with exp_ns.py) ---
    ntrain = 1000
    ntest = 200
    T_in = 10
    T_pred = 10  # named T in exp_ns.py
    step = 1
    r = 1  # downsample factor; defaults to 1 in exp_ns.py
    h = int(((64 - 1) / r) + 1)
    ckpt_dir = './checkpoints'

    # --- 1. Data loading (follows the exp_ns.py logic) ---
    print("--- 1. Loading and Processing NS Dataset ---")
    data = scio.loadmat(args.data_path)
    u = data['u']

    # Validate up front: with too few samples or time steps the reshapes
    # below could silently fold samples into the spatial axis instead of failing.
    if u.ndim != 4:
        raise ValueError(f"Expected data['u'] to be 4-D [samples, H, W, T], got shape {u.shape}")
    if u.shape[0] < max(ntrain, ntest):
        raise ValueError(
            f"data['u'] has {u.shape[0]} samples; need at least {max(ntrain, ntest)}")
    if u.shape[3] < T_in + T_pred:
        raise ValueError(
            f"data['u'] has {u.shape[3]} time steps; need at least {T_in + T_pred}")
    if u.shape[0] < ntrain + ntest:
        print(f"Warning: only {u.shape[0]} samples available; "
              f"train and test splits overlap.")

    # a. Split into input history (a) and future ground truth (u).
    #    Training uses the first ntrain samples, testing the last ntest.
    train_a = torch.from_numpy(u[:ntrain, ::r, ::r, :T_in][:, :h, :h, :]).float()
    train_u = torch.from_numpy(u[:ntrain, ::r, ::r, T_in:T_in + T_pred][:, :h, :h, :]).float()
    test_a = torch.from_numpy(u[-ntest:, ::r, ::r, :T_in][:, :h, :h, :]).float()
    test_u = torch.from_numpy(u[-ntest:, ::r, ::r, T_in:T_in + T_pred][:, :h, :h, :]).float()

    # b. Flatten the spatial grid: [B, H, W, T] -> [B, N, T]
    train_a = train_a.reshape(ntrain, -1, T_in)
    train_u = train_u.reshape(ntrain, -1, T_pred)
    test_a = test_a.reshape(ntest, -1, T_in)
    test_u = test_u.reshape(ntest, -1, T_pred)

    # c. Build the position coordinates on a uniform [0, 1] x [0, 1] grid.
    gridx = torch.tensor(np.linspace(0, 1, h), dtype=torch.float).reshape(1, h, 1, 1).repeat(h, 1, 1, 1)
    gridy = torch.tensor(np.linspace(0, 1, h), dtype=torch.float).reshape(h, 1, 1, 1).repeat(1, h, 1, 1)
    pos = torch.cat((gridx.permute(1, 0, 2, 3), gridy.permute(1, 0, 2, 3)), dim=-1).reshape(1, -1, 2)
    pos_train = pos.repeat(ntrain, 1, 1)
    pos_test = pos.repeat(ntest, 1, 1)

    # d. Create the DataLoaders (only the training set is shuffled).
    train_loader = DataLoader(TensorDataset(pos_train, train_a, train_u), batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(pos_test, test_a, test_u), batch_size=args.batch_size, shuffle=False)

    print("Dataloading is over.")
    print("-" * 50)

    # --- 2. Model creation (out_dim=step: single-step predictor) ---
    print(f"--- 2. Creating Model: {args.model} ---")
    model_class = get_model(args).Model

    # Every parsed CLI option is forwarded as a keyword argument.
    model = model_class(
        **vars(args),
        space_dim=2,
        fun_dim=T_in,
        out_dim=step,
        H=h, W=h,
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # OneCycleLR is stepped once per batch, hence steps_per_epoch.
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
                                              steps_per_epoch=len(train_loader))
    # NOTE(review): size_average=False presumably sums over the batch, which
    # is why the totals are divided by ntrain / ntest below; confirm in TestLoss.
    myloss = TestLoss(size_average=False)

    print(args)
    count_parameters(model)
    print("-" * 50)

    # --- 3. Training and evaluation loop (teacher forcing, as in exp_ns.py) ---
    print(f"--- 3. Starting Training (Teacher Forcing) for {args.epochs} Epochs ---")

    for ep in range(args.epochs):
        model.train()
        train_l2_step = 0
        train_l2_full = 0

        pbar = tqdm(train_loader, desc=f"Epoch {ep+1}/{args.epochs} [Training]")
        for x, fx, yy in pbar:
            x, fx, yy = x.to(device), fx.to(device), yy.to(device)
            bsz = x.shape[0]

            # a. Teacher-forced rollout: the true frame feeds the next step.
            loss, pred = _rollout(model, myloss, x, fx, yy, T_pred, step, teacher_forcing=True)

            train_l2_step += loss.item()
            # The full-trajectory loss is a metric only, so no graph is needed.
            train_l2_full += myloss(pred.detach().reshape(bsz, -1), yy.reshape(bsz, -1)).item()

            optimizer.zero_grad()
            loss.backward()
            if args.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()

            pbar.set_postfix(step_loss=loss.item() / bsz / (T_pred/step))

        # Evaluation (pure autoregression)
        test_l2_step = 0
        test_l2_full = 0
        model.eval()
        with torch.no_grad():
            for x, fx, yy in test_loader:
                x, fx, yy = x.to(device), fx.to(device), yy.to(device)
                bsz = x.shape[0]

                # b. Autoregressive rollout: the prediction feeds the next step.
                loss, pred = _rollout(model, myloss, x, fx, yy, T_pred, step, teacher_forcing=False)

                test_l2_step += loss.item()
                test_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()

        # --- 4. Print the loss summary in the same format as exp_ns.py ---
        print(
            "Epoch {} , train_step_loss:{:.5f} , train_full_loss:{:.5f} , test_step_loss:{:.5f} , test_full_loss:{:.5f}".format(
                ep+1, train_l2_step / ntrain / (T_pred / step), train_l2_full / ntrain,
                test_l2_step / ntest / (T_pred / step), test_l2_full / ntest
            )
        )

        # Save every 100 completed epochs so the file name (_ep100, _ep200, ...)
        # matches the number of epochs actually trained.
        if (ep + 1) % 100 == 0:
            os.makedirs(ckpt_dir, exist_ok=True)
            print('save model')
            torch.save(model.state_dict(), os.path.join(ckpt_dir, args.save_name + f'_ep{ep+1}.pt'))

    print("--- Training Complete ---")
    os.makedirs(ckpt_dir, exist_ok=True)
    print('save final model')
    torch.save(model.state_dict(), os.path.join(ckpt_dir, args.save_name + '_final.pt'))

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()