import torch
import torch.nn as nn
import numpy as np
import argparse
from tqdm import tqdm
import time
import os
from torch.optim.lr_scheduler import OneCycleLR
from tensorboardX import SummaryWriter
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端，防止在服务器上出错

# --- 导入模型、损失函数和工具 (与pipe版本相同) ---
from nn_module.encoder_module import SpatialEncoder2D
from nn_module.decoder_module import PointWiseDecoder2DSimple, AdaptivePointWiseDecoder2D_SteadyState
from loss_fn import pointwise_rel_l2norm_loss
from utils import load_checkpoint, save_checkpoint, ensure_dir
import matplotlib.pyplot as plt
import logging
import shutil
from torch.utils.data import Dataset, DataLoader, TensorDataset

# --- Reproducibility settings (same as the pipe version) ---
# cuDNN autotuning (benchmark=True) may select different convolution algorithms
# from run to run, which defeats deterministic=True; keep it off so seeded runs
# are reproducible.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Synchronous kernel launches give precise CUDA error locations but slow training;
# setdefault keeps that debugging default while letting the caller opt out by
# exporting CUDA_LAUNCH_BLOCKING=0 before launching the script.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
# Tensor-sharing strategy used by torch.multiprocessing (e.g. DataLoader workers).
torch.multiprocessing.set_sharing_strategy('file_system')

# --- Model construction (almost identical to the pipe version) ---
def build_model(args, res: int) -> "tuple[nn.Module, nn.Module]":
    """Build the encoder/decoder pair used for NACA_Cylinder training.

    Args:
        args: parsed command-line options; reads ``adaptive``, ``propagator_depth``,
            ``capacity_ratios`` and ``final_keep_ratio``.
        res: resolution forwarded to ``SpatialEncoder2D`` (and to
            ``PointWiseDecoder2DSimple`` for the dense decoder).

    Returns:
        ``(encoder, decoder)``; modules are left on their default device.

    Raises:
        ValueError: if ``args.adaptive`` is set and ``args.capacity_ratios`` does
            not contain exactly ``args.propagator_depth`` values.
    """
    capacity_ratios = None
    if args.adaptive:
        # Resolve/validate before building anything so a bad CLI value fails fast
        # (and without consuming the seeded RNG for parameter initialisation).
        capacity_ratios = args.capacity_ratios
        if capacity_ratios is None:
            # Default: one ratio per propagator layer, decaying linearly from
            # 1.0 down to final_keep_ratio.
            capacity_ratios = np.linspace(1.0, args.final_keep_ratio, args.propagator_depth).tolist()
        elif len(capacity_ratios) != args.propagator_depth:
            raise ValueError(
                f'--capacity_ratios expects {args.propagator_depth} values '
                f'(one per propagator layer), got {len(capacity_ratios)}')

    # Encoder is built first so parameter initialisation order matches earlier runs.
    encoder = SpatialEncoder2D(
        2, 96, 256, 4, 6, res=res, use_ln=True
    )

    if args.adaptive:
        print("Building ADAPTIVE Decoder...")
        decoder = AdaptivePointWiseDecoder2D_SteadyState(
            latent_channels=256, out_channels=1,
            propagator_depth=args.propagator_depth,
            capacity_ratios=capacity_ratios,
            scale=0.5, heads=4, dim_head=32)
    else:
        print("Building DEEP (Dense) Decoder...")
        decoder = PointWiseDecoder2DSimple(
            latent_channels=256, out_channels=1,
            refinement_depth=args.propagator_depth,
            res=res, scale=0.5, heads=4, dim_head=32)

    total_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad) + \
                   sum(p.numel() for p in decoder.parameters() if p.requires_grad)
    print(f'Total trainable parameters: {total_params}')
    return encoder, decoder

