import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CausalFilter(nn.Module):
    """Learned sigmoid gate that rescales its input, with a decaying bias term.

    A small MLP maps an input ``x`` (last dimension ``input_dim``) to a single
    logit per leading index, ``u`` with shape ``(..., 1)``. The gate is
    ``sigmoid(lambda_t + u)``. Here ``lambda_t`` is the learnable scalar
    ``lambda_param`` scaled by ``decay_rate ** epoch_step`` and floored at
    ``lambda_min``. The output mixes the gated and the raw input::

        out = alpha * (gate * x) + (1 - alpha) * x

    A parameter-free layer norm over the last dimension is optionally applied
    afterwards.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the first MLP layer; defaults to
            ``max(input_dim // 2, 16)``. The second layer is half as wide.
        lambda_init: Initial value of ``lambda_param``; also the value
            restored by :meth:`reset_lambda`.
        lambda_min: Lower bound applied to the decayed lambda.
        decay_rate: Multiplicative decay applied once per :meth:`step` call.
        normalize: If True, apply ``F.layer_norm`` over the last dimension.
        dropout: Dropout probability used inside the MLP.
        alpha: Weight of the gated path in the output mix (default 0.8).
            Must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.

    Note:
        ``epoch_step`` is a plain Python attribute, not a registered buffer,
        so it is not saved in ``state_dict``. Restore it manually when
        resuming training if the decay schedule matters.
    """

    def __init__(self, input_dim, hidden_dim=None, lambda_init=5.0, lambda_min=-2.0,
                 decay_rate=0.99, normalize=False, dropout=0.1, alpha=0.8):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else max(input_dim // 2, 16)
        self.lambda_init = lambda_init
        self.lambda_min = lambda_min
        self.decay_rate = decay_rate
        self.normalize = normalize
        self.alpha = alpha

        # Produces one logit per leading index; the trailing size-1 dimension
        # broadcasts across the feature dimension when the gate multiplies x.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_dim // 2, 1),
        )

        # Learnable scalar bias added to every gate logit.
        self.lambda_param = nn.Parameter(torch.tensor(lambda_init, dtype=torch.float32))
        # Number of step() calls so far; drives the decay of lambda_param.
        self.epoch_step = 0

    def forward(self, x):
        """Gate ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            ``alpha * sigmoid(lambda_t + mlp(x)) * x + (1 - alpha) * x``.
            When ``normalize`` is set, the result is then layer-normalised
            over the last dimension.
        """
        u = self.mlp(x)
        gate = torch.sigmoid(self.get_current_lambda() + u)

        # Residual mix: for alpha < 1 a share of the raw signal always passes
        # through, so the gate cannot fully suppress the input.
        filtered_x = self.alpha * (gate * x) + (1 - self.alpha) * x

        if self.normalize:
            filtered_x = F.layer_norm(filtered_x, filtered_x.shape[-1:])
        return filtered_x

    def get_current_lambda(self):
        """Return ``max(lambda_param * decay_rate ** epoch_step, lambda_min)``.

        The same value is used in training and evaluation mode. The result is
        always a 0-dim tensor. The gradient w.r.t. ``lambda_param`` is zero
        while the decayed value is below ``lambda_min``.
        """
        decay_factor = self.decay_rate ** self.epoch_step
        # torch.clamp keeps the result a tensor on lambda_param's device/dtype.
        # Builtin max() would return a Python float whenever the floor wins.
        return torch.clamp(self.lambda_param * decay_factor, min=self.lambda_min)

    def step(self):
        """Advance the decay schedule by one step (presumably once per epoch)."""
        self.epoch_step += 1

    def reset_lambda(self):
        """Restore ``lambda_param`` to ``lambda_init`` and restart the decay schedule."""
        with torch.no_grad():
            self.lambda_param.fill_(self.lambda_init)
        self.epoch_step = 0


class CausalFilterWrapper(nn.Module):
    """Hold a model together with a set of per-layer ``CausalFilter`` modules.

    This class defines no ``forward`` and never calls ``self.model`` itself.
    Callers route activations through :meth:`apply_filter` at the layer
    indices they choose.
    """

    def __init__(self, model, layer_dims, filter_positions=None, **filter_kwargs):
        """Create one filter per requested layer position.

        Args:
            model: Module stored as ``self.model`` (registered as a submodule).
            layer_dims: Indexable sequence of feature sizes, one per layer.
            filter_positions: Layer indices that should get a filter; defaults
                to every index of ``layer_dims``. Indices that are not below
                ``len(layer_dims)`` are silently skipped.
            **filter_kwargs: Extra keyword arguments forwarded to every
                ``CausalFilter``.
        """
        super(CausalFilterWrapper, self).__init__()

        self.model = model
        self.layer_dims = layer_dims
        self.filter_positions = (
            list(range(len(layer_dims))) if filter_positions is None else filter_positions
        )

        layer_count = len(layer_dims)
        # ModuleDict keys must be strings, so positions are stored as str(index).
        self.filters = nn.ModuleDict()
        for index in self.filter_positions:
            if index >= layer_count:
                continue
            self.filters[str(index)] = CausalFilter(layer_dims[index], **filter_kwargs)

    def apply_filter(self, x, layer_idx):
        """Return ``x`` passed through the filter for ``layer_idx``, or ``x`` unchanged if none."""
        key = str(layer_idx)
        if key not in self.filters:
            return x
        return self.filters[key](x)

    def step_all_filters(self):
        """Call ``step()`` on every contained filter."""
        for module in self.filters.values():
            module.step()

    def reset_all_filters(self):
        """Call ``reset_lambda()`` on every contained filter."""
        for module in self.filters.values():
            module.reset_lambda()



