import os, torch, numpy as np, torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch import nn
from utils import *

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Print tensors as fixed-point numbers with four decimals (no scientific notation).
torch.set_printoptions(sci_mode=False, precision=4)

# Global experiment configuration read by main().
para = {
    # --- data / model -----------------------------------------------------
    'data_path': r'',            # path handed to np.load; empty placeholder, must be set
    'num': 1000,                 # only the first `num` rows of the data array are kept
    'view': 3,                   # number of views (one encoder/decoder pair per view)
    'input_dim': [64, 64, 64],   # per-view input width, passed to split_func / AutoEncoder
    'latent': 32,                # latent width passed to AutoEncoder

    # --- optimisation -----------------------------------------------------
    'epochs_pre': 9,             # reconstruction-only pretraining iterations
    'epochs': 50,                # main training iterations (including warm-up)
    'lr': 1e-2,                  # AdamW learning rate
    'wd': 1e-3,                  # AdamW weight decay

    # --- graph construction / loss weights ---------------------------------
    'alpha': 0.5,                # forwarded to compute_phi
    'beta': 8.0,                 # forwarded to compute_W_with_psi
    'K': 10,                     # forwarded to the neighbor builders and compute_psi
    'lambda_ang': 1.0,           # weight of the angular loss term
    'lambda_dis': 0.0,           # final weight of the distance loss term (0.0 disables it)
    'tau_pi': 20,                # forwarded to compute_pi

    # --- schedule (0-based epoch indices) ------------------------------------
    'warm_align': 5,             # epochs below this index train on reconstruction only
    'ramp_dis_from': 15,         # distance weight is 0 before this epoch ...
    'ramp_dis_to': 40,           # ... and reaches lambda_dis from this epoch on
}

def _forward_views(encoders, decoders, X):
    """Encode and reconstruct every view, one view at a time.

    Returns:
        (H, Rec): two lists holding, per view, the latent tensor and the
        reconstruction produced from it.
    """
    H, Rec = [], []
    for v in range(len(encoders)):
        h = encoders[v](X[v])
        H.append(h)
        Rec.append(decoders[v](h))
    return H, Rec


def _build_graphs(X, H, Rec):
    """Build the per-view neighbor masks used by the structural losses.

    The weighting quantities (psi, PHI, W) are computed without gradient
    tracking; the final ``build_neighbors`` call runs in whatever grad mode
    the caller is in.

    Returns:
        (M_list, stats): the list of per-view masks and a tuple
        ``(phi_mean, psi_mean, W_min, W_mean, W_max)`` of Python floats for
        logging.
    """
    n_view = para['view']
    with torch.no_grad():
        # Shape annotations below are carried over from the original code.
        M0_list = [build_neighbors_unweighted(H[v], para['K']) for v in range(n_view)]
        psi = compute_psi(M0_list, K=para['K'])  # [N] ∈ [0,1]
        PHI = torch.stack(
            [compute_phi(X[v], Rec[v], para['alpha']) for v in range(n_view)],
            dim=0,
        )  # [V,N,1]
        W = compute_W_with_psi(PHI, psi, beta=para['beta'])  # [N,N]
        # Read the statistics before W is handed to build_neighbors.
        stats = (
            PHI.mean().item(),
            psi.mean().item(),
            W.min().item(),
            W.mean().item(),
            W.max().item(),
        )
    M_list = [build_neighbors(H[v], W, para['K']) for v in range(n_view)]
    return M_list, stats


def _distance_weight(ep):
    """Return the distance-loss weight for 0-based epoch ``ep``.

    The weight is 0 before ``ramp_dis_from``, equals ``lambda_dis`` from
    ``ramp_dis_to`` onwards and is interpolated linearly in between.
    """
    start, stop = para['ramp_dis_from'], para['ramp_dis_to']
    if ep < start:
        return 0.0
    if ep >= stop:
        return para['lambda_dis']
    t = (ep - start) / max(1, (stop - start))
    return para['lambda_dis'] * t


def _zscore(s):
    """Standardize a 1-D score tensor (epsilon guards a zero std)."""
    return (s - s.mean()) / (s.std() + 1e-8)


