import pdb

import numpy as np
import torch as th

from models.transformers import SetTransformer


def train(cfg):
    """Load offline multi-agent data and build a SetTransformer forward model.

    This is currently a scaffold. It loads the offline dataset, builds the
    model and an Adam optimizer, and runs one forward pass on a small batch as
    a shape smoke test. No optimisation steps are taken yet.

    Args:
        cfg: Configuration mapping. Recognised keys:
            ``"lr"`` -- Adam learning rate (default ``1e-3``).
            ``"data_path"`` -- path to the offline-data checkpoint (default
            ``./ckpt_plot/coop_navigation_n15/offline-data-best.ckpt``).

    Returns:
        The constructed (untrained) ``SetTransformer`` model.
    """
    data_path = cfg.get(
        "data_path", "./ckpt_plot/coop_navigation_n15/offline-data-best.ckpt"
    )
    lr = cfg.get("lr", 1e-3)

    # Load data.
    # NOTE(review): th.load unpickles the file, so only load trusted
    # checkpoints. On newer torch releases ``weights_only`` defaults to True,
    # which may reject the numpy arrays stored here -- confirm, and pass
    # ``weights_only=False`` if needed.
    offline_data = th.load(data_path)
    obs_data = th.from_numpy(offline_data["obs"]).float()  # episodes x steps x agents x dims
    next_obs_data = th.from_numpy(offline_data["next_obs"]).float()  # same as above
    action_data = th.from_numpy(offline_data["action"]).float()  # episodes x steps x agents x dims
    reward_data = th.from_numpy(offline_data["reward"]).float()  # episodes x steps x agents
    # Per-agent model input: observation and action concatenated on the feature axis.
    omega_data = th.cat([obs_data, action_data], dim=-1)

    # Build model and optimizer. The model maps a set of per-agent
    # (obs, action) vectors to one obs-sized output per agent
    # (num_outputs = agent axis, dim_output = obs feature size).
    model = SetTransformer(
        dim_input=omega_data.shape[-1],
        num_outputs=omega_data.shape[-2],
        dim_output=obs_data.shape[-1],
        ln=True,
    )
    # TODO: next_obs_data, reward_data and optimizer are reserved for the
    # training loop, which is not written yet.
    optimizer = th.optim.Adam(model.parameters(), lr=lr)

    # Shape smoke test: first 10 episodes at step 0 -> (10, agents, dims).
    # omega_data is already float32, so no extra cast is needed. The output is
    # discarded, so no autograd graph is built.
    batch = omega_data[:10, 0]
    with th.no_grad():
        model(batch)

    return model


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train forward model")
    parser.add_argument("--lr", type=float, default=1e-3)
    args = parser.parse_args()

    # Pass every parsed CLI option (e.g. ``lr``) to train() as a plain dict,
    # so the command-line values actually reach the training code.
    cfg = vars(args)

    train(cfg)
