import torch
from torch import nn
from torch_scatter import scatter
import torch.nn.functional as F
from torch_geometric.utils import softmax
from utils import RBFExpansion
from transformer import ComformerConv


def _signed_components(eq_rows, anti_rows, size):
    """Group indices tied together by "equal" / "opposite" constraints.

    Runs a union-find with parity over ``size`` nodes. Only pairs ``(a, c)``
    with ``a < c`` are read. Pairs are visited in row-major order and, for a
    given pair, the "equal" constraint is applied before the "opposite" one.
    A constraint that contradicts earlier ones is silently ignored
    (best effort).

    Args:
        eq_rows: ``size`` x ``size`` nested sequence; a truthy ``[a][c]``
            ties ``a`` and ``c`` to the same value.
        anti_rows: same layout; a truthy ``[a][c]`` ties ``a`` and ``c`` to
            values of opposite sign.
        size: number of nodes.

    Returns:
        A list of ``(nodes, signs)`` pairs, one per component holding at
        least two nodes. ``nodes`` is in ascending order and ``signs[k]`` is
        ``+1`` / ``-1`` depending on whether ``nodes[k]`` has the same or the
        opposite sign as the component root.
    """
    parent = list(range(size))
    rank = [0] * size
    # parity[u] == 1 means u carries the opposite sign to parent[u].
    parity = [0] * size

    def find(u):
        if parent[u] == u:
            return u, 0
        root, parent_parity = find(parent[u])
        # Path compression: re-express u's parity relative to the root.
        parity[u] ^= parent_parity
        parent[u] = root
        return root, parity[u]

    def union(a, c, w):
        root_a, par_a = find(a)
        root_c, par_c = find(c)
        if root_a == root_c:
            # Already linked: report whether the new constraint is consistent.
            return (par_a ^ par_c) == w
        if rank[root_a] < rank[root_c]:
            root_a, root_c = root_c, root_a
            par_a, par_c = par_c, par_a
        parent[root_c] = root_a
        parity[root_c] = par_a ^ par_c ^ w
        if rank[root_a] == rank[root_c]:
            rank[root_a] += 1
        return True

    for a in range(size):
        eq_row = eq_rows[a]
        anti_row = anti_rows[a]
        for c in range(a + 1, size):
            if eq_row[c]:
                union(a, c, 0)
            if anti_row[c]:
                union(a, c, 1)

    members = {}
    signs = {}
    for u in range(size):
        root, par = find(u)
        members.setdefault(root, []).append(u)
        signs.setdefault(root, []).append(-1 if (par & 1) else 1)

    return [
        (nodes, signs[root])
        for root, nodes in members.items()
        if len(nodes) > 1
    ]


def equality_adjustment(equality, batch):
    """Project each sample onto its equal / opposite entry constraints.

    Every sample ``batch[i]`` is flattened to ``L = l1 * l2`` entries. Entries
    linked through ``equality[i, 0]`` (must be equal) or ``equality[i, 1]``
    (must be opposite) form a component; each component with two or more
    entries is replaced by the sign-aligned mean of its entries, re-applied
    with each entry's sign. Unconstrained entries are left untouched.

    The input ``batch`` is never modified: the adjustment runs on a copy, so
    the result does not depend on whether ``batch`` is contiguous and leaf
    tensors that require grad are accepted. Gradients still flow back to
    ``batch`` through the copy.

    Args:
        equality: tensor indexed as ``equality[i, c, a, b]`` with ``c`` in
            ``{0, 1}`` and ``a, b < L``. Only the strict upper triangle
            (``a < b``) is read; any non-zero entry counts as a constraint.
        batch: tensor of shape ``(b, l1, l2)``.

    Returns:
        A new tensor of shape ``(b, l1, l2)`` with the constraints applied.
    """
    b, l1, l2 = batch.size()
    L = l1 * l2
    # Private copy: the in-place index assignments below must not leak into
    # the caller's tensor (reshape alone returns a view for contiguous input).
    x = batch.reshape(b, L).clone()

    eq_cpu = equality.detach().to('cpu')

    for i in range(b):
        # Convert each constraint matrix to nested Python lists once instead
        # of reading O(L^2) individual tensor elements.
        eq_rows = eq_cpu[i, 0].tolist()
        anti_rows = eq_cpu[i, 1].tolist()

        for nodes, signs in _signed_components(eq_rows, anti_rows, L):
            idx = torch.tensor(nodes, device=x.device, dtype=torch.long)
            sgn = torch.tensor(signs, device=x.device, dtype=x.dtype)
            # Align signs, average, then restore each entry's own sign.
            base = (sgn * x[i, idx]).mean()
            x[i, idx] = sgn * base

    return x.reshape(b, l1, l2)