def _anomaly_scores(X, H, Rec, M_list, n_samples):
    """Combine reconstruction and cross-view structural residuals per sample.

    Args:
        X, H, Rec: per-view inputs, latents and reconstructions.
        M_list: per-view neighbor masks; entries > 0.5 are treated as edges.
        n_samples: number of samples (length of the returned score vector).

    Returns:
        1-D tensor: z-scored attribute score plus z-scored structural score.
    """
    n_view = para['view']

    # Attribute score: squared reconstruction error summed over all views.
    S_attr = torch.zeros(n_samples, device=device)
    for v in range(n_view):
        S_attr += (X[v] - Rec[v]).pow(2).sum(1)

    # Structural score: how much each view's edges disagree with the others.
    S_class = torch.zeros(n_samples, device=device)
    for v in range(n_view):
        others = [u for u in range(n_view) if u != v]
        if not others:
            # Single-view setup: nothing to compare against, and stacking an
            # empty list would raise. The structural score stays zero.
            continue

        I, J = torch.nonzero(M_list[v] > 0.5, as_tuple=True)
        if I.numel() == 0:
            continue

        # Edge difference vectors, computed once and reused for both the
        # length and the direction residual.
        diff_v = H[v][I] - H[v][J]
        diff_others = [H[u][I] - H[u][J] for u in others]

        # Length residual: |own edge length - mean edge length in other views|.
        dv = diff_v.norm(p=2, dim=1)
        dc = torch.stack([d.norm(p=2, dim=1) for d in diff_others], dim=0).mean(0)
        res_dis = (dv - dc).abs()

        # Direction residual: 1 - <own unit direction, mean of the other
        # views' unit directions>, clamped at zero.
        dv_dir = F.normalize(diff_v, p=2, dim=1)
        dc_dir = torch.stack(
            [F.normalize(d, p=2, dim=1) for d in diff_others], dim=0
        ).mean(0)
        cos_sim = (dv_dir * dc_dir).sum(1)
        res_ang = (1 - cos_sim).clamp_min(0)

        edge_res = res_dis + res_ang

        # Average the edge residuals over the edges leaving each source node.
        N = H[v].size(0)
        accum = torch.zeros(N, device=device)
        counts = torch.zeros(N, device=device)
        accum.index_add_(0, I, edge_res)
        counts.index_add_(0, I, torch.ones_like(edge_res))
        S_class += accum / (counts + 1e-8)

    return _zscore(S_attr) + _zscore(S_class)


