import os
import argparse
import numpy as np
import scipy.io as scio
import torch
import torch.nn.functional as F
from tqdm import tqdm
from einops import rearrange

# Import the new (adaptive) model registry and the same utilities used by exp_darcy.py
from model_dict_adaptive import get_model
from utils.testloss import TestLoss
from utils.normalizer import UnitTransformer

def count_parameters(model):
    """Print and return how many trainable scalars ``model`` holds.

    Only parameters with ``requires_grad=True`` are counted; frozen parameters
    are ignored.
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Trainable Params: {trainable}")
    return trainable

def central_diff(x: torch.Tensor, dx, resolution):
    """Approximate in-plane gradients of a gridded field with central differences.

    Kept consistent with exp_darcy.py. It uses a 3-point stencil,
    ``(f[i+1] - f[i-1]) / (2 * dx)``, applied via ``conv2d`` with
    ``padding='same'``. That padding is zero padding, so border points see a
    zero neighbour outside the domain.

    Args:
        x: tensor of shape ``(batch, resolution * resolution, channels)``. The
            point axis is the row-major flattening of a
            ``(resolution, resolution)`` grid. Any number of channels is
            accepted; each channel is differentiated independently.
        dx: grid spacing used to scale the stencil.
        resolution: side length of the square grid.

    Returns:
        ``(grad_x, grad_y)``, each with the same shape as ``x``.
        ``grad_x`` differentiates along the last (fastest-varying) grid axis.
        ``grad_y`` differentiates along the first grid axis.

    Raises:
        ValueError: if the number of points is not ``resolution ** 2``.
    """
    batch, n_points, channels = x.shape
    if n_points != resolution * resolution:
        raise ValueError(
            f"central_diff expects {resolution * resolution} points for a "
            f"{resolution}x{resolution} grid, got {n_points}"
        )

    # (B, H*W, C) -> (B, C, H, W); same layout as rearrange 'b (h w) c -> b c h w'.
    field = x.transpose(1, 2).reshape(batch, channels, resolution, resolution)

    # Cross-correlating with [-0.5, 0, 0.5] / dx yields (f[i+1] - f[i-1]) / (2 * dx).
    stencil = torch.tensor([-0.5, 0.0, 0.5], dtype=x.dtype, device=x.device) / dx
    # One depthwise kernel per channel (groups=channels) so channels do not mix.
    kernel_x = stencil.view(1, 1, 1, 3).repeat(channels, 1, 1, 1)
    kernel_y = stencil.view(1, 1, 3, 1).repeat(channels, 1, 1, 1)

    # 'same' padding keeps the output grid the same size as the input.
    grad_x = F.conv2d(field, kernel_x, padding='same', groups=channels)
    grad_y = F.conv2d(field, kernel_y, padding='same', groups=channels)

    def _to_points(t):
        # (B, C, H, W) -> (B, H*W, C), the inverse of the reshape above.
        return t.reshape(batch, channels, n_points).transpose(1, 2)

    return _to_points(grad_x), _to_points(grad_y)

def _load_darcy_split(path, n_samples, r, s):
    """Load, downsample and flatten the first ``n_samples`` fields of a Darcy ``.mat`` file.

    The ``coeff`` and ``sol`` arrays are strided by ``r`` along both spatial
    axes, cropped to ``s x s`` and flattened. Returns a ``(coeff, sol)`` pair
    of float32 tensors, each of shape ``(n_samples, s * s)``. The loaded
    ``.mat`` dict is released when this function returns.

    Raises:
        ValueError: if the file holds fewer than ``n_samples`` samples.
            Without this check the reshape would either fail with an opaque
            error or silently mis-shape the data.
    """
    data = scio.loadmat(path)
    available = data['coeff'].shape[0]
    if n_samples > available:
        raise ValueError(
            f"Requested {n_samples} samples from {path}, but it only contains {available}."
        )

    def _prepare(key):
        field = data[key][:n_samples, ::r, ::r][:, :s, :s]
        return torch.from_numpy(field.reshape(n_samples, -1)).float()

    return _prepare('coeff'), _prepare('sol')


def _evaluate(model, loader, y_normalizer, loss_fn, device):
    """Return ``loss_fn`` summed over every batch of ``loader``.

    Predictions are decoded with ``y_normalizer`` before comparison. ``y`` is
    compared as given; the caller keeps the test targets in physical
    (un-normalised) units.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for x, fx, y in loader:
            x, fx, y = x.to(device), fx.to(device), y.to(device)
            # The model returns a pair; only the prediction is used here.
            out, _ = model(x, fx=fx.unsqueeze(-1))
            total += loss_fn(y_normalizer.decode(out.squeeze(-1)), y).item()
    return total


