# train_qm9_gauge_gnn_plus.py  (vectorized + fast readout)
import os, csv, argparse, time
from typing import List, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch_geometric.datasets import QM9

# ----------------------- utils -----------------------

def set_seed(seed: int = 0):
    """Seed Python's ``random``, NumPy, torch (CPU) and all CUDA devices with ``seed``."""
    import random
    import numpy as np

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def move_to_device(obj, device):
    """Recursively move every tensor inside ``obj`` to ``device``.

    Dicts are rebuilt as plain dicts; lists and tuples are rebuilt with their
    own type via ``type(obj)(...)``. Any other leaf is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: move_to_device(val, device) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(item, device) for item in obj)
    if torch.is_tensor(obj):
        return obj.to(device)
    return obj

def random_SO3(device=None, dtype=torch.float32):
    """Sample a rotation matrix uniformly (Haar measure) from SO(3).

    A bare QR factorisation of a Gaussian matrix is not Haar-distributed: the
    column signs of Q are tied to the sign convention the QR routine uses for
    R's diagonal. Rescaling column k of Q by sign(R[k, k]) removes that bias.
    If the result is a reflection (det = -1), flipping the last column maps it
    onto SO(3). Right-multiplying by a fixed reflection preserves uniformity.

    Returns:
        (3, 3) tensor ``q`` on ``device`` with ``dtype``, ``q.T @ q ≈ I`` and
        ``det(q) = +1``.
    """
    M = torch.randn(3, 3, device=device, dtype=dtype)
    q, r = torch.linalg.qr(M)
    signs = torch.diagonal(r).sign()
    signs[signs == 0] = 1          # exact zeros are measure-zero; treat as +1
    q = q * signs                  # broadcasts over rows: scales column k by signs[k]
    if torch.linalg.det(q) < 0:
        q[:, -1] *= -1
    return q

# ----------------------- geom helpers -----------------------

def build_knn_edges(X: torch.Tensor, k: int) -> List[Tuple[int,int]]:
    """Return a symmetrised k-nearest-neighbour graph over the rows of ``X``.

    Each point is linked to its ``k`` nearest other points. ``k`` is clipped
    to the range [1, n-1]. The union of those links is made undirected and
    returned as directed pairs, with both (a, b) and (b, a) present for every
    linked pair. Fewer than two points yield an empty list.
    """
    num_pts = X.size(0)
    if num_pts <= 1:
        return []
    k_eff = max(1, min(k, num_pts - 1))
    dist = torch.cdist(X, X)
    # k_eff + 1 hits per row because the point itself (distance 0) is normally among them.
    neighbour_lists = [
        torch.topk(row, k_eff + 1, largest=False).indices.tolist() for row in dist
    ]
    undirected = {
        (min(src, dst), max(src, dst))
        for src, hits in enumerate(neighbour_lists)
        for dst in hits
        if src != dst
    }
    return [pair for lo, hi in undirected for pair in ((lo, hi), (hi, lo))]

def pca_frame(X: torch.Tensor, idx: int, nbr_idx: List[int]) -> torch.Tensor:
    """Build a right-handed orthonormal local frame at node ``idx`` via PCA.

    The columns of the returned (3, 3) matrix are:
      e1: principal axis of the displacements ``X[nbr_idx] - X[idx]`` (largest eigenvalue)
      e2: second axis, Gram-Schmidt orthogonalised against e1
      e3: e1 x e2

    Eigenvectors from ``eigh`` are only defined up to sign. To make the frame
    rotate with the input (pca_frame(X Q^T) = Q pca_frame(X) for Q in SO(3)),
    e1 and e2 are each oriented so that the summed projection of the
    displacements onto them is non-negative. Without this, frames could flip
    arbitrarily under a rigid rotation, and the transports R_i R_j^T built from
    them would not transform consistently. If a summed projection is exactly
    zero (e.g. perfectly symmetric neighbourhoods), that axis keeps the sign
    ``eigh`` returned.

    Returns the identity when fewer than two neighbours are given.
    """
    if len(nbr_idx) < 2:
        return torch.eye(3, dtype=X.dtype, device=X.device)
    P = X[nbr_idx] - X[idx]                    # (m, 3) displacements from the centre
    C = (P.T @ P) / len(nbr_idx)               # (3, 3) second-moment matrix
    _, evecs = torch.linalg.eigh(C)            # eigenvalues in ascending order
    e1 = F.normalize(evecs[:, 2], dim=0)
    e2 = evecs[:, 1]
    e2 = F.normalize(e2 - (e2 @ e1) * e1, dim=0)
    # Resolve the eigenvector sign ambiguity deterministically from the data.
    if (P @ e1).sum() < 0:
        e1 = -e1
    if (P @ e2).sum() < 0:
        e2 = -e2
    # e1, e2 orthonormal => [e1, e2, e1 x e2] has det +1 by construction.
    e3 = F.normalize(torch.linalg.cross(e1, e2), dim=0)
    return torch.stack([e1, e2, e3], dim=1)

def frames_and_U(X: torch.Tensor, edges: List[Tuple[int,int]]):
    """Compute frame transports U[i, j] = R_i @ R_j^T for every ordered node pair.

    Each R_i is the local frame returned by ``pca_frame`` for node i. Node i's
    neighbour list collects every j from its (i, j) pairs in ``edges``, with
    duplicates dropped and first-seen order kept.

    Transports are formed for all pairs, not only for edges. Diagonal blocks
    are set to the exact identity, because R_i R_i^T is only approximately I
    in floating point. The lower triangle is copied from the upper one, so
    ``U[j, i] == U[i, j].T`` holds exactly.

    Args:
        X: (n, 3) node coordinates.
        edges: directed (i, j) index pairs.

    Returns:
        (n, n, 3, 3) tensor with the dtype and device of ``X``.
    """
    n = X.size(0)
    if n == 0:
        return torch.zeros(0, 0, 3, 3, dtype=X.dtype, device=X.device)

    nbrs: List[List[int]] = [[] for _ in range(n)]
    for i, j in edges:
        nbrs[i].append(j)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    nbrs = [list(dict.fromkeys(lst)) for lst in nbrs]

    R = torch.stack([pca_frame(X, i, nbrs[i]) for i in range(n)])   # (n, 3, 3)
    # U[i, j, a, c] = sum_b R[i, a, b] * R[j, c, b] = (R_i @ R_j^T)[a, c]
    U = torch.einsum('iab,jcb->ijac', R, R)

    diag = torch.arange(n, device=X.device)
    U[diag, diag] = torch.eye(3, dtype=X.dtype, device=X.device)
    # Mirror the strict upper triangle so the transpose relation is exact.
    iu, ju = torch.triu_indices(n, n, offset=1, device=X.device)
    U[ju, iu] = U[iu, ju].transpose(-1, -2)
    return U

# ----------------------- model (vectorized) -----------------------

class ScalarMLP(nn.Module):
    """Three-layer SiLU MLP mapping ``in_dim`` scalar features to ``out_dim`` outputs.

    Inputs are clamped elementwise to [-1e4, 1e4] before the first layer.
    """

    def __init__(self, in_dim, hidden=64, out_dim=2):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        clipped = x.clamp(min=-1e4, max=1e4)
        return self.net(clipped)

class EquivLayerPlusVec(nn.Module):
    """
    Vectorized O(3)-equivariant message passing with scalar edge features.

    Shapes:
      h: (n,3)              node vectors
      U: (n,n,3,3)          pairwise transports
      i_idx, j_idx: (E,)    directed edges; messages are summed into node i_idx
      edge_s: (E, s_dim)    [r_ij, 1/r_ij, Zi/10, Zj/10]

    Edge message:  m_ij = alpha * (U_ij h_j) + beta * h_i
    Node update:   h_i' = gamma * h_i + delta * sum_j m_ij
    The coefficients alpha, beta, gamma, delta come from MLPs on invariant
    scalars and are clamped to [-2, 2]. ``U`` is returned unchanged.
    """

    def __init__(self, d=3, s_dim=4):
        super().__init__()
        self.d = d
        self.msg_mlp = ScalarMLP(in_dim=4 + s_dim, hidden=64, out_dim=2)
        self.up_mlp = ScalarMLP(in_dim=4, hidden=64, out_dim=2)

    @staticmethod
    def _inv4_batch(a, b):
        """Row-wise invariants [|a|^2, |b|^2, <a, b>, 1]: (..., 3) x (..., 3) -> (..., 4)."""
        norm_a = (a * a).sum(-1, keepdim=True)
        norm_b = (b * b).sum(-1, keepdim=True)
        inner = (a * b).sum(-1, keepdim=True)
        bias = torch.ones_like(norm_a)
        return torch.cat([norm_a, norm_b, inner, bias], dim=-1)

    @staticmethod
    def _clamped_pair(coeffs):
        """Clamp an (N, 2) coefficient tensor to [-2, 2] and split it into two (N, 1) columns."""
        bounded = torch.clamp(coeffs, -2.0, 2.0)
        return bounded[:, 0:1], bounded[:, 1:2]

    def forward(self, h, U, i_idx, j_idx, edge_s):
        _num_nodes, _dim = h.shape   # h must be a 2-D (n, d) tensor

        receiver = h[i_idx]                                                    # (E,3)
        transported = (U[i_idx, j_idx] @ h[j_idx].unsqueeze(-1)).squeeze(-1)   # (E,3)

        edge_feat = torch.cat([self._inv4_batch(transported, receiver), edge_s], dim=-1)
        alpha, beta = self._clamped_pair(self.msg_mlp(edge_feat))
        messages = alpha * transported + beta * receiver                       # (E,3)

        aggregated = torch.zeros_like(h)
        aggregated.index_add_(0, i_idx, messages)                              # sum into receivers

        gamma, delta = self._clamped_pair(self.up_mlp(self._inv4_batch(h, aggregated)))
        return gamma * h + delta * aggregated, U

class InvariantReadoutPlusFast(nn.Module):
    """
    Normalized readout producing a single prediction per graph.

    Four invariant features go through an MLP after being clamped to [-1e6, 1e6]:
      S0_bar  = mean_{i<=j} <h_i, h_j>
      S1_bar  = mean_{i,j}  <h_i, U_ij h_j>
      Z_mean  = mean over edges with i<j of (Z_i + Z_j)/2, read from edge_s
                columns 2 and 3 (stored as Z/10); a pair-weighted estimate,
                not the per-atom mean
      InvR_mu = mean of edge_s column 1 (1/r_ij) over edges with i<j
    Z_mean and InvR_mu are 0 when no i<j edge exists. The output has shape (1,).
    """

    def __init__(self, d=3):
        super().__init__()
        width = 128
        self.mlp = nn.Sequential(
            nn.Linear(4, width), nn.SiLU(),
            nn.Linear(width, width), nn.SiLU(),
            nn.Linear(width, 1),
        )

    def forward(self, h, U, i_idx, j_idx, edge_s):
        n, _ = h.shape

        gram = h @ h.T                                             # (n,n)
        s0_bar = torch.triu(gram).sum() / max(n * (n + 1) // 2, 1)
        s1_bar = torch.einsum('ip,ijpq,jq->', h, U, h) / max(n * n, 1)

        undirected = edge_s[i_idx < j_idx]                         # (E_u, 4)
        pair_z = 0.5 * (undirected[:, 2] + undirected[:, 3]) * 10.0
        inv_r = undirected[:, 1]
        has_edges = undirected.numel() > 0
        zero = torch.tensor(0.0, device=h.device, dtype=h.dtype)
        z_mean = pair_z.mean() if has_edges else zero
        inv_r_mu = inv_r.mean() if has_edges else zero

        feats = torch.stack([s0_bar, s1_bar, z_mean, inv_r_mu]).unsqueeze(0)  # (1,4)
        return self.mlp(torch.clamp(feats, -1e6, 1e6)).squeeze(0)

class GaugeGNNPlusVec(nn.Module):
    """``depth`` stacked ``EquivLayerPlusVec`` layers followed by ``InvariantReadoutPlusFast``.

    ``U`` is threaded through every layer (each returns it with the updated
    node vectors) and finally handed to the readout together with ``h``.
    """

    def __init__(self, d=3, depth=2, s_dim=4):
        super().__init__()
        blocks = []
        for _ in range(depth):
            blocks.append(EquivLayerPlusVec(d=d, s_dim=s_dim))
        self.layers = nn.ModuleList(blocks)
        self.readout = InvariantReadoutPlusFast(d=d)

    def forward(self, h, U, i_idx, j_idx, edge_s):
        graph = (i_idx, j_idx, edge_s)
        for layer in self.layers:
            h, U = layer(h, U, *graph)
        return self.readout(h, U, *graph)

# ----------------------- dataset -----------------------

# Maps a target name to its column in each molecule's ``y`` tensor.
# NOTE(review): assumed to follow torch_geometric's QM9 target ordering — confirm against the PyG docs.
TARGET_INDEX = dict(
    mu=0,
    alpha=1,
    homo=2,
    lumo=3,
    gap=4,
    r2=5,
    zpve=6,
    U0=7,
    U=8,
    H=9,
    G=10,
    Cv=11,
)

class QM9GaugeDatasetPlus(Dataset):
    """
    QM9 wrapper that converts each molecule into gauge-GNN inputs.

    Returns CPU tensors (moved to device in the loops).
    For each graph:
      h: (n,3) positions (centered)
      U: (n,n,3,3) local frame transports
      i_idx, j_idx: (E,) directed edge indices
      edge_s: (E,4) with [r_ij, 1/r_ij, Zi/10, Zj/10]
      y (normalized) and y_raw (target)

    ``y`` is normalized with the mean/std of the selected target over the
    whole dataset, computed once in ``__init__``.

    Args:
        root: directory passed to ``torch_geometric.datasets.QM9``.
        target_name: key of ``TARGET_INDEX`` selecting the regression target.
        k: neighbours per atom for the (symmetrized) kNN graph.

    Raises:
        ValueError: if ``target_name`` is not a key of ``TARGET_INDEX``.
    """
    def __init__(self, root='./data/qm9', target_name='U0', k=6):
        super().__init__()
        # Fail fast (before loading QM9) instead of silently training on another target.
        if target_name not in TARGET_INDEX:
            raise ValueError(
                f"unknown target_name {target_name!r}; expected one of {list(TARGET_INDEX)}"
            )
        self.raw = QM9(root=root)
        self.k = k
        self.tidx = TARGET_INDEX[target_name]

        # Statistics in float64: large-magnitude targets (e.g. U0) lose
        # precision when summed over the whole dataset in float32.
        ys = [float(self._target_value(self.raw[i].y)) for i in range(len(self.raw))]
        y_all = torch.tensor(ys, dtype=torch.float64)
        self.y_mean = float(y_all.mean())
        self.y_std = float(y_all.std() + 1e-8)

    def _target_value(self, y: torch.Tensor) -> torch.Tensor:
        """Select this dataset's target column from a molecule's ``y`` as a float32 scalar."""
        if y.ndim == 1:
            y = y.unsqueeze(0)
        tidx = min(self.tidx, y.shape[1] - 1)   # guard against rows with fewer columns
        return y[0, tidx].to(torch.float32)

    def __len__(self): return len(self.raw)

    def __getitem__(self, idx):
        data = self.raw[idx]
        pos = data.pos.to(torch.float32)           # (n,3) CPU
        Z = data.z.to(torch.float32)               # (n,)  CPU
        target = self._target_value(data.y)

        pos = pos - pos.mean(dim=0, keepdim=True)  # centre the molecule
        edges = build_knn_edges(pos, k=self.k)
        U = frames_and_U(pos, edges)
        h = pos.clone()

        # i/j indices and edge scalars aligned with ``edges``
        i_idx = torch.tensor([i for (i, _) in edges], dtype=torch.long)
        j_idx = torch.tensor([j for (_, j) in edges], dtype=torch.long)
        eps = 1e-6   # keeps 1/r finite if two atoms coincide
        r = torch.linalg.norm(pos[i_idx] - pos[j_idx], dim=-1) + eps
        s = torch.stack([r, 1.0 / r, Z[i_idx] / 10.0, Z[j_idx] / 10.0], dim=-1)  # (E,4)

        return {
            'h': h, 'U': U,
            'i_idx': i_idx, 'j_idx': j_idx, 'edge_s': s,
            'y': (target - self.y_mean) / self.y_std,
            'y_raw': target,
        }

def collate_list(samples):
    """``DataLoader`` collate function that performs no stacking.

    The incoming list of per-graph sample dicts is returned as the batch
    itself (the very same list object); the training/eval loops iterate it.
    """
    batch = samples
    return batch

# ----------------------- train / eval / plot -----------------------

def train_epoch(model, loader, opt, device):
    """Run one training pass, taking an optimiser step after every graph.

    Gradients are clipped to a global norm of 5.0 before each step. Returns
    the mean per-graph MSE on normalized targets, or 0.0 if the loader is empty.
    """
    model.train()
    total_loss = 0.0
    n_graphs = 0
    samples = (sample for batch in loader for sample in batch)
    for sample in samples:
        sample = move_to_device(sample, device)
        opt.zero_grad()
        pred = model(sample['h'], sample['U'], sample['i_idx'], sample['j_idx'], sample['edge_s']).squeeze()
        loss = F.mse_loss(pred, sample['y'])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()
        total_loss += float(loss.detach())
        n_graphs += 1
    return total_loss / max(n_graphs, 1)

@torch.no_grad()
def eval_epoch(model, loader, device, y_mean, y_std):
    """Evaluate without gradients, graph by graph.

    Returns:
        (mean MSE on normalized targets, mean absolute error in raw target
        units after de-normalizing predictions with ``y_mean``/``y_std``).
        Both are 0.0 for an empty loader.
    """
    model.eval()
    sq_errs, abs_errs = [], []
    for batch in loader:
        for sample in batch:
            sample = move_to_device(sample, device)
            pred = model(sample['h'], sample['U'], sample['i_idx'], sample['j_idx'], sample['edge_s']).squeeze()
            sq_errs.append(float(F.mse_loss(pred, sample['y'])))
            pred_raw = pred * y_std + y_mean
            abs_errs.append(float(torch.abs(pred_raw - sample['y_raw'])))
    denom = max(len(sq_errs), 1)
    return sum(sq_errs) / denom, sum(abs_errs) / denom

@torch.no_grad()
def invariance_sanity(model, ex, trials=3):
    """Mean |f(x) - f(Qx)| over ``trials`` random rotations Q in SO(3).

    Node vectors are rotated (h -> h Q^T) and every transport is conjugated
    (U_ij -> Q U_ij Q^T). Edge scalars are rotation-invariant and reused.
    The transformed U is used exactly as computed. It is not re-symmetrized,
    so the check compares the model on a consistently rotated version of the
    same inputs. Frames are not rebuilt from rotated coordinates, so this does
    not test frame construction itself.

    Returns:
        Mean absolute prediction difference as a float (0.0 if ``trials`` <= 0).
    """
    h, U = ex['h'], ex['U']
    graph = (ex['i_idx'], ex['j_idx'], ex['edge_s'])
    base = float(model(h, U, *graph).squeeze())
    diffs = []
    for _ in range(trials):
        Q = random_SO3(device=h.device, dtype=h.dtype)
        h_rot = h @ Q.T
        U_rot = Q @ U @ Q.T          # matmul broadcasts over the leading (n, n) dims
        diffs.append(abs(base - float(model(h_rot, U_rot, *graph).squeeze())))
    return sum(diffs) / max(1, len(diffs))

def plot_metrics(csv_path, out_png):
    """Plot per-epoch train/val/test MSE from ``csv_path`` and save it to ``out_png``.

    The CSV must have the columns ``epoch``, ``train_mse``, ``val_mse`` and
    ``test_mse``. The figure is closed after saving, even on error, so
    repeated calls do not accumulate open pyplot figures.
    """
    import matplotlib.pyplot as plt

    series = {'train_mse': [], 'val_mse': [], 'test_mse': []}
    epochs = []
    with open(csv_path, newline='') as f:
        for row in csv.DictReader(f):
            epochs.append(int(row['epoch']))
            for key, values in series.items():
                values.append(float(row[key]))

    fig, ax = plt.subplots()
    try:
        ax.plot(epochs, series['train_mse'], label='train MSE')
        ax.plot(epochs, series['val_mse'], label='val MSE')
        ax.plot(epochs, series['test_mse'], label='test MSE')
        ax.set_xlabel('epoch')
        ax.set_ylabel('MSE (normalized)')
        ax.legend()
        ax.set_title('QM9 regression')
        fig.savefig(out_png, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)

# ----------------------- main -----------------------

def _parse_args():
    """Parse the command-line options of a training run."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--root', type=str, default='./data/qm9')
    ap.add_argument('--epochs', type=int, default=10)
    ap.add_argument('--target_name', type=str, default='U0')
    ap.add_argument('--k', type=int, default=6)
    ap.add_argument('--depth', type=int, default=2)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--split', type=str, default='3000,300,300')
    ap.add_argument('--device', type=str, default='cuda')
    ap.add_argument('--out_dir', type=str, default='runs_qm9_plus/default')
    return ap.parse_args()


