# train_qm9_gauge_gnn.py  (logging + plotting enabled)
import os
import argparse
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch_geometric.datasets import QM9

# ----------------------- utilities -----------------------

def set_seed(seed: int = 0):
    """Seed every random number generator this script relies on.

    The same value is handed, in order, to Python's ``random``, NumPy's
    global generator, PyTorch's default generator and all CUDA generators.

    Args:
        seed: Integer seed shared by all generators.
    """
    import random
    import numpy as np

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def random_SO3(device=None, dtype=torch.float32):
    """Sample a random 3x3 rotation matrix (orthogonal, determinant +1).

    A Gaussian 3x3 matrix is orthogonalised with a QR decomposition. When the
    resulting Q has a negative determinant (a reflection), its last column is
    negated so the returned matrix is a proper rotation.

    Args:
        device: Device on which the matrix is created.
        dtype: Floating point dtype of the matrix.

    Returns:
        A ``(3, 3)`` tensor.
    """
    gaussian = torch.randn(3, 3, device=device, dtype=dtype)
    rot = torch.linalg.qr(gaussian).Q
    is_reflection = bool(torch.linalg.det(rot) < 0)
    if is_reflection:
        rot[:, -1] = -rot[:, -1]
    return rot

def pca_frame(X: torch.Tensor, idx: int, nbr_idx: List[int]) -> torch.Tensor:
    """Build a local right-handed orthonormal frame at point ``idx`` via PCA.

    The neighbour offsets ``X[nbr_idx] - X[idx]`` are summarised by their 3x3
    second-moment matrix. Its two leading eigenvectors give the first two
    axes; the third is their cross product.

    Args:
        X: Point coordinates, indexed by row; each row is a 3-vector.
        idx: Row of the point the frame is attached to.
        nbr_idx: Rows of the neighbouring points.

    Returns:
        A ``(3, 3)`` tensor whose columns are the frame axes. The identity is
        returned when fewer than two neighbours are given.
    """
    count = len(nbr_idx)
    if count < 2:
        return torch.eye(3, dtype=X.dtype, device=X.device)

    offsets = X[nbr_idx] - X[idx]
    moment = (offsets.T @ offsets) / max(count, 1)

    # eigh sorts eigenvalues in ascending order, so the last column is the
    # direction of largest spread and the middle column the second largest.
    basis = torch.linalg.eigh(moment).eigenvectors
    axis1 = F.normalize(basis[:, 2], dim=0)
    second = basis[:, 1]
    # Gram-Schmidt step: strip any residual component along axis1.
    axis2 = F.normalize(second - (second @ axis1) * axis1, dim=0)
    axis3 = F.normalize(torch.linalg.cross(axis1, axis2, dim=0), dim=0)

    frame = torch.stack([axis1, axis2, axis3], dim=1)
    # Guard against a left-handed frame by flipping the third axis.
    if torch.linalg.det(frame) < 0:
        frame[:, 2] = -frame[:, 2]
    return frame

def build_knn_edges(X: torch.Tensor, k: int = 6) -> List[Tuple[int,int]]:
    """Return a symmetrised k-nearest-neighbour edge list for the points in X.

    Each point is linked to its ``k`` closest other points (fewer when the
    point set is small). A link found from either endpoint is kept, and every
    kept link is emitted in both directions.

    Args:
        X: Point coordinates, one point per row.
        k: Number of neighbours requested per point.

    Returns:
        A list of directed ``(i, j)`` pairs containing both orientations of
        every undirected edge; empty when X has at most one point.
    """
    num_pts = X.size(0)
    if num_pts <= 1:
        return []

    k_eff = min(k, max(1, num_pts - 1))
    dist = torch.cdist(X, X)

    unordered = set()
    for src in range(num_pts):
        # Ask for one extra hit because a point is normally its own nearest.
        nearest = torch.topk(dist[src], k_eff + 1, largest=False).indices.tolist()
        for dst in nearest:
            if dst == src:
                continue
            unordered.add((src, dst) if src < dst else (dst, src))

    return [pair for (lo, hi) in unordered for pair in ((lo, hi), (hi, lo))]

