import torch
import torch.nn as nn
import torch.nn.functional as F
from model.gnns import Graph_Conv, prune_edge_index_by_attn
from model.br_sa3 import BDGTrans


def calculate_spectral_loss(
    B: torch.Tensor, g: int, eps: float = 1e-6, jitter: float = 1e-6
):
    """Sum the ``g`` smallest eigenvalues of the normalised Laplacian of ``B``'s row affinity.

    The affinity between the rows of ``B`` is their inner product
    ``B @ B.T``, symmetrised, clamped to be non-negative and given a small
    ``jitter`` on the diagonal. With degrees ``d`` (clamped below at ``eps``)
    the normalised Laplacian is ``L = I - D^{-1/2} S D^{-1/2}``. A small sum
    means the affinity graph is close to having ``g`` disconnected components
    (standard spectral-clustering fact). The loss is presumably used as a
    grouping regulariser on the anchors; confirm at the call site.

    Everything is computed in float64 for a stable eigen-decomposition. This
    matches the promotion the float64 diagonal jitter already caused, so the
    result is float64 whatever ``B``'s dtype is.

    Args:
        B: 2-D matrix whose rows are the items to be grouped.
        g: Number of smallest eigenvalues to sum; clipped to ``[0, B.shape[0]]``
            (a negative value used to slice from the end instead).
        eps: Lower bound applied to the degrees before ``rsqrt``.
        jitter: Value added to the affinity diagonal.

    Returns:
        0-d float64 tensor, differentiable with respect to ``B``.
    """
    b64 = B.to(torch.float64)
    gram = b64 @ b64.t()
    affinity = (0.5 * (gram + gram.t())).clamp_min(0.0)

    n = affinity.size(0)
    eye = torch.eye(n, device=affinity.device, dtype=affinity.dtype)
    affinity = affinity + jitter * eye

    d_inv_sqrt = affinity.sum(-1).clamp_min(eps).rsqrt()
    # D^{-1/2} S D^{-1/2} via broadcasting rather than two dense diag matmuls.
    laplacian = eye - d_inv_sqrt[:, None] * affinity * d_inv_sqrt[None, :]

    # Eigenvalues only (ascending); eigvalsh has numerically stable gradients
    # even when eigenvalues are repeated, which is common for this Laplacian.
    evals = torch.linalg.eigvalsh(laplacian)

    k = max(0, min(int(g), n))
    return evals[:k].sum()

