from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedCausalFilter(nn.Module):
    """Feature-wise soft gate built from per-sample head scores and a batch-level channel gate.

    ``forward`` proceeds as follows:

    1. During training, add a sign-gradient perturbation of the input derived
       from ``confound_detector`` (``adversarial_strength``).
    2. If ``feature_decorr`` is on, blend the input with a learned projection
       (``causal_proj``) before scoring.
    3. ``heads`` small MLPs (ending in ``Tanh``) produce one score per sample;
       the scores are averaged with weights favouring heads that agree with
       the head consensus.
    4. ``channel_gate`` maps the per-feature batch mean to a per-feature gate.
    5. ``gate = sigmoid((lambda + node_score + channel_gate + interaction) / T)``,
       optionally pulled toward ``sparsity_target`` after ``sparsity_warmup``
       steps.
    6. The output is ``gate * x + residual * x``, optionally re-calibrated to
       the input's per-feature batch mean/std and layer-normalized.

    Lambda and temperature follow per-``step()`` schedules. Diagnostics are kept
    in ``gate_stats`` / ``extra_stats``. Auxiliary regularizers are exposed as
    floats through ``get_auxiliary_losses`` and as differentiable tensors
    through ``get_auxiliary_loss_tensors``.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int | None = None,
                 lambda_init: float = 1.0,
                 lambda_min: float = -2.0,
                 decay_rate: float = 0.99,
                 temperature: float = 1.0,
                 residual_weight: float = 0.2,
                 normalize: bool = False,
                 dropout: float = 0.1,
                 heads: int = 4,
                 channel_reduction: int = 4,
                 temp_decay: float = 0.995,
                 temp_min: float = 0.2,
                 sparsity_target: float = 0.5,
                 sparsity_warmup: int = 5,
                 calibrate_output: bool = True,
                 adversarial_strength: float = 0.01,
                 feature_decorr: bool = True,
                 gradient_penalty: float = 0.1,
                 moment_matching: bool = True,
                 spectral_norm: bool = True,
                 **kwargs):
        """Build the scoring heads, channel gate and optional regularizer modules.

        Args:
            input_dim: Number of input features (last dimension of ``x``).
            hidden_dim: Head hidden width; defaults to ``max(input_dim // 4, 8)``.
            lambda_init: Initial value of the learnable gate bias ``lambda_param``.
            lambda_min: Lower bound applied to the decayed lambda.
            decay_rate: Multiplicative lambda decay per ``step()``.
            temperature: Initial sigmoid temperature.
            residual_weight: Base weight of the ``residual * x`` skip term.
            normalize: Apply ``layer_norm`` over the last dimension of the output.
            dropout: Dropout rate in the heads (half of it in the channel gate).
            heads: Number of scoring heads.
            channel_reduction: Width divisor for the channel gate (at least 1).
            temp_decay: Multiplicative temperature decay per ``step()``.
            temp_min: Temperature floor.
            sparsity_target: Gate mean that the sparsity adjustment pulls toward.
            sparsity_warmup: Number of steps before the sparsity adjustment starts.
            calibrate_output: Re-match the output's per-feature batch mean/std
                to the input's.
            adversarial_strength: Magnitude of the training-time input perturbation.
            feature_decorr: Create ``causal_proj`` / ``confound_proj``.
            gradient_penalty: If > 0, record a gate-gradient penalty during
                training (only when the input requires grad).
            moment_matching: Track running input moments and compute a
                moment-matching regularizer.
            spectral_norm: Wrap the head and channel-gate linears in spectral norm.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else max(input_dim // 4, 8)
        self.lambda_init = lambda_init
        self.lambda_min = lambda_min
        self.decay_rate = decay_rate
        self.temperature0 = temperature
        self.temperature = temperature
        self.temp_decay = temp_decay
        self.temp_min = temp_min
        self.residual_weight0 = residual_weight
        self.normalize = normalize
        self.heads = heads
        self.channel_reduction = max(1, channel_reduction)
        self.sparsity_target = sparsity_target
        self.sparsity_warmup = sparsity_warmup
        self.calibrate_output = calibrate_output

        self.adversarial_strength = adversarial_strength
        self.feature_decorr = feature_decorr
        self.gradient_penalty = gradient_penalty
        self.moment_matching = moment_matching

        # Scoring heads: each maps a row of input_dim features to one scalar in [-1, 1].
        self.node_mlps = nn.ModuleList()
        for _ in range(heads):
            layers = [
                nn.Linear(input_dim, self.hidden_dim),
                nn.BatchNorm1d(self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.hidden_dim, self.hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(self.hidden_dim // 2, 1),
                nn.Tanh()
            ]
            if spectral_norm:
                # Indices 0, 4 and 6 are the Linear layers above.
                layers[0] = nn.utils.spectral_norm(layers[0])
                layers[4] = nn.utils.spectral_norm(layers[4])
                layers[6] = nn.utils.spectral_norm(layers[6])
            self.node_mlps.append(nn.Sequential(*layers))

        # Channel gate: input_dim -> input_dim logits.
        ch_hidden = max(8, input_dim // self.channel_reduction)
        ch_layers = [
            nn.Linear(input_dim, ch_hidden),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(ch_hidden, ch_hidden),
            nn.ReLU(),
            nn.Linear(ch_hidden, input_dim)
        ]
        if spectral_norm:
            # Indices 0, 3 and 5 are the Linear layers above.
            ch_layers[0] = nn.utils.spectral_norm(ch_layers[0])
            ch_layers[3] = nn.utils.spectral_norm(ch_layers[3])
            ch_layers[5] = nn.utils.spectral_norm(ch_layers[5])
        self.channel_gate = nn.Sequential(*ch_layers)

        # Only used to derive the training-time input perturbation.
        self.confound_detector = nn.Sequential(
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 1),
            nn.Sigmoid()
        )

        if feature_decorr:
            self.causal_proj = nn.Linear(input_dim, input_dim)
            self.confound_proj = nn.Linear(input_dim, input_dim)

        self.lambda_param = nn.Parameter(torch.tensor(lambda_init, dtype=torch.float32))
        self.epoch_step = 0
        self.gate_stats = {
            'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
            'entropy': 0.0, 'sparsity': 0.0
        }
        self.extra_stats = {}
        # Differentiable auxiliary losses from the most recent forward call.
        self._aux_loss_tensors: dict[str, torch.Tensor] = {}

        if moment_matching:
            self.register_buffer('running_mean', torch.zeros(input_dim))
            self.register_buffer('running_var', torch.ones(input_dim))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def get_current_lambda(self) -> torch.Tensor:
        """Return ``lambda_param * decay_rate ** epoch_step``, floored at ``lambda_min``.

        Always returns a 0-d tensor. Gradient reaches ``lambda_param`` while
        the floor is not active.
        """
        decay_factor = self.decay_rate ** self.epoch_step
        return (self.lambda_param * decay_factor).clamp_min(self.lambda_min)

    def _adaptive_residual(self, gate_mean: torch.Tensor) -> float:
        """Residual weight that shrinks linearly as the mean gate approaches 1 (detached float)."""
        return float(self.residual_weight0 * (1.0 - gate_mean.clamp(0, 1)))

    def _entropy(self, p: torch.Tensor) -> torch.Tensor:
        """Element-wise binary entropy (nats) of probabilities ``p``."""
        eps = 1e-8
        return -(p * (p + eps).log() + (1 - p) * (1 - p + eps).log())

    def _adversarial_noise(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``adversarial_strength * sign(d mean(confound_detector(x)) / dx)``.

        Returns zeros outside training or when the strength is not positive.
        Gradients are enabled locally, so this also works when a training-mode
        module is called under ``torch.no_grad()``. The result is detached.
        """
        if not self.training or self.adversarial_strength <= 0:
            return torch.zeros_like(x)

        with torch.enable_grad():
            x_adv = x.detach().clone().requires_grad_(True)
            probe = self.confound_detector(x_adv).mean()
            grad = torch.autograd.grad(probe, x_adv, create_graph=False)[0]
        return (self.adversarial_strength * grad.sign()).detach()

    def _feature_decorrelation_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Mean absolute off-diagonal cosine similarity between projected features.

        Each feature column of ``causal_proj(x)`` and ``confound_proj(x)`` is
        L2-normalized over the batch (dim 0), and the cross-similarity matrix
        is averaged with its diagonal masked out. The loss is zero when the
        projections do not exist or the module is not training.
        """
        if not hasattr(self, 'causal_proj') or not self.training:
            return torch.tensor(0.0, device=x.device)

        causal_feat = self.causal_proj(x)
        confound_feat = self.confound_proj(x)

        causal_norm = F.normalize(causal_feat, p=2, dim=0)
        confound_norm = F.normalize(confound_feat, p=2, dim=0)
        cross_corr = torch.mm(causal_norm.T, confound_norm)

        mask = torch.eye(cross_corr.size(0), device=x.device)
        return (cross_corr * (1 - mask)).abs().mean()

    def _moment_matching_regularization(self, x_filtered: torch.Tensor, x_orig: torch.Tensor) -> torch.Tensor:
        """Penalize differences in mean/variance/third moment between output and input.

        Updates the running input moments (momentum 0.1) as a side effect.
        Returns zero when moment matching is disabled or outside training.
        """
        if not self.moment_matching or not self.training:
            return torch.tensor(0.0, device=x_filtered.device)

        with torch.no_grad():
            batch_mean = x_orig.mean(dim=0)
            batch_var = x_orig.var(dim=0, unbiased=False)
            momentum = 0.1
            # Rebind (not in-place) so tensors saved by earlier graphs stay valid.
            self.running_mean = (1 - momentum) * self.running_mean + momentum * batch_mean
            self.running_var = (1 - momentum) * self.running_var + momentum * batch_var
            self.num_batches_tracked += 1

        filtered_mean = x_filtered.mean(dim=0)
        mean_loss = F.mse_loss(filtered_mean, self.running_mean.detach())

        filtered_var = x_filtered.var(dim=0, unbiased=False)
        var_loss = F.mse_loss(filtered_var, self.running_var.detach())

        centered_orig = x_orig - self.running_mean.detach()
        centered_filt = x_filtered - filtered_mean.detach()
        skew_orig = (centered_orig ** 3).mean(dim=0)
        skew_filt = (centered_filt ** 3).mean(dim=0)
        skew_loss = F.mse_loss(skew_filt, skew_orig.detach())

        return mean_loss + 0.5 * var_loss + 0.1 * skew_loss

    @staticmethod
    def _run_head(mlp: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
        """Apply one scoring head, using BatchNorm running stats for batches smaller than 4.

        Batch statistics from so few rows are unreliable, and a single row
        makes ``BatchNorm1d`` raise in training mode. For these batches the
        BatchNorm layers are evaluated with their running statistics (which
        are then not updated). All other layers are applied as usual.
        """
        if x.size(0) >= 4:
            return mlp(x)
        h = x
        for layer in mlp:
            if isinstance(layer, nn.BatchNorm1d):
                # Clone so later in-place running-stat updates cannot invalidate this graph.
                h = F.batch_norm(h, layer.running_mean.clone(), layer.running_var.clone(),
                                 weight=layer.weight, bias=layer.bias,
                                 training=False, eps=layer.eps)
            else:
                h = layer(h)
        return h

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate ``x`` feature-wise and return the filtered tensor.

        Args:
            x: Input whose dim 0 is the batch and whose last dim has
                ``input_dim`` features. The ``BatchNorm1d`` in the heads and
                the ``dim=0`` reductions imply a 2-D ``(batch, input_dim)``
                input.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size = x.size(0)
        # Drop tensors from the previous call so stale graphs are not kept alive.
        self._aux_loss_tensors = {}

        x_robust = x + self._adversarial_noise(x)

        # Apply the projection blend in both train and eval so the heads see
        # the same input distribution at inference as during training.
        if hasattr(self, 'causal_proj'):
            x_work = 0.8 * self.causal_proj(x_robust) + 0.2 * x_robust
        else:
            x_work = x_robust

        # (B, heads): one tanh score per sample per head.
        node_u = torch.stack([self._run_head(mlp, x_work) for mlp in self.node_mlps], dim=1).squeeze(-1)

        # Weight heads by agreement with the consensus; the weights sum to 1
        # over heads. This is well defined for a single head (weight 1).
        consensus = node_u.mean(dim=1, keepdim=True)
        disagreement = (node_u - consensus).pow(2)
        head_weights = F.softmax(-disagreement / 0.1, dim=1)
        node_score = (node_u * head_weights).sum(dim=1, keepdim=True)

        # channel_gate expects input_dim features, so only the per-feature batch mean is used.
        ch_context = x.mean(dim=0, keepdim=True)
        ch_gate = torch.sigmoid(self.channel_gate(ch_context))

        current_lambda = self.get_current_lambda()
        # A single element has no defined (unbiased) std.
        data_complexity = x.std().item() if x.numel() > 1 else 0.0
        adaptive_lambda = current_lambda * (1.0 + 0.1 * data_complexity)

        temp = self.temperature

        raw = adaptive_lambda + node_score.expand(-1, x.size(1))
        raw = raw + ch_gate
        interaction = 0.1 * torch.tanh(node_score) * ch_gate
        raw = raw + interaction

        gate = torch.sigmoid(raw / temp)

        if self.epoch_step >= self.sparsity_warmup:
            # Shift the gate toward sparsity_target; the shift itself is not differentiated.
            with torch.no_grad():
                mean_act = gate.mean()
                gate_entropy = self._entropy(gate).mean()
                sparsity_strength = 0.1 * (1.0 + gate_entropy)
                adjust = sparsity_strength * (mean_act - self.sparsity_target)
            gate = (gate - adjust).clamp(0.0, 1.0)

        gate_mean = gate.mean()
        gate_confidence = 1.0 - gate.var()
        residual_dynamic = self._adaptive_residual(gate_mean) * gate_confidence.clamp(0.5, 1.0)

        if self.training and self.gradient_penalty > 0 and x.requires_grad and gate.requires_grad:
            gate_grad = torch.autograd.grad(
                gate.mean(), x, create_graph=True, retain_graph=True
            )[0]
            grad_penalty = (gate_grad.norm(dim=-1) ** 2).mean()
            self._aux_loss_tensors['grad_penalty'] = grad_penalty
            self.extra_stats['grad_penalty'] = grad_penalty.item()

        filtered_x = gate * x + residual_dynamic * x

        # Per-feature batch statistics need at least two rows; with one row the
        # unbiased std is NaN and would turn the whole output into NaN.
        if self.calibrate_output and batch_size > 1:
            in_mean = x.mean(dim=0, keepdim=True)
            in_std = x.std(dim=0, keepdim=True).clamp_min(1e-6)
            out_mean = filtered_x.mean(dim=0, keepdim=True)
            out_std = filtered_x.std(dim=0, keepdim=True).clamp_min(1e-6)
            filtered_x = (filtered_x - out_mean) / out_std * in_std + in_mean

            moment_reg = self._moment_matching_regularization(filtered_x, x)
            self._aux_loss_tensors['moment_reg'] = moment_reg
            self.extra_stats['moment_reg'] = moment_reg.item()

        if self.normalize:
            filtered_x = F.layer_norm(filtered_x, filtered_x.shape[-1:])

        if self.training:
            decorr_loss = self._feature_decorrelation_loss(x)
            self._aux_loss_tensors['decorr_loss'] = decorr_loss
            self.extra_stats['decorr_loss'] = decorr_loss.item()

        with torch.no_grad():
            entropy = self._entropy(gate).mean().item()
            sparsity = (gate < 0.5).float().mean().item()
            gate_consistency = 1.0 - gate.var(dim=0).mean().item()
            gate_selectivity = (gate.max(dim=0)[0] - gate.min(dim=0)[0]).mean().item()

            self.gate_stats.update(
                mean=gate_mean.item(),
                std=gate.std().item(),
                min=gate.min().item(),
                max=gate.max().item(),
                entropy=entropy,
                sparsity=sparsity,
            )
            self.extra_stats.update({
                'node_score_mean': node_score.mean().item(),
                'channel_gate_mean': ch_gate.mean().item(),
                'temperature': temp,
                'residual_dynamic': float(residual_dynamic),
                'gate_consistency': gate_consistency,
                'gate_selectivity': gate_selectivity,
                # Stored as a float so the stats dict never holds an autograd graph.
                'adaptive_lambda': float(adaptive_lambda),
            })

        return filtered_x

    def step(self):
        """Advance the lambda/sparsity schedule by one step and decay the temperature."""
        self.epoch_step += 1
        self.temperature = max(self.temp_min, self.temperature * self.temp_decay)

    def reset_lambda(self):
        """Reset ``lambda_param``, the step counter and the temperature to their initial values."""
        self.lambda_param.data.fill_(self.lambda_init)
        self.epoch_step = 0
        self.temperature = self.temperature0

    def get_stats(self):
        """Return a snapshot of the schedule state and the latest diagnostics."""
        cur_lambda = self.get_current_lambda()
        return {
            'lambda': float(cur_lambda.detach().item() if isinstance(cur_lambda, torch.Tensor) else cur_lambda),
            'epoch_step': self.epoch_step,
            'gate_stats': self.gate_stats.copy(),
            'extra': self.extra_stats.copy(),
        }

    def get_auxiliary_losses(self) -> dict:
        """Return the last recorded auxiliary losses as plain floats (for logging).

        Keys present are a subset of ``decorr_loss``, ``moment_reg`` and
        ``grad_penalty``. Values may be left over from an earlier call.
        """
        keys = ('decorr_loss', 'moment_reg', 'grad_penalty')
        return {key: self.extra_stats[key] for key in keys if key in self.extra_stats}

    def get_auxiliary_loss_tensors(self) -> dict:
        """Return the auxiliary losses of the most recent forward call as tensors.

        Unlike ``get_auxiliary_losses``, the values keep their autograd graph,
        so they can be weighted and added to a training loss. Only losses
        computed in the latest call are present.
        """
        return dict(getattr(self, '_aux_loss_tensors', {}))

class AdaptiveCausalFilter(nn.Module):
    """Per-sample soft gate whose strength follows a warmup-then-ramp schedule.

    ``forward`` returns ``gate * x + residual_weight * x``, where
    ``gate = (1 - s) + s * sigmoid(mlp(x))`` and ``s`` is the filter strength.
    ``s`` grows linearly from 0 toward ``base_strength`` during the first
    ``warmup_epochs`` steps. It then grows linearly from ``base_strength`` to
    ``max_strength``, reached at ``max_epochs``. Advance the schedule with
    ``step()``.
    """

    def __init__(self, input_dim, hidden_dim=None, warmup_epochs=10, max_epochs=50,
                 dropout=0.1, residual_weight=0.1, base_strength=0.1, max_strength=0.9, **kwargs):
        """Create the importance MLP and store the schedule parameters.

        Args:
            input_dim: Number of input features (last dimension of ``x``).
            hidden_dim: MLP hidden width; defaults to ``max(input_dim // 4, 8)``.
            warmup_epochs: Steps spent in the ``'warmup'`` phase.
            max_epochs: Step at which ``max_strength`` is reached.
            dropout: Dropout rate inside the MLP.
            residual_weight: Weight of the ``residual_weight * x`` skip term.
            base_strength: Filter strength at the end of warmup.
            max_strength: Final filter strength.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else max(input_dim // 4, 8)
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.residual_weight = residual_weight
        self.base_strength = base_strength
        self.max_strength = max_strength
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_dim, 1)
        )
        self.epoch_step = 0

    def _schedule(self) -> tuple[float, str]:
        """Return ``(filter_strength, phase)`` for the current ``epoch_step``."""
        if self.epoch_step < self.warmup_epochs:
            progress = self.epoch_step / max(1, self.warmup_epochs)
            return self.base_strength * progress, 'warmup'
        span = max(1, self.max_epochs - self.warmup_epochs)
        tail_progress = float(min(1.0, (self.epoch_step - self.warmup_epochs) / span))
        strength = self.base_strength + (self.max_strength - self.base_strength) * tail_progress
        return strength, 'training'

    def forward(self, x):
        """Gate each row of ``x`` by its MLP importance and add the residual term."""
        filter_strength, phase = self._schedule()
        importance = torch.sigmoid(self.mlp(x))
        gate = (1 - filter_strength) + filter_strength * importance
        out = gate * x + self.residual_weight * x
        self._last_stats = {
            'epoch_step': self.epoch_step,
            'filter_strength': float(filter_strength),
            'phase': phase
        }
        return out

    def step(self):
        """Advance the strength schedule by one step."""
        self.epoch_step += 1

    def get_stats(self):
        """Return the schedule state for the current step.

        The values are computed from ``epoch_step`` directly, so they stay
        correct after ``step()`` even if ``forward`` has not run since.
        """
        filter_strength, phase = self._schedule()
        return {
            'epoch_step': self.epoch_step,
            'filter_strength': float(filter_strength),
            'phase': phase,
        }

# Public alias: ``CausalFilter`` refers to the ``ImprovedCausalFilter`` class.
CausalFilter = ImprovedCausalFilter
