import torch
import torch.nn as nn
import numpy as np
import argparse
from tqdm import tqdm
import time
import os
import shutil
import logging
from einops import rearrange
from torch.optim.lr_scheduler import OneCycleLR
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, TensorDataset
from scipy.io import loadmat
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# --- 1. Import all required project modules ---
from nn_module.encoder_module import Encoder1D
# Key point: import the new adaptive decoder.
from nn_module.decoder_module import AdaptivePointWiseDecoder1D
from loss_fn import rel_loss
from utils import load_checkpoint, save_checkpoint, ensure_dir

# --- Seed every random number generator (same values as the original script) ---
_SEED = 1
torch.manual_seed(_SEED)
torch.cuda.manual_seed_all(_SEED)
np.random.seed(_SEED)
# NOTE(review): deterministic=True is combined with benchmark=True below.
# cuDNN auto-tuning may select different algorithms from run to run, so
# confirm whether exact run-to-run reproducibility is actually required.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

# --- Helper functions (same as the original script) ---
def central_diff(x: torch.Tensor, h):
    """Return the central finite difference of ``x`` along dimension 1.

    Every output entry is ``(next - previous) / 2 / h``. At the two ends the
    sequence is wrapped: the neighbour before index 0 is ``x[:, -2]`` and the
    neighbour after the last index is ``x[:, 1]``.

    Args:
        x: Tensor with at least two dimensions; differencing runs over dim 1.
        h: Step size the difference is divided by.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Value one step ahead of each position; the last position wraps to index 1.
    ahead = torch.cat([x[:, 1:], x[:, 1:2]], dim=1)
    # Value one step behind each position; the first position wraps to index -2.
    behind = torch.cat([x[:, -2:-1], x[:, :-1]], dim=1)
    half_difference = (ahead - behind) / 2
    return half_difference / h

def make_image_grid(init: torch.Tensor, sequence: torch.Tensor, gt: torch.Tensor, out_path, nrow=8):
    """Plot predictions against ground truth on a grid of axes and save the figure.

    One axis is drawn per sample: the prediction as a solid red line and the
    ground truth as a dashed green line, both over ``np.linspace(0, 1, n)``.

    Args:
        init: Not used for plotting; kept so existing call sites keep working.
        sequence: Predictions, unpacked as ``(b, n, c)``. The last dimension is
            squeezed before plotting (presumably ``c == 1`` - confirm with callers).
        gt: Ground truth, converted the same way as ``sequence``.
        out_path: Destination handed to ``Figure.savefig``.
        nrow: Number of axes per grid row.
    """
    b, n, _ = sequence.shape
    sequence = sequence.detach().cpu().squeeze(-1).numpy()
    gt = gt.detach().cpu().squeeze(-1).numpy()

    # Ceil-divide so that a batch smaller than ``nrow`` (previously zero grid
    # rows) or not a multiple of ``nrow`` (previously truncated) is fully drawn.
    n_rows = max(1, -(-b // nrow))
    fig = plt.figure(figsize=(16., 16.))
    grid = ImageGrid(fig, 111, nrows_ncols=(n_rows, nrow))
    x_coords = np.linspace(0, 1, n)
    for im_no, ax in enumerate(grid):
        if im_no < b:
            ax.plot(x_coords, sequence[im_no], c='r', label='Prediction')
            ax.plot(x_coords, gt[im_no], '--', c='g', alpha=0.8, label='Ground Truth')
            ax.axis('equal')
        # Hide the frame on every cell, including unused ones in the last row.
        ax.axis('off')
    # Address the figure explicitly rather than relying on pyplot's current figure.
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)

# ==============================================================================
# 2. Command-line argument definitions (modified)
# ==============================================================================
def get_arguments(parser):
    """Register every command-line option of the training script on ``parser``.

    Options are added in the order listed below: adaptive-decoder settings,
    GPU selection, training control, then dataset settings.
    ``--capacity_ratios`` and ``--dataset_path`` are required.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow call chaining.
    """
    option_table = (
        # Adaptive-decoder specific options
        ('--capacity_ratios', dict(nargs='+', type=float, required=True,
                                   help='[自适应] Decoder每层递归的保留比例列表 (必须提供!)')),
        ('--propagator_depth', dict(type=int, default=3, help='自适应传播器的深度')),
        # GPU / device
        ('--gpu', dict(type=int, default=0, help='指定使用的GPU设备ID')),
        # Training control
        ('--lr', dict(type=float, default=3e-4, help='学习率')),
        ('--resume', dict(action='store_true', help='从检查点恢复训练')),
        ('--path_to_resume', dict(type=str, default='', help='用于恢复的检查点路径')),
        ('--iters', dict(type=int, default=100000, help='训练迭代总次数')),
        ('--log_dir', dict(type=str, default='./burgers_adapt_exp', help='日志和模型保存路径')),
        ('--ckpt_every', dict(type=int, default=5000, help='保存检查点的频率')),
        # Dataset
        ('--batch_size', dict(type=int, default=16, help='批处理大小')),
        ('--dataset_path', dict(type=str, required=True, help='数据集 .mat 文件路径')),
        ('--train_seq_num', dict(type=int, default=1000, help='训练集中的序列数量')),
        ('--test_seq_num', dict(type=int, default=100, help='测试集中的序列数量')),
        ('--resolution', dict(type=int, default=2048, help='空间分辨率')),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser

# ==============================================================================
# 3. Model construction (modified)
# ==============================================================================
def build_model(res, opt):
    """Create the encoder and adaptive decoder used for the Burgers task.

    Prints a banner before building the decoder and the total number of
    trainable parameters of both modules afterwards.

    Args:
        res: Resolution forwarded to both modules as ``res``.
        opt: Parsed options; ``propagator_depth`` and ``capacity_ratios`` are read.

    Returns:
        The tuple ``(encoder, decoder)``.
    """
    encoder = Encoder1D(2, 512, 512, 4, res=res)

    print("\n--- 正在构建模型: Encoder1D + 自适应Decoder1D ---\n")
    decoder_settings = dict(
        latent_channels=512,
        out_channels=1,
        propagator_depth=opt.propagator_depth,
        capacity_ratios=opt.capacity_ratios,
        res=res,
        scale=2,
    )
    decoder = AdaptivePointWiseDecoder1D(**decoder_settings)

    # Count only the parameters that will receive gradients, across both modules.
    total_params = sum(
        weight.numel()
        for module in (encoder, decoder)
        for weight in module.parameters()
        if weight.requires_grad
    )
    print(f'总可训练参数: {total_params:,}')
    return encoder, decoder

# ==============================================================================
# 4. Main execution logic
# ==============================================================================
if __name__ == '__main__':
    # Script entry point: parse options, prepare logging and data, build the
    # encoder/decoder pair, then train with periodic evaluation and checkpoints.
    parser = argparse.ArgumentParser(description="[自适应版] 训练 OFormer (Burgers)")
    opt = get_arguments(parser).parse_args()
    print('当前使用的配置:')
    print(opt)

    # --- Device setup ---
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{opt.gpu}' if use_cuda else 'cpu')
    print(f"训练将在 {device} 上进行。")
    if use_cuda:
        torch.cuda.set_device(device)

    # --- Logging and output directories ---
    ensure_dir(opt.log_dir)
    checkpoint_dir = os.path.join(opt.log_dir, 'model_ckpt')
    ensure_dir(checkpoint_dir)
    sample_dir = os.path.join(opt.log_dir, 'samples')
    ensure_dir(sample_dir)
    writer = SummaryWriter(log_dir=opt.log_dir)

    logger = logging.getLogger("LOG")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(opt.log_dir, 'logging_info.txt'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('=======Option used=======')
    for arg in vars(opt):
        logger.info(f'{arg}: {getattr(opt, arg)}')

    # --- Back up the code used for this run (best effort; failures are only reported) ---
    print("正在备份脚本代码...")
    script_dir = os.path.join(opt.log_dir, 'script_cache')
    ensure_dir(script_dir)
    try:
        shutil.copy(__file__, script_dir)
        shutil.copy('./nn_module/encoder_module.py', script_dir)
        shutil.copy('./nn_module/decoder_module.py', script_dir)
    except Exception as e:
        print(f"代码备份时出错: {e}")

    # --- Data loading ---
    print('正在准备数据...')
    data_path = opt.dataset_path
    ntrain = opt.train_seq_num
    ntest = opt.test_seq_num
    res = opt.resolution
    # The sub-sampling stride is derived from a hard-coded 2**13 source points.
    full_res = 2 ** 13
    if not 0 < res <= full_res:
        raise ValueError(f'--resolution must be between 1 and {full_res}, got {res}')
    if ntrain <= 0 or ntest <= 0:
        raise ValueError(
            f'--train_seq_num and --test_seq_num must be positive, got {ntrain} and {ntest}'
        )
    sub = full_res // res
    dx = 1. / res

    data = loadmat(data_path)
    x_data = data['a'][:, ::sub]
    y_data = data['u'][:, ::sub]

    # Fail early with a clear message: a mismatch would otherwise surface as a
    # cryptic reshape error or, when element counts happen to line up, as
    # silently re-interpreted data.
    for field, array in (('a', x_data), ('u', y_data)):
        if array.shape[1] != res:
            raise ValueError(
                f"dataset field '{field}' has {array.shape[1]} points per sequence after "
                f"sub-sampling with stride {sub}, expected {res}"
            )
        if array.shape[0] < max(ntrain, ntest):
            raise ValueError(
                f"dataset field '{field}' holds {array.shape[0]} sequences, fewer than the "
                f"requested train/test sizes ({ntrain}/{ntest})"
            )
    if ntrain + ntest > x_data.shape[0]:
        logger.warning(
            'Train (%d) and test (%d) slices overlap: the dataset only has %d sequences.',
            ntrain, ntest, x_data.shape[0],
        )

    # Training samples are taken from the front, test samples from the back.
    x_train = torch.from_numpy(x_data[:ntrain, :].reshape(ntrain, res, 1)).float()
    y_train = torch.from_numpy(y_data[:ntrain, :].reshape(ntrain, res, 1)).float()
    x_test = torch.from_numpy(x_data[-ntest:, :].reshape(ntest, res, 1)).float()
    y_test = torch.from_numpy(y_data[-ntest:, :].reshape(ntest, res, 1)).float()

    # The coordinate grid is constant, so move it to the device once instead
    # of copying it host-to-device on every training / evaluation step.
    gridx = torch.tensor(np.linspace(0, 1, res), dtype=torch.float32).reshape(1, res, 1).to(device)

    train_dataloader = DataLoader(TensorDataset(x_train, y_train), batch_size=opt.batch_size, shuffle=True)
    test_dataloader = DataLoader(TensorDataset(x_test, y_test), batch_size=opt.batch_size, shuffle=False)

    # --- Model, optimizer and learning-rate schedule ---
    encoder, decoder = build_model(res, opt)
    encoder.to(device)
    decoder.to(device)

    optim_params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.AdamW(optim_params, lr=opt.lr, weight_decay=1e-4)
    scheduler = OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.iters, div_factor=1e4, final_div_factor=1e4)

    start_n_iter = 0
    if opt.resume:
        # TODO: checkpoint restoration is not implemented in this script. Warn
        # instead of silently training from scratch when --resume is given.
        if os.path.exists(opt.path_to_resume):
            resume_msg = (
                f'--resume was given, but restoring from {opt.path_to_resume!r} is not '
                'implemented in this script; training starts from iteration 0.'
            )
        else:
            resume_msg = (
                f'--resume was given, but checkpoint {opt.path_to_resume!r} does not exist; '
                'training starts from iteration 0.'
            )
        logger.warning(resume_msg)
        print(resume_msg)

    # --- 5. Main training loop ---
    n_iter = start_n_iter
    try:
        with tqdm(total=opt.iters, initial=start_n_iter) as pbar:
            train_data_iter = iter(train_dataloader)
            while n_iter < opt.iters:
                # Evaluation below switches to eval mode, so re-enable training mode each step.
                encoder.train()
                decoder.train()

                # Restart the loader transparently when an epoch is exhausted.
                try:
                    x, y = next(train_data_iter)
                except StopIteration:
                    train_data_iter = iter(train_dataloader)
                    x, y = next(train_data_iter)

                x, y = x.to(device), y.to(device)
                input_pos = prop_pos = gridx.repeat(x.shape[0], 1, 1)
                x_with_pos = torch.cat((x, input_pos), dim=-1)

                optimizer.zero_grad()

                z = encoder(x_with_pos, input_pos)

                # The adaptive decoder returns the prediction together with diagnostics.
                x_out, diagnostics = decoder(z, prop_pos, input_pos)

                if diagnostics:
                    writer.add_scalar('decoder_diagnostics/avg_time_per_layer_ms', np.mean(diagnostics['time_per_layer_ms']), n_iter)

                pred_loss = rel_loss(x_out, y, p=2)

                # Derivative-matching term, added to the total with a 1e-3 weight.
                gt_deriv = central_diff(y, dx)
                pred_deriv = central_diff(x_out, dx)
                deriv_loss = rel_loss(pred_deriv, gt_deriv, p=2)

                loss = pred_loss + 1e-3 * deriv_loss

                loss.backward()
                optimizer.step()
                scheduler.step()

                writer.add_scalar('train_loss/total', loss.item(), n_iter)
                pbar.set_description(f'Total (1e-4): {loss.item()*1e4:.1f} | Pred (1e-4): {pred_loss.item()*1e4:.1f} | Deriv (1e-4): {deriv_loss.item()*1e4:.1f}')

                n_iter += 1
                pbar.update(1)

                # --- 6. Evaluation and checkpointing (every ckpt_every steps and at the end) ---
                if (n_iter % opt.ckpt_every == 0) or (n_iter >= opt.iters):
                    encoder.eval()
                    decoder.eval()
                    all_avg_loss = []
                    # Placeholders for the visualization cache (see the TODO below).
                    visualization_cache = {'in_seq': [], 'pred': [], 'gt': []}
                    picked = 0

                    with torch.no_grad():
                        for x_test_batch, y_test_batch in tqdm(test_dataloader, desc="Testing"):
                            x_test_batch, y_test_batch = x_test_batch.to(device), y_test_batch.to(device)

                            input_pos_test = gridx.repeat(x_test_batch.shape[0], 1, 1)
                            x_test_with_pos = torch.cat((x_test_batch, input_pos_test), dim=-1)

                            z_test = encoder(x_test_with_pos, input_pos_test)
                            x_out_test, _ = decoder(z_test, input_pos_test, input_pos_test)

                            avg_loss = rel_loss(x_out_test, y_test_batch, p=2)
                            all_avg_loss.append(avg_loss.item())

                            if picked < 64:
                                # TODO: visualization caching is not implemented here.
                                pass

                    # Unweighted mean over batches; a smaller final batch counts as much as a full one.
                    avg_test_loss = np.mean(all_avg_loss)
                    writer.add_scalar('test_loss/avg_rel_l2', avg_test_loss, n_iter)
                    logger.info(f'Iter: {n_iter}, Test Loss: {avg_test_loss:.6f}')
                    print(f'\nIter: {n_iter}, Test Loss: {avg_test_loss:.6f}')

                    # Only model weights and the iteration counter are stored.
                    ckpt = {'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), 'n_iter': n_iter}
                    save_checkpoint(ckpt, os.path.join(checkpoint_dir, f'model_checkpoint_{n_iter}.ckpt'))
    finally:
        # Close the event writer even when training raises or is interrupted.
        writer.close()

    print('训练完成。')