import torch
import torch.nn as nn
import numpy as np
import argparse
from tqdm import tqdm
import time
import os
from torch.optim.lr_scheduler import OneCycleLR
from tensorboardX import SummaryWriter

# --- Key change: import the newly designed encoder/decoder modules ---
from nn_module.encoder_module import SpatialEncoder2D
from nn_module.decoder_module import PointWiseDecoder2DSimple, AdaptivePointWiseDecoder2D_SteadyState

from loss_fn import pointwise_rel_l2norm_loss
from utils import load_checkpoint, save_checkpoint, ensure_dir
import matplotlib.pyplot as plt
import logging
import shutil
from torch.utils.data import Dataset, DataLoader, TensorDataset

# --- Fix the random seeds and backend flags so repeated runs behave the same ---
_SEED = 1

# Exported before any CUDA call is issued by this script.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice, which
# can differ between runs; confirm it is intended next to deterministic=True.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

# Seed NumPy, the default torch generator and the CUDA generator alike.
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(_SEED)

torch.multiprocessing.set_sharing_strategy('file_system')

def build_model(args, res: int) -> "tuple[nn.Module, nn.Module]":
    """Build the encoder and the decoder selected by the command-line options.

    Args:
        args: Parsed options. ``adaptive`` picks the decoder class and
            ``propagator_depth`` sets its depth; ``capacity_ratios`` and
            ``final_keep_ratio`` are only read for the adaptive decoder.
        res: Value forwarded to ``SpatialEncoder2D`` as its ``res`` argument.

    Returns:
        The ``(encoder, decoder)`` pair. The number of trainable parameters of
        both modules combined is printed as a side effect.
    """
    # The encoder consumes 2D coordinates, hence 2 input channels (x, y).
    # The remaining positional values are passed through unchanged; their
    # meaning is defined by SpatialEncoder2D, which is not visible here.
    encoder = SpatialEncoder2D(
        2,
        96,
        256,
        4,
        6,
        res=res,
        use_ln=True
    )

    # --adaptive switches between the adaptive and the dense decoder.
    if args.adaptive:
        print("Building ADAPTIVE Decoder...")
        decoder = AdaptivePointWiseDecoder2D_SteadyState(
            latent_channels=256,
            out_channels=1,
            propagator_depth=args.propagator_depth,
            # Forwarded as given on the command line (None when the flag is absent).
            capacity_ratios=args.capacity_ratios,
            final_keep_ratio=args.final_keep_ratio,
            scale=0.5
        )
    else:
        print("Building DEEP (Dense) Decoder...")
        decoder = PointWiseDecoder2DSimple(
            latent_channels=256,
            out_channels=1,
            # The dense decoder reuses --propagator_depth as its refinement depth.
            refinement_depth=args.propagator_depth,
            scale=0.5,
        )

    total_params = sum(
        p.numel()
        for module in (encoder, decoder)
        for p in module.parameters()
        if p.requires_grad
    )
    print(f'Total trainable parameters: {total_params}')
    return encoder, decoder