class BRFormer(nn.Module):
    """Hybrid message-passing / anchor-attention node classifier.

    Node features are projected by ``fc_in`` and processed by two branches:

    * an attention stack (``BDGTrans``) that is given the projected features
      and a learnable anchor matrix ``B0`` of shape ``(k_blocks, hidden_dim)``;
    * a stack of ``Graph_Conv`` layers. From ``prune_start_epoch`` onwards its
      edge set is pruned by ``prune_edge_index_by_attn`` using the averaged
      attention maps of the attention stack.

    The two branch outputs are L2-normalised and blended as
    ``w_attn * attn + (1 - w_attn) * gnn``. The blend goes through a pre-norm
    residual MLP and is then classified by ``fc_out``.

    Args:
        input_dim: Size of the raw node features fed to ``fc_in``.
        hidden_dim: Width of the hidden representations and of the anchors.
        output_dim: Number of output classes.
        activation: Zero-argument callable (e.g. ``nn.ReLU``), called once to
            build the activation module shared by all layers.
        num_gnns: Number of ``Graph_Conv`` layers.
        num_trans: Forwarded to ``BDGTrans`` as ``num_layers``.
        num_heads: Forwarded to ``BDGTrans``.
        dropout: Dropout probability used throughout (also given to ``BDGTrans``).
        graphconv: ``"sgc"`` applies activation and dropout once after all
            propagation steps. Any other value applies BatchNorm, activation,
            dropout and a residual connection after every layer.
        k_blocks: Number of anchors (rows of ``B0``).
        g_blocks: Number of smallest Laplacian eigenvalues summed by the
            spectral loss; also forwarded to ``BDGTrans``.
        hard: Forwarded to ``BDGTrans``.
        w_attn: Blend weight of the attention branch.
        keep_ratio: Forwarded to ``prune_edge_index_by_attn``.
        prune_start_epoch: First epoch at which attention-based edge pruning is
            applied; before it the full ``edge_index`` is used. The default 30
            matches the value that used to be hard-coded.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        activation,
        num_gnns: int,
        num_trans: int,
        num_heads: int,
        dropout: float,
        graphconv: str,
        k_blocks: int = 32,
        g_blocks: int = 10,
        hard: bool = True,
        w_attn: float = 0.1,
        keep_ratio: float = 0.8,
        prune_start_epoch: int = 30,
    ):
        super().__init__()
        self.dropout = float(dropout)
        self.activation = activation()
        self.graphconv = graphconv
        self.keep_ratio = keep_ratio
        self.prune_start_epoch = prune_start_epoch
        self.w_attn = w_attn

        self.num_anchors = k_blocks
        self.out_channels = output_dim
        self.g_blocks = g_blocks

        # NOTE: module creation order is kept unchanged so that parameter
        # initialisation consumes the RNG in the same sequence and the
        # state_dict keys keep their order.
        self.B0 = nn.Parameter(torch.empty(k_blocks, hidden_dim))
        nn.init.xavier_uniform_(self.B0)

        self.fc_in = nn.Linear(input_dim, hidden_dim)

        # Only used by the non-"sgc" path in forward().
        self.gnn_norms = nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(num_gnns)]
        )
        # Both graphconv modes use the same layer type; they differ only in how
        # forward() interleaves normalisation/activation between the layers.
        self.vanilla_gnns = nn.ModuleList([Graph_Conv() for _ in range(num_gnns)])

        self.attn_stack = BDGTrans(
            in_channels=hidden_dim,
            hidden_channels=hidden_dim,
            activation=self.activation,
            num_layers=num_trans,
            num_heads=num_heads,
            k_blocks=k_blocks,
            dropout=self.dropout,
            g_blocks=g_blocks,
            hard=hard,
        )

        self.post_mlp_1 = nn.Linear(hidden_dim, hidden_dim)
        self.post_mlp_2 = nn.Linear(hidden_dim, hidden_dim)
        self.post_norm = nn.LayerNorm(hidden_dim)

        self.fc_out = nn.Linear(hidden_dim, output_dim)

    @staticmethod
    def _mean_attention(maps):
        """Average the stacked attention maps, then a leading head dim if one remains."""
        attn = torch.stack(maps, dim=0).mean(dim=0)
        # NOTE(review): a 3-D average is assumed to be (heads, ...); confirm
        # against BDGTrans.
        if attn.dim() == 3:
            attn = attn.mean(dim=0)
        return attn

    def forward(self, x0, edge_index, epoch=0):
        """Run both branches and return class log-probabilities.

        Args:
            x0: Raw node features; last dim must be ``input_dim``.
            edge_index: Graph connectivity passed to the ``Graph_Conv`` layers.
            epoch: Current epoch. Edge pruning is applied once
                ``epoch >= self.prune_start_epoch``.

        Returns:
            ``(log_probs, loss_spec)``. ``log_probs`` is the log-softmax (last
            dim) of the ``fc_out`` logits. ``loss_spec`` is the output of
            ``calculate_spectral_loss`` on the anchors ``B0``.
        """
        x = self.fc_in(x0)
        x_a, _ = self.attn_stack(x, self.B0)
        loss_spec = calculate_spectral_loss(self.B0, g=self.g_blocks)

        if epoch < self.prune_start_epoch:
            # Warm-up: propagate over the full graph.
            edge_index_used = edge_index
        else:
            # The averaged attention maps are only needed for pruning.
            attn_bN = self._mean_attention(self.attn_stack.all_attn_bN)
            attn_Nb = self._mean_attention(self.attn_stack.all_attn_Nb)
            edge_index_used = prune_edge_index_by_attn(
                attn_bN, attn_Nb, edge_index,
                keep_ratio=self.keep_ratio,
            )

        z = x
        if self.graphconv == "sgc":
            # Pure propagation, one non-linearity at the end.
            for conv in self.vanilla_gnns:
                z = conv(z, edge_index_used)
            z = self.activation(z)
            z = F.dropout(z, p=self.dropout, training=self.training)
        else:
            for conv, norm in zip(self.vanilla_gnns, self.gnn_norms):
                z_in = z
                z = conv(z, edge_index_used)
                z = norm(z)
                z = self.activation(z)
                z = F.dropout(z, p=self.dropout, training=self.training)
                z = z + z_in

        # Blend unit-norm branch outputs so w_attn weighs directions, not scales.
        x_agg = (
            self.w_attn * F.normalize(x_a, dim=-1)
            + (1 - self.w_attn) * F.normalize(z, dim=-1)
        )

        # Pre-norm residual MLP.
        out = self.post_norm(x_agg)
        out = self.post_mlp_1(out)
        out = self.activation(out)
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.post_mlp_2(out)
        out = F.dropout(out, p=self.dropout, training=self.training)
        x_agg = x_agg + out

        return F.log_softmax(self.fc_out(x_agg), dim=-1), loss_spec

    def init_anchors_from_labels(self, x, y, train_idx):
        """Re-initialise ``fc_in`` and seed the anchors ``B0`` from class centroids.

        First, ``fc_in.weight`` is re-initialised orthogonally and its bias
        (if any) is zeroed. The ``num_anchors`` rows of ``B0`` are then split
        across the ``out_channels`` classes as evenly as possible; the first
        ``num_anchors % out_channels`` classes get one extra row. A class's
        rows are the ``fc_in`` projection of its mean training feature plus
        Gaussian noise (std 0.1). A class with no training samples uses a
        standard-normal random vector in place of its mean.

        Args:
            x: Raw node features; rows are selected by ``train_idx``.
            y: Node labels compared against class ids ``0..out_channels-1``.
            train_idx: Index tensor or boolean mask selecting training nodes.
        """
        num_classes = self.out_channels

        nn.init.orthogonal_(self.fc_in.weight)
        if self.fc_in.bias is not None:
            nn.init.zeros_(self.fc_in.bias)

        base_count, remainder = divmod(self.num_anchors, num_classes)
        class_counts = [
            base_count + (1 if c < remainder else 0) for c in range(num_classes)
        ]

        new_anchors = []
        with torch.no_grad():
            x_train = x[train_idx]
            y_train = y[train_idx]

            for c, count in enumerate(class_counts):
                if count == 0:
                    continue

                mask_c = y_train == c
                if mask_c.any():
                    raw_centroid = x_train[mask_c].mean(dim=0, keepdim=True)
                else:
                    # Use x's dtype so fc_in sees the same dtype as in the
                    # real-centroid branch (the default float32 broke float64 models).
                    raw_centroid = torch.randn(
                        1, x.shape[1], device=x.device, dtype=x.dtype
                    )

                projected_centroid = self.fc_in(raw_centroid)
                noise = torch.randn(
                    count, self.B0.shape[1],
                    device=x.device, dtype=projected_centroid.dtype,
                ) * 0.1
                new_anchors.append(projected_centroid + noise)

            final_anchors = torch.cat(new_anchors, dim=0)
            self.B0.data.copy_(final_anchors.to(self.B0.dtype))

        print("[Model] Anchor init done")

