import torch
from torch import nn
import time  # NEW: 导入 time 模块

# Assumes 'StaticRouter' has already been added to layers.py
from models.ipot.layers import PreNorm, Attention, FeedForward, StaticRouter


class IPOTProcessorAdapt(nn.Module):
    """Latent processor that updates a shrinking subset of tokens per layer.

    The router output is computed once per forward pass and used as a
    per-token score. At layer ``i`` only the highest-scoring
    ``reduction_schedule[i]`` fraction of latent tokens goes through
    self-attention and the feed-forward block; every other token is carried
    over unchanged.

    The same attention and feed-forward module instances are reused for
    every layer, so all layers share one set of weights.

    Optionally, wall-clock time of ``forward`` calls made in eval mode can be
    accumulated (see ``enable_timing`` / ``report_time``).

    Args:
        self_per_cross_attn: Number of processor layers.
        latent_channel: Channel size handed to the router, the norms, the
            attention block and the feed-forward block.
        self_heads_num: Number of attention heads.
        self_heads_channel: Channels per head; defaults to
            ``latent_channel // self_heads_num`` when ``None``.
        attn_dropout: Dropout passed to the attention block.
        ff_dropout: Dropout passed to the feed-forward block.
        ff_mult: Multiplier passed to the feed-forward block.
        use_query_residual: Whether the attention output is added back to
            its input.
        reduction_schedule: Fraction of tokens kept active at each layer;
            must have exactly ``self_per_cross_attn`` entries.
        router_hidden_dim_ratio: Hidden-size ratio passed to the router.
    """

    def __init__(
            self,
            *,
            self_per_cross_attn=4,
            latent_channel=64,
            self_heads_num=8,
            self_heads_channel=None,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            use_query_residual=True,
            reduction_schedule: tuple = (1.0, 0.9, 0.8, 0.7),
            router_hidden_dim_ratio: float = 0.25
    ):
        super().__init__()
        assert len(reduction_schedule) == self_per_cross_attn, \
            f"Reduction schedule length ({len(reduction_schedule)}) must match the number of processor layers ({self_per_cross_attn})."

        self.use_query_residual = use_query_residual
        self.reduction_schedule = reduction_schedule

        # State of the optional inference timer.
        self.is_timing_enabled = False
        self.total_inference_time = 0.0  # accumulated seconds
        self.inference_call_count = 0

        self.router = StaticRouter(
            latent_channel=latent_channel,
            hidden_dim_ratio=router_hidden_dim_ratio
        )
        if self_heads_channel is None:
            self_heads_channel = int(latent_channel // self_heads_num)

        self.processor_self_attn = PreNorm(
            latent_channel,
            Attention(
                latent_channel,
                heads_num=self_heads_num,
                heads_channel=self_heads_channel,
                dropout=attn_dropout
            )
        )
        self.processor_ff = PreNorm(
            latent_channel,
            FeedForward(
                latent_channel,
                mult=ff_mult,
                dropout=ff_dropout
            )
        )

        # Every entry holds the SAME two module instances, i.e. the weights
        # are shared across all processor layers.
        self.layers = nn.ModuleList([
            nn.ModuleList([self.processor_self_attn, self.processor_ff]) for _ in range(self_per_cross_attn)
        ])

    def enable_timing(self):
        """Turn the inference timer on."""
        print(f"[{self.__class__.__name__}] 内部计时器已开启。")
        self.is_timing_enabled = True

    def disable_timing(self):
        """Turn the inference timer off."""
        self.is_timing_enabled = False

    def reset_timer(self):
        """Clear the accumulated time and the call counter."""
        self.total_inference_time = 0.0
        self.inference_call_count = 0

    def report_time(self):
        """Return the mean time per timed forward call in milliseconds.

        Returns ``0.0`` when no call has been timed yet.
        """
        if self.inference_call_count == 0:
            return 0.0
        avg_time_ms = (self.total_inference_time / self.inference_call_count) * 1000
        return avg_time_ms

    @staticmethod
    def _sync_device(z):
        """Wait for pending CUDA work on ``z``'s device; no-op otherwise."""
        if z.is_cuda:
            torch.cuda.synchronize(z.device)

    def forward(self, z, mask=None):
        """Run all processor layers over the latent tokens.

        Args:
            z: Tensor with three dimensions ``(batch, tokens, channels)``.
            mask: Accepted for interface compatibility; not used here.

        Returns:
            Tensor of the same shape as ``z``. Tokens that were inactive at
            a layer keep the value they had before that layer.
        """
        # Decide once so the start/stop bookkeeping can never disagree.
        timing = self.is_timing_enabled and not self.training
        if timing:
            # Only CUDA tensors need a sync; calling it unconditionally
            # fails on machines without CUDA.
            self._sync_device(z)
            start_time = time.perf_counter()

        _, nz, d = z.shape
        scores = self.router(z)
        # Ranking is computed once from the input and reused by every layer.
        sorted_indices = torch.argsort(scores.squeeze(-1), dim=1, descending=True)

        for i, (self_attn, self_ff) in enumerate(self.layers):
            keep_ratio = self.reduction_schedule[i]
            # At least one token stays active; never more than exist.
            k = min(max(int(nz * keep_ratio), 1), nz)

            if k >= nz:
                # All tokens active: skip the gather/scatter round trip.
                active_indices = None
                active_z = z
            else:
                active_indices = sorted_indices[:, :k]
                expanded_active_indices = active_indices.unsqueeze(-1).expand(-1, -1, d)
                active_z = torch.gather(z, 1, expanded_active_indices)

            if self.use_query_residual:
                active_z_updated = self_attn(active_z, context=active_z) + active_z
            else:
                active_z_updated = self_attn(active_z, context=active_z)

            active_z_updated = self_ff(active_z_updated) + active_z_updated

            if active_indices is not None:
                # Write the updated tokens back to their original positions.
                z = z.scatter(1, expanded_active_indices, active_z_updated)
            else:
                z = active_z_updated

        if timing:
            self._sync_device(z)
            end_time = time.perf_counter()
            self.total_inference_time += (end_time - start_time)
            self.inference_call_count += 1

        return z