def main():
    """Train the multi-view autoencoders and report the anomaly-detection AUC.

    Steps: load the first ``para['num']`` rows of the ``data`` array, split
    them into views, pretrain on reconstruction, train with the warm-up /
    angular / distance schedule, then score every sample and print the
    ROC-AUC against the labels stored in the last data column.

    Raises:
        Whatever ``np.load`` raises for an unreadable ``para['data_path']``,
        and ``ValueError`` from ``roc_auc_score`` for unusable labels/scores.
    """
    # SECURITY: allow_pickle=True lets the archive execute arbitrary code on
    # load; only use it with trusted files. Kept because existing datasets
    # may contain object arrays.
    # The context manager closes the underlying file, which the bare
    # np.load(...) call never did.
    with np.load(para['data_path'], allow_pickle=True) as npz:
        data_np = npz['data'].astype(np.float32)[:para['num']]
    data = torch.tensor(data_np, device=device)
    labels = data[:, -1].long().cpu().numpy()  # last column holds the label

    views = split_func(data, para['view'], para['input_dim'])
    X = [v.to(device) for v in views]
    n_view = para['view']

    encoders, decoders = [], []
    for v in range(n_view):
        ae = AutoEncoder(para['input_dim'][v], para['latent']).to(device)
        encoders.append(ae.encoder)
        decoders.append(ae.decoder)

    params = [p for e in encoders + decoders for p in e.parameters()]
    opt = torch.optim.AdamW(params, lr=para['lr'], weight_decay=para['wd'])

    # ---- Stage 1: reconstruction-only pretraining (full batch) ----------
    print("Pretraining...")
    for ep in range(para['epochs_pre']):
        opt.zero_grad(set_to_none=True)
        _, Rec = _forward_views(encoders, decoders, X)
        loss = sum(F.mse_loss(Rec[v], X[v]) for v in range(n_view))
        loss.backward()
        nn.utils.clip_grad_norm_(params, 5.0)
        opt.step()
        print(f"[Pre] {ep+1}/{para['epochs_pre']}  recon={float(loss):.4f}")

    # ---- Stage 2: joint training ------------------------------------------
    print("\nTraining...")
    pi = torch.ones(n_view, device=device) / n_view  # uniform view weights to start

    # Graphs from the most recent epoch, i.e. built from the embeddings
    # computed *before* that epoch's optimizer step. Reused for scoring.
    last_M_list = None

    for ep in range(para['epochs']):
        opt.zero_grad(set_to_none=True)

        H, Rec = _forward_views(encoders, decoders, X)
        M_list, (phi_mean, psi_mean, W_min, W_mean, W_max) = _build_graphs(X, H, Rec)
        E_edges = int(sum(M.sum().item() for M in M_list) / n_view)
        last_M_list = M_list

        if ep < para['warm_align']:
            # Warm-up: unweighted reconstruction loss only.
            loss_rec = sum(F.mse_loss(Rec[v], X[v]) for v in range(n_view))
            loss_rec.backward()
            nn.utils.clip_grad_norm_(params, 5.0)
            opt.step()

            if (ep+1) % 5 == 0 or ep == 0:
                print(f"[Warm] {ep+1}/{para['epochs']}  recon={float(loss_rec):.4f}  "
                      f"phi_mean={phi_mean:.3f}  psi_mean={psi_mean:.3f}  "
                      f"W[min/mean/max]=[{W_min:.3f}/{W_mean:.3f}/{W_max:.3f}]  |E|≈{E_edges}")
            continue

        loss_rec = sum(pi[v] * F.mse_loss(Rec[v], X[v]) for v in range(n_view))

        loss_ang, loss_dis = 0.0, 0.0
        err_per_view = []
        for v in range(n_view):
            ang_v = angular_loss(H, v, M_list[v])
            dis_v = distance_loss(H, v, M_list[v])
            loss_ang += ang_v
            loss_dis += dis_v
            # Detached per-view error drives the next view-weight update.
            err_per_view.append((ang_v + dis_v).detach())

        loss_ang = loss_ang / n_view
        loss_dis = loss_dis / n_view
        lam_dis = _distance_weight(ep)

        loss = loss_rec + para['lambda_ang'] * loss_ang + lam_dis * loss_dis
        if not torch.isfinite(loss):
            # Skip the update (and the pi refresh) rather than poison the weights.
            print(f"[Warn] Non-finite loss @ epoch {ep+1}; skip step.")
            continue

        loss.backward()
        nn.utils.clip_grad_norm_(params, 5.0)
        opt.step()

        with torch.no_grad():
            err = torch.stack(err_per_view)  # [V]
            pi = compute_pi(err, tau_pi=para['tau_pi'])

        if (ep+1) % 5 == 0:
            print(f"Epoch {ep+1}/{para['epochs']}  "
                  f"recon={float(loss_rec):.4f}  "
                  f"ang={float(loss_ang):.4f}  "
                  f"dis={float(loss_dis):.4f} (λ_dis={lam_dis:.2f})  "
                  f"phi_mean={phi_mean:.3f}  psi_mean={psi_mean:.3f}  "
                  f"W[min/mean/max]=[{W_min:.3f}/{W_mean:.3f}/{W_max:.3f}]  "
                  f"|E|≈{E_edges}  pi={pi.detach().cpu().numpy()}")

    # ---- Stage 3: scoring ---------------------------------------------------
    # NOTE(review): the modules are left in training mode here, as before; if
    # AutoEncoder contains dropout or batch-norm layers, consider .eval().
    with torch.no_grad():
        H, Rec = _forward_views(encoders, decoders, X)

        use_M_list = last_M_list
        if use_M_list is None:
            # para['epochs'] == 0: no training graph exists, so build one from
            # the final embeddings instead of crashing on None.
            use_M_list, _ = _build_graphs(X, H, Rec)

        S = _anomaly_scores(X, H, Rec, use_M_list, len(labels))

        auc = roc_auc_score(labels, S.cpu().numpy())
        print(f"\nFinal AUC = {auc:.4f}")

if __name__ == "__main__":
    # Script entry point. An infinite threshold means printed tensors are
    # never summarized, so logged values such as `pi` always appear in full.
    torch.set_printoptions(threshold=torch.inf)
    main()