def main(args):
    """Train and evaluate the adaptive model on the Darcy-flow benchmark.

    The pipeline follows exp_darcy.py:
      1. Load the ``smooth1`` (train) and ``smooth2`` (test) ``.mat`` files and
         downsample them by ``args.downsample``. Unit-normalise the inputs and
         the training targets.
      2. Build the model class returned by ``get_model(args)``.
      3. Train with a ``TestLoss`` L2 term plus a 0.1-weighted
         central-difference gradient term. Evaluate every epoch and save the
         best weights to ``./checkpoints/<save_name>.pt``.

    Raises:
        ValueError: if ``len(args.capacity_ratios) != args.n_layers``, or if a
            data file holds fewer samples than requested.
    """
    # Validate adaptive-model arguments before the slow .mat loading so a bad CLI fails fast.
    if len(args.capacity_ratios) != args.n_layers:
        raise ValueError(
            f"参数错误: --capacity_ratios 的长度 ({len(args.capacity_ratios)}) "
            f"必须与 --n-layers ({args.n_layers}) 完全匹配。"
        )

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- 1. Data processing (follows exp_darcy.py) ---
    print("--- 1. Loading and Processing Darcy Dataset ---")
    train_path = os.path.join(args.data_path, 'piececonst_r421_N1024_smooth1.mat')
    test_path = os.path.join(args.data_path, 'piececonst_r421_N1024_smooth2.mat')
    ntrain = args.ntrain
    ntest = 200

    # a. Downsampling stride and resulting grid side length (raw grids are 421 points wide).
    r = args.downsample
    s = int(((421 - 1) / r) + 1)
    # NOTE(review): the coordinate grid below has spacing 1/(s-1); 1/s is kept to
    # stay consistent with exp_darcy.py.
    dx = 1.0 / s

    # b./c. Train and test splits (coeff = model input, sol = target).
    x_train_coeff, y_train_sol = _load_darcy_split(train_path, ntrain, r, s)
    x_test_coeff, y_test_sol = _load_darcy_split(test_path, ntest, r, s)

    # d. Normalise with training statistics. Test targets are deliberately left in
    #    physical units; evaluation decodes the predictions instead.
    x_normalizer = UnitTransformer(x_train_coeff)
    y_normalizer = UnitTransformer(y_train_sol)
    x_train_coeff = x_normalizer.encode(x_train_coeff)
    x_test_coeff = x_normalizer.encode(x_test_coeff)
    y_train_sol = y_normalizer.encode(y_train_sol)
    x_normalizer.to(device)
    y_normalizer.to(device)

    # e. Point coordinates on [0, 1]^2, flattened in the same row-major order as the fields.
    axis = torch.tensor(np.linspace(0, 1, s), dtype=torch.float)
    gridx, gridy = torch.meshgrid(axis, axis, indexing='ij')
    pos = torch.stack([gridx.ravel(), gridy.ravel()], dim=-1).unsqueeze(0)
    pos_train = pos.repeat(ntrain, 1, 1)
    pos_test = pos.repeat(ntest, 1, 1)

    # f. DataLoaders yielding (pos, coeff, sol) triples.
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_train, x_train_coeff, y_train_sol),
                                               batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(pos_test, x_test_coeff, y_test_sol),
                                              batch_size=args.batch_size, shuffle=False, num_workers=4)

    print(f"Data processing complete. Grid size: {s}x{s}")
    print("-" * 50)

    # --- 2. Model creation (generic adaptive model) ---
    print(f"--- 2. Creating Model: {args.model} ---")
    model = get_model(args).Model(
        space_dim=2,
        fun_dim=1,  # Darcy has a single input function (coeff)
        out_dim=1,
        H=s, W=s,
        n_hidden=args.n_hidden,
        n_layers=args.n_layers,
        n_head=args.n_head,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
        slice_num=args.slice_num,
        capacity_ratios=args.capacity_ratios,
        ref=args.ref,
        unified_pos=args.unified_pos
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # OneCycleLR is stepped once per batch, hence steps_per_epoch=len(train_loader).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=args.epochs,
                                                    steps_per_epoch=len(train_loader))
    # size_average=False: losses are summed over the batch and divided by the
    # sample count at the end of each epoch.
    myloss = TestLoss(size_average=False)

    print(args)
    count_parameters(model)
    print("-" * 50)

    # --- 3. Training / evaluation loop (composite loss as in exp_darcy.py) ---
    print(f"--- 3. Starting Training for {args.epochs} Epochs ---")
    deriv_weight = 0.1  # weight of the finite-difference gradient term
    best_rel_err = float('inf')
    ckpt_dir = './checkpoints'

    for ep in range(args.epochs):
        model.train()
        train_l2_loss, train_reg_loss = 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {ep+1}/{args.epochs}")

        for x, fx, y in pbar:
            x, fx, y = x.to(device), fx.to(device), y.to(device)
            optimizer.zero_grad()

            # fx is the (normalised) permeability coefficient.
            out, _ = model(x, fx=fx.unsqueeze(-1))

            # Losses are computed in physical units.
            out_decoded = y_normalizer.decode(out.squeeze(-1))
            y_decoded = y_normalizer.decode(y)

            # a. L2 term.
            l2loss = myloss(out_decoded, y_decoded)

            # b. Gradient term from central differences of prediction and target.
            gt_grad_x, gt_grad_y = central_diff(y_decoded.unsqueeze(-1), dx, s)
            pred_grad_x, pred_grad_y = central_diff(out_decoded.unsqueeze(-1), dx, s)
            deriv_loss = (myloss(pred_grad_x.squeeze(-1), gt_grad_x.squeeze(-1))
                          + myloss(pred_grad_y.squeeze(-1), gt_grad_y.squeeze(-1)))

            # c. Combined objective.
            reg_loss = deriv_weight * deriv_loss
            loss = l2loss + reg_loss
            loss.backward()

            if args.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()

            # One host sync per scalar, reused for both the running totals and the progress bar.
            l2_val, reg_val = l2loss.item(), reg_loss.item()
            train_l2_loss += l2_val
            train_reg_loss += reg_val
            pbar.set_postfix(l2=l2_val / x.size(0), reg=reg_val / x.size(0))

        train_l2_loss /= ntrain
        train_reg_loss /= ntrain

        # Evaluation: mean per-sample loss over the test split.
        rel_err = _evaluate(model, test_loader, y_normalizer, myloss, device) / ntest

        print(f"Epoch {ep+1} | Train L2: {train_l2_loss:.5f} | Reg Loss: {train_reg_loss:.5f} | Test Rel Error: {rel_err:.5f}")

        if rel_err < best_rel_err:
            best_rel_err = rel_err
            print(f"🚀 New best model found! Saving to checkpoints/{args.save_name}.pt")
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(ckpt_dir, args.save_name + '.pt'))

    print("--- Training Complete ---")