# --- Visualization helper ---
def make_scatter_plot(pos: torch.Tensor, u_pred: torch.Tensor, u_gt: torch.Tensor, out_path, num_samples=3):
    """Save a figure comparing prediction, ground truth and absolute error per sample.

    One row is drawn per sample (at most ``num_samples``), with three scatter
    panels: prediction, ground truth (sharing a colour scale) and |pred - gt|.

    Args:
        pos: point coordinates, reshaped to (batch, num_points, 2); x is taken
            from ``[..., 0]`` and y from ``[..., 1]``.
        u_pred: predicted field; reshaped to (batch, num_points).
        u_gt: ground-truth field; reshaped to (batch, num_points).
        out_path: path of the image file to write.
        num_samples: maximum number of samples (rows) to plot.

    Returns:
        None. Nothing is written when there is nothing to plot.
    """
    # Inputs may live on different devices (e.g. CPU coords vs GPU predictions).
    coords = pos.detach().cpu()
    batch = coords.shape[0]
    coords = coords.reshape(batch, -1, 2).numpy()
    pred = u_pred.detach().cpu().reshape(batch, -1).numpy()
    gt = u_gt.detach().cpu().reshape(batch, -1).numpy()

    n_rows = min(num_samples, batch)
    if n_rows <= 0:
        return None

    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 4 * n_rows), squeeze=False)
    try:
        for row in range(n_rows):
            xs, ys = coords[row, :, 0], coords[row, :, 1]
            # Shared colour range so prediction and ground truth are comparable.
            vmin = min(gt[row].min(), pred[row].min())
            vmax = max(gt[row].max(), pred[row].max())
            panels = (
                (pred[row], 'Prediction', dict(vmin=vmin, vmax=vmax, cmap='viridis')),
                (gt[row], 'Ground truth', dict(vmin=vmin, vmax=vmax, cmap='viridis')),
                (np.abs(pred[row] - gt[row]), 'Absolute error', dict(cmap='magma')),
            )
            for ax, (values, title, kwargs) in zip(axes[row], panels):
                sc = ax.scatter(xs, ys, c=values, s=4, **kwargs)
                fig.colorbar(sc, ax=ax)
                ax.set_title(f'Sample {row}: {title}')
                ax.set_aspect('equal')
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure; long training runs call this repeatedly.
        plt.close(fig)
    return None

