# train_binary_gauge_gnn_logged_fast.py
import os, csv, time, argparse
from typing import List, Tuple
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ------------------ utils ------------------
def set_seed(seed=0):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) with `seed`."""
    import random
    import numpy as np
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def random_SO(d, device=None, dtype=torch.float32):
    """
    Return a random d x d rotation matrix (orthogonal with determinant +1).

    The Q factor of a Gaussian matrix is orthogonal; if its determinant is -1,
    one column is negated to land in SO(d). The signs of R's diagonal are not
    normalised, so the distribution is not guaranteed to be exactly Haar.
    """
    q = torch.linalg.qr(torch.randn(d, d, device=device, dtype=dtype))[0]
    if torch.linalg.det(q) < 0:
        q[:, -1].neg_()  # flip last column in place: det -> +1, still orthogonal
    return q

def skew(M):
    """Return the skew-symmetric part (M - M^T) / 2 over the last two axes (batched)."""
    transposed = M.transpose(-1, -2)
    return 0.5 * (M - transposed)

# ------------------ data -------------------
def init_flat_graph(n, d, device, h_scale=1.0):
    """
    Sample a flat (holonomy-free) connection on the complete graph with n nodes.

    Each node i gets a random rotation frame Q_i and every link is
    U_ij = Q_i Q_j^T, so any closed loop product of links telescopes to I.

    Returns:
      h: (n, d) node features, h_scale * N(0, 1). The normal draw happens even
         when h_scale == 0, in which case h is all zeros.
      U: (n, n, d, d) links with U_ii ~= I and U_ji == U_ij^T exactly.
    """
    frames = torch.stack([random_SO(d, device=device) for _ in range(n)])  # (n, d, d)
    links = torch.matmul(frames.unsqueeze(1), frames.unsqueeze(0).transpose(-1, -2)).contiguous()
    # Both triangles of the product are computed independently; copy the upper
    # triangle into the lower one so the transpose symmetry holds bit-for-bit.
    for j in range(1, n):
        for i in range(j):
            links[j, i] = links[i, j].T
    h = h_scale * torch.randn(n, d, device=device)
    return h, links

def all_triangles(n):
    """Return every node triple (i, j, k) with i < j < k < n, in lexicographic order."""
    return [(a, b, c)
            for a in range(n)
            for b in range(a + 1, n)
            for c in range(b + 1, n)]

def inject_triangle_holonomy(U, tri, angle):
    """
    Left-multiply link U[i, j] by a random rotation R and mirror it into U[j, i].

    R = exp(angle * K), where K is the skew part of a Gaussian matrix rescaled
    to unit Frobenius norm. For d = 3, this means R rotates by angle / sqrt(2)
    about a random axis. With a zero skew part (e.g. d = 1), K stays zero and
    R = I.

    Only edge (i, j) of `tri` is touched; k is ignored. The holonomy therefore
    changes on every triangle that contains edge (i, j), not just on `tri`.
    U (shape (n, n, d, d)) is mutated in place and also returned.
    """
    i, j, _ = tri
    size = U.size(-1)
    gen = torch.randn(size, size, device=U.device, dtype=U.dtype)
    K = 0.5 * (gen - gen.transpose(-1, -2))  # skew-symmetric generator
    scale = torch.linalg.norm(K)
    if float(scale) > 0:
        K = K / scale
    rot = torch.matrix_exp(angle * K)
    U[i, j] = rot @ U[i, j]
    U[j, i] = U[i, j].T
    return U

class GraphDataset(Dataset):
    """
    Balanced synthetic dataset of (h, U, label) graph samples, built eagerly.

    Even indices are positives (label 1.0): a flat connection on which
    holonomy is injected for min(k_tris, #triangles) randomly chosen triangles.
    The rotation goes on edge (i, j) of each chosen triangle, using angle_pos.
    Odd indices are negatives (label 0.0): a plain flat connection.
    h_scale controls noise in node features (set 0 to remove).

    Construction calls set_seed(seed), which reseeds the global Python, NumPy
    and torch RNGs as a side effect.
    """
    def __init__(self, N=2000, n=6, d=3, angle_pos=1.2, k_tris=6, h_scale=0.0, device='cpu', seed=0):
        super().__init__()
        import random
        set_seed(seed)
        self.N, self.n, self.d = N, n, d
        self.angle_pos, self.k_tris, self.h_scale = angle_pos, k_tris, h_scale
        self.device = torch.device(device)
        triangles = all_triangles(n)
        n_pick = min(k_tris, len(triangles))  # > 0 iff k_tris > 0 and triangles exist
        self.samples = []
        for idx in range(N):
            positive = idx % 2 == 0
            h, U = init_flat_graph(n, d, device=self.device, h_scale=self.h_scale)
            if positive and n_pick > 0:
                for tri in random.sample(triangles, k=n_pick):
                    U = inject_triangle_holonomy(U, tri, angle_pos)
            label = torch.tensor(1.0 if positive else 0.0, dtype=torch.float32, device=self.device)
            self.samples.append((h, U, label))

    def __len__(self):
        return self.N

    def __getitem__(self, i):
        return self.samples[i]

# --------------- model (vectorized) ---------------
class ScalarMLP(nn.Module):
    """
    Three-layer SiLU MLP over scalar features: in_dim -> hidden -> hidden -> out_dim.

    Inputs are clamped elementwise to [-1e4, 1e4] before the first layer.
    """
    def __init__(self, in_dim, hidden=32, out_dim=1):
        super().__init__()
        widths = [in_dim, hidden, hidden, out_dim]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth:
                layers.append(nn.SiLU())  # activation between consecutive Linear layers
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        bounded = x.clamp(min=-1e4, max=1e4)
        return self.net(bounded)

class EquivLayerVec(nn.Module):
    """
    Vectorized O(d)-equivariant message-passing layer (no link update, for
    stability/speed).

    h: (B, n, d) node vectors, U: (B, n, n, d, d) link matrices.

    For each ordered pair (i, j), two vectors transform like node i: the
    transported neighbour a_ij = U_ij h_j and the node's own vector b_ij = h_i.
    An MLP over their inner products yields (alpha, beta), each clamped to
    [-1, 1]. The messages alpha*a + beta*b are summed over j != i into M_i.
    A second MLP over the inner products of (h_i, M_i) yields (gamma, delta),
    also clamped, and the update is h_i <- gamma*h_i + delta*M_i.

    Only inner products feed the MLPs, so the layer commutes with a global
    orthogonal map (h -> r h, U -> r U r^T). It also commutes with per-node
    orthogonal gauge changes (h_i -> g_i h_i, U_ij -> g_i U_ij g_j^T).
    Returns (h_new, U), with U passed through unchanged.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d
        self.msg_mlp = ScalarMLP(4, 32, 2)  # edge invariants -> (alpha, beta)
        self.up_mlp = ScalarMLP(4, 32, 2)   # node invariants -> (gamma, delta)

    @staticmethod
    def _pair_invariants(x, y):
        """Return [x.x, y.y, x.y, 1] stacked along a new trailing axis of size 4."""
        xx = (x * x).sum(-1, keepdim=True)
        yy = (y * y).sum(-1, keepdim=True)
        xy = (x * y).sum(-1, keepdim=True)
        return torch.cat([xx, yy, xy, torch.ones_like(xx)], dim=-1)

    def forward(self, h, U):
        B, n, d = h.shape
        # transported[b, i, j] = U[b, i, j] @ h[b, j]  (neighbour j seen from node i)
        transported = torch.einsum('b i j p q, b j q -> b i j p', U, h)
        own = h[:, :, None, :].expand(B, n, n, d)
        edge_coef = self.msg_mlp(self._pair_invariants(transported, own).view(-1, 4))
        edge_coef = torch.clamp(edge_coef.view(B, n, n, 2), -1.0, 1.0)
        messages = edge_coef[..., 0:1] * transported + edge_coef[..., 1:2] * own
        # zero out self-messages (j == i) before aggregating over neighbours
        not_self = ~torch.eye(n, dtype=torch.bool, device=h.device)
        agg = (messages * not_self[None, :, :, None]).sum(dim=2)  # (B, n, d)
        node_coef = self.up_mlp(self._pair_invariants(h, agg).view(-1, 4))
        node_coef = torch.clamp(node_coef.view(B, n, 2), -1.0, 1.0)
        h_new = node_coef[..., 0:1] * h + node_coef[..., 1:2] * agg
        return h_new, U

class InvariantReadout(nn.Module):
    """
    Graph-level readout: four scalar invariants mapped to one logit by an MLP.

    Per graph in the batch (h: (B, n, d), U: (B, n, n, d, d)):
      S0 = sum_{i<=j} h_i . h_j
           Invariant under a global O(d) rotation. The off-diagonal terms are
           NOT invariant under independent per-node gauge transformations.
      S1 = sum_{i,j} h_i . (U_ij h_j)
      W3 = sum over ordered triples of distinct nodes (i, j, k) of tr(P_ijk),
           with P_ijk = U_ij U_jk U_ki
      E3 = the same sum of ||I - P_ijk||_F^2 (holonomy energy). When every
           P_ijk is orthogonal, E3 = 2 * (d * #triples - W3).
    E3 gives a very strong signal for classification.

    All n(n-1)(n-2) triangle products are computed in one gather plus a batched
    matmul, rather than a Python triple loop issuing O(n^3) tiny kernels per
    forward pass. Memory is O(B * n^3 * d^2). For n < 3 there are no triangles
    and W3 = E3 = 0.
    """
    def __init__(self, d):
        super().__init__()
        self.d = d
        self.mlp = nn.Sequential(
            nn.Linear(4, 64), nn.SiLU(),
            nn.Linear(64, 64), nn.SiLU(),
            nn.Linear(64, 1)
        )

    @staticmethod
    def _ordered_triples(n, device):
        """Return a (T, 3) long tensor of all ordered triples of distinct nodes < n."""
        triples = [(i, j, k)
                   for i in range(n) for j in range(n) for k in range(n)
                   if i != j and j != k and i != k]
        # view(-1, 3) keeps a well-formed (0, 3) shape when there are no triples
        return torch.tensor(triples, dtype=torch.long, device=device).view(-1, 3)

    def invariant_features(self, h, U):
        """Return the (B, 4) matrix [S0, S1, W3, E3], clamped to [-1e6, 1e6]."""
        _, n, d = h.shape
        HH = torch.einsum('b i p, b j p -> b i j', h, h)                  # (B,n,n)
        S0 = torch.triu(HH).sum(dim=(1, 2))                               # (B,)

        a = torch.einsum('b i j p q, b j q -> b i j p', U, h)             # (B,n,n,d)
        S1 = (h[:, :, None, :] * a).sum(-1).sum(dim=(1, 2))               # (B,)

        ii, jj, kk = self._ordered_triples(n, h.device).unbind(-1)        # each (T,)
        P = U[:, ii, jj] @ U[:, jj, kk] @ U[:, kk, ii]                    # (B,T,d,d)
        W3 = torch.diagonal(P, dim1=-2, dim2=-1).sum(dim=(-2, -1)).to(h.dtype)
        I = torch.eye(d, device=P.device, dtype=P.dtype)
        # squared Frobenius norm summed directly (no sqrt followed by **2)
        E3 = (I - P).pow(2).sum(dim=(-3, -2, -1)).to(h.dtype)

        feats = torch.stack([S0, S1, W3, E3], dim=-1)                     # (B,4)
        return torch.clamp(feats, -1e6, 1e6)

    def forward(self, h, U):
        return self.mlp(self.invariant_features(h, U)).squeeze(-1)        # (B,)

class GaugeGNN(nn.Module):
    """
    `depth` EquivLayerVec message-passing layers followed by an InvariantReadout.

    forward(h, U) threads the (h, U) pair through every layer, then returns
    the readout's per-graph logits.
    """
    def __init__(self, d=3, depth=2):
        super().__init__()
        self.layers = nn.ModuleList(EquivLayerVec(d) for _ in range(depth))
        self.readout = InvariantReadout(d)

    def forward(self, h, U):
        state = (h, U)
        for layer in self.layers:
            state = layer(*state)
        return self.readout(*state)

# --------------- train / eval / plot ---------------
def train_epoch(model, loader, opt, device):
    """
    Run one optimisation pass over `loader` with BCE-with-logits loss.

    Gradients are clipped to a global norm of 5.0 before every step.
    Returns (mean loss, accuracy), both weighted per sample. The loss and
    accuracy of each batch are measured before that batch's optimiser step.
    The threshold is sigmoid(logit) >= 0.5. An empty loader gives (0.0, 0.0).
    """
    model.train()
    running_loss = 0.0
    n_correct = n_seen = 0
    for batch in loader:
        h, U, y = (t.to(device) for t in batch)

        opt.zero_grad()
        logit = model(h, U).view(-1)            # (B,)
        target = y.view(-1).float()             # (B,)
        loss = F.binary_cross_entropy_with_logits(logit, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()

        batch_size = target.numel()
        running_loss += loss.item() * batch_size
        predicted = (torch.sigmoid(logit.detach()) >= 0.5).to(target.dtype)
        n_correct += int((predicted == target).sum())
        n_seen += int(batch_size)

    denom = max(n_seen, 1)
    return running_loss / denom, n_correct / denom

@torch.no_grad()
def eval_epoch(model, loader, device):
    """
    Evaluate `model` on `loader` in eval mode, with autograd disabled.

    Returns (mean BCE-with-logits loss, accuracy, number of positive labels,
    number of samples). The loss is weighted per sample, and predictions
    threshold sigmoid(logit) at 0.5. The positive count is returned as a
    class-balance sanity check.
    """
    model.eval()
    total_loss = 0.0
    hits = positives = seen = 0
    for h, U, y in loader:
        logit = model(h.to(device), U.to(device)).view(-1)   # (B,)
        target = y.to(device).view(-1).float()                # (B,)
        count = target.numel()

        total_loss += float(F.binary_cross_entropy_with_logits(logit, target)) * count
        positives += int(target.sum())

        guess = (torch.sigmoid(logit) >= 0.5).to(target.dtype)
        hits += int((guess == target).sum())
        seen += int(count)

    denom = max(seen, 1)
    return total_loss / denom, hits / denom, positives, seen

@torch.no_grad()
def invariance_sanity(model, device, n=6, d=3, h_scale=1.0, angle=1.2):
    """
    Measure how much the model's logit changes under a global rotation.

    Builds one probe graph and applies a random rotation r to every node
    feature (h -> r h) and every link (U_ij -> r U_ij r^T). Returns
    |model(h, U) - model(r h, r U r^T)|, which should be ~0 up to float error.

    The probe uses random node features (h_scale) and injects holonomy on
    triangle (0, 1, 2). With h = 0 and a flat connection, every feature that
    GaugeGNN's readout computes is the same constant before and after the
    rotation, so the check could never fail.

    model:   callable taking (1, n, d) features and (1, n, n, d, d) links,
             returning a single logit.
    h_scale: std of the probe node features (the old behaviour used 0.0).
    angle:   holonomy angle passed to inject_triangle_holonomy (used when n >= 3).
    Returns: the absolute logit difference as a Python float.
    """
    h, U = init_flat_graph(n, d, device=device, h_scale=h_scale)
    if n >= 3:
        U = inject_triangle_holonomy(U, (0, 1, 2), angle)
    base = float(model(h.unsqueeze(0), U.unsqueeze(0)))
    r = random_SO(d, device=device)
    h2 = (r @ h.T).T
    U2 = r.view(1, 1, d, d) @ U @ r.T.view(1, 1, d, d)
    # restore exact U_ji == U_ij^T, which the rotated product only meets up to rounding
    for i in range(n):
        for j in range(i + 1, n):
            U2[j, i] = U2[i, j].T
    y2 = float(model(h2.unsqueeze(0), U2.unsqueeze(0)))
    return abs(base - y2)

def plot_csv(path_csv, out_png):
    """
    Plot per-epoch train/val/test accuracy from a metrics CSV and save it as an image.

    path_csv: CSV with at least the columns 'epoch', 'train_acc', 'val_acc'
              and 'test_acc' (as written by main()).
    out_png:  output image path, passed to Figure.savefig.

    The figure is always closed afterwards, even if saving fails, so repeated
    calls do not accumulate open matplotlib figures.
    """
    import matplotlib.pyplot as plt
    ep, tr_a, va_a, te_a = [], [], [], []
    with open(path_csv, 'r', newline='') as f:
        for row in csv.DictReader(f):
            ep.append(int(row['epoch']))
            tr_a.append(float(row['train_acc']))
            va_a.append(float(row['val_acc']))
            te_a.append(float(row['test_acc']))
    fig, ax = plt.subplots()
    try:
        ax.plot(ep, tr_a, label='train acc')
        ax.plot(ep, va_a, label='val acc')
        ax.plot(ep, te_a, label='test acc')
        ax.set_xlabel('epoch')
        ax.set_ylabel('accuracy')
        ax.set_title('Binary accuracy')
        ax.legend()
        fig.savefig(out_png, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)

# ---------------------- main ----------------------
def main():
    """
    CLI entry point for the whole training run.

    Steps:
      1. Parse arguments and build the train/val/test synthetic datasets (on CPU).
      2. Train GaugeGNN with Adam, evaluating on val and test every epoch.
      3. Append per-epoch metrics to <out_dir>/metrics.csv.
      4. Print a global O(d) invariance check.
      5. Optionally save an accuracy plot to <out_dir>/<plot>.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--epochs', type=int, default=50)
    ap.add_argument('--trainN', type=int, default=5000)
    ap.add_argument('--valN', type=int, default=1000)
    ap.add_argument('--testN', type=int, default=1000)
    ap.add_argument('--n', type=int, default=6)
    ap.add_argument('--d', type=int, default=3)
    ap.add_argument('--depth', type=int, default=2)
    ap.add_argument('--batch_size', type=int, default=512)
    ap.add_argument('--lr', type=float, default=3e-4)
    ap.add_argument('--angle_pos', type=float, default=1.2)
    ap.add_argument('--k_tris', type=int, default=6)
    ap.add_argument('--h_scale', type=float, default=0.0, help='0.0 removes node-feature noise')
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--device', type=str, default='cuda')
    ap.add_argument('--out_dir', type=str, default='runs_binary_fast')
    ap.add_argument('--plot', type=str, default='metrics.png')
    args = ap.parse_args()

    set_seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print('device:', device)

    os.makedirs(args.out_dir, exist_ok=True)
    csv_path = os.path.join(args.out_dir, 'metrics.csv')
    with open(csv_path,'w',newline='') as f:
        csv.writer(f).writerow(['epoch','train_loss','train_acc','val_loss','val_acc','test_loss','test_acc'])

    # keep datasets on CPU; we move batches to device inside the loops
    train_ds = GraphDataset(args.trainN, args.n, args.d, args.angle_pos, args.k_tris, args.h_scale, 'cpu', args.seed)
    val_ds   = GraphDataset(args.valN,   args.n, args.d, args.angle_pos, args.k_tris, args.h_scale, 'cpu', args.seed+1)
    test_ds  = GraphDataset(args.testN,  args.n, args.d, args.angle_pos, args.k_tris, args.h_scale, 'cpu', args.seed+2)

    # Pinned host memory only helps host->GPU copies. On a CPU-only run it does
    # nothing useful and PyTorch may warn about it, so enable it only for CUDA.
    pin = device.type == 'cuda'
    train_ld = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, pin_memory=pin)
    val_ld   = DataLoader(val_ds,   batch_size=args.batch_size, shuffle=False, pin_memory=pin)
    test_ld  = DataLoader(test_ds,  batch_size=args.batch_size, shuffle=False, pin_memory=pin)

    model = GaugeGNN(d=args.d, depth=args.depth).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    for ep in range(1, args.epochs+1):
        t0=time.time()
        trL,trA = train_epoch(model, train_ld, opt, device)
        vaL,vaA,va_pos,va_tot = eval_epoch(model, val_ld, device)
        teL,teA,te_pos,te_tot = eval_epoch(model, test_ld, device)

        if ep == 1:
            print(f"[sanity] class balance: val pos={va_pos/max(va_tot,1):.3f}, "
                  f"test pos={te_pos/max(te_tot,1):.3f}")

        print(f"[{ep:03d}] train {trL:.4f} acc {trA:.3f} | "
              f"val {vaL:.4f} acc {vaA:.3f} | "
              f"test {teL:.4f} acc {teA:.3f} | "
              f"{time.time()-t0:.1f}s")

        with open(csv_path,'a',newline='') as f:
            csv.writer(f).writerow([ep,trL,trA,vaL,vaA,teL,teA])

    diff = invariance_sanity(model, device, n=args.n, d=args.d)
    print(f"O(d) invariance sanity |logit - logit'| = {diff:.3e}")

    if args.plot:
        # plotting is best-effort (e.g. matplotlib may be missing); report and continue
        try:
            out_png = os.path.join(args.out_dir, args.plot)
            plot_csv(csv_path, out_png)
            print("Saved plot:", out_png)
        except Exception as e:
            print("(plot skipped)", e)

if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()
