from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedCausalFilter(nn.Module):
    """Soft feature filter built from per-row head scores, a channel gate and a decayed bias.

    ``forward`` multiplies the input by a sigmoid gate and adds a small
    residual copy of the input. The gate logit is the sum of three terms:

    * a per-row score aggregated over ``heads`` small MLPs (Tanh output),
    * a per-channel gate computed from the mean of the input over dim 0,
    * a learnable scalar ``lambda_param`` decayed by ``decay_rate ** epoch_step``
      and bounded below by ``lambda_min``.

    ``epoch_step`` and ``temperature`` are plain Python attributes advanced by
    :meth:`step`; they are not registered as buffers.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Hidden width of each head MLP; defaults to
            ``max(input_dim // 4, 8)`` when ``None``.
        lambda_init: Initial value of the learnable bias ``lambda_param``.
        lambda_min: Lower bound applied to the decayed bias.
        decay_rate: Per-step multiplicative decay applied to the bias.
        temperature: Initial sigmoid temperature.
        residual_weight: Base weight of the residual term.
        normalize: If true, apply layer normalisation over the last dim of the output.
        dropout: Dropout probability inside each head MLP.
        heads: Number of head MLPs (must be at least 1).
        channel_reduction: Reduction factor for the channel-gate bottleneck
            (values below 1 are treated as 1).
        temp_decay: Per-step multiplicative decay of the temperature.
        temp_min: Lower bound of the temperature.
        sparsity_target: Target mean gate value used after the warmup.
        sparsity_warmup: Number of steps before the sparsity correction starts.
        calibrate_output: If true, rescale the output to the per-column mean
            and standard deviation of the input.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``heads`` is smaller than 1.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int | None = None,
                 lambda_init: float = 1.0,
                 lambda_min: float = -2.0,
                 decay_rate: float = 0.99,
                 temperature: float = 1.0,
                 residual_weight: float = 0.2,
                 normalize: bool = False,
                 dropout: float = 0.1,
                 heads: int = 4,
                 channel_reduction: int = 4,
                 temp_decay: float = 0.995,
                 temp_min: float = 0.2,
                 sparsity_target: float = 0.5,
                 sparsity_warmup: int = 5,
                 calibrate_output: bool = True,
                 **kwargs):
        super().__init__()
        # Without at least one head, forward() would fail later inside torch.stack.
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else max(input_dim // 4, 8)
        self.lambda_init = lambda_init
        self.lambda_min = lambda_min
        self.decay_rate = decay_rate
        self.temperature0 = temperature  # kept so reset_lambda() can restore it
        self.temperature = temperature
        self.temp_decay = temp_decay
        self.temp_min = temp_min
        self.residual_weight0 = residual_weight
        self.normalize = normalize
        self.heads = heads
        self.channel_reduction = max(1, channel_reduction)
        self.sparsity_target = sparsity_target
        self.sparsity_warmup = sparsity_warmup
        self.calibrate_output = calibrate_output
        # One scoring MLP per head; each maps a row to a scalar in (-1, 1).
        self.node_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.hidden_dim, 1),
                nn.Tanh()
            ) for _ in range(heads)
        ])
        # Bottleneck MLP producing one logit per input channel.
        ch_hidden = max(4, input_dim // self.channel_reduction)
        self.channel_gate = nn.Sequential(
            nn.Linear(input_dim, ch_hidden),
            nn.ReLU(),
            nn.Linear(ch_hidden, input_dim),
        )
        self.lambda_param = nn.Parameter(torch.tensor(lambda_init, dtype=torch.float32))
        self.epoch_step = 0
        # Statistics of the most recent forward pass.
        self.gate_stats = {
            'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
            'entropy': 0.0, 'sparsity': 0.0
        }
        self.extra_stats = {}

    def get_current_lambda(self):
        """Return the decayed bias as a 0-dim tensor, bounded below by ``lambda_min``."""
        decay_factor = self.decay_rate ** self.epoch_step
        # torch.clamp avoids the implicit tensor->bool conversion of Python's
        # max() and always yields a tensor, whichever bound is active.
        return torch.clamp(self.lambda_param * decay_factor, min=self.lambda_min)

    def _adaptive_residual(self, gate_mean: torch.Tensor) -> float:
        """Return the residual weight, shrinking linearly as the mean gate approaches 1."""
        # Converted to a Python float, so no gradient flows through this weight.
        return float(self.residual_weight0 * (1.0 - gate_mean.clamp(0, 1)))

    def _entropy(self, p: torch.Tensor) -> torch.Tensor:
        """Return the element-wise binary entropy of ``p`` (natural log)."""
        eps = 1e-8  # keeps log() finite when p is exactly 0 or 1
        return -(p * (p + eps).log() + (1 - p) * (1 - p + eps).log())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate ``x`` and return a tensor of the same shape.

        Args:
            x: Input whose last dimension is ``input_dim``. The code uses
                ``x.size(1)`` as the gate width, so a 2-D
                ``(rows, input_dim)`` tensor is the expected layout.

        Returns:
            The gated input plus the residual term, optionally recalibrated
            to the input statistics and layer-normalised.
        """
        head_scores = [mlp(x) for mlp in self.node_mlps]
        node_u = torch.stack(head_scores, dim=1).squeeze(-1)
        # Blend of the head mean and the most pessimistic head.
        alpha = 0.7
        node_score = (alpha * node_u.mean(dim=1, keepdim=True)
                      + (1 - alpha) * node_u.min(dim=1, keepdim=True).values)

        ch_context = x.mean(dim=0, keepdim=True)
        ch_logits = self.channel_gate(ch_context)
        ch_gate = torch.sigmoid(ch_logits)

        current_lambda = self.get_current_lambda()
        temp = self.temperature
        # Spread the per-row score across all channels. new_ones inherits the
        # dtype/device of node_score, so the product also works after
        # .double() / .half() (a plain torch.ones would always be float32).
        widen = node_score.new_ones(1, x.size(1))
        raw = current_lambda + node_score @ widen
        raw = raw + ch_gate
        gate = torch.sigmoid(raw / temp)

        if self.epoch_step >= self.sparsity_warmup:
            # Shift the gate towards the target mean; the shift itself is
            # computed without gradient.
            with torch.no_grad():
                mean_act = gate.mean()
                adjust = (mean_act - self.sparsity_target)
            gate = (gate - 0.1 * adjust).clamp(0.0, 1.0)

        gate_mean = gate.mean()
        residual_dynamic = self._adaptive_residual(gate_mean)
        filtered_x = gate * x + residual_dynamic * x

        # The unbiased std over dim 0 is NaN for a single row (and clamp_min
        # does not remove NaN), so calibration is skipped in that case.
        if self.calibrate_output and x.size(0) > 1:
            in_mean = x.mean(dim=0, keepdim=True)
            in_std = x.std(dim=0, keepdim=True).clamp_min(1e-6)
            out_mean = filtered_x.mean(dim=0, keepdim=True)
            out_std = filtered_x.std(dim=0, keepdim=True).clamp_min(1e-6)
            filtered_x = (filtered_x - out_mean) / out_std * in_std + in_mean

        if self.normalize:
            filtered_x = F.layer_norm(filtered_x, filtered_x.shape[-1:])

        with torch.no_grad():
            entropy = self._entropy(gate).mean().item()
            sparsity = (gate < 0.5).float().mean().item()
            # std of a single element is NaN; report 0.0 instead.
            gate_std = gate.std().item() if gate.numel() > 1 else 0.0
            self.gate_stats.update(
                mean=gate_mean.item(),
                std=gate_std,
                min=gate.min().item(),
                max=gate.max().item(),
                entropy=entropy,
                sparsity=sparsity,
            )
            self.extra_stats = {
                'node_score_mean': node_score.mean().item(),
                'channel_gate_mean': ch_gate.mean().item(),
                'temperature': temp,
                'residual_dynamic': residual_dynamic,
            }
        return filtered_x

    def step(self):
        """Advance the step counter and decay the temperature, floored at ``temp_min``."""
        self.epoch_step += 1
        self.temperature = max(self.temp_min, self.temperature * self.temp_decay)

    def reset_lambda(self):
        """Restore the bias, the step counter and the temperature to their initial values."""
        with torch.no_grad():
            self.lambda_param.fill_(self.lambda_init)
        self.epoch_step = 0
        self.temperature = self.temperature0

    def get_stats(self):
        """Return the current bias, the step counter and copies of the last forward statistics."""
        cur_lambda = self.get_current_lambda()
        return {
            'lambda': float(cur_lambda.detach().item() if isinstance(cur_lambda, torch.Tensor) else cur_lambda),
            'epoch_step': self.epoch_step,
            'gate_stats': self.gate_stats.copy(),
            'extra': self.extra_stats.copy(),
        }