def frames_and_U(X: torch.Tensor, edges: List[Tuple[int,int]]):
    """Compute a local frame per point and the relative rotation of every pair.

    Args:
        X: Point coordinates, one point per row.
        edges: Directed ``(i, j)`` pairs; ``j`` is treated as a neighbour of
            ``i`` when building the frame at ``i``.

    Returns:
        ``(R_list, U)`` where ``R_list[i]`` is the frame from ``pca_frame`` at
        point ``i`` and ``U`` is an ``(n, n, 3, 3)`` tensor with
        ``U[i, j] = R_list[i] @ R_list[j].T`` for ``i < j``,
        ``U[j, i] = U[i, j].T`` and the identity on the diagonal.
    """
    n = X.size(0)

    # Neighbour lists keep first-seen order and drop repeated edges.
    adjacency = [[] for _ in range(n)]
    for src, dst in edges:
        bucket = adjacency[src]
        if dst not in bucket:
            bucket.append(dst)

    frames = [pca_frame(X, node, adjacency[node]) for node in range(n)]

    U = torch.zeros(n, n, 3, 3, dtype=X.dtype, device=X.device)
    identity = torch.eye(3, dtype=X.dtype, device=X.device)
    for i, frame_i in enumerate(frames):
        U[i, i] = identity
        for j in range(i + 1, n):
            relative = frame_i @ frames[j].T
            U[i, j] = relative
            # The lower triangle is the exact transpose of the upper one.
            U[j, i] = relative.T
    return frames, U

# ----------------------- model (no link update; fast & stable) -----------------------

