import pandas as pd
from tqdm import tqdm
import torch

from dataset import get_qm9_full_loader
from model import HEGNNModel
from trainer import Regressor

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Constants reused further down: y_mean / y_std rescale the model outputs
# (y_mean + pred * y_std) and label_idx picks the ground-truth column of batch.y.
# NOTE(review): presumably the training-set statistics of that label - confirm.
y_mean = 75.2691421508789
y_std = 8.183727264404297
label_idx = 1

# One encoder checkpoint plus one head checkpoint per maximum order 0..11.
encoder_ckpt = './hegnn_16c/encoder_pretrained.pt'
head_ckpt = [f'./hegnn_16c/head-best-mask_0_{order}.pt' for order in range(12)]

# A single encoder instance is created here and handed to every regressor below.
encoder = HEGNNModel(irreps_channels=16)

# Labels "0e", "1o", "2e", ...: the order followed by 'e' (even) or 'o' (odd).
ell_all = [f"{order}{'e' if order % 2 == 0 else 'o'}" for order in range(12)]

# Variant A - mask made of the single order l = L (currently disabled):
# mask_all = {l_max: [ell_all[l]] for l_max in range(12)}
# Variant B - mask made of all orders l = 0, ..., L (active):
mask_all = {l_max: ell_all[: l_max + 1] for l_max in range(12)}

# Build models: one Regressor per maximum order, each paired with its own head
# checkpoint and with the mask stored under the same key in mask_all.
# NOTE(review): every Regressor is given the same `encoder` object; if the
# constructor loads weights into it, all models end up sharing them - confirm
# that this is intended.
models = []
for order, head_path in enumerate(head_ckpt):
    regressor = Regressor(
        encoder=encoder,
        input_dim=encoder.output_dim,
        y_mean=y_mean,
        y_std=y_std,
        label_idx=label_idx,
        mode="finetune",
        mask_list=mask_all[order],
        encoder_ckpt=encoder_ckpt,
        head_ckpt=head_path,
        hidden_dim=64,
    )
    # Place the model on the selected device and call eval() before inference.
    regressor.to(device)
    regressor.eval()
    models.append(regressor)

print("Load Model Successfully!")

# -----------------------------------------
#     Collect predictions + ground truth
# -----------------------------------------
# Per-batch pieces are accumulated in lists and concatenated after the loop.
all_idx = []
all_smiles = []
all_alpha = []

# One list of per-batch prediction tensors per model, keyed "pred_<order>".
all_pred = {f"pred_{order}": [] for order in range(len(models))}

loader = get_qm9_full_loader()

# Inference only: gradient tracking is switched off for the whole pass.
with torch.no_grad():
    for raw_batch in tqdm(loader):
        batch = raw_batch.to(device)

        # Basic info: sample indices (moved to CPU) and SMILES entries
        all_idx.append(batch.idx.cpu())
        all_smiles.extend(batch.smiles)

        # Truth: column label_idx of batch.y, flattened to 1-D on the CPU
        all_alpha.append(batch.y[:, label_idx].view(-1).cpu())

        # Predictions for each order, rescaled with y_mean / y_std
        for order, model in enumerate(models):
            rescaled = y_mean + model(batch) * y_std
            all_pred[f"pred_{order}"].append(rescaled.view(-1).cpu())

# -----------------------------------------
#     Stack arrays
# -----------------------------------------
# Join the per-batch tensors into one NumPy array per quantity.
all_idx = torch.cat(all_idx).numpy()
all_alpha = torch.cat(all_alpha).numpy()

# Same for every prediction column; key order of all_pred is preserved.
pred_cols = {
    name: torch.cat(chunks).numpy() for name, chunks in all_pred.items()
}

# -----------------------------------------
#     Build DataFrame
# -----------------------------------------
# Columns: idx, smiles, alpha, followed by the prediction columns in order.
df_pred = pd.DataFrame({
    "idx": all_idx,
    "smiles": all_smiles,
    "alpha": all_alpha,
    **pred_cols,
})

# Write the table without the pandas row index, then show a short preview.
df_pred.to_csv("qm9_hegnn_pred_0_l.csv", index=False)
print(df_pred.head())
print("Saved → qm9_hegnn_pred_0_l.csv")
