
from tqdm import tqdm
import torch
from refiner.models.utils import rmsd_align
from refiner.models.loss import batchwise_l2_loss
from utils.commons.io import save_pkl
import numpy as np
import matplotlib.pyplot as plt

def _move_batch(batch, device):
    """Return a copy of dict ``batch`` with every tensor value moved to ``device``."""
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}


def _batch_flow_matching_loss(model, batch, t, device):
    """Return the scalar flow-matching loss of one on-device ``batch`` at timestep ``t``.

    ``t`` is a Python float; it is broadcast to a ``(num_graphs, 1)`` tensor,
    where ``num_graphs`` is derived from the max index in ``batch["batch"]``.
    Must be called under ``torch.no_grad()`` for evaluation.
    """
    num_graphs = batch["batch"].max().item() + 1
    t_tensor = torch.full((num_graphs, 1), fill_value=t, device=device)

    # Perturb the reference positions with a base-distribution sample, then
    # RMSD-align the perturbed structure back onto the reference.
    delta_x0 = model.sample_base_dist(
        batch["pos"].shape,
        edge_index=batch["edge_index"],
        batch=batch["batch"],
        smiles=batch.get("smiles", None),
    )
    x0 = batch["pos"] + delta_x0
    x0 = rmsd_align(pos=x0, ref_pos=batch["pos"], batch=batch["batch"])

    # Interpolated point x_t and target vector field u_t for this t.
    x_t, u_t = model.compute_conditional_vector_field(x0, batch["pos"], t_tensor, batch=batch["batch"])

    v_t = model(
        z=batch["atomic_numbers"],
        t=t_tensor,
        pos=x_t,
        bond_index=batch["edge_index"],
        edge_attr=batch.get("edge_attr", None),
        node_attr=batch.get("node_attr", None),
        batch=batch["batch"],
    )

    loss = batchwise_l2_loss(v_t, u_t, batch=batch["batch"], reduce="mean")
    return loss.item()


def _plot_loss_vs_t(t_list, loss_by_t, plot_path):
    """Save a line plot of mean loss against ``t`` to ``plot_path``, closing the figure afterwards."""
    # A dedicated figure (rather than pyplot's implicit current figure) keeps
    # repeated calls from drawing onto each other and from leaking figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(t_list, loss_by_t)
        ax.set_xlabel("t")
        ax.set_ylabel("Flow Matching Loss")
        ax.set_title("Validation Loss vs Timestep t")
        ax.grid()
        fig.savefig(plot_path)
    finally:
        plt.close(fig)


def analyze_t_vs_loss(
    model,
    val_loader,
    steps=25,
    *,
    device="cuda",
    loss_path="loss.pkl",
    plot_path="loss_vs_t.png",
):
    """Sweep the timestep ``t`` and record the validation flow-matching loss at each value.

    For each of ``steps`` evenly spaced ``t`` in [0.01, 0.99], every batch of
    ``val_loader`` is scored with ``batchwise_l2_loss`` between the model's
    predicted vector field and the conditional target vector field.

    Side effects: switches ``model`` to eval mode and moves it to ``device``;
    prints the mean loss per ``t``; pickles a dict ``{t (float): [per-batch
    losses]}`` to ``loss_path``; saves a loss-vs-t plot to ``plot_path``.

    Args:
        model: Flow-matching model exposing ``sample_base_dist``,
            ``compute_conditional_vector_field`` and a forward taking
            ``z``/``t``/``pos``/``bond_index``/``edge_attr``/``node_attr``/``batch``.
        val_loader: Iterable of dict batches with at least the keys
            ``"pos"``, ``"batch"``, ``"edge_index"``, ``"atomic_numbers"``;
            ``"smiles"``, ``"edge_attr"`` and ``"node_attr"`` are optional.
        steps: Number of ``t`` values in the sweep.
        device: Torch device for the model and batch tensors.
        loss_path: Output path for the pickled per-batch losses.
        plot_path: Output path for the plot image.
    """
    model.eval()
    model.to(device)
    # Plain Python floats: used as fill values, plot x-coordinates and as
    # pickle keys (tensor keys would hash by identity and, on CUDA, could not
    # be unpickled on a CPU-only machine).
    t_list = torch.linspace(0.01, 0.99, steps=steps).tolist()

    loss_by_t = []
    loss_dict = {}

    with torch.no_grad():
        for t in t_list:
            all_loss = [
                _batch_flow_matching_loss(model, _move_batch(batch, device), t, device)
                for batch in tqdm(val_loader)
            ]
            # An empty loader yields NaN explicitly instead of numpy's
            # "Mean of empty slice" RuntimeWarning.
            mean_loss = float(np.mean(all_loss)) if all_loss else float("nan")

            print(f"t: {t} loss: {mean_loss}")
            loss_by_t.append(mean_loss)
            loss_dict[t] = all_loss

    save_pkl(file_path=loss_path, data=loss_dict)
    print(loss_by_t)

    _plot_loss_vs_t(t_list, loss_by_t, plot_path)
    print(f"Saved {plot_path}")