# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import sys
from typing import Iterable

import torch

import util.misc as misc
import util.lr_sched as lr_sched


def train_one_epoch_de(model: torch.nn.Module,
                       data_loader: Iterable, optimizer: torch.optim.Optimizer,
                       device: torch.device, epoch: int, loss_scaler,
                       log_writer=None,
                       args=None):
    """Train ``model`` for one epoch on (noisy, clean) sample pairs.

    Gradients are accumulated over ``args.accum_iter`` iterations. The
    learning rate is re-scheduled at the start of every accumulation window
    through ``lr_sched.adjust_learning_rate`` with a fractional epoch value.

    Args:
        model: module called as ``model(noisy, clean)``; it must return a
            3-tuple whose first element is the scalar loss.
        data_loader: iterable yielding ``(samples_noisy, samples_clean)``
            batches; ``len(data_loader)`` must be defined.
        optimizer: optimizer whose first param group's ``lr`` is logged.
        device: target device for both batch tensors.
        epoch: current epoch index, used for LR scheduling and log x-axis.
        loss_scaler: callable performing backward (and, on update steps, the
            optimizer step) as ``loss_scaler(loss, optimizer, parameters=...,
            update_grad=...)``.
        log_writer: optional TensorBoard-style writer (``add_scalar``,
            ``log_dir``).
        args: namespace providing at least ``accum_iter``; also forwarded to
            the LR scheduler.

    Returns:
        Dict mapping each metric-logger meter name (e.g. ``loss``, ``lr``)
        to its global average over the epoch.

    Exits the process with status 1 if a non-finite loss is encountered.
    """
    model.train(True)

    logger = misc.MetricLogger(delimiter="  ")
    logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    accum = args.accum_iter
    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    batches = logger.log_every(data_loader, 20, header)
    for step, batch in enumerate(batches):
        noisy, clean = batch
        # True on the last iteration of each accumulation window.
        is_update_step = (step + 1) % accum == 0
        # Fractional epoch; drives both the per-iteration LR schedule and
        # the TensorBoard x-axis.
        progress = step / len(data_loader) + epoch

        if step % accum == 0:
            lr_sched.adjust_learning_rate(optimizer, progress, args)

        noisy = noisy.to(device, non_blocking=True)
        clean = clean.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss, _, _ = model(noisy, clean)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale down so the accumulated gradient averages over the window.
        loss_scaler(loss / accum, optimizer, parameters=model.parameters(),
                    update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        current_lr = optimizer.param_groups[0]["lr"]
        logger.update(loss=loss_value)
        logger.update(lr=current_lr)

        # Called on every iteration, not only when logging; presumably a
        # cross-process collective that every rank must enter — keep it
        # unconditional.
        reduced_loss = misc.all_reduce_mean(loss_value)
        if is_update_step and log_writer is not None:
            # epoch_1000x is used as the x-axis so curves recorded with
            # different batch sizes line up.
            x_axis = int(progress * 1000)
            log_writer.add_scalar('train_loss', reduced_loss, x_axis)
            log_writer.add_scalar('lr', current_lr, x_axis)

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger)
    return {name: meter.global_avg for name, meter in logger.meters.items()}


def get_random_mask_half(batch_size, seq_len, patch_size, device=None):
    """Build a random 0/1 patch mask that masks half of the patches per sample.

    A sequence of length ``seq_len`` is grouped into
    ``n_patches = seq_len // patch_size`` patches. Any trailing remainder
    shorter than ``patch_size`` is ignored. In each of the ``batch_size``
    rows, ``n_patches // 2`` patches are chosen uniformly at random,
    independently per row, and set to 1.0 (masked). The rest stay 0.0 (kept).
    With an odd ``n_patches`` the masked count is rounded down.

    Args:
        batch_size: number of independent mask rows to draw.
        seq_len: sequence length in points.
        patch_size: number of consecutive points merged into one patch.
        device: device on which to allocate the mask; defaults to CPU.

    Returns:
        A ``torch.float32`` tensor of shape ``[batch_size, n_patches]``
        containing only 0.0 and 1.0.
    """
    n_patches = seq_len // patch_size
    n_mask = n_patches // 2
    if device is None:
        device = torch.device("cpu")

    mask = torch.zeros(batch_size, n_patches, dtype=torch.float32, device=device)
    if n_mask == 0:
        # Fewer than two patches: nothing to mask.
        return mask

    # Sorting i.i.d. uniform noise per row gives an independent random
    # permutation of patch indices for every sample in one batched op.
    # This replaces a Python loop of per-row randperm calls, which needed
    # one launch per sample.
    noise = torch.rand(batch_size, n_patches, device=device)
    masked_idx = noise.argsort(dim=1)[:, :n_mask]
    mask.scatter_(1, masked_idx, 1.0)
    return mask
def train_one_epoch_r(
    model: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
    seq_len: int = 256,
    patch_size: int = 2,
):
    """Train ``model`` for one epoch with a fresh random patch mask per batch.

    On every iteration a new mask is drawn with ``get_random_mask_half``. It
    marks half of the ``seq_len // patch_size`` patches of each sample, and
    is passed to the model as ``model(noisy, clean, mask=mask)``. Gradients
    are accumulated over ``args.accum_iter`` iterations. The learning rate
    is re-scheduled at the start of every accumulation window via
    ``lr_sched.adjust_learning_rate``.

    Args:
        model: module whose forward accepts ``(noisy, clean, mask=...)`` and
            returns a 3-tuple whose first element is the scalar loss.
        data_loader: iterable yielding ``(samples_noisy, samples_clean)``
            batches; ``len(data_loader)`` must be defined.
        optimizer: optimizer whose first param group's ``lr`` is logged.
        device: target device for the batch tensors and the mask.
        epoch: current epoch index, used for LR scheduling and log x-axis.
        loss_scaler: callable performing backward (and, on update steps, the
            optimizer step).
        log_writer: optional TensorBoard-style writer.
        args: namespace providing at least ``accum_iter``; also forwarded to
            the LR scheduler.
        seq_len: sequence length used to size the mask (previously
            hard-coded to 256).
        patch_size: points per patch used to size the mask (previously
            hard-coded to 2).

    Returns:
        Dict mapping each metric-logger meter name (e.g. ``loss``, ``lr``)
        to its global average over the epoch.

    Exits the process with status 1 if a non-finite loss is encountered.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples_noisy, samples_clean) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # Per-iteration (not per-epoch) LR schedule, applied at the start of
        # each gradient-accumulation window.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(
                optimizer,
                data_iter_step / len(data_loader) + epoch,
                args
            )

        samples_noisy = samples_noisy.to(device, non_blocking=True)
        samples_clean = samples_clean.to(device, non_blocking=True)

        # Draw a new random mask every iteration, one row per sample in the
        # batch.
        mask = get_random_mask_half(
            samples_noisy.size(0), seq_len, patch_size, device=device
        )

        # The model's forward must accept a ``mask`` keyword argument.
        with torch.cuda.amp.autocast():
            loss, _, _ = model(samples_noisy, samples_clean, mask=mask)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Gradient accumulation: scale the loss and only step the optimizer
        # on the last iteration of each window.
        loss = loss / accum_iter
        loss_scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            update_grad=(data_iter_step + 1) % accum_iter == 0
        )
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        # Logging
        metric_logger.update(loss=loss_value)
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        # Called on every iteration, not only when logging; presumably a
        # cross-process collective that every rank must enter.
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            # epoch_1000x (fractional epoch * 1000) is the TensorBoard x-axis,
            # so curves from different batch sizes line up.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}