if __name__ == "__main__":
    parser = argparse.ArgumentParser('Adaptive Transolver Training for Darcy Flow (exp_darcy style)')

    # CLI declared as (flags, options) pairs, registered in this exact order.
    # The first groups mirror exp_darcy.py; the final entry is adaptive-model specific.
    cli_spec = [
        # optimisation / data / bookkeeping
        (('--lr',), dict(type=float, default=1e-3)),
        (('--epochs',), dict(type=int, default=500)),
        (('--weight_decay',), dict(type=float, default=1e-5)),
        (('--batch-size',), dict(type=int, default=8)),
        (('--gpu',), dict(type=str, default='1', help="GPU index to use")),
        (('--max_grad_norm',), dict(type=float, default=None)),
        (('--downsample',), dict(type=int, default=5)),
        (('--ntrain',), dict(type=int, default=1000)),
        (('--save_name',), dict(type=str, default='darcy_adaptive')),
        (('--data_path',), dict(type=str, required=True,
                                help='Path to the directory containing darcy .mat files')),
        # model architecture
        (('--model',), dict(type=str, default='PipeAdaptiveTransolver', help="Model's name")),
        (('--n-hidden',), dict(type=int, default=64)),
        (('--n-layers',), dict(type=int, default=3)),
        (('--n-head',), dict(type=int, default=4)),
        (('--mlp_ratio',), dict(type=int, default=1)),
        (('--dropout',), dict(type=float, default=0.0)),
        (('--slice_num',), dict(type=int, default=32)),
        (('--unified_pos',), dict(type=int, default=0)),
        (('--ref',), dict(type=int, default=8)),
        # adaptive model only: one keep ratio per recursive layer (length must equal --n-layers)
        (('--capacity_ratios',), dict(type=float, nargs='+', required=True,
                                      help='[必需] 每个递归层的保留比例列表，长度必须等于 n-layers')),
    ]
    for flags, options in cli_spec:
        parser.add_argument(*flags, **options)

    main(parser.parse_args())