import torch
import torch.nn as nn
import torch.nn.functional as F
from model.layers import ChebGraphConv
from improved_filter import ImprovedCausalFilter

class ChebyNetWithFilter(nn.Module):
    """Stack of ``ChebGraphConv`` layers with optional per-layer causal filters.

    Each of the ``Kl`` layers optionally passes its input through an
    ``ImprovedCausalFilter``, then applies ``ChebGraphConv`` -> ReLU -> dropout.
    A final ``nn.Linear`` maps hidden features to ``n_class`` scores. When
    ``task == 'graph'`` they are pooled with ``torch_geometric``'s
    ``global_mean_pool``. Scores are returned as ``log_softmax`` over dim 1.

    Args:
        n_feat: Input feature size of the first layer.
        n_hid: Output size of every conv layer and input size of later layers.
        n_class: Output size of the final linear layer.
        enable_bias: Forwarded to every ``ChebGraphConv``.
        Ko: Forwarded as the first argument of ``ChebGraphConv``
            (presumably the Chebyshev polynomial order; confirm in model.layers).
        Kl: Number of ``ChebGraphConv`` layers; must be >= 1.
        droprate: Dropout probability applied after every conv layer.
        use_causal_filter: If True, build one ``ImprovedCausalFilter`` per layer
            and apply it before that layer's convolution.
        filter_config: Optional dict whose ``'input'`` / ``'hidden'`` entries are
            keyword-argument dicts for the first-layer / later-layer filters.
            ``None``, an empty dict or a missing entry means the filter is built
            with its default arguments.
        task: ``'graph'`` enables mean pooling over ``batch`` in ``forward``;
            any other value returns per-row (per-node) outputs.

    Raises:
        ValueError: If ``Kl < 1``.
    """

    def __init__(self, n_feat, n_hid, n_class, enable_bias, Ko, Kl, droprate, use_causal_filter=False, filter_config=None, task='node'):
        super().__init__()
        if Kl < 1:
            # With no conv layer, forward() would feed n_feat-sized inputs
            # into a Linear expecting n_hid features.
            raise ValueError(f"Kl must be >= 1, got {Kl}")
        self.use_causal_filter = use_causal_filter
        self.task = task
        self.Kl = Kl
        self.droprate = droprate

        # Module creation order (convs, linear, filters) is kept as before so
        # parameter initialisation consumes the RNG in the same order.
        self.cheb_convs = nn.ModuleList()
        self.cheb_convs.append(ChebGraphConv(Ko, n_feat, n_hid, enable_bias))
        for _ in range(1, Kl):
            self.cheb_convs.append(ChebGraphConv(Ko, n_hid, n_hid, enable_bias))
        self.linear = nn.Linear(n_hid, n_class)

        # Always a ModuleList so the attribute type does not depend on the flag.
        self.filters = nn.ModuleList()
        if self.use_causal_filter:
            # A None/empty config used to skip filter construction while the
            # flag stayed True, which made forward() fail with an IndexError.
            config = filter_config or {}
            self.filters.append(ImprovedCausalFilter(n_feat, **config.get('input', {})))
            for _ in range(1, Kl):
                self.filters.append(ImprovedCausalFilter(n_hid, **config.get('hidden', {})))

    def forward(self, feature, gso, batch=None):
        """Run the network and return log-probabilities.

        Args:
            feature: Input features for the first layer (presumably
                ``(num_nodes, n_feat)``; confirm against ``ChebGraphConv``).
            gso: Passed unchanged to every ``ChebGraphConv`` (presumably the
                graph shift operator).
            batch: Passed to ``global_mean_pool``; only used when
                ``task == 'graph'``.

        Returns:
            ``F.log_softmax`` of the (optionally pooled) class scores over dim 1.
        """
        x = feature
        for i, conv in enumerate(self.cheb_convs):
            if self.use_causal_filter:
                x = self.filters[i](x)
            x = conv(x, gso)
            x = F.relu(x)
            x = F.dropout(x, self.droprate, training=self.training)

        # The linear layer runs per row before pooling. For mean pooling this
        # equals pooling first, because both operations are affine.
        x = self.linear(x)

        if self.task == 'graph':
            # Local import keeps torch_geometric optional for node-level tasks.
            from torch_geometric.nn import global_mean_pool
            x = global_mean_pool(x, batch)

        return F.log_softmax(x, dim=1)

    def step_epoch(self):
        """Call ``step()`` on every causal filter; no-op when filters are disabled."""
        if not self.use_causal_filter:
            return
        for causal_filter in self.filters:
            causal_filter.step()

    def get_filter_info(self):
        """Return a multi-line summary of each filter's ``get_stats()`` values.

        Returns:
            ``"Causal filter not in use."`` when filters are disabled. Otherwise
            a header line followed by one line per filter, showing its
            ``'lambda'`` value and its ``gate_stats['mean']`` value.
        """
        if not self.use_causal_filter:
            return "Causal filter not in use."

        lines = ["Filter Stats:"]
        for i, causal_filter in enumerate(self.filters):
            stats = causal_filter.get_stats()
            lines.append(f"  Layer {i} Filter: lambda={stats['lambda']:.4f}, gate_mean={stats['gate_stats']['mean']:.4f}")
        return "\n".join(lines)
