import torch
from torch import nn
from utils import RBFExpansion
from torch_scatter import scatter
from transformer import ComformerConv
import torch.nn.functional as F
from torch_geometric.utils import softmax

def equality_adjustment(equality: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Replace entries that are declared equal by their group mean.

    ``batch`` holds ``b`` matrices of shape ``(l1, l2)``; their entries are
    addressed by flattened position ``0 .. n-1`` with ``n = l1 * l2``.
    A truthy ``equality[..., i, j]`` declares positions ``i`` and ``j`` equal.
    The relation is symmetrised, made reflexive and transitively closed, so
    every connected group of positions is replaced by the mean of its
    members. The averaging is differentiable.

    Args:
        equality: Pairwise mask broadcastable to ``(b, n, n)``; converted to
            bool and moved to ``batch``'s device.
        batch: Tensor of shape ``(b, l1, l2)``.

    Returns:
        Tensor with the same shape as ``batch``.
    """
    b, l1, l2 = batch.shape
    n = l1 * l2
    x = batch.reshape(b, n)

    eq = equality.to(device=x.device, dtype=torch.bool)
    eq = eq | eq.transpose(-1, -2)
    eye = torch.eye(n, device=x.device, dtype=torch.bool).expand(b, n, n)
    reach = eq | eye  # [b, n, n]

    # Transitive closure by repeated squaring: after k squarings `reach`
    # covers paths of up to 2**k hops, and any path needs at most n - 1.
    # Without this a chain i~j, j~k (no explicit i~k) splits into two groups.
    for _ in range((n - 1).bit_length()):
        r = reach.float()
        reach = torch.matmul(r, r) > 0

    # After closure all members of a group share an identical row, so argmax
    # picks the same representative (the first True column) for all of them.
    rep = reach.float().argmax(dim=-1)  # [b, n] in [0, n-1]

    sum_ = torch.zeros((b, n), device=x.device, dtype=x.dtype).scatter_add_(1, rep, x)
    cnt_ = torch.zeros((b, n), device=x.device, dtype=x.dtype).scatter_add_(1, rep, torch.ones_like(x))
    mean = sum_ / cnt_.clamp_min(1.0)

    x_adj = mean.gather(1, rep)
    return x_adj.reshape(b, l1, l2)


class WeightedDielectricPooling(nn.Module):
    """Attention-weighted sum of per-node 3x3 tensors into one per graph.

    A small Tanh MLP scores every node from its features. The scores are
    normalised with torch_geometric's grouped ``softmax`` over ``batch``, and
    each node's tensor is scaled by its weight before a per-graph sum.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        hidden = max(8, in_channels // 2)
        self.att_net = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, atom_tensors: torch.Tensor, node_feats: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool node tensors per graph.

        Args:
            atom_tensors: Per-node tensors, ``[N, 3, 3]``.
            node_feats: Per-node features fed to the scoring MLP, ``[N, C]``.
            batch: Graph index of each node, ``[N]``.

        Returns:
            ``[B, 3, 3]`` weighted sums; B is determined by ``scatter`` from
            ``batch`` (no explicit size is passed).
        """
        node_weight = softmax(self.att_net(node_feats), batch).view(-1, 1, 1)  # [N, 1, 1]
        return scatter(atom_tensors * node_weight, batch, dim=0, reduce="sum")


class MatrixInteractionHead(nn.Module):
    """Map node and edge features to one symmetric 3x3 tensor per node.

    Every edge ``src -> dst`` contributes its traceless direction tensor
    ``Q = u u^T - I/3`` (``u`` = unit edge vector), scaled by non-negative
    per-channel weights ``softplus(edge_mlp([x[src], x[dst], edge_feat]))``.
    Contributions are summed at ``dst`` and scaled by ``deg ** -p``, with a
    learned ``p`` in ``[0, 0.999)``. They are then compressed from M to K
    channels and combined into a linear term plus a gated, symmetrised
    quadratic term ``sum_{m,n} W[m, n] S_m @ S_n``.

    Args:
        in_channels: Node feature width C.
        edge_dim: Edge feature width F.
        num_matrix_channels: Channel count M produced by ``edge_mlp``.
        mixing_channels: Channel count K after ``compress``.
    """

    def __init__(self, in_channels: int, edge_dim: int, num_matrix_channels: int = 128, mixing_channels: int = 16):
        super().__init__()
        self.in_channels = in_channels
        self.num_channels = num_matrix_channels   # M
        self.mix_channels = mixing_channels       # K

        self.edge_mlp = nn.Sequential(
            nn.Linear(in_channels * 2 + edge_dim, in_channels),
            nn.SiLU(),
            nn.Linear(in_channels, num_matrix_channels),
        )

        # Must stay bias-free: forward() relies on compress being linear so it
        # can be applied per edge before aggregation.
        self.compress = nn.Linear(num_matrix_channels, mixing_channels, bias=False)

        self.lin_weight = nn.Parameter(torch.randn(mixing_channels) / mixing_channels)
        self.mix_weight = nn.Parameter(torch.randn(mixing_channels, mixing_channels) / mixing_channels)

        # Not used by forward(); left in place so the state_dict layout is unchanged.
        self.scalar_mlp = nn.Sequential(
            nn.Linear(in_channels, 32),
            nn.SiLU(),
            nn.Linear(32, 1),
        )

        # Unconstrained; forward() uses p = 0.999 * sigmoid(p_raw).
        self.p_raw = nn.Parameter(torch.tensor(0.0))
        # Unconstrained; forward() uses gamma = sigmoid(gamma_raw).
        self.gamma_raw = nn.Parameter(torch.tensor(0.0))

        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonally initialise the last edge-MLP layer and zero its bias."""
        last = self.edge_mlp[-1]
        nn.init.orthogonal_(last.weight, gain=1.0)
        if last.bias is not None:
            nn.init.zeros_(last.bias)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_vec: torch.Tensor, edge_feat: torch.Tensor):
        """Compute the per-node symmetric tensor.

        Args:
            x: Node features, ``[N, C]``.
            edge_index: ``[2, E]`` (src, dst) indices; messages aggregate at dst.
            edge_vec: Edge vectors, ``[E, 3]``.
            edge_feat: Edge features, ``[E, F]``.

        Returns:
            ``[N, 3, 3]`` symmetric tensors; nodes with no incoming edge get zeros.
        """
        src, dst = edge_index
        N = x.size(0)
        device = x.device

        # Traceless direction tensor per edge: u u^T - I/3.
        dist = torch.norm(edge_vec, dim=-1, keepdim=True).clamp_min(1e-8)
        unit = edge_vec / dist                                           # [E, 3]
        eye3 = torch.eye(3, device=unit.device, dtype=unit.dtype)
        Q = unit.unsqueeze(-1) * unit.unsqueeze(-2) - eye3 / 3.0         # [E, 3, 3]

        edge_input = torch.cat([x[src], x[dst], edge_feat], dim=-1)      # [E, 2C+F]
        weights = F.softplus(self.edge_mlp(edge_input))                  # [E, M]

        # compress is linear and bias-free, and the sum over edges and the
        # per-node degree scaling are linear too, so compressing M -> K per edge
        # *before* aggregation yields the same S_core while materialising an
        # [E, K, 3, 3] tensor instead of [E, M, 3, 3].
        core_w = self.compress(weights)                                  # [E, K]
        contrib = core_w[:, :, None, None] * Q[:, None, :, :]            # [E, K, 3, 3]
        S_core = contrib.new_zeros((N, *contrib.shape[1:])).index_add(0, dst, contrib)  # [N, K, 3, 3]

        # Degree normalisation deg**-p; isolated nodes are clamped to deg = 1.
        deg = torch.zeros(N, device=device).index_add(
            0, dst, torch.ones(dst.numel(), device=device)
        ).clamp_min(1.0)                                                 # [N]
        p = torch.sigmoid(self.p_raw) * 0.999
        S_core = S_core * deg.pow(-p).view(N, 1, 1, 1)

        T_lin = torch.einsum("k,bkij->bij", self.lin_weight, S_core)     # [N, 3, 3]

        # Quadratic term sum_{m,n} W[m, n] S_m @ S_n, done as mix-then-multiply.
        S_mixed = torch.einsum("mn,bmij->bnij", self.mix_weight, S_core)  # [N, K, 3, 3]
        T_quad = torch.einsum("bnij,bnjk->bik", S_mixed, S_core)         # [N, 3, 3]
        T_quad = 0.5 * (T_quad + T_quad.transpose(-1, -2))

        gamma = torch.sigmoid(self.gamma_raw)
        T_out = T_lin + gamma * T_quad

        # T_lin is a weighted sum of symmetric Q's; this only guards rounding.
        return 0.5 * (T_out + T_out.transpose(-1, -2))