class AdaptiveCausalFilter(nn.Module):
    """Lightweight adaptive filter (kept for compatibility with older code).

    A single MLP scores each row; the score is blended into a gate whose
    strength ramps up with ``epoch_step``: linearly from 0 to
    ``base_strength`` during the warmup, then linearly from ``base_strength``
    to ``max_strength`` until ``max_epochs``, after which it stays constant.
    """

    def __init__(self, input_dim, hidden_dim=None, warmup_epochs=10, max_epochs=50,
                 dropout=0.1, residual_weight=0.1, base_strength=0.1, max_strength=0.9, **kwargs):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = max(input_dim // 4, 8)

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Schedule configuration.
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.base_strength = base_strength
        self.max_strength = max_strength
        self.residual_weight = residual_weight

        # Row-wise importance scorer (one logit per row).
        scorer_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.mlp = nn.Sequential(*scorer_layers)
        self.epoch_step = 0

    def _schedule(self):
        """Return ``(filter_strength, phase)`` for the current ``epoch_step``."""
        step = self.epoch_step
        warmup = self.warmup_epochs
        if step < warmup:
            ramp = step / max(1, warmup)
            return self.base_strength * ramp, 'warmup'

        span = max(1, self.max_epochs - warmup)
        ramp = float(min(1.0, (step - warmup) / span))
        strength = self.base_strength + (self.max_strength - self.base_strength) * ramp
        return strength, 'training'

    def forward(self, x):
        """Gate ``x`` by the scheduled blend of 1 and the learned importance, plus a residual."""
        strength, phase_name = self._schedule()

        score = torch.sigmoid(self.mlp(x))
        keep = 1 - strength
        gate = keep + strength * score
        result = gate * x + self.residual_weight * x

        # Remember the schedule state of this call for get_stats().
        self._last_stats = dict([
            ('epoch_step', self.epoch_step),
            ('filter_strength', float(strength)),
            ('phase', phase_name),
        ])
        return result

    def step(self):
        """Advance the epoch counter by one."""
        self.epoch_step = self.epoch_step + 1

    def get_stats(self):
        """Return the stats of the last forward call, or defaults if forward has not run yet."""
        try:
            return self._last_stats
        except AttributeError:
            in_warmup = self.epoch_step < self.warmup_epochs
            return {
                'epoch_step': self.epoch_step,
                'filter_strength': 0.0,
                'phase': 'warmup' if in_warmup else 'training',
            }

# Module-level alias: the generic name ``CausalFilter`` refers to ImprovedCausalFilter.
CausalFilter = ImprovedCausalFilter