def make_scatter_plot(pos: torch.Tensor, u_pred: torch.Tensor, u_gt: torch.Tensor, out_path, num_samples=3):
    """Save side-by-side scatter plots of predictions and ground truth.

    One row is drawn per sample: the left panel is coloured by ``u_pred``, the
    right panel by ``u_gt``. Each panel has its own colour bar.

    Args:
        pos: Point coordinates, indexed as ``pos[i, :, 0]`` (x) and
            ``pos[i, :, 1]`` (y) for sample ``i``.
        u_pred: Predicted values; a trailing axis of size 1 is squeezed away.
        u_gt: Ground-truth values, handled the same way as ``u_pred``.
        out_path: Destination passed to ``Figure.savefig``.
        num_samples: Upper bound on the number of samples (rows) to draw.

    Raises:
        ValueError: If there is nothing to plot (empty batch or
            ``num_samples`` < 1).
    """
    coords = pos.detach().cpu().numpy()
    panels = (
        ('Prediction', u_pred.detach().cpu().squeeze(-1).numpy()),
        ('Ground Truth', u_gt.detach().cpu().squeeze(-1).numpy()),
    )

    num_to_plot = min(coords.shape[0], num_samples)
    if num_to_plot < 1:
        # Fail before a figure is allocated; matplotlib would reject a
        # zero-row grid with the same exception type.
        raise ValueError(
            f'Nothing to plot: batch size {coords.shape[0]}, num_samples {num_samples}'
        )

    fig, axes = plt.subplots(num_to_plot, 2, figsize=(10, 5 * num_to_plot), squeeze=False)
    try:
        fig.suptitle('Prediction vs. Ground Truth', fontsize=16)

        for i in range(num_to_plot):
            for ax, (label, values) in zip(axes[i], panels):
                sc = ax.scatter(coords[i, :, 0], coords[i, :, 1], c=values[i], cmap='viridis', s=4)
                ax.set_title(f'Sample {i} - {label}')
                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.axis('equal')
                fig.colorbar(sc, ax=ax)

        # Leave room at the top for the figure-level title.
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Always release the figure, even when saving fails, so repeated
        # calls from a long training run do not accumulate open figures.
        plt.close(fig)

def get_arguments(parser):
    """Register all training, dataset and model options on ``parser``.

    The options are declared in a table and added in order, so the help
    output lists them in the same sequence as the table.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object) to extend.

    Returns:
        The same ``parser`` instance, for call chaining.
    """
    option_table = (
        # Optimisation, logging and checkpointing
        ('--lr', dict(type=float, default=5e-4, help='Learning rate')),
        ('--iters', dict(type=int, default=200000, help='Number of training iterations')),
        ('--batch_size', dict(type=int, default=16, help='Batch size')),
        ('--log_dir', dict(type=str, default='./pipe_exp', help='Path for logging and checkpoints')),
        ('--ckpt_every', dict(type=int, default=10000, help='Save checkpoint frequency')),
        ('--resume', dict(action='store_true', help='Resume training from checkpoint')),
        ('--path_to_resume', dict(type=str, default='', help='Path to checkpoint for resuming')),

        # Dataset
        ('--data_path', dict(type=str, required=True, help='Base path to Pipe .npy files')),
        ('--ntrain', dict(type=int, default=1000, help='Number of training samples')),
        ('--ntest', dict(type=int, default=200, help='Number of testing samples')),
        ('--downsample_res', dict(type=int, default=8, help='Downsampling ratio for the grid')),

        # Model
        ('--adaptive', dict(action='store_true',
                            help='Use the adaptive decoder instead of the deep dense one')),
        ('--propagator_depth', dict(type=int, default=4,
                                    help='Depth of the refinement/propagator network')),
        ('--final_keep_ratio', dict(type=float, default=0.25,
                                    help='Final keep ratio for adaptive capacity decay')),
        ('--capacity_ratios', dict(
            type=float, nargs='+', default=None,
            help='A list of capacity ratios for each adaptive layer, separated by spaces. E.g., --capacity_ratios 1.0 0.8 0.6 0.4')),
    )

    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser

# --- Main Code ---
# Trains the encoder/decoder pair on the Pipe dataset: loads and normalises the
# data, runs an iteration-based training loop, and every `--ckpt_every`
# iterations evaluates on the whole test split and writes a checkpoint.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train a transformer for the Pipe dataset")
    parser = get_arguments(parser)
    opt = parser.parse_args()
    print('Using following options:')
    print(opt)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    ensure_dir(opt.log_dir)

    # --- Load the Pipe dataset ---
    print('Preparing the Pipe data...')
    INPUT_X_PATH = os.path.join(opt.data_path, 'Pipe_X.npy')
    INPUT_Y_PATH = os.path.join(opt.data_path, 'Pipe_Y.npy')
    OUTPUT_Q_PATH = os.path.join(opt.data_path, 'Pipe_Q.npy')

    ntrain, ntest = opt.ntrain, opt.ntest
    r = opt.downsample_res
    if ntrain < 1 or ntest < 1:
        raise ValueError(f'--ntrain and --ntest must be positive, got {ntrain} and {ntest}')
    if r < 1:
        raise ValueError(f'--downsample_res must be a positive integer, got {r}')

    inputX = torch.from_numpy(np.load(INPUT_X_PATH)).float()
    inputY = torch.from_numpy(np.load(INPUT_Y_PATH)).float()
    coords = torch.stack([inputX, inputY], dim=-1)
    # Only channel 0 of Pipe_Q is used as the target; a trailing channel axis is added.
    output = torch.from_numpy(np.load(OUTPUT_Q_PATH)[:, 0]).float().unsqueeze(-1)

    # Training samples are taken from the front and test samples from the back,
    # so the two splits overlap silently unless they fit into the dataset.
    total_samples = coords.shape[0]
    if ntrain + ntest > total_samples:
        raise ValueError(
            f'ntrain + ntest = {ntrain + ntest} exceeds the {total_samples} available samples'
        )

    # Number of grid points per axis after keeping every r-th point. The grid
    # size is read from the data (129 for the stock Pipe files) instead of being
    # hard-coded. NOTE(review): assumes a square grid - confirm for other data.
    grid_size = coords.shape[1]
    s = (grid_size - 1) // r + 1

    x_train = coords[:ntrain, ::r, ::r][:, :s, :s]
    y_train = output[:ntrain, ::r, ::r][:, :s, :s]
    x_test = coords[-ntest:, ::r, ::r][:, :s, :s]
    y_test = output[-ntest:, ::r, ::r][:, :s, :s]

    # Flatten each grid into a point cloud: (samples, points, channels).
    x_train = x_train.reshape(ntrain, -1, 2)
    y_train = y_train.reshape(ntrain, -1, 1)
    x_test = x_test.reshape(ntest, -1, 2)
    y_test = y_test.reshape(ntest, -1, 1)

    # Normalisation statistics come from the training split only.
    x_mean, x_std = x_train.mean(dim=[0, 1]), x_train.std(dim=[0, 1])
    y_mean, y_std = y_train.mean(dim=[0, 1]), y_train.std(dim=[0, 1])

    x_train = (x_train - x_mean) / x_std
    x_test_unnorm = x_test.clone()  # Un-normalised coordinates, kept for visualisation
    x_test = (x_test - x_mean) / x_std

    # Targets stay in physical units; the network output is mapped back with
    # these statistics before the loss is computed, so they must live on the
    # same device as the model.
    y_mean, y_std = y_mean.to(device), y_std.to(device)

    train_dataloader = DataLoader(TensorDataset(x_train, y_train), batch_size=opt.batch_size, shuffle=True)
    test_dataloader = DataLoader(TensorDataset(x_test, y_test, x_test_unnorm), batch_size=opt.batch_size, shuffle=False)

    # --- Model and optimisers ---
    print('Building network...')
    encoder, decoder = build_model(opt, res=s)
    encoder, decoder = encoder.to(device), decoder.to(device)

    writer = SummaryWriter(log_dir=opt.log_dir)
    checkpoint_dir = os.path.join(opt.log_dir, 'model_ckpt')
    ensure_dir(checkpoint_dir)
    sample_dir = os.path.join(opt.log_dir, 'samples')
    ensure_dir(sample_dir)

    # ... (logging setup and script backup are similar to the Darcy version) ...

    enc_optim = torch.optim.Adam(encoder.parameters(), lr=opt.lr, weight_decay=1e-5)
    dec_optim = torch.optim.Adam(decoder.parameters(), lr=opt.lr, weight_decay=1e-5)

    enc_scheduler = OneCycleLR(enc_optim, max_lr=opt.lr, total_steps=opt.iters)
    dec_scheduler = OneCycleLR(dec_optim, max_lr=opt.lr, total_steps=opt.iters)

    start_n_iter = 0
    # ... (checkpoint restoring is the same as in the Darcy version) ...
    # NOTE(review): --resume / --path_to_resume are parsed but nothing in the
    # code shown here consumes them; confirm the restore logic is present.

    # --- Main training loop ---
    n_iter = start_n_iter
    try:
        with tqdm(total=opt.iters, initial=n_iter) as pbar:
            train_data_iter = iter(train_dataloader)
            while n_iter < opt.iters:
                encoder.train()
                decoder.train()

                # Iteration-based training: restart the loader when an epoch ends.
                try:
                    x, y = next(train_data_iter)
                except StopIteration:
                    train_data_iter = iter(train_dataloader)
                    x, y = next(train_data_iter)

                x, y = x.to(device), y.to(device)

                # --- Forward pass ---
                # For the Pipe data the coordinates serve both as input
                # features and as query positions.
                input_pos = prop_pos = x

                # Call the modules directly (not .forward) so registered hooks run.
                z = encoder(input_pos, input_pos)
                x_out, diagnostics = decoder(z, prop_pos, input_pos)

                # Map the prediction back to physical units.
                x_out = x_out * y_std + y_mean

                loss = pointwise_rel_l2norm_loss(x_out, y)

                enc_optim.zero_grad()
                dec_optim.zero_grad()
                loss.backward()
                enc_optim.step()
                dec_optim.step()
                enc_scheduler.step()
                dec_scheduler.step()

                loss_value = loss.item()  # Read once for both log sinks
                pbar.set_description(f'Iter: {n_iter}/{opt.iters} | Loss: {loss_value:.4e}')
                writer.add_scalar('train/loss', loss_value, n_iter)
                if opt.adaptive and 'active_tokens_per_layer' in diagnostics:
                    for i, k in enumerate(diagnostics['active_tokens_per_layer']):
                        writer.add_scalar(f'diagnostics/active_tokens_layer_{i}', k, n_iter)

                n_iter += 1
                pbar.update(1)

                # --- Evaluation and checkpointing ---
                if n_iter % opt.ckpt_every == 0 or n_iter >= opt.iters:
                    print('Testing...')
                    encoder.eval()
                    decoder.eval()

                    with torch.no_grad():
                        # Evaluate every test batch, not just the first one,
                        # and report the mean of the per-batch losses.
                        all_avg_loss = []
                        for x_test_batch, y_test_batch, x_test_unnorm_batch in test_dataloader:
                            x_test_batch = x_test_batch.to(device)
                            y_test_batch = y_test_batch.to(device)

                            input_pos = prop_pos = x_test_batch
                            z = encoder(input_pos, input_pos)
                            x_out_test, _ = decoder(z, prop_pos, input_pos)
                            x_out_test = x_out_test * y_std + y_mean

                            batch_loss = pointwise_rel_l2norm_loss(x_out_test, y_test_batch)
                            all_avg_loss.append(batch_loss.item())

                            # Visualisation is currently disabled. To re-enable it,
                            # plot one batch (e.g. the first) from inside this loop:
                            # make_scatter_plot(x_test_unnorm_batch, x_out_test, y_test_batch,
                            #                   os.path.join(sample_dir, f'result_iter_{n_iter}.png'))

                        test_loss = float(np.mean(all_avg_loss))

                    writer.add_scalar('test/loss', test_loss, n_iter)
                    print(f'Testing avg loss: {test_loss:.4e}')

                    # Save everything needed to continue training later.
                    ckpt = { 'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), 'n_iter': n_iter,
                             'enc_optim': enc_optim.state_dict(), 'dec_optim': dec_optim.state_dict(),
                             'enc_sched': enc_scheduler.state_dict(), 'dec_sched': dec_scheduler.state_dict() }
                    save_checkpoint(ckpt, os.path.join(checkpoint_dir, f'model_checkpoint{n_iter}.ckpt'))
                    del ckpt
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()

    print("Training finished.")