def add_causal_filter_to_model(model_class):
    """Return a subclass of ``model_class`` that can host named ``CausalFilter`` modules.

    The returned class is always named ``FilteredModel``. Its constructor
    accepts two extra keyword-only arguments; everything else is forwarded
    unchanged to ``model_class.__init__``:

    * ``use_causal_filter`` (default False): master switch for filtering.
    * ``filter_config``: mapping of layer name -> ``CausalFilter`` kwargs.
      Filters are only built when the switch is on and this mapping is
      non-empty.

    NOTE(review): the subclass assigns ``self.filters`` after the base
    constructor runs, so any ``filters`` attribute set by ``model_class`` is
    replaced.
    """

    class FilteredModel(model_class):
        def __init__(self, *args, use_causal_filter=False, filter_config=None, **kwargs):
            super().__init__(*args, **kwargs)

            self.use_causal_filter = use_causal_filter
            self.filters = nn.ModuleDict()
            if not (use_causal_filter and filter_config):
                return
            for name, cfg in filter_config.items():
                self.filters[name] = CausalFilter(**cfg)

        def apply_causal_filter(self, x, layer_name):
            """Filter ``x`` with the filter named ``layer_name`` if enabled and present."""
            enabled = self.use_causal_filter and layer_name in self.filters
            return self.filters[layer_name](x) if enabled else x

        def step_filters(self):
            """Advance every filter's decay schedule (no-op when filtering is disabled)."""
            if not self.use_causal_filter:
                return
            for module in self.filters.values():
                module.step()

    return FilteredModel



def create_filter_config(layer_dims, lambda_init=100.0, lambda_min=0.1,
                         decay_rate=0.95, normalize=True, **extra_kwargs):
    """Build a mapping of layer name -> ``CausalFilter`` keyword arguments.

    Args:
        layer_dims: Mapping of layer name -> input dimension for that layer.
        lambda_init, lambda_min, decay_rate, normalize: Settings copied into
            every layer's entry. These defaults differ from the defaults of
            ``CausalFilter.__init__`` itself.
        **extra_kwargs: Any further ``CausalFilter`` arguments (for example
            ``hidden_dim``, ``dropout``) to copy into every layer's entry.

    Returns:
        Dict mapping each layer name to its own kwargs dict, suitable for
        ``CausalFilter(**config[layer_name])``.

    Raises:
        TypeError: If ``extra_kwargs`` contains ``input_dim``; that value is
            always taken from ``layer_dims``.
    """
    if 'input_dim' in extra_kwargs:
        raise TypeError("'input_dim' is taken from layer_dims and cannot be passed explicitly")

    shared = {
        'lambda_init': lambda_init,
        'lambda_min': lambda_min,
        'decay_rate': decay_rate,
        'normalize': normalize,
        **extra_kwargs,
    }
    # Build a fresh dict per layer so mutating one entry never affects another.
    return {name: {'input_dim': dim, **shared} for name, dim in layer_dims.items()}


def insert_causal_filters(model, layer_outputs, filter_positions=None, **filter_kwargs):
    """Create a standalone ``ModuleDict`` of ``CausalFilter`` modules keyed ``filter_<pos>``.

    Args:
        model: Accepted for interface compatibility; not used by this function.
        layer_outputs: Indexable sequence of feature sizes, one per layer.
        filter_positions: Indices that should get a filter; defaults to every
            index of ``layer_outputs``. Indices not below
            ``len(layer_outputs)`` are silently skipped.
        **filter_kwargs: Extra keyword arguments forwarded to every filter.

    Returns:
        ``nn.ModuleDict`` of the created filters, in position order.
    """
    output_count = len(layer_outputs)
    positions = range(output_count) if filter_positions is None else filter_positions

    result = nn.ModuleDict()
    for index in positions:
        if index >= output_count:
            continue
        result[f'filter_{index}'] = CausalFilter(layer_outputs[index], **filter_kwargs)
    return result