class VoigtBlock(nn.Module):
    """Gather a 6x6 matrix out of the last four (3x3x3x3) axes of a tensor.

    Row ``r`` and column ``c`` of the output read entry
    ``[pair[r][0], pair[r][1], pair[c][0], pair[c][1]]`` where ``pair`` is the
    fixed list ``(0,0), (1,1), (2,2), (0,1), (1,2), (2,0)``.

    NOTE(review): this ordering (xx, yy, zz, xy, yz, zx) is not the more
    common Voigt ordering (xx, yy, zz, yz, xz, xy) -- confirm it matches the
    layout of the training targets.
    """

    # Index pairs selected for each of the six rows / columns.
    _PAIRS = ((0, 0), (1, 1), (2, 2), (0, 1), (1, 2), (2, 0))

    def __init__(self):
        super().__init__()
        first = torch.tensor([p[0] for p in self._PAIRS], dtype=torch.long)
        second = torch.tensor([p[1] for p in self._PAIRS], dtype=torch.long)

        # Column vectors (6, 1) index the row pair, row vectors (1, 6) index
        # the column pair; broadcasting them together yields a (6, 6) gather.
        self.register_buffer("row_u", first.unsqueeze(1))
        self.register_buffer("row_v", second.unsqueeze(1))
        self.register_buffer("col_k", first.unsqueeze(0))
        self.register_buffer("col_l", second.unsqueeze(0))

    def forward(self, C_rank4):
        """Return ``C_rank4`` with its last four axes gathered into (6, 6)."""
        selector = (Ellipsis, self.row_u, self.row_v, self.col_k, self.col_l)
        return C_rank4[selector]

class WeightedElasticPooling(nn.Module):
    """Attention-weighted sum of per-node rank-4 tensors into per-graph tensors.

    A small MLP scores every node from its features; the scores are passed
    through ``softmax(scores, batch)`` and used to weight the node tensors
    before they are summed per ``batch`` index.
    """

    def __init__(self, in_channels):
        super().__init__()
        hidden = in_channels // 2
        # Scalar score per node: Linear -> Tanh -> Linear.
        self.att_net = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, atom_tensors, node_feats, batch):
        """Pool ``atom_tensors`` over ``batch`` using learned attention.

        Args:
            atom_tensors: per-node tensors; the weights are broadcast as
                ``(-1, 1, 1, 1, 1)``, so five axes with nodes first.
            node_feats: per-node features fed to the scoring MLP.
            batch: index assigning each node to its output row.

        Returns:
            The weighted sum of ``atom_tensors`` along dim 0 per ``batch`` id.
        """
        attention = softmax(self.att_net(node_feats), batch)
        scaled = atom_tensors * attention.view(-1, 1, 1, 1, 1)
        return scatter(scaled, batch, dim=0, reduce='sum')

class MatrixInteractionHead(nn.Module):
    """Build a fully symmetrised 3x3x3x3 tensor for every node.

    The output is the sum of three contributions, symmetrised over ``i<->j``,
    ``k<->l`` and ``(ij)<->(kl)``:

    * an *interaction* term made of outer products of aggregated per-node
      3x3 matrices (one rank-1 "linear" product plus a channel-mixed
      "quadratic" product gated by ``sigmoid(gamma_raw)``),
    * a *linear* term aggregated from per-edge ``Q (x) Q`` and symmetrised
      ``I (x) Q`` bases, where ``Q = n n^T - I/3`` for the unit edge
      direction ``n``,
    * an *isotropic* term whose two coefficients are predicted from the mean
      node feature of each graph.

    Edge aggregates are divided by ``deg ** p`` with a learned
    ``p = 0.98 * sigmoid(p_raw) + 0.01``.
    """

    def __init__(self, in_channels, edge_dim, num_matrix_channels=128, mixing_channels: int = 16):
        super().__init__()
        self.in_channels = in_channels
        self.num_channels = num_matrix_channels  # M
        self.mix_channels = mixing_channels      # K

        pair_dim = 2 * in_channels + edge_dim

        def edge_network(out_dim):
            # MLP over concatenated (source, target, edge) features.
            return nn.Sequential(
                nn.Linear(pair_dim, in_channels),
                nn.SiLU(),
                nn.Linear(in_channels, out_dim),
            )

        # Sub-modules and parameters are created in a fixed order so that
        # seeded initialisation stays reproducible.
        # Per edge: M deviatoric + M isotropic matrix coefficients.
        self.edge_mlp = edge_network(2 * num_matrix_channels)
        # Per edge: one coefficient each for the Q(x)Q and I(x)Q bases.
        self.edge_mlp_lin = edge_network(2)

        # Channel compression M -> K applied to the aggregated matrices.
        self.compress = nn.Linear(num_matrix_channels, mixing_channels, bias=False)

        uniform = torch.ones(mixing_channels) / mixing_channels
        self.lin_weight_u = nn.Parameter(uniform)
        self.lin_weight_v = nn.Parameter(torch.ones(mixing_channels) / mixing_channels)

        self.mix_weight = nn.Parameter(
            torch.randn(mixing_channels, mixing_channels) / mixing_channels
        )
        self.gamma_raw = nn.Parameter(torch.tensor(0.0))

        self.iso_mlp = nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.SiLU(),
            nn.Linear(64, 2),
        )

        self.p_raw = nn.Parameter(torch.tensor(0.0))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the last layer of both edge MLPs close to zero."""
        for last in (self.edge_mlp_lin[-1], self.edge_mlp[-1]):
            nn.init.normal_(last.weight, std=0.01)
            nn.init.constant_(last.bias, 0.0)

    @staticmethod
    def _symmetrize(t):
        """Symmetrise a (B, 3, 3, 3, 3) tensor over i<->j, k<->l, (ij)<->(kl)."""
        t = 0.5 * (t + t.transpose(1, 2))
        t = 0.5 * (t + t.transpose(3, 4))
        return 0.5 * (t + t.transpose(1, 3).transpose(2, 4))

    def forward(self, x, edge_index, edge_vec, edge_feat, batch):
        """Compute the per-node symmetric rank-4 tensor.

        Args:
            x: node features, ``(N, in_channels)``.
            edge_index: pair ``(src, dst)`` of edge endpoint indices;
                messages are aggregated at ``dst``.
            edge_vec: per-edge vectors with 3 components.
            edge_feat: per-edge features, ``(E, edge_dim)``.
            batch: graph index of every node.

        Returns:
            Tensor of shape ``(N, 3, 3, 3, 3)``.
        """
        src, dst = edge_index

        num_nodes = x.size(0)
        channels = self.num_channels
        device = x.device

        # --- edge geometry: unit direction and traceless projector -------
        length = torch.norm(edge_vec, dim=-1, keepdim=True).clamp_min(1e-8)
        unit = edge_vec / length

        eye = torch.eye(3, device=device).unsqueeze(0)
        outer = torch.einsum("ei,ej->eij", unit, unit)
        Q = outer - (1.0 / 3.0) * eye

        pair_input = torch.cat([x[src], x[dst], edge_feat], dim=-1)

        # --- per-edge 3x3 matrices, M channels each -----------------------
        w_dev, w_iso = self.edge_mlp(pair_input).chunk(2, dim=-1)
        edge_mats = (
            w_dev.view(-1, channels, 1, 1) * Q.view(-1, 1, 3, 3)
            + w_iso.view(-1, channels, 1, 1) * eye.view(-1, 1, 3, 3)
        )

        # --- learned degree normalisation ---------------------------------
        edge_count = torch.ones(dst.size(0), device=device)
        deg = scatter(edge_count, dst, dim=0, dim_size=num_nodes, reduce='sum').clamp_min(1.0)
        p = torch.sigmoid(self.p_raw) * 0.98 + 0.01

        mat_scale = deg.pow(-p).view(num_nodes, 1, 1, 1)
        node_mats = scatter(edge_mats, dst, dim=0, reduce='sum', dim_size=num_nodes) * mat_scale

        # Compress channels M -> K; result is [N, K, 3, 3].
        core = self.compress(node_mats.transpose(1, -1)).transpose(1, -1)

        # --- interaction term: outer products of aggregated matrices ------
        U = torch.einsum("k,nkij->nij", self.lin_weight_u, core)
        V = torch.einsum("k,nkij->nij", self.lin_weight_v, core)
        T_lin = torch.einsum("nij,nkl->nijkl", U, V)

        mixed = torch.einsum("ab,naij->nbij", self.mix_weight, core)
        T_quad = torch.einsum("ncij,nckl->nijkl", mixed, core)

        T_lin = self._symmetrize(T_lin)
        T_quad = self._symmetrize(T_quad)

        gamma = torch.sigmoid(self.gamma_raw)
        T_interaction = T_lin + gamma * T_quad

        # --- linear term: per-edge rank-4 bases aggregated at dst ---------
        eye_per_edge = eye.expand_as(Q)
        basis_qq = torch.einsum("eij,ekl->eijkl", Q, Q)
        basis_iq = 0.5 * (
            torch.einsum("eij,ekl->eijkl", eye_per_edge, Q)
            + torch.einsum("eij,ekl->eijkl", Q, eye_per_edge)
        )

        w_qq, w_iq = self.edge_mlp_lin(pair_input).chunk(2, dim=-1)
        edge_rank4 = (
            w_qq.view(-1, 1, 1, 1, 1) * basis_qq
            + w_iq.view(-1, 1, 1, 1, 1) * basis_iq
        )
        T_linear = scatter(edge_rank4, dst, dim=0, reduce='sum', dim_size=num_nodes)
        T_linear = T_linear * deg.pow(-p).view(num_nodes, 1, 1, 1, 1)

        # --- isotropic term from the graph-level mean feature -------------
        graph_feat = scatter(x, batch, dim=0, reduce='mean')
        iso = self.iso_mlp(graph_feat)
        c1 = iso[batch, 0]
        c2 = iso[batch, 1]

        dd_ij_kl = torch.einsum('bij,bkl->bijkl', eye, eye)
        dd_cross = (
            torch.einsum('bik,bjl->bijkl', eye, eye)
            + torch.einsum('bil,bjk->bijkl', eye, eye)
        )
        T_iso = c1.view(-1, 1, 1, 1, 1) * dd_ij_kl + c2.view(-1, 1, 1, 1, 1) * dd_cross

        return self._symmetrize(T_interaction + T_linear + T_iso)

class CEITNet(nn.Module):
    """Graph network mapping a batch of graphs to one 6x6 matrix per graph.

    Pipeline: linear atom embedding -> RBF edge embedding of ``-0.75 / |r|``
    -> two ``ComformerConv`` layers -> LayerNorm -> per-atom rank-4 tensor
    head -> attention pooling per graph -> 6x6 gather via ``VoigtBlock``,
    optionally followed by ``equality_adjustment``.
    """

    def __init__(self, args=None):
        # ``args`` is accepted but not used by this constructor.
        super().__init__()
        width = 128

        # Input node features have 92 entries.
        self.atom_embedding = nn.Linear(92, width)
        self.rbf = nn.Sequential(
            RBFExpansion(vmin=-4.0, vmax=0.0, bins=512),
            nn.Linear(512, width),
            nn.Softplus(),
        )

        conv_layers = []
        for _ in range(2):
            conv_layers.append(
                ComformerConv(in_channels=width, out_channels=width, heads=1, edge_dim=width)
            )
        self.att_layers = nn.ModuleList(conv_layers)

        self.norm = nn.LayerNorm(width)

        self.tensor_head = MatrixInteractionHead(
            in_channels=width,
            edge_dim=width,
            num_matrix_channels=128,
            mixing_channels=16,
        )

        # Per-graph pooling of the per-atom tensors.
        self.pooling = WeightedElasticPooling(in_channels=width)

        self.voigt_block = VoigtBlock()

        banner = "=" * 60
        print(banner)
        print('CEITNet')
        print(banner)

    def forward(self, data, feat_mask=None, equality=None) -> torch.Tensor:
        """Predict the 6x6 matrix of every graph in ``data``.

        Args:
            data: batch object providing ``x``, ``edge_index``, ``edge_attr``
                (used as edge vectors) and ``batch``.
            feat_mask: accepted but not used.
            equality: optional constraint tensor forwarded to
                ``equality_adjustment``; skipped when ``None``.

        Returns:
            The gathered (and optionally constraint-adjusted) per-graph
            matrices.
        """
        h = self.atom_embedding(data.x)

        # Edge scalar fed to the RBF expansion: -0.75 / |r| (eps avoids 1/0).
        edge_norm = torch.norm(data.edge_attr, dim=1, keepdim=True)
        inv_scaled = -0.75 / (edge_norm + 1e-8)
        edge_emb = self.rbf(inv_scaled.squeeze(-1))

        for conv in self.att_layers:
            h = conv(h, data.edge_index, edge_emb)
        h = self.norm(h)

        per_atom = self.tensor_head(
            x=h,
            edge_index=data.edge_index,
            edge_vec=data.edge_attr,
            edge_feat=edge_emb,
            batch=data.batch,
        )

        per_graph = self.pooling(per_atom, h, data.batch)
        voigt = self.voigt_block(per_graph)

        if equality is None:
            return voigt
        return equality_adjustment(equality, voigt)