def _resolve_split(split: str, total: int) -> Tuple[int, int, int]:
    """Parse ``'train,val,test'`` sizes and shrink the test size so the sum fits ``total``.

    Raises:
        ValueError: if ``split`` is not exactly three integers, any size is
            negative, or train+val alone exceed ``total``.
    """
    parts = split.split(',')
    if len(parts) != 3:
        raise ValueError(f"--split must look like 'train,val,test', got {split!r}")
    a, b, c = (int(p) for p in parts)
    if min(a, b, c) < 0:
        raise ValueError(f"--split sizes must be non-negative, got {split!r}")
    if a + b > total:
        raise ValueError("train+val exceeds available samples")
    return a, b, min(c, total - (a + b))


def main():
    """Train the gauge GNN on a QM9 subset and write metrics under ``--out_dir``.

    Each epoch runs one training pass, then evaluates on val and test. The
    normalized MSEs are appended to ``metrics.csv``; MSE and raw-unit MAE are
    printed. After training, a rotation sanity check runs on the first test
    graph, and a metrics plot is attempted. The plot is skipped with a message
    if plotting fails.
    """
    args = _parse_args()

    set_seed(args.seed)
    # Falls back to CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print('device:', device)

    ds_all = QM9GaugeDatasetPlus(root=args.root, target_name=args.target_name, k=args.k)

    total = len(ds_all)
    a, b, c = _resolve_split(args.split, total)
    use_total = a + b + c
    print(f"Split summary: train={a} val={b} test={c} (sum={use_total}, original_total={total})")
    # Draw the working subset first, then partition it; separate seeds keep both steps reproducible.
    ds, _ = random_split(ds_all, [use_total, total - use_total],
                         generator=torch.Generator().manual_seed(args.seed))
    train_set, val_set, test_set = random_split(ds, [a, b, c],
                                                generator=torch.Generator().manual_seed(args.seed + 1))

    pin = (device.type == 'cuda')
    # batch_size=1 with collate_list: each batch is a one-element list of per-graph dicts.
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True,  collate_fn=collate_list, pin_memory=pin)
    val_loader   = DataLoader(val_set,   batch_size=1, shuffle=False, collate_fn=collate_list, pin_memory=pin)
    test_loader  = DataLoader(test_set,  batch_size=1, shuffle=False, collate_fn=collate_list, pin_memory=pin)

    model = GaugeGNNPlusVec(d=3, depth=args.depth, s_dim=4).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

    os.makedirs(args.out_dir, exist_ok=True)
    csv_path = os.path.join(args.out_dir, 'metrics.csv')
    with open(csv_path, 'w', newline='') as f:
        csv.writer(f).writerow(['epoch', 'train_mse', 'val_mse', 'test_mse'])

    print(f"Target: {args.target_name} (mapped index {TARGET_INDEX.get(args.target_name,'?')}); "
          f"y_mean={ds_all.y_mean:.4f}, y_std={ds_all.y_std:.4f}")

    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        tr_mse = train_epoch(model, train_loader, opt, device)
        va_mse, va_mae = eval_epoch(model, val_loader, device, ds_all.y_mean, ds_all.y_std)
        te_mse, te_mae = eval_epoch(model, test_loader, device, ds_all.y_mean, ds_all.y_std)
        dt = time.time() - t0
        print(f"[{ep:02d}] train MSE {tr_mse:.4f} | val MSE {va_mse:.4f} (MAE {va_mae:.4f}) | "
              f"test MSE {te_mse:.4f} (MAE {te_mae:.4f}) | {dt:.1f}s")
        with open(csv_path, 'a', newline='') as f:
            csv.writer(f).writerow([ep, tr_mse, va_mse, te_mse])

    if len(test_set) > 0:
        ex = move_to_device(test_set[0], device)
        diff = invariance_sanity(model, ex, trials=3)
        print(f"SO(3) invariance sanity |y - y'| (avg over 3) = {diff:.3e}")

    png_path = os.path.join(args.out_dir, 'metrics.png')
    try:
        plot_metrics(csv_path, png_path)
        print("Saved plot:", png_path)
    except Exception as e:  # plotting is optional; report and continue
        print("(plot skipped)", e)

if __name__ == "__main__":
    main()
