import torch
import torch.nn as nn
import numpy as np
import argparse
from tqdm import tqdm
import time
import os
from functools import partial
from torch.optim.lr_scheduler import StepLR, OneCycleLR
from tensorboardX import SummaryWriter

from nn_module.encoder_module import SpatialTemporalEncoder2D
# 更新导入以使用自适应解码器
from nn_module.decoder_module import PointWiseDecoder2D_Adaptive

from dataset import get_new_data_loader
from loss_fn import rel_loss, rel_l2norm_loss
from utils import load_checkpoint, save_checkpoint, ensure_dir
import torchvision
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import datetime
import logging
import shutil
from typing import Union
from einops import rearrange, repeat
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Global flags and RNG seeds: numpy, torch (CPU) and torch (CUDA) are all seeded with 1.
# NOTE(review): cudnn.benchmark=True lets cuDNN choose algorithms by timing them, which
# works against the reproducibility that deterministic=True and the fixed seeds aim for;
# PyTorch's reproducibility notes recommend benchmark=False in that case. Confirm intent.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
# Synchronous CUDA kernel launches (errors are reported at the failing call, at a speed
# cost). NOTE(review): looks like a debugging leftover — confirm it is still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
# Use the 'file_system' strategy for tensors shared between processes.
torch.multiprocessing.set_sharing_strategy('file_system')
# Autograd anomaly detection (helps locate the op producing NaN/inf in backward, but
# slows training). NOTE(review): likely a debugging leftover — confirm.
torch.autograd.set_detect_anomaly(True)


def build_model(opt) -> "tuple[SpatialTemporalEncoder2D, PointWiseDecoder2D_Adaptive]":
    """Build the spatio-temporal encoder and the adaptive point-wise decoder.

    Args:
        opt: parsed command-line options (see ``get_arguments``). Reads
            in_channels, encoder_emb_dim, out_seq_emb_dim, encoder_heads,
            encoder_depth, decoder_emb_dim, out_channels, out_step,
            propagator_depth, fourier_frequency, capacity_ratios and
            final_keep_ratio.

    Returns:
        ``(encoder, decoder)``. As a side effect, the combined number of
        trainable parameters of both modules is printed.
    """
    encoder = SpatialTemporalEncoder2D(
        opt.in_channels,
        opt.encoder_emb_dim,
        opt.out_seq_emb_dim,
        opt.encoder_heads,
        opt.encoder_depth,
    )

    # Encoder-side options the adaptive decoder needs, plus its capacity schedule
    # (per get_arguments' help text, capacity_ratios overrides final_keep_ratio).
    decoder_kwargs = {
        'out_seq_emb_dim': opt.out_seq_emb_dim,
        'encoder_heads': opt.encoder_heads,
        'capacity_ratios': opt.capacity_ratios,
        'final_keep_ratio': opt.final_keep_ratio,
    }

    # NOTE(review): opt.decoding_depth is parsed but never forwarded to the decoder —
    # confirm this is intended.
    decoder = PointWiseDecoder2D_Adaptive(
        latent_channels=opt.decoder_emb_dim,
        out_channels=opt.out_channels,
        out_steps=opt.out_step,
        propagator_depth=opt.propagator_depth,
        scale=opt.fourier_frequency,
        dropout=0.0,
        **decoder_kwargs,
    )

    total_params = sum(
        p.numel()
        for module in (encoder, decoder)
        for p in module.parameters()
        if p.requires_grad
    )
    print(f'Total trainable parameters: {total_params}')
    return encoder, decoder


# Adapted from the Galerkin Transformer.
def central_diff(x: torch.Tensor, height: int = 64, width: int = 64):
    """Periodic central-difference spatial gradients of flattened 2-D fields.

    Args:
        x: tensor of shape (batch, seq_len, height * width) whose last axis is a
            row-major flattened grid (height outer, width inner).
        height: number of grid rows (default 64, the resolution used elsewhere
            in this script).
        width: number of grid columns (default 64).

    Returns:
        ``(grad_x, grad_y)``, each of shape (batch, seq_len, height, width):
        the derivative along the width axis and along the height axis, on a
        unit-length periodic domain (spacing 1/width and 1/height).

    Raises:
        RuntimeError: if the last axis of ``x`` is not ``height * width``.
    """
    # Same as rearrange(x, 'b t (h w) -> b t h w'); reshape raises on a size mismatch.
    x = x.reshape(x.shape[0], x.shape[1], height, width)
    dx = 1.0 / width
    dy = 1.0 / height
    # Wrap one cell on every side so the stencil sees periodic neighbours: [b, t, h+2, w+2].
    x = F.pad(x, (1, 1, 1, 1), mode='circular')
    grad_x = (x[..., 1:-1, 2:] - x[..., 1:-1, :-2]) / (2 * dx)  # (f(x+dx) - f(x-dx)) / (2 dx)
    grad_y = (x[..., 2:, 1:-1] - x[..., :-2, 1:-1]) / (2 * dy)  # (f(y+dy) - f(y-dy)) / (2 dy)

    return grad_x, grad_y