class ScalarMLP(nn.Module):
    """Three-layer SiLU perceptron whose input is clamped before use.

    Args:
        in_dim: Size of the last input dimension.
        hidden: Width of both hidden layers.
        out_dim: Size of the last output dimension.
    """

    def __init__(self, in_dim, hidden=64, out_dim=2):
        super().__init__()
        stack = [
            nn.Linear(in_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the network to ``x`` after limiting entries to [-1e4, 1e4]."""
        bounded = torch.clamp(x, -1e4, 1e4)
        return self.net(bounded)

class EquivLayer(nn.Module):
    """One message-passing layer acting on per-node vector features.

    For every ordered pair ``i != j`` the neighbour feature is transported
    into node ``i``'s frame, ``a = U[i, j] @ h[j]``, and combined with
    ``b = h[i]`` as ``alpha * a + beta * b``. The scalar coefficients come
    from an MLP fed only with inner products of ``a`` and ``b``. The summed
    messages ``M`` then update each node as ``gamma * h + delta * M`` with
    coefficients from a second MLP. All coefficients are clamped to [-2, 2].

    Messages run over all node pairs; the ``edges`` argument of ``forward``
    is accepted for interface compatibility but is not read.

    Args:
        d: Feature dimension. Accepted for interface symmetry; it does not
            affect the layer because the MLP input size is fixed at 4.
    """

    def __init__(self, d=3):
        super().__init__()
        self.msg_mlp = ScalarMLP(4, 64, 2)
        self.up_mlp  = ScalarMLP(4, 64, 2)

    def _inv4(self, a, b):
        """Return ``[<a,a>, <b,b>, <a,b>, 1]`` along the last axis, clamped to +-1e6."""
        xx = torch.sum(a*a, dim=-1, keepdim=True)
        yy = torch.sum(b*b, dim=-1, keepdim=True)
        xy = torch.sum(a*b, dim=-1, keepdim=True)
        return torch.clamp(torch.cat([xx, yy, xy, torch.ones_like(xx)], -1), -1e6, 1e6)

    def forward(self, h, U, edges):
        """Run one round of message passing.

        Args:
            h: ``(n, d)`` node features.
            U: ``(n, n, d, d)`` pairwise transport matrices.
            edges: Unused; kept so callers can pass the edge list through.

        Returns:
            ``(h_new, U)`` with ``h_new`` of shape ``(n, d)``; ``U`` is
            returned unchanged.
        """
        n, d = h.shape

        # All pairs are evaluated in one batched pass instead of n*(n-1)
        # separate MLP calls inside Python loops.
        transported = torch.einsum('ijpq,jq->ijp', U, h)   # [i, j] -> U_ij h_j
        receiver = h.unsqueeze(1).expand(n, n, d)          # [i, j] -> h_i
        coef = torch.clamp(self.msg_mlp(self._inv4(transported, receiver)), -2.0, 2.0)
        messages = coef[..., 0:1] * transported + coef[..., 1:2] * receiver

        # A node sends no message to itself: zero the diagonal before summing.
        self_pair = torch.eye(n, dtype=torch.bool, device=h.device).unsqueeze(-1)
        M = messages.masked_fill(self_pair, 0.0).sum(dim=1)

        upd = torch.clamp(self.up_mlp(self._inv4(h, M)), -2.0, 2.0)
        h_new = upd[:, 0:1] * h + upd[:, 1:2] * M
        return h_new, U

class InvariantReadout(nn.Module):
    """
    Pool node vectors into two scalars and regress a single value from them.

    Fast invariants: S0 = sum_{i<=j} <h_i, h_j>, S1 = sum_{i,j} <h_i, U_ij h_j>

    S0 is unchanged when every h_i is rotated by the same matrix R, and S1 is
    unchanged when, in addition, every U_ij is replaced by R U_ij R^T.

    Args:
        d: Feature dimension. Accepted for interface symmetry; the MLP input
            is always the two pooled scalars.
    """
    def __init__(self, d=3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 1)
        )

    def forward(self, h, U):
        """Compute the prediction for one graph.

        Args:
            h: ``(n, d)`` node features.
            U: ``(n, n, d, d)`` pairwise transport matrices.

        Returns:
            Tensor of shape ``(1,)``.
        """
        # Pairwise inner products live in the (n, n) Gram matrix h h^T. The
        # (d, d) product h^T h is a different object whose upper-triangle sum
        # changes under rotation, so it must not be used here.
        gram = h @ h.T
        S0 = torch.sum(torch.triu(gram))
        # One contraction over all (i, j); always a tensor, also for n == 0.
        S1 = torch.einsum('ip,ijpq,jq->', h, U, h)
        feats = torch.stack([S0, S1]).unsqueeze(0)
        feats = torch.clamp(feats, -1e6, 1e6)
        return self.mlp(feats).squeeze(0)

class GaugeGNN(nn.Module):
    """Stack of ``EquivLayer`` blocks followed by an ``InvariantReadout``.

    Args:
        d: Feature dimension forwarded to every layer and to the readout.
        depth: Number of ``EquivLayer`` blocks.
    """

    def __init__(self, d=3, depth=2):
        super().__init__()
        blocks = []
        for _ in range(depth):
            blocks.append(EquivLayer(d))
        self.layers = nn.ModuleList(blocks)
        self.readout = InvariantReadout(d)

    def forward(self, h, U, edges):
        """Pass ``(h, U)`` through each layer in turn, then through the readout."""
        feats, links = h, U
        for block in self.layers:
            feats, links = block(feats, links, edges)
        return self.readout(feats, links)

# ----------------------- dataset wrapper -----------------------

# Property name -> column index used to pick the regression target.
TARGET_INDEX = {
    name: column
    for column, name in enumerate(
        ('mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv')
    )
}

class QM9GaugeDataset(Dataset):
    """QM9 molecules prepared as inputs for the gauge GNN.

    Each item holds the centred atom positions (used as node features), the
    pairwise frame-transport tensor built from a k-NN graph, the edge list,
    and one regression target in standardised and raw form.

    Args:
        root: Directory passed to ``QM9``.
        target_name: Key into ``TARGET_INDEX``; unknown names fall back to
            column 7.
        k: Neighbour count for the k-NN graph.
        max_atoms: When given, keep only molecules with at most this many atoms.
    """

    def __init__(self, root='./data/qm9', target_name='U0', k=6, max_atoms=None):
        super().__init__()
        self.raw = QM9(root=root)                     # CPU tensors
        self.k = k
        self.tidx = TARGET_INDEX.get(target_name, 7)

        # Molecules actually used, optionally restricted by atom count.
        self.idxs = list(range(len(self.raw)))
        if max_atoms is not None:
            self.idxs = [i for i in self.idxs
                         if int(self.raw[i].z.numel()) <= int(max_atoms)]

        # Normalisation statistics over the molecules that were kept.
        targets = torch.tensor([float(self._target_of(self.raw[i])) for i in self.idxs])
        self.y_mean = float(targets.mean())
        self.y_std  = float(targets.std() + 1e-8)

    def _target_of(self, data):
        """Return the selected target entry of ``data.y`` as a 0-dim tensor."""
        y = data.y
        if y.ndim == 1:
            y = y.unsqueeze(0)
        # Clamp the column in case y has fewer columns than the index asks for.
        column = min(self.tidx, y.shape[1] - 1)
        return y[0, column]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, local_idx):
        """Build the sample dict for the ``local_idx``-th kept molecule."""
        data = self.raw[self.idxs[local_idx]]
        target = self._target_of(data).to(torch.float32).cpu()

        centred = data.pos.to(torch.float32).cpu()
        centred = centred - centred.mean(dim=0, keepdim=True)

        edges = build_knn_edges(centred, k=self.k)
        _, U = frames_and_U(centred, edges)

        return {
            'h': centred.clone(),                    # CPU
            'U': U,                                  # CPU
            'edges': edges,                          # list
            'y': ((target - self.y_mean) / self.y_std),
            'y_raw': target
        }

def collate_list(samples):
    """Identity collate function: return the list of samples untouched.

    The training and evaluation loops iterate over the samples of a batch one
    by one, so nothing is stacked here.
    """
    return samples

# ----------------------- manual split -----------------------

def make_splits(ds: Dataset, split_str: str, seed: int):
    """Cut ``ds`` into random, disjoint train/val/test subsets.

    Args:
        ds: Dataset to split.
        split_str: Three comma-separated integers: the requested train,
            validation and test sizes.
        seed: Seed of the generator that shuffles the dataset indices.

    Returns:
        ``(train, val, test)`` as ``Subset`` objects. When the requested sizes
        add up to more than ``len(ds)``, train is filled first, then
        validation, and test receives whatever is left.
    """
    total = len(ds)
    n_train, n_val, n_test = [int(tok) for tok in split_str.split(',')]
    budget = min(n_train + n_val + n_test, total)

    rng = torch.Generator().manual_seed(seed)
    chosen = torch.randperm(total, generator=rng)[:budget].tolist()

    n_train = min(n_train, budget)
    n_val = min(n_val, budget - n_train)
    n_test = budget - n_train - n_val

    val_end = n_train + n_val
    train_part = Subset(ds, chosen[:n_train])
    val_part = Subset(ds, chosen[n_train:val_end])
    test_part = Subset(ds, chosen[val_end:val_end + n_test])
    return (train_part, val_part, test_part)

# ----------------------- training / eval -----------------------

def train_epoch(model, loader, opt, device, y_mean, y_std):
    """Run one training pass, taking an optimizer step after every sample.

    Args:
        model: Callable as ``model(h, U, edges)``.
        loader: Iterable of batches; each batch is an iterable of sample
            dicts with keys ``'h'``, ``'U'``, ``'y'`` and ``'edges'``.
        opt: Optimizer over the model parameters.
        device: Device the tensors are moved to.
        y_mean: Unused here; present so the signature mirrors ``eval_epoch``.
        y_std: Unused here; present so the signature mirrors ``eval_epoch``.

    Returns:
        Mean per-sample MSE on the standardised target (0.0 if no samples).
    """
    model.train()
    running_loss = 0.0
    steps = 0
    for batch in loader:
        for sample in batch:
            feats = sample['h'].to(device, non_blocking=True)
            links = sample['U'].to(device, non_blocking=True)
            target = sample['y'].to(device, non_blocking=True)
            edge_list = sample['edges']

            opt.zero_grad()
            prediction = model(feats, links, edge_list).squeeze()
            loss = F.mse_loss(prediction, target)
            loss.backward()
            # Cap the global gradient norm before stepping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()

            running_loss += float(loss.detach())
            steps += 1
    return running_loss / max(steps, 1)

@torch.no_grad()
def eval_epoch(model, loader, device, y_mean, y_std):
    """Evaluate the model without gradient tracking.

    Args:
        model: Callable as ``model(h, U, edges)``.
        loader: Iterable of batches; each batch is an iterable of sample
            dicts with keys ``'h'``, ``'U'``, ``'y'``, ``'y_raw'``, ``'edges'``.
        device: Device the tensors are moved to.
        y_mean: Mean used to map predictions back to raw target units.
        y_std: Standard deviation used for the same mapping.

    Returns:
        ``(mse, mae)``: mean squared error on the standardised target and
        mean absolute error in raw units (both 0.0 if no samples).
    """
    model.eval()
    total_mse = 0.0
    total_mae = 0.0
    seen = 0
    for batch in loader:
        for sample in batch:
            feats = sample['h'].to(device, non_blocking=True)
            links = sample['U'].to(device, non_blocking=True)
            target = sample['y'].to(device, non_blocking=True)
            target_raw = sample['y_raw'].to(device, non_blocking=True)
            edge_list = sample['edges']

            prediction = model(feats, links, edge_list).squeeze()
            # Undo the standardisation before measuring absolute error.
            prediction_raw = prediction * y_std + y_mean

            total_mse += float(F.mse_loss(prediction, target))
            total_mae += float(torch.abs(prediction_raw - target_raw))
            seen += 1
    denom = max(seen, 1)
    return total_mse / denom, total_mae / denom

@torch.no_grad()
def invariance_sanity(model, ex, device):
    """Measure how much the prediction moves under a random global rotation.

    The features are rotated as ``h_i -> R h_i`` and the transport matrices
    as ``U_ij -> R U_ij R^T``; the model is evaluated before and after.

    Args:
        model: Callable as ``model(h, U, edges)``.
        ex: Sample dict with keys ``'h'``, ``'U'`` and ``'edges'``.
        device: Device used for the evaluation.

    Returns:
        ``|y - y'|`` as a Python float.
    """
    h = ex['h'].to(device)
    U = ex['U'].to(device)
    edges = ex['edges']
    reference = float(model(h, U, edges).squeeze())

    R = random_SO3(device=device, dtype=h.dtype)
    h_rot = (R @ h.T).T
    left = R.view(1, 1, 3, 3)
    right = R.T.view(1, 1, 3, 3)
    U_rot = (left @ U @ right).clone()

    # Make the lower triangle the exact transpose of the upper triangle.
    for j in range(1, h.size(0)):
        for i in range(j):
            U_rot[j, i] = U_rot[i, j].T

    rotated = float(model(h_rot, U_rot, edges).squeeze())
    return abs(reference - rotated)

# ----------------------- logging / plotting -----------------------

def ensure_dir(p):
    """Create directory ``p`` (and missing parents); no error if it exists."""
    os.makedirs(name=p, exist_ok=True)

def write_csv_header(csv_path):
    """Create or truncate ``csv_path`` and write the metrics column header."""
    with open(csv_path, mode='w', newline='') as sink:
        columns = 'epoch,train_mse,val_mse,val_mae,test_mse,test_mae\n'
        sink.write(columns)

def append_csv(csv_path, ep, tr, vm, va, tm, ta):
    """Append one metrics row to ``csv_path``.

    Args:
        csv_path: File to append to.
        ep: Epoch number, written as-is.
        tr: Train MSE.
        vm: Validation MSE.
        va: Validation MAE.
        tm: Test MSE.
        ta: Test MAE.

    All metric values are written with six decimal places.
    """
    with open(csv_path, mode='a', newline='') as sink:
        row = f'{ep},{tr:.6f},{vm:.6f},{va:.6f},{tm:.6f},{ta:.6f}\n'
        sink.write(row)

def plot_from_csv(csv_path, out_png):
    """Plot train/val/test MSE against epoch from a metrics CSV and save a PNG.

    Plotting is best-effort: any failure (matplotlib not installed, missing
    or malformed CSV, unwritable output path) is reported on stdout and
    swallowed so that it cannot abort the caller.

    Args:
        csv_path: CSV file with at least the columns ``epoch``,
            ``train_mse``, ``val_mse`` and ``test_mse``.
        out_png: Destination image path.
    """
    try:
        import csv as _csv
        import matplotlib.pyplot as plt

        ep, tr, vm, tm = [], [], [], []
        with open(csv_path, 'r', newline='') as f:
            for row in _csv.DictReader(f):
                ep.append(int(row['epoch']))
                tr.append(float(row['train_mse']))
                vm.append(float(row['val_mse']))
                tm.append(float(row['test_mse']))

        fig, ax = plt.subplots()
        try:
            ax.plot(ep, tr, label='train MSE')
            ax.plot(ep, vm, label='val MSE')
            ax.plot(ep, tm, label='test MSE')
            ax.set_xlabel('epoch')
            ax.set_ylabel('MSE')
            ax.set_title('QM9 regression')
            ax.legend()
            fig.savefig(out_png, dpi=150, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)
        print('Saved plot:', out_png)
    except Exception as e:
        # Deliberately broad: a plotting problem must not fail a training run.
        print('(plot skipped)', e)

# ----------------------- main -----------------------

def main():
    """Command-line entry point: train and evaluate ``GaugeGNN`` on QM9.

    Parses the arguments, builds the dataset and its train/val/test split,
    trains for ``--epochs`` epochs while appending one row per epoch to
    ``<out_dir>/metrics.csv``, runs a rotation sanity check on the first test
    sample and finally renders the metrics plot.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--root', type=str, default='./data/qm9')
    ap.add_argument('--epochs', type=int, default=5)
    ap.add_argument('--target_name', type=str, default='U0')
    ap.add_argument('--k', type=int, default=6)
    ap.add_argument('--depth', type=int, default=1)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--split', type=str, default='11000,1000,1000')
    ap.add_argument('--device', type=str, default='cuda')
    ap.add_argument('--max_atoms', type=int, default=None)
    ap.add_argument('--num_workers', type=int, default=0)
    ap.add_argument('--out_dir', type=str, default='runs_qm9')
    ap.add_argument('--plot', type=str, default='metrics.png')
    args = ap.parse_args()

    set_seed(args.seed)
    # Without CUDA the requested device is ignored and the CPU is used.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print('device:', device)

    if args.target_name not in TARGET_INDEX:
        # The dataset silently maps unknown names to column 7; say so up front.
        print(f"Warning: unknown target '{args.target_name}'; falling back to index 7 (U0). "
              f"Known targets: {', '.join(TARGET_INDEX)}")

    ds_all = QM9GaugeDataset(root=args.root, target_name=args.target_name, k=args.k, max_atoms=args.max_atoms)
    train_set, val_set, test_set = make_splits(ds_all, args.split, args.seed)
    used = len(train_set) + len(val_set) + len(test_set)
    print(f"Split summary: train={len(train_set)} val={len(val_set)} test={len(test_set)} "
          f"(sum={used}, original_total={len(ds_all)})")
    print(f"Target: {args.target_name} (mapped index {TARGET_INDEX.get(args.target_name,'?')}); "
          f"y_mean={ds_all.y_mean:.4f}, y_std={ds_all.y_std:.4f}")

    ensure_dir(args.out_dir)
    csv_path = os.path.join(args.out_dir, 'metrics.csv')
    write_csv_header(csv_path)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin = device.type == 'cuda'

    def make_loader(subset, shuffle):
        return DataLoader(subset, batch_size=1, shuffle=shuffle,
                          collate_fn=collate_list, pin_memory=pin,
                          num_workers=args.num_workers)

    train_loader = make_loader(train_set, True)
    val_loader = make_loader(val_set, False)
    test_loader = make_loader(test_set, False)

    model = GaugeGNN(d=3, depth=args.depth).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

    for ep in range(1, args.epochs+1):
        tr_mse = train_epoch(model, train_loader, opt, device, ds_all.y_mean, ds_all.y_std)
        va_mse, va_mae = eval_epoch(model, val_loader, device, ds_all.y_mean, ds_all.y_std)
        te_mse, te_mae = eval_epoch(model, test_loader, device, ds_all.y_mean, ds_all.y_std)
        print(f"[{ep:02d}] train MSE {tr_mse:.4f} | val MSE {va_mse:.4f} MAE {va_mae:.4f} | "
              f"test MSE {te_mse:.4f} MAE {te_mae:.4f}")
        append_csv(csv_path, ep, tr_mse, va_mse, va_mae, te_mse, te_mae)

    if len(test_set) > 0:
        ex = test_set[0]
        diff = invariance_sanity(model, ex, device)
        print(f"O(3) invariance sanity |y - y'| (avg over 1) = {diff:.3e}")

    # plot
    if args.plot:
        out_png = os.path.join(args.out_dir, args.plot)
        plot_from_csv(csv_path, out_png)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