class CEITNet(nn.Module):
    """Crystal graph network predicting one symmetric 3x3 tensor per graph.

    Pipeline: atom embedding -> RBF-encoded inverse edge lengths -> two
    ComformerConv layers + LayerNorm -> per-node tensors from
    MatrixInteractionHead -> attention pooling per graph -> plus a learned
    non-negative isotropic term ``c_iso * I`` -> symmetrisation -> optional
    equality averaging.
    """

    def __init__(self, args=None):
        super().__init__()
        # `args` is accepted for interface compatibility but not read here.
        width = 128

        # NOTE: submodules are created in a fixed order; changing it would
        # change the random initialisation.
        self.atom_embedding = nn.Linear(92, width)

        self.rbf = nn.Sequential(
            RBFExpansion(vmin=-4.0, vmax=0.0, bins=512),
            nn.Linear(512, width),
            nn.Softplus(),
        )

        self.att_layers = nn.ModuleList(
            ComformerConv(in_channels=width, out_channels=width, heads=1, edge_dim=width)
            for _ in range(2)
        )

        self.norm = nn.LayerNorm(width)

        self.tensor_head = MatrixInteractionHead(
            in_channels=width,
            edge_dim=width,
            num_matrix_channels=128,
            mixing_channels=16,
        )

        self.pooling = WeightedDielectricPooling(in_channels=width)

        self.iso_mlp = nn.Sequential(nn.Linear(width, 64), nn.SiLU(), nn.Linear(64, 1))

        rule = "=" * 60
        print(rule)
        print("CEITNet")
        print(rule)

    def forward(self, data, feat_mask=None, equality=None) -> torch.Tensor:
        """Predict per-graph 3x3 tensors.

        Args:
            data: Graph batch exposing ``x``, ``edge_index``, ``edge_attr``
                (used as edge vectors) and ``batch``.
            feat_mask: Accepted but unused.
            equality: Optional mask forwarded to ``equality_adjustment``.

        Returns:
            Symmetric ``[B, 3, 3]`` tensors (then equality-averaged if a mask is given).
        """
        edge_index = data.edge_index
        graph_id = data.batch

        h = self.atom_embedding(data.x)

        # Inverse-distance edge encoding (-0.75 / |r|) expanded onto RBFs.
        lengths = torch.norm(data.edge_attr, dim=1, keepdim=True)
        edge_features = self.rbf((-0.75 / (lengths + 1e-8)).squeeze(-1))

        for conv in self.att_layers:
            h = conv(h, edge_index, edge_features)
        h = self.norm(h)

        per_atom = self.tensor_head(
            x=h,
            edge_index=edge_index,
            edge_vec=data.edge_attr,
            edge_feat=edge_features,
        )
        pooled = self.pooling(per_atom, h, graph_id)

        # Learned non-negative isotropic offset from mean-pooled node features.
        graph_feat = scatter(h, graph_id, dim=0, reduce="mean")           # [B, C]
        c_iso = F.softplus(self.iso_mlp(graph_feat)).view(-1, 1, 1)       # [B, 1, 1]
        pooled = pooled + c_iso * torch.eye(3, device=pooled.device).unsqueeze(0)

        pooled = 0.5 * (pooled + pooled.transpose(-1, -2))

        if equality is None:
            return pooled
        return equality_adjustment(equality, pooled)