# --- Command-line arguments ---
def get_arguments(parser):
    """Register every option of the NACA_Cylinder training script on ``parser``.

    Options are added in a fixed order (this order is what ``--help`` shows):
    optimisation/bookkeeping, dataset location and sampling, then model settings.

    Args:
        parser: an ``argparse.ArgumentParser`` to populate.

    Returns:
        The same parser instance, so calls can be chained.
    """
    option_table = (
        # Optimisation, logging and checkpointing.
        ('--lr', dict(type=float, default=5e-4, help='Learning rate')),
        ('--iters', dict(type=int, default=200000, help='Number of training iterations')),
        # NACA samples carry more points, so a smaller batch size is a reasonable default.
        ('--batch_size', dict(type=int, default=8, help='Batch size')),
        ('--log_dir', dict(type=str, default='./naca_exp', help='Path for logging and checkpoints')),
        ('--ckpt_every', dict(type=int, default=10000, help='Save checkpoint frequency')),
        ('--resume', dict(action='store_true', help='Resume training from checkpoint')),
        ('--path_to_resume', dict(type=str, default='', help='Path to checkpoint for resuming')),
        # NACA_Cylinder dataset location, split sizes and grid downsampling.
        ('--data_path', dict(type=str, required=True, help='Base path to NACA_Cylinder .npy files')),
        ('--ntrain', dict(type=int, default=1000, help='Number of training samples')),
        ('--ntest', dict(type=int, default=200, help='Number of testing samples')),
        ('--downsample_x', dict(type=int, default=10, help='Downsampling ratio for the x-axis')),
        ('--downsample_y', dict(type=int, default=2, help='Downsampling ratio for the y-axis')),
        # Model options (shared with the pipe version).
        ('--adaptive', dict(action='store_true', help='Use the adaptive decoder')),
        ('--propagator_depth', dict(type=int, default=4, help='Depth of the refinement/propagator network')),
        ('--capacity_ratios', dict(type=float, nargs='+', default=None,
                                   help='List of capacity ratios for each adaptive layer')),
        ('--final_keep_ratio', dict(type=float, default=0.25, help='Final keep ratio for default capacity decay')),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser

# --- Main Code ---
def _evaluate(encoder, decoder, dataloader, device, y_mean, y_std):
    """Score the model on every batch of ``dataloader`` (no gradients).

    Each batch is ``(x_normalised, y_target, x_unnormalised)``; predictions are
    de-normalised with ``y_std`` / ``y_mean`` before the loss, matching training.

    Returns:
        ``(avg_loss, first_batch)``. ``avg_loss`` is the sample-weighted mean of
        the per-batch losses (``nan`` for an empty loader). ``first_batch`` is
        ``(x_unnormalised, prediction, target)`` of the first batch, for
        visualisation, or ``None`` for an empty loader.
    """
    encoder.eval()
    decoder.eval()
    total_loss, total_count = 0.0, 0
    first_batch = None
    with torch.no_grad():
        for x_norm, y_true, x_unnorm in dataloader:
            x_norm, y_true = x_norm.to(device), y_true.to(device)
            input_pos = prop_pos = x_norm
            z = encoder.forward(input_pos, input_pos)
            pred, _ = decoder.forward(z, prop_pos, input_pos)
            pred = pred * y_std + y_mean
            batch_loss = pointwise_rel_l2norm_loss(pred, y_true)
            # NOTE(review): weighting by batch size assumes the loss is a per-batch
            # mean (so the last, smaller batch is not over-weighted) — confirm.
            batch_size = x_norm.shape[0]
            total_loss += batch_loss.item() * batch_size
            total_count += batch_size
            if first_batch is None:
                first_batch = (x_unnorm, pred, y_true)
    avg_loss = total_loss / total_count if total_count else float('nan')
    return avg_loss, first_batch


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train a transformer for the NACA_Cylinder dataset")
    parser = get_arguments(parser)
    opt = parser.parse_args()
    print('Using following options:')
    print(opt)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    ensure_dir(opt.log_dir)

    # --- Load the NACA_Cylinder dataset ---
    print('Preparing the NACA_Cylinder data...')
    INPUT_X_PATH = os.path.join(opt.data_path, 'NACA_Cylinder_X.npy')
    INPUT_Y_PATH = os.path.join(opt.data_path, 'NACA_Cylinder_Y.npy')
    OUTPUT_Q_PATH = os.path.join(opt.data_path, 'NACA_Cylinder_Q.npy')

    ntrain, ntest = opt.ntrain, opt.ntest
    r1, r2 = opt.downsample_x, opt.downsample_y
    # The raw NACA grid is 221 x 51; s1 x s2 is the grid kept after striding by r1 / r2.
    s1 = int(((221 - 1) / r1) + 1)
    s2 = int(((51 - 1) / r2) + 1)

    inputX = torch.from_numpy(np.load(INPUT_X_PATH)).float()
    inputY = torch.from_numpy(np.load(INPUT_Y_PATH)).float()
    coords = torch.stack([inputX, inputY], dim=-1)
    # Following the Transolver code, channel 5 (index 4) of Q is the prediction target.
    output = torch.from_numpy(np.load(OUTPUT_Q_PATH)[:, 4]).float().unsqueeze(-1)

    print(f"Original data shape: Coords-{coords.shape}, Output-{output.shape}")
    print(f"Downsampled grid size: {s1} x {s2} = {s1*s2} points")

    # Validate the split up front: with too few samples, reshape(ntest, -1, ...)
    # below can silently succeed with the wrong number of points per sample.
    n_available = min(coords.shape[0], output.shape[0])
    if ntrain <= 0 or ntest <= 0:
        raise ValueError(f'--ntrain and --ntest must be positive, got {ntrain} and {ntest}')
    if ntrain + ntest > n_available:
        raise ValueError(
            f'--ntrain + --ntest = {ntrain + ntest} exceeds the {n_available} '
            f'samples available in {opt.data_path}')

    x_train = coords[:ntrain, ::r1, ::r2][:, :s1, :s2]
    y_train = output[:ntrain, ::r1, ::r2][:, :s1, :s2]
    # The test split is the next ntest samples after the training split.
    x_test = coords[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]
    y_test = output[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]

    # Flatten the grid into a point cloud: (samples, points, channels).
    x_train = x_train.reshape(ntrain, -1, 2)
    y_train = y_train.reshape(ntrain, -1, 1)
    x_test = x_test.reshape(ntest, -1, 2)
    y_test = y_test.reshape(ntest, -1, 1)

    # Normalisation (same as the pipe version): inputs are standardised with
    # training statistics; targets stay raw and the model output is de-normalised.
    x_mean, x_std = x_train.mean(dim=[0, 1]), x_train.std(dim=[0, 1])
    y_mean, y_std = y_train.mean(dim=[0, 1]), y_train.std(dim=[0, 1])
    x_train_norm = (x_train - x_mean) / x_std
    x_test_norm = (x_test - x_mean) / x_std

    x_mean, x_std = x_mean.to(device), x_std.to(device)
    y_mean, y_std = y_mean.to(device), y_std.to(device)

    train_dataloader = DataLoader(TensorDataset(x_train_norm, y_train), batch_size=opt.batch_size, shuffle=True)
    # The un-normalised x_test is carried along for visualisation.
    test_dataloader = DataLoader(TensorDataset(x_test_norm, y_test, x_test), batch_size=opt.batch_size, shuffle=False)

    # --- Model and optimisers ---
    print('Building network...')
    # Non-square grid: use the longer side so the positional encoding has enough
    # resolution (with the default downsampling s2 = 26 > s1 = 23).
    res_for_model = max(s1, s2)
    encoder, decoder = build_model(opt, res=res_for_model)
    encoder.to(device)
    decoder.to(device)

    writer = SummaryWriter(log_dir=opt.log_dir)
    checkpoint_dir = os.path.join(opt.log_dir, 'model_ckpt')
    ensure_dir(checkpoint_dir)
    sample_dir = os.path.join(opt.log_dir, 'samples')
    ensure_dir(sample_dir)

    enc_optim = torch.optim.Adam(encoder.parameters(), lr=opt.lr, weight_decay=1e-5)
    dec_optim = torch.optim.Adam(decoder.parameters(), lr=opt.lr, weight_decay=1e-5)
    enc_scheduler = OneCycleLR(enc_optim, max_lr=opt.lr, total_steps=opt.iters)
    dec_scheduler = OneCycleLR(dec_optim, max_lr=opt.lr, total_steps=opt.iters)

    start_n_iter = 0
    if opt.resume:
        if not opt.path_to_resume:
            raise ValueError('--resume requires --path_to_resume')
        # NOTE(review): assumes load_checkpoint(path) returns the dict written by
        # save_checkpoint at each evaluation below — confirm against utils.
        ckpt = load_checkpoint(opt.path_to_resume)
        encoder.load_state_dict(ckpt['encoder'])
        decoder.load_state_dict(ckpt['decoder'])
        enc_optim.load_state_dict(ckpt['enc_optim'])
        dec_optim.load_state_dict(ckpt['dec_optim'])
        enc_scheduler.load_state_dict(ckpt['enc_sched'])
        dec_scheduler.load_state_dict(ckpt['dec_sched'])
        start_n_iter = ckpt['n_iter']
        print(f'Resumed from {opt.path_to_resume} at iteration {start_n_iter}')
        del ckpt

    # --- Main training loop ---
    n_iter = start_n_iter
    with tqdm(total=opt.iters, initial=n_iter) as pbar:
        train_data_iter = iter(train_dataloader)
        while n_iter < opt.iters:
            encoder.train()
            decoder.train()

            # Cycle through the training set indefinitely, one batch per iteration.
            try:
                x, y = next(train_data_iter)
            except StopIteration:
                train_data_iter = iter(train_dataloader)
                x, y = next(train_data_iter)

            x, y = x.to(device), y.to(device)

            input_pos = prop_pos = x
            z = encoder.forward(input_pos, input_pos)
            x_out, diagnostics = decoder.forward(z, prop_pos, input_pos)

            x_out = x_out * y_std + y_mean
            loss = pointwise_rel_l2norm_loss(x_out, y)

            enc_optim.zero_grad()
            dec_optim.zero_grad()
            loss.backward()
            enc_optim.step()
            dec_optim.step()
            enc_scheduler.step()
            dec_scheduler.step()

            loss_value = loss.item()
            pbar.set_description(f'Iter: {n_iter}/{opt.iters} | Loss: {loss_value:.4e}')
            writer.add_scalar('train/loss', loss_value, n_iter)
            if opt.adaptive and diagnostics and 'active_tokens_per_layer' in diagnostics:
                for i, k in enumerate(diagnostics['active_tokens_per_layer']):
                    writer.add_scalar(f'diagnostics/active_tokens_layer_{i}', k, n_iter)

            n_iter += 1
            pbar.update(1)

            # Periodic (and final) evaluation + checkpoint.
            if n_iter % opt.ckpt_every == 0 or n_iter >= opt.iters:
                print('Testing...')
                # Score the whole test split, not just its first batch.
                test_loss, first_batch = _evaluate(encoder, decoder, test_dataloader, device, y_mean, y_std)
                writer.add_scalar('test/loss', test_loss, n_iter)
                print(f'Testing avg loss: {test_loss:.4e}')

                # Visualisation of the first test batch (disabled):
                # x_vis, pred_vis, gt_vis = first_batch
                # make_scatter_plot(x_vis, pred_vis, gt_vis,
                #                   os.path.join(sample_dir, f'result_iter_{n_iter}.png'))

                ckpt = { 'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), 'n_iter': n_iter,
                         'enc_optim': enc_optim.state_dict(), 'dec_optim': dec_optim.state_dict(),
                         'enc_sched': enc_scheduler.state_dict(), 'dec_sched': dec_scheduler.state_dict() }
                save_checkpoint(ckpt, os.path.join(checkpoint_dir, f'model_checkpoint{n_iter}.ckpt'))
                del ckpt

    # Flush pending TensorBoard events.
    writer.close()
    print("Training finished.")