def make_image_grid(image: torch.Tensor, out_path, nrow=25):
    """Save a batch of image sequences as one grid figure.

    Args:
        image: tensor of shape (b, t, h, w); all b*t frames are drawn in order.
        out_path: destination path passed to ``savefig``.
        nrow: number of images per grid row (grid columns).

    The number of grid rows is rounded up, so a partially filled last row is
    still drawn. Previously floor division silently dropped the trailing images
    and failed when b*t < nrow. Empty cells are hidden.
    """
    b, t, h, w = image.shape
    n_images = b * t
    frames = image.detach().cpu().numpy().reshape((n_images, h, w))
    n_rows = max(1, -(-n_images // nrow))  # ceil(n_images / nrow), at least one row

    fig = plt.figure(figsize=(8., 8.))
    try:
        grid = ImageGrid(fig, 111,  # like subplot(111)
                         nrows_ncols=(n_rows, nrow),
                         )
        for i, ax in enumerate(grid):
            if i < n_images:
                ax.imshow(frames[i])
            ax.axis('off')  # also hides the unused cells of a partial last row
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def index_points(points, idx):
    """Gather, for every batch element, the rows of ``points`` selected by ``idx``.

    Args:
        points: tensor of shape [B, N, C].
        idx: integer index tensor of shape [B, S] (extra trailing index
            dimensions such as [B, S, K] also work).

    Returns:
        Tensor of shape idx.shape + [C], i.e. [B, S, C] for 2-D ``idx``, with
        ``out[b, ...] = points[b, idx[b, ...], :]``.
    """
    n_batch = points.shape[0]
    # Batch index laid out as [idx.shape[0], 1, ..., 1] and broadcast to idx's shape.
    lead_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_idx = torch.arange(n_batch, dtype=torch.long, device=points.device)
    batch_idx = batch_idx.view(lead_shape).expand_as(idx)
    return points[batch_idx, idx, :]


def get_arguments(parser):
    """Register every command-line option of this training script on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser``; options are added in place.

    Returns:
        The same parser, for chaining.

    Help texts render defaults through argparse's ``%(default)s`` placeholder, so
    they cannot drift from the real ``default=`` values. Previously several
    hard-coded "(默认: …)" texts were wrong (lr, iters, ckpt_every, path_to_resume).
    """
    # Basic training settings
    parser.add_argument(
        '--lr', type=float, default=1e-4, help='指定微调的学习率 (默认: %(default)s)'
    )
    parser.add_argument(
        '--gpu', type=int, default=0, help='指定要使用的GPU ID (默认: %(default)s)'
    )
    parser.add_argument(
        '--resume_training', action='store_true',
        help='从检查点恢复训练'
    )
    parser.add_argument(
        '--path_to_resume', type=str,
        default='none', help='用于恢复训练的检查点路径 (默认: %(default)s)'
    )
    parser.add_argument(
        '--eval_mode', action='store_true',
        help='仅加载预训练检查点并进行评估'
    )
    parser.add_argument(
        '--iters', type=int, default=5000, help='训练迭代次数 (默认: %(default)s)'
    )
    parser.add_argument(
        '--log_dir', type=str, default='./', help='日志和检查点保存路径'
    )
    parser.add_argument(
        '--ckpt_every', type=int, default=1000, help='每 x 次迭代保存模型检查点 (默认: %(default)s)'
    )

    # General options
    parser.add_argument(
        '--in_seq_len', type=int, default=10, help='输入序列长度 (默认: %(default)s)'
    )
    # Encoder options
    parser.add_argument(
        '--in_channels', type=int, default=3, help='输入特征通道数 (默认: %(default)s)'
    )
    parser.add_argument(
        '--encoder_emb_dim', type=int, default=128, help='编码器中Token嵌入的通道数 (默认: %(default)s)'
    )
    parser.add_argument(
        '--out_seq_emb_dim', type=int, default=128, help='输出特征图的通道数 (默认: %(default)s)'
    )
    parser.add_argument(
        '--encoder_depth', type=int, default=2, help='编码器中Transformer的深度 (默认: %(default)s)'
    )
    parser.add_argument(
        '--encoder_heads', type=int, default=4, help='编码器中Transformer的头数 (默认: %(default)s)'
    )

    # Decoder options
    parser.add_argument(
        '--out_channels', type=int, default=1, help='输出通道数 (默认: %(default)s)'
    )
    parser.add_argument(
        '--decoder_emb_dim', type=int, default=128, help='解码器中Token嵌入的通道数 (默认: %(default)s)'
    )
    parser.add_argument(
        '--out_step', type=int, default=10, help='每次调用向前传播的步数 (默认: %(default)s)'
    )
    parser.add_argument(
        '--out_seq_len', type=int, default=40, help='输出序列长度 (默认: %(default)s)'
    )
    parser.add_argument(
        '--propagator_depth', type=int, default=2, help='传播器中MLP的深度 (默认: %(default)s)'
    )
    # NOTE(review): --decoding_depth is not read by build_model in this file.
    parser.add_argument(
        '--decoding_depth', type=int, default=2, help='解码器中解码网络的深度 (默认: %(default)s)'
    )
    parser.add_argument(
        '--fourier_frequency', type=int, default=8, help='傅里叶特征频率 (默认: %(default)s)'
    )
    parser.add_argument(
        '--use_grad', action='store_true', help='在损失中加入中心差分空间梯度项'
    )
    parser.add_argument(
        '--curriculum_steps', type=int, default=0, help='在初始阶段，不要展开太长'
    )
    parser.add_argument(
        '--curriculum_ratio', type=float, default=0.2, help='初始阶段的时长比例'
    )
    parser.add_argument(
        '--aug_ratio', type=float, default=0.0, help='随机裁剪的概率'
    )

    # Options for the adaptive decoder
    parser.add_argument(
        '--capacity_ratios',
        # Comma-separated floats, e.g. "1.0,0.8,0.5" -> [1.0, 0.8, 0.5].
        type=lambda s: [float(item) for item in s.split(',')],
        default=None,
        help='为每个传播器层设置的容量比率，以逗号分隔。例如："1.0,0.8,0.5"。将覆盖 final_keep_ratio。'
    )
    parser.add_argument(
        '--final_keep_ratio',
        type=float,
        default=0.25,
        help='如果未提供 capacity_ratios，则用于容量线性衰减的最终保留比率。'
    )

    # ===================================
    # Dataset options
    parser.add_argument(
        '--batch_size', type=int, default=16, help='每个批次的大小 (默认: %(default)s)'
    )
    parser.add_argument(
        '--dataset_path', type=str, required=True, help='数据集路径'
    )
    parser.add_argument(
        '--train_seq_num', type=int, default=50, help='训练集中的序列数量'
    )
    parser.add_argument(
        '--test_seq_num', type=int, default=100, help='测试集中的序列数量'
    )

    return parser


# Main entry point
if __name__ == '__main__':
    # ---- Options -------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description="训练一个自适应 PDE Transformer")
    parser = get_arguments(parser)
    opt = parser.parse_args()
    # Restrict the visible GPU. NOTE(review): this only takes effect if CUDA has not
    # been initialised yet; the first explicit .cuda() call is further below — confirm
    # nothing earlier touches the device.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)
    print(f"程序将运行在 GPU: {opt.gpu}")
    print('使用以下选项')
    print(opt)

    print('准备数据中')

    # ---- Model ---------------------------------------------------------------
    print('构建网络中')
    encoder, decoder = build_model(opt)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    # ---- Output directories, TensorBoard and file logging ----------------------
    writer = SummaryWriter(logdir=opt.log_dir)
    checkpoint_dir = os.path.join(opt.log_dir, 'model_ckpt')
    ensure_dir(checkpoint_dir)
    sample_dir = os.path.join(opt.log_dir, 'samples')
    ensure_dir(sample_dir)

    logger = logging.getLogger("LOG")
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(opt.log_dir, 'logging_info.txt'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)
    logger.info('=======使用的选项=======')
    for arg in vars(opt):
        logger.info(f'{arg}: {getattr(opt, arg)}')

    # ---- Optimisers, LR schedules and optional checkpoint ----------------------
    start_n_iter = 0
    enc_optim = torch.optim.AdamW(list(encoder.parameters()), lr=opt.lr, weight_decay=1e-4)
    dec_optim = torch.optim.AdamW(list(decoder.parameters()), lr=opt.lr, weight_decay=1e-4)

    has_checkpoint = opt.path_to_resume != 'none'
    if has_checkpoint and not opt.resume_training:
        # Fine-tuning / evaluation from a checkpoint: short warm-up, gentler decay.
        sched_kwargs = dict(div_factor=20, pct_start=0.05, final_div_factor=1e3)
    else:
        # Training from scratch or resuming an interrupted run.
        sched_kwargs = dict(div_factor=1e4, final_div_factor=1e4)
    enc_scheduler = OneCycleLR(enc_optim, max_lr=opt.lr, total_steps=opt.iters, **sched_kwargs)
    dec_scheduler = OneCycleLR(dec_optim, max_lr=opt.lr, total_steps=opt.iters, **sched_kwargs)

    if has_checkpoint:
        print(f'从以下检查点恢复: {opt.path_to_resume}')
        ckpt = load_checkpoint(opt.path_to_resume)
        encoder.load_state_dict(ckpt['encoder'])
        decoder.load_state_dict(ckpt['decoder'])

        if opt.resume_training:
            enc_optim.load_state_dict(ckpt['enc_optim'])
            dec_optim.load_state_dict(ckpt['dec_optim'])
            enc_scheduler.load_state_dict(ckpt['enc_sched'])
            dec_scheduler.load_state_dict(ckpt['dec_sched'])
            start_n_iter = ckpt['n_iter']
            print("已恢复预训练检查点，继续训练")
            logger.info("已恢复预训练检查点，继续训练")
        elif not opt.eval_mode:
            print("已恢复预训练检查点，使用微调模式")
            logger.info("已恢复预训练检查点，使用微调模式")
        else:
            print("已恢复预训练检查点，使用评估模式")
            logger.info("已恢复预训练检查点，使用评估模式")
        del ckpt
    else:
        print("无预训练检查点，从头开始训练")
        logger.info("无预训练检查点，从头开始训练")

    # ---- Data ----------------------------------------------------------------
    n_iter = start_n_iter
    ntrain = opt.train_seq_num
    ntest = opt.test_seq_num

    # The raw array is indexed as (t, h, w, n_sequences). The first in_seq_len frames
    # are the input and the next out_seq_len frames are the target. Training uses the
    # first ntrain sequences; testing uses the last ntest.
    raw = np.load(opt.dataset_path)
    out_end = opt.in_seq_len + opt.out_seq_len

    def to_samples(arr):
        """(t, h, w, n) array -> float32 tensor of shape (n, t, h*w)."""
        return rearrange(torch.as_tensor(arr, dtype=torch.float32), 't h w n -> n t (h w)')

    x_train = to_samples(raw[:opt.in_seq_len, ..., :ntrain])
    y_train = to_samples(raw[opt.in_seq_len:out_end, ..., :ntrain])
    x_test = to_samples(raw[:opt.in_seq_len, ..., -ntest:])
    y_test = to_samples(raw[opt.in_seq_len:out_end, ..., -ntest:])
    del raw

    # Gaussian normalisation with scalar statistics of the training split. Test
    # targets stay un-normalised; predictions are de-normalised before test metrics.
    x_mean = torch.mean(x_train).unsqueeze(0)
    x_std = torch.std(x_train).unsqueeze(0)
    y_mean = torch.mean(y_train).unsqueeze(0)
    y_std = torch.std(y_train).unsqueeze(0)

    x_train = (x_train - x_mean) / x_std
    y_train = (y_train - y_mean) / y_std
    x_test = (x_test - x_mean) / x_std

    if use_cuda:
        x_mean, x_std, y_mean, y_std = x_mean.cuda(), x_std.cuda(), y_mean.cuda(), y_std.cuda()

    # Fixed 64x64 coordinate grid on [0, 1]^2, flattened to shape (1, 64*64, 2).
    x0, y0 = np.meshgrid(np.linspace(0, 1, 64), np.linspace(0, 1, 64))
    xs = np.concatenate((x0[None, ...], y0[None, ...]), axis=0)
    grid = rearrange(torch.from_numpy(xs), 'c h w -> (h w) c').unsqueeze(0).float()

    train_dataloader = DataLoader(TensorDataset(x_train, y_train), batch_size=opt.batch_size, shuffle=True)
    test_dataloader = DataLoader(TensorDataset(x_test, y_test), batch_size=opt.batch_size, shuffle=False)

    def run_evaluation(step):
        """Evaluate on the test split and log the metrics under ``step``.

        Returns a cache holding up to 20 de-normalised input / prediction /
        ground-truth sequences (every second frame) for visualisation.
        """
        logger.info('测试中')
        print('测试中')
        encoder.eval()
        decoder.eval()

        all_avg_loss, all_acc_loss, all_last_loss = [], [], []
        visualization_cache = {'in_seq': [], 'pred': [], 'gt': []}
        picked = 0
        with torch.no_grad():
            for in_seq, gt in tqdm(test_dataloader):
                input_pos = prop_pos = repeat(grid, '() n c -> b n c', b=in_seq.shape[0])
                if use_cuda:
                    in_seq, gt = in_seq.cuda(), gt.cuda()
                    input_pos, prop_pos = input_pos.cuda(), prop_pos.cuda()

                in_seq = rearrange(in_seq, 'b t n -> b n t')
                in_seq = torch.cat((in_seq, input_pos), dim=-1)
                z = encoder.forward(in_seq, input_pos)
                # The adaptive decoder's diagnostics are ignored during evaluation.
                x_out, _ = decoder.rollout(z, prop_pos, opt.out_seq_len, input_pos)
                x_out = x_out * y_std + y_mean  # de-normalise

                avg_loss = rel_loss(x_out, gt, p=2)
                accumulated_mse = torch.nn.MSELoss(reduction='sum')(x_out, gt) / (gt.shape[-1] * gt.shape[0])
                loss_at_last_step = rel_loss(x_out[:, -1:, ...], gt[:, -1:, ...], p=2)

                all_avg_loss.append(avg_loss.item())
                all_acc_loss.append(accumulated_mse.item())
                all_last_loss.append(loss_at_last_step.item())

                if picked < 20:
                    # Drop the two appended coordinate channels and undo input normalisation.
                    vis_in = rearrange(in_seq[..., :-2], 'b n t -> b t n') * x_std + x_mean
                    vis_in = rearrange(vis_in, 'b t (h w) -> b t h w', h=64, w=64)
                    vis_pred = rearrange(x_out, 'b t (h w) -> b t h w', h=64, w=64)
                    vis_gt = rearrange(gt, 'b t (h w) -> b t h w', h=64, w=64)
                    idx = np.arange(0, min(20 - picked, vis_in.shape[0]))
                    visualization_cache['gt'].append(vis_gt[idx, ::2])
                    visualization_cache['in_seq'].append(vis_in[idx, ::2])
                    visualization_cache['pred'].append(vis_pred[idx, ::2])
                    picked += vis_in.shape[0]

        mean_avg = np.mean(all_avg_loss)
        mean_acc = np.mean(all_acc_loss)
        mean_last = np.mean(all_last_loss)
        writer.add_scalar('testing avg loss', mean_avg, global_step=step)
        print(f'测试平均损失 (1e-4): {mean_avg*1e4}')
        print(f'测试累积MSE损失 (1e-4): {mean_acc*1e4}')
        print(f'测试最后一步损失 (1e-4): {mean_last*1e4}')
        logger.info(f'当前迭代: {step}')
        logger.info(f'测试平均损失 (1e-4): {mean_avg*1e4}')
        logger.info(f'测试累积MSE损失 (1e-4): {mean_acc*1e4}')
        logger.info(f'测试最后一步损失 (1e-4): {mean_last*1e4}')
        return visualization_cache

    def save_training_checkpoint(step):
        """Write models, optimisers and schedulers to checkpoint_dir, tagged with ``step``."""
        ckpt = {
            'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(),
            'n_iter': step,
            'enc_optim': enc_optim.state_dict(), 'dec_optim': dec_optim.state_dict(),
            'enc_sched': enc_scheduler.state_dict(), 'dec_sched': dec_scheduler.state_dict(),
        }
        save_checkpoint(ckpt, os.path.join(checkpoint_dir, f'model_checkpoint{step}.ckpt'))

    if opt.eval_mode:
        # Evaluation-only run. Before this fix, evaluation lived inside the training
        # loop, which --eval_mode skips, so nothing was ever evaluated.
        run_evaluation(n_iter)
        print('运行结束...')
    else:
        with tqdm(total=opt.iters) as pbar:
            pbar.update(n_iter)
            train_data_iter = iter(train_dataloader)

            # Bounded by opt.iters so a resumed, already-finished run does not step
            # OneCycleLR past total_steps.
            while n_iter < opt.iters:
                encoder.train()
                decoder.train()

                # Cycle through the training set indefinitely.
                try:
                    in_seq, gt = next(train_data_iter)
                except StopIteration:
                    train_data_iter = iter(train_dataloader)
                    in_seq, gt = next(train_data_iter)

                input_pos = prop_pos = repeat(grid, '() n c -> b n c', b=in_seq.shape[0])
                if use_cuda:
                    in_seq, gt = in_seq.cuda(), gt.cuda()
                    input_pos, prop_pos = input_pos.cuda(), prop_pos.cuda()

                in_seq = rearrange(in_seq, 'b t n -> b n t')

                # With probability aug_ratio, keep only a random subset (45%-95%) of the
                # input points per sample; the propagation positions stay on the full grid.
                if np.random.uniform() > (1-opt.aug_ratio):
                    sampling_ratio = np.random.uniform(0.45, 0.95)
                    n_points = input_pos.shape[1]
                    n_keep = int(sampling_ratio * n_points)
                    input_idx = torch.as_tensor(
                        np.concatenate(
                            [np.random.choice(n_points, n_keep, replace=False).reshape(1, -1)
                             for _ in range(in_seq.shape[0])], axis=0)
                    ).view(in_seq.shape[0], -1).to(in_seq.device)  # was .cuda(): crashed on CPU
                    in_seq = index_points(in_seq, input_idx)
                    input_pos = index_points(input_pos, input_idx)

                in_seq = torch.cat((in_seq, input_pos), dim=-1)
                z = encoder.forward(in_seq, input_pos)

                # Curriculum: early in training, roll out fewer steps, growing towards
                # out_seq_len over the first curriculum_ratio fraction of iterations.
                if opt.curriculum_steps > 0 and n_iter < int(opt.curriculum_ratio * opt.iters):
                    progress = (n_iter*2) / (opt.iters*opt.curriculum_ratio)
                    curriculum_steps = opt.curriculum_steps + int(max(0,  progress - 1.)*((opt.out_seq_len - opt.curriculum_steps)/2.)) * 2
                    gt = gt[:, :curriculum_steps, :]
                    rollout_steps = curriculum_steps
                else:
                    rollout_steps = opt.out_seq_len
                # rollout returns a tuple (prediction, diagnostics)
                x_out, diagnostics = decoder.rollout(z, prop_pos, rollout_steps, input_pos)

                # Periodically log the adaptive decoder's diagnostics.
                if n_iter % 200 == 0 and diagnostics:
                    logger.info(f"迭代 {n_iter} 的自适应诊断信息: {diagnostics}")

                pred_loss = rel_l2norm_loss(x_out, gt)
                if opt.use_grad:
                    gt_grad_x, gt_grad_y = central_diff(gt)
                    pred_grad_x, pred_grad_y = central_diff(x_out)
                    grad_loss = rel_l2norm_loss(pred_grad_x, gt_grad_x) + rel_l2norm_loss(pred_grad_y, gt_grad_y)
                    # Out-of-place on purpose: the former `loss += ...` mutated pred_loss
                    # (the same tensor object), corrupting the logged prediction loss, and
                    # in-place updates of autograd outputs can break backward().
                    loss = pred_loss + 5e-2 * grad_loss
                else:
                    grad_loss = torch.tensor([-1.])  # placeholder for the progress bar
                    loss = pred_loss

                enc_optim.zero_grad()
                dec_optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), 2.)
                torch.nn.utils.clip_grad_norm_(decoder.parameters(), 2.)
                enc_optim.step()
                dec_optim.step()
                enc_scheduler.step()
                dec_scheduler.step()

                writer.add_scalar('train_loss', loss.item(), n_iter)
                writer.add_scalar('prediction_loss', pred_loss.item(), n_iter)
                pbar.set_description(
                    f'Total (1e-4): {loss.item()*1e4:.1f}||'
                    f'pred (1e-4): {pred_loss.item()*1e4:.1f}||'
                    f'grad (1e-4): {grad_loss.item()*1e4:.1f}||'
                    f'lr (1e-3): {enc_scheduler.get_last_lr()[0]*1e3:.4f}||'
                    f'Seq len: {gt.shape[1]}||'
                )
                pbar.update(1)
                n_iter += 1

                # Evaluate and checkpoint whenever n_iter - 1 is a multiple of ckpt_every
                # (for a fresh run: after the first step, then every ckpt_every steps),
                # and once more after the final step.
                if (n_iter - 1) % opt.ckpt_every == 0 or n_iter >= opt.iters:
                    run_evaluation(n_iter)
                    save_training_checkpoint(n_iter)

        print('运行结束...')

    writer.close()