import torch
from torch import nn
from torch_scatter import scatter
from torch_geometric.utils import softmax
import pickle as pk
from scipy.spatial.transform import Rotation as R
from utils import RBFExpansion
from transformer import ComformerConv

def _signed_components(eq_rows, anti_rows, size):
    """Group indices that are tied together by equal / opposite constraints.

    Args:
        eq_rows: nested sequence; a truthy ``eq_rows[a][c]`` (only ``a < c`` is
            read) ties entries ``a`` and ``c`` to the same value.
        anti_rows: nested sequence; a truthy ``anti_rows[a][c]`` (only
            ``a < c`` is read) ties entries ``a`` and ``c`` to opposite values.
        size: number of entries (indices ``0 .. size - 1``).

    Returns:
        A list of ``(nodes, signs)`` pairs, one per group with at least two
        members. ``signs[k]`` is ``+1`` or ``-1`` and gives the sign of
        ``nodes[k]`` relative to the group's representative.
    """
    parent = list(range(size))
    rank = [0] * size
    # parity[u] == 1 means u carries the opposite sign to parent[u].
    parity = [0] * size

    def find(u):
        # Returns (root, parity of u relative to root), compressing the path.
        if parent[u] == u:
            return u, 0
        root, parent_parity = find(parent[u])
        parity[u] ^= parent_parity
        parent[u] = root
        return root, parity[u]

    def union(a, c, w):
        # w == 0: same sign, w == 1: opposite sign.
        ra, pa = find(a)
        rc, pc = find(c)
        if ra == rc:
            # Already linked; a constraint that contradicts the existing
            # relation is ignored (first constraint wins).
            return
        if rank[ra] < rank[rc]:
            ra, rc = rc, ra
            pa, pc = pc, pa
        parent[rc] = ra
        parity[rc] = pa ^ pc ^ w
        if rank[ra] == rank[rc]:
            rank[ra] += 1

    # Same visiting order as a plain double loop over the upper triangle,
    # with the "equal" mask applied before the "opposite" mask for each pair.
    for a in range(size):
        eq_row = eq_rows[a]
        anti_row = anti_rows[a]
        for c in range(a + 1, size):
            if eq_row[c]:
                union(a, c, 0)
            if anti_row[c]:
                union(a, c, 1)

    members = {}
    for u in range(size):
        root, flipped = find(u)
        nodes, signs = members.setdefault(root, ([], []))
        nodes.append(u)
        signs.append(-1 if flipped else 1)

    return [(nodes, signs) for nodes, signs in members.values() if len(nodes) > 1]


def equality_adjustment(equality, batch):
    """Make constrained entries of each sample share one magnitude.

    Each sample of ``batch`` is flattened to ``l1 * l2`` entries. Entries
    linked by the constraint masks are replaced by ``sign * base``, where
    ``base`` is the mean of the sign-corrected values in the group, so
    "equal" entries end up identical and "opposite" entries end up negated.
    Entries that are not part of any constraint are returned unchanged.

    Args:
        equality: tensor indexed as ``equality[i, 0, a, c]`` (entries ``a``
            and ``c`` of sample ``i`` are equal) and ``equality[i, 1, a, c]``
            (entries are opposite). Only ``a < c`` is read; any truthy value
            activates the constraint.
        batch: tensor of shape ``(b, l1, l2)``.

    Returns:
        A new tensor of shape ``(b, l1, l2)``. ``batch`` itself is never
        modified, which also makes the function usable on leaf tensors that
        require grad.
    """
    b, l1, l2 = batch.size()
    L = l1 * l2
    # reshape() may return a view of `batch`; clone so the in-place writes
    # below can never alias the caller's tensor (or a leaf that requires grad).
    x = batch.reshape(b, L).clone()

    # One bulk conversion to Python lists instead of one tensor read per pair.
    constraints = equality.detach().to('cpu').tolist()

    for i in range(b):
        groups = _signed_components(constraints[i][0], constraints[i][1], L)
        for nodes, signs in groups:
            idx = torch.tensor(nodes, device=x.device, dtype=torch.long)
            sgn = torch.tensor(signs, device=x.device, dtype=x.dtype)
            # Average in the representative's sign convention, then restore
            # each entry's own sign.
            base = (sgn * x[i, idx]).mean()
            x[i, idx] = sgn * base

    return x.reshape(b, l1, l2)

class WeightedPiezoPooling(nn.Module):
    """Attention-weighted pooling of per-node tensors into one tensor per group.

    A small MLP turns each node's feature vector into a scalar score. The
    scores are normalised with ``softmax`` grouped by ``batch``, used to weight
    the per-node tensors, and the weighted tensors are summed per group. The
    result is symmetrised over its last two axes.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # Hidden width: half the input width, but never narrower than 8.
        half_width = in_channels // 2
        hidden = half_width if half_width > 8 else 8
        scorer_layers = [
            nn.Linear(in_channels, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        ]
        self.att_net = nn.Sequential(*scorer_layers)

    def forward(self, atom_tensors: torch.Tensor, node_feats: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool ``atom_tensors`` over the groups given by ``batch``.

        Args:
            atom_tensors: per-node tensors; the weights are broadcast as
                ``[-1, 1, 1, 1]``, so a 4-D input is expected
                (presumably ``[N, 3, 3, 3]`` -- confirm against the caller).
            node_feats: per-node features fed to the scoring MLP.
            batch: group index of every node.

        Returns:
            The per-group weighted sum, averaged with its transpose over the
            last two axes.
        """
        attention = softmax(self.att_net(node_feats), batch)
        pooled = scatter(
            atom_tensors * attention.view(-1, 1, 1, 1),
            batch,
            dim=0,
            reduce="sum",
        )
        # Enforce symmetry in the trailing index pair.
        return 0.5 * (pooled + pooled.transpose(-1, -2))


class MatrixInteractionHead(nn.Module):
    """Build a per-node ``[3, 3, 3]`` tensor from edge directions and features.

    For every edge, three MLPs map ``[x[src], x[dst], edge_feat]`` to ``M``
    channel weights. These weights scale geometric bases built from the unit
    edge direction ``n``: ``n (x) I``, ``n (x) Q`` with ``Q = n n^T - I / 3``,
    the vector ``n``, and the matrices ``I`` / ``Q``. Edge contributions are
    summed onto the destination node, scaled by ``deg ** -p``, compressed from
    ``M`` to ``K`` channels, and combined into a linear term plus a gated
    vector-times-matrix term. The output is symmetric in its last two axes.
    """

    def __init__(
        self,
        in_channels: int,
        edge_dim: int,
        num_tensor_channels: int = 64,
        mixing_channels: int = 16,
    ):
        super().__init__()
        self.M = num_tensor_channels
        self.K = mixing_channels
        self.register_buffer("I3", torch.eye(3))

        edge_in = in_channels * 2 + edge_dim

        def build_edge_mlp() -> nn.Sequential:
            # Edge input -> hidden (in_channels) -> M channel weights.
            return nn.Sequential(
                nn.Linear(edge_in, in_channels),
                nn.SiLU(),
                nn.Linear(in_channels, self.M),
            )

        # Creation order is kept (iso, dev, vec) so parameter order and
        # random initialisation stay the same.
        self.edge_mlp_iso = build_edge_mlp()
        self.edge_mlp_dev = build_edge_mlp()
        self.edge_mlp_vec = build_edge_mlp()

        # Bias-free channel compressions M -> K, one per geometric type.
        self.compress_rank3 = nn.Linear(self.M, self.K, bias=False)
        self.compress_vec = nn.Linear(self.M, self.K, bias=False)
        self.compress_mat = nn.Linear(self.M, self.K, bias=False)

        # Channel weights of the linear term start as a uniform average.
        self.lin_weight = nn.Parameter(torch.ones(self.K) / self.K)
        # Channel mixing applied to the matrix part of the bilinear term.
        self.mix_weight = nn.Parameter(torch.randn(self.K, self.K) / self.K)
        # Raw (pre-sigmoid) degree-normalisation exponent.
        self.p_raw = nn.Parameter(torch.tensor(0.0))
        # Raw (pre-sigmoid) gate of the bilinear term.
        self.gamma_raw = nn.Parameter(torch.tensor(0.0))

        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonally initialise the last layer of each edge MLP.

        Its bias, when present, is zeroed. ``lin_weight`` and ``mix_weight``
        keep the values assigned in ``__init__``.
        """
        for mlp in (self.edge_mlp_iso, self.edge_mlp_dev, self.edge_mlp_vec):
            last = mlp[-1]
            if not hasattr(last, "weight"):
                continue
            nn.init.orthogonal_(last.weight, gain=1.0)
            if last.bias is not None:
                nn.init.zeros_(last.bias)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_vec: torch.Tensor, edge_feat: torch.Tensor) -> torch.Tensor:
        """Compute one ``[3, 3, 3]`` tensor per node.

        Args:
            x: node features; ``x.size(0)`` is used as the node count.
            edge_index: unpacked into ``(src, dst)``; contributions are
                accumulated on ``dst``.
            edge_vec: per-edge vectors, normalised along the last axis
                (reshaped as 3-vectors below).
            edge_feat: per-edge features concatenated with both endpoint
                features.

        Returns:
            Tensor indexed ``[node, a, b, c]``, symmetric in ``(b, c)``.
        """
        src, dst = edge_index
        num_nodes = x.size(0)
        device = x.device
        channels = self.M

        # Unit edge directions; the clamp guards against zero-length edges.
        length = torch.norm(edge_vec, dim=-1, keepdim=True).clamp_min(1e-8)
        unit = edge_vec / length

        eye = self.I3.to(device)[None, :, :]
        outer = torch.einsum("ea,eb->eab", unit, unit)
        # Traceless when `unit` really has unit length.
        dev = outer - (1.0 / 3.0) * eye

        # Rank-3 bases: direction (x) identity and direction (x) deviator.
        unit_lifted = unit[:, :, None, None]
        basis_iso = unit_lifted * eye[None, None, :, :]
        basis_dev = unit_lifted * dev[:, None, :, :]

        pair_input = torch.cat([x[src], x[dst], edge_feat], dim=-1)
        w_iso = self.edge_mlp_iso(pair_input)
        w_dev = self.edge_mlp_dev(pair_input)
        w_vec = self.edge_mlp_vec(pair_input)

        # Rank-3 edge messages, summed on the destination node.
        rank3_iso = w_iso.view(-1, channels, 1, 1, 1) * basis_iso.view(-1, 1, 3, 3, 3)
        rank3_dev = w_dev.view(-1, channels, 1, 1, 1) * basis_dev.view(-1, 1, 3, 3, 3)
        rank3_sum = scatter(rank3_iso + rank3_dev, dst, dim=0, reduce="sum", dim_size=num_nodes)

        # Vector edge messages.
        vec_msg = w_vec.view(-1, channels, 1) * unit.view(-1, 1, 3)
        vec_sum = scatter(vec_msg, dst, dim=0, reduce="sum", dim_size=num_nodes)

        # Matrix edge messages (same iso / dev weights as the rank-3 part).
        mat_iso = w_iso.view(-1, channels, 1, 1) * eye.view(1, 1, 3, 3)
        mat_dev = w_dev.view(-1, channels, 1, 1) * dev.view(-1, 1, 3, 3)
        mat_sum = scatter(mat_iso + mat_dev, dst, dim=0, reduce="sum", dim_size=num_nodes)

        # In-degree (at least 1) raised to a learnable power -p, p in (0.01, 0.99).
        edge_ones = torch.ones(dst.size(0), device=device)
        deg = scatter(edge_ones, dst, dim=0, reduce="sum", dim_size=num_nodes).clamp_min(1.0)
        p = torch.sigmoid(self.p_raw) * 0.98 + 0.01

        def degree_scale(*trailing: int) -> torch.Tensor:
            # deg ** -p, shaped to broadcast over the given trailing axes.
            return deg.pow(-p).view(num_nodes, *trailing)

        rank3 = rank3_sum * degree_scale(1, 1, 1, 1)
        vec = vec_sum * degree_scale(1, 1)
        mat = mat_sum * degree_scale(1, 1, 1)

        # Compress M channels down to K.
        rank3_core = torch.einsum("nmabc,km->nkabc", rank3, self.compress_rank3.weight)
        vec_core = torch.einsum("nma,km->nka", vec, self.compress_vec.weight)
        mat_core = torch.einsum("nmab,km->nkab", mat, self.compress_mat.weight)

        # Linear term: weighted sum over the K rank-3 channels.
        t_lin = torch.einsum("k,nkabc->nabc", self.lin_weight, rank3_core)

        # Bilinear term: vector channel h times the mixed matrix channel h.
        mat_mixed = torch.einsum("hk,nkbc->nhbc", self.mix_weight, mat_core)
        t_bilin = torch.einsum("nha,nhbc->nabc", vec_core, mat_mixed)

        def symmetrise(t: torch.Tensor) -> torch.Tensor:
            # Average with the transpose over the last two axes.
            return 0.5 * (t + t.transpose(-1, -2))

        t_bilin = symmetrise(t_bilin)
        t_lin = symmetrise(t_lin)

        gate = torch.sigmoid(self.gamma_raw)
        return symmetrise(t_lin + gate * t_bilin)


class CEITNet(nn.Module):
    """Graph network that predicts one ``[3, 6]`` Voigt-form tensor per graph.

    Pipeline: linear atom embedding, RBF-expanded edge features, two
    ``ComformerConv`` layers, layer norm, ``MatrixInteractionHead`` producing a
    per-node ``[3, 3, 3]`` tensor, attention pooling per graph, conversion to
    Voigt form, and an optional equality-constraint adjustment.
    """

    def __init__(self, args=None):
        super().__init__()
        width = 128

        # Input node features are 92-dimensional.
        self.atom_embedding = nn.Linear(92, width)

        # Scalar edge feature -> 512 radial basis values -> embedding width.
        self.rbf = nn.Sequential(
            RBFExpansion(vmin=-4.0, vmax=0.0, bins=512),
            nn.Linear(512, width),
            nn.SiLU(),
        )

        conv_layers = []
        for _ in range(2):
            conv_layers.append(
                ComformerConv(in_channels=width, out_channels=width, heads=1, edge_dim=width)
            )
        self.att_layers = nn.ModuleList(conv_layers)
        self.norm = nn.LayerNorm(width)

        self.tensor_head = MatrixInteractionHead(
            in_channels=width,
            edge_dim=width,
            num_tensor_channels=64,
            mixing_channels=16,
        )

        self.pooling = WeightedPiezoPooling(in_channels=width)

        rule = "=" * 60
        print(rule)
        print("CEITNet")
        print(rule)

    def to_voigt(self, tensor_333: torch.Tensor) -> torch.Tensor:
        """
        Input: [B, 3, 3, 3] -> Output: [B, 3, 6]
        Order of the last axis: xx, yy, zz, xy, yz, zx
        """
        # (row, column) pairs picked from the last two axes, in output order.
        index_pairs = ((0, 0), (1, 1), (2, 2), (0, 1), (1, 2), (2, 0))
        components = [tensor_333[..., row, col] for row, col in index_pairs]
        return torch.stack(components, dim=-1)

    def forward(self, data, feat_mask=None, equality=None, add_feat_mask=None) -> torch.Tensor:
        """Run the network on a batched graph.

        Args:
            data: graph batch providing ``x``, ``edge_attr``, ``edge_index``
                and ``batch``. ``edge_attr`` is used both for its norm and as
                the edge vector of the tensor head.
            feat_mask: not used by this method.
            equality: optional constraint tensor passed to
                ``equality_adjustment``; skipped when ``None``.
            add_feat_mask: not used by this method.

        Returns:
            The pooled tensors in Voigt form, adjusted when ``equality`` is given.
        """
        hidden = self.atom_embedding(data.x)

        # Edge scalar is -0.75 / length; the epsilon avoids division by zero.
        edge_len = torch.norm(data.edge_attr, dim=1, keepdim=True)
        inv_len = -0.75 / (edge_len + 1e-8)
        edge_features = self.rbf(inv_len.squeeze(-1))

        for conv in self.att_layers:
            hidden = conv(hidden, data.edge_index, edge_features)
        hidden = self.norm(hidden)

        per_atom = self.tensor_head(
            x=hidden,
            edge_index=data.edge_index,
            edge_vec=data.edge_attr,
            edge_feat=edge_features,
        )

        pooled = self.pooling(per_atom, hidden, data.batch)
        voigt = self.to_voigt(pooled)

        if equality is None:
            return voigt
        return equality_adjustment(equality, voigt)