import torch
from torch import nn
import time  # NEW: 导入 time 模块

from models.ipot.layers import PreNorm, Attention, FeedForward

class IPOTProcessor(nn.Module):
    """Repeated self-attention + feed-forward processing of a latent tensor.

    A single ``PreNorm(Attention)`` module and a single ``PreNorm(FeedForward)``
    module are created, and that same pair is registered
    ``self_per_cross_attn`` times in ``self.layers``. Every iteration of the
    stack therefore reuses the same two module instances (and their
    parameters).

    The module also carries an optional wall-clock timer that measures the
    duration of ``forward`` calls made while the module is in eval mode.
    Control it with ``enable_timing`` / ``disable_timing`` / ``reset_timer``
    and read it with ``report_time``.

    Args:
        self_per_cross_attn: Number of (attention, feed-forward) iterations
            applied per ``forward`` call.
        latent_channel: Passed as the first argument to ``PreNorm``,
            ``Attention`` and ``FeedForward``.
        self_heads_num: Forwarded to ``Attention`` as ``heads_num``.
        self_heads_channel: Forwarded to ``Attention`` as ``heads_channel``.
            When ``None`` it defaults to ``latent_channel // self_heads_num``.
        attn_dropout: Forwarded to ``Attention`` as ``dropout``.
        ff_dropout: Forwarded to ``FeedForward`` as ``dropout``.
        ff_mult: Forwarded to ``FeedForward`` as ``mult``.
        use_query_residual: If true, the attention input is added back to the
            attention output in ``forward``.
        weight_tie_layers: Accepted for interface compatibility only; this
            class never reads it. The layers share one module pair regardless
            of its value.
    """

    def __init__(
            self,
            *,
            self_per_cross_attn=4,
            latent_channel=64,
            self_heads_num=8,
            self_heads_channel=None,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            use_query_residual=True,
            weight_tie_layers=False  # kept from the original signature; unused
    ):
        super().__init__()
        self.use_query_residual = use_query_residual
        if self_heads_channel is None:
            self_heads_channel = int(latent_channel // self_heads_num)

        # Timer state: off by default, with an empty running total.
        self.is_timing_enabled = False
        self.total_inference_time = 0.0
        self.inference_call_count = 0

        self.processor_self_attn = PreNorm(
            latent_channel,
            Attention(
                latent_channel,
                heads_num=self_heads_num,
                heads_channel=self_heads_channel,
                dropout=attn_dropout
            )
        )
        self.processor_ff = PreNorm(
            latent_channel,
            FeedForward(
                latent_channel,
                mult=ff_mult,
                dropout=ff_dropout
            )
        )

        self.layers = nn.ModuleList([])

        # The same two module instances are appended on every iteration, so
        # all entries of ``self.layers`` share parameters.
        for _ in range(self_per_cross_attn):
            self.layers.append(nn.ModuleList([
                self.processor_self_attn,
                self.processor_ff
            ]))

    def enable_timing(self):
        """Turn the internal timer on and print a notice."""
        print(f"[{self.__class__.__name__}] 内部计时器已开启。")
        self.is_timing_enabled = True

    def disable_timing(self):
        """Turn the internal timer off; accumulated statistics are kept."""
        self.is_timing_enabled = False

    def reset_timer(self):
        """Clear the accumulated time and the call counter."""
        self.total_inference_time = 0.0
        self.inference_call_count = 0

    def report_time(self):
        """Return the mean duration of timed ``forward`` calls in milliseconds.

        Returns ``0.0`` when no call has been timed yet.
        """
        if self.inference_call_count == 0:
            return 0.0
        # The total is accumulated in seconds; convert the mean to ms.
        avg_time_ms = (self.total_inference_time / self.inference_call_count) * 1000
        return avg_time_ms

    @staticmethod
    def _sync_if_cuda(tensor):
        """Block until queued CUDA work on ``tensor``'s device has finished.

        Does nothing for tensors that are not on a CUDA device, so timing
        also works on CPU-only installations where
        ``torch.cuda.synchronize()`` would raise.
        """
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    def forward(self, z, mask=None):
        """Apply every (attention, feed-forward) pair in ``self.layers`` to ``z``.

        Args:
            z: Input tensor. It is passed to the attention module both as the
                query input and as ``context``.
            mask: Accepted but not used by this implementation.

        Returns:
            The processed tensor.
        """
        # Time only when explicitly enabled AND in eval mode (model.eval()).
        timing = self.is_timing_enabled and not self.training
        if timing:
            # Drain pending GPU work so it is not attributed to this call.
            self._sync_if_cuda(z)
            start_time = time.perf_counter()

        for self_attn, self_ff in self.layers:
            if self.use_query_residual:
                z = self_attn(z, context=z) + z
            else:
                z = self_attn(z, context=z)

            z = self_ff(z) + z

        if timing:
            # Wait for the kernels launched above before stopping the clock.
            self._sync_if_cuda(z)
            end_time = time.perf_counter()
            self.total_inference_time += (end_time - start_time)
            self.inference_call_count += 1

        return z