
import matplotlib.pyplot as plt
from scripts.utils import *
from scripts.utils import compute_loss_all_batches
from scripts.create_latent_ode_model import create_LatentODE_model
from torch.distributions.normal import Normal
import scripts.utils as utils
import torch.optim as optim
import torch
from random import SystemRandom
import numpy as np
import argparse
from tqdm import tqdm
from scripts.dataLoader import ParseData
import sys
import os
import networkx as nx
from matplotlib import rc
from scipy import signal
# from scripts.dataLoader_bball import ParseData as BBallData 
# from scripts.dataLoader_penalize import  ParseData as PenalizeData
# from scripts.dataLoader_motion import ParseData as MotionData 
BASE_DIRECTORY = "/DIR/Projects/AttentionNet/"
# NOTE(review): this append runs after the `scripts.*` imports at the top of
# the file, so it does not influence how those imports are resolved.
sys.path.append(BASE_DIRECTORY)

DATA_DIRECTORY = BASE_DIRECTORY + "data/"
COMMON_CONFIG = BASE_DIRECTORY + "configs/common/common.cfg"

# Command line: the experiment config to evaluate and, optionally, an explicit
# checkpoint path. All remaining settings are read from the config file.
parser = argparse.ArgumentParser('Latent ODE')
parser.add_argument('--config', type=str, default=None,
                    help='Path to the experiment config file.')
parser.add_argument('--ckpt', type=str, default=None, help = 'Path to the model checkpoint')
args = parser.parse_args()
args = utils.set_args_from_config(args, args.config)
print(args)

# cfg = utils.read_config(args.config)

# for section_name in cfg.sections():
#     # loop over all keys in the section
#     for key in cfg[section_name]:
#         # get the value for the key
#         value = cfg[section_name][key]
#         # check if the value can be converted to an integer
#         try:
#             value = int(value)
#         except ValueError:
#             try:
#                 value = float(value)
#             except ValueError:
#                 try:
#                     value = str(value)
#                     if value == "":
#                         value = None
#                 except ValueError:
#                     pass

#         setattr(args, key, value)
        # print the key and value
        # print(f"{key}: {value} (int value: {value})")
# print(args.load)


# Derived settings: total node count, output directories and the
# dataset-specific file suffixes / ODE horizon.
args.total_balls = args.n_balls + args.hide_balls
args.logging_dir = BASE_DIRECTORY + "neurips_experiments/" +  \
    str(args.exp_no) + "_" + args.data + "_" + str(args.total_balls) + \
    "_" + str(args.hide_balls) + "/"
# rec_dims must split evenly over n_heads (presumably one slice per attention
# head). Raised explicitly so the check survives `python -O`.
if int(args.rec_dims % args.n_heads) != 0:
    raise ValueError("rec_dims ({}) must be divisible by n_heads ({})".format(
        args.rec_dims, args.n_heads))
args.save = args.logging_dir + args.save_tests

# exist_ok=True: the graph sub-directory is now also created when args.save
# already exists (previously it was only made together with a new args.save).
os.makedirs(args.save, exist_ok=True)
os.makedirs(args.save + "/" + args.save_graph, exist_ok=True)

if args.data in ("spring", "charged"):
    _data_tag = "_springs" if args.data == "spring" else "_charged"
    args.suffix = _data_tag + str(args.total_balls)
    args.data_load_suffix = _data_tag + str(args.total_balls)
    args.total_ode_step = 60
elif args.data == "motion":
    args.suffix = 'motion'
    args.total_ode_step = 60
    args.data_load_suffix = "motion"
elif args.data == "bball":
    # The basketball dataset location is fixed here rather than read from the config.
    args.dataset = BASE_DIRECTORY + 'data_files/basketball'
    args.suffix = '_bball'
    args.total_ode_step = 49
    args.data_load_suffix = "_bball" + str(args.total_balls)
  


# ############ CPU AND GPU related, Mode related, Dataset Related
# if torch.cuda.is_available():
# 	print("Using GPU" + "-"*80)
# 	device = torch.device("cuda:" + str(args.gpu))
# else:
# 	print("Using CPU" + "-" * 80)
# 	device = torch.device("cpu")

# if args.extrap == "True":
#     print("Running extrap mode" + "-"*80)
#     args.mode = "extrap"
# elif args.extrap == "False":
#     print("Running interp mode" + "-" * 80)
#     args.mode = "interp"

############ CPU AND GPU related, Mode related, Dataset Related
# NOTE(review): hard-coded; overrides any `gpu` value coming from the config.
args.gpu = 5
# Evaluation runs on the CPU. The original picked cuda:<gpu> when CUDA was
# available (printing "Using GPU") and then unconditionally replaced it with
# the CPU, so that message was wrong. The final device is unchanged.
device = torch.device("cpu")
print("Using CPU" + "-" * 80)

# The config may provide the strings "True"/"False" or real booleans; str()
# lets both spellings select a mode (a bool previously left args.mode unset).
_extrap = str(args.extrap)
if _extrap == "True":
    print("Running extrap mode" + "-"*80)
    args.mode = "extrap"
elif _extrap == "False":
    print("Running interp mode" + "-" * 80)
    args.mode = "interp"
### Plot related
# Per-atom trajectory colours: twelve base colours cycled three times
# (36 entries in total), indexed by atom number.
_BASE_COLORS = ['red', 'green', 'blue', 'yellow', 'orange',
                'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'black']
colors = _BASE_COLORS * 3

# Graph node colours: visible atoms reuse their trajectory colour, hidden
# atoms are drawn in black.
node_colors = [colors[k] if k < args.n_balls else 'black'
               for k in range(args.n_balls + args.hide_balls)]

assert len(node_colors) == args.n_balls + args.hide_balls

# Global matplotlib font: DejaVu Sans at size 10; mathtext uses the regular face.
rc('font', **{'family': 'sans-serif',
   'sans-serif': ['DejaVu Sans'], 'size': 10})
rc('mathtext', **{'default': 'regular'})

if __name__ == '__main__':

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    ############ Saving Path and Preload.
    utils.makedirs(args.save)
    # The graph directory lives under args.save, matching the directory made
    # at start-up (the original created it relative to the working directory).
    utils.makedirs(os.path.join(args.save, args.save_graph))

    print("Loading dataset: " + args.dataset)
    # Alternative loaders are imported lazily: their top-of-file imports are
    # commented out, so referring to them directly raised NameError.
    if args.data == "bball":
        from scripts.dataLoader_bball import ParseData as BBallData
        loader_cls = BBallData
    elif getattr(args, "penalize", False):
        from scripts.dataLoader_penalize import ParseData as PenalizeData
        loader_cls = PenalizeData
    elif args.data == "motion":
        from scripts.dataLoader_motion import ParseData as MotionData
        loader_cls = MotionData
    else:
        loader_cls = ParseData
    dataloader = loader_cls(args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    test_encoder, test_decoder, test_graph, test_batch, vis_encoder, full_edges = dataloader.load_data(
        sample_percent=args.sample_percent_test,
        batch_size=args.batch_size,
        data_type="test")

    input_dim = dataloader.feature  # TODO: feature dimension

    obsrv_std = torch.Tensor([0.01]).to(device)
    # Standard-normal prior over the latent initial state z0.
    z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))

    model = create_LatentODE_model(
        args, input_dim, z0_prior, obsrv_std, device)

    # Checkpoint: an explicit --ckpt path wins (it used to be ignored);
    # otherwise `load` is resolved relative to args.save as before.
    ckpt_path = None
    if getattr(args, "ckpt", None):
        ckpt_path = args.ckpt
    elif args.load is not None:
        ckpt_path = os.path.join(args.save, args.load)
    if ckpt_path is not None:
        utils.get_ckpt_model(ckpt_path, model, device)
        print("Loaded checkpoint from {}".format(ckpt_path))
    else:
        print("WARNING: no checkpoint given (--ckpt / load); evaluating an untrained model")

    def test():
        """Evaluate the model on every test batch and plot each sample.

        For each test batch the decoder window is reconstructed with
        ``model.get_reconstruction`` (one trajectory sample). For every sample
        in the batch, two figures are drawn:

        * ``plot_iclr``  - 2x3 diagnostics: trajectories, normalised
          cross-correlation of x / y, phase spectra of x / y and the edge graph;
        * ``plot_iclr2`` - 1x5 summary: predicted / true positions and
          velocities with the encoder context overlaid, plus the edge graph.

        Figures are written under ``args.save`` only when ``args.save_flag`` is
        truthy. Per-atom, per-timestep squared errors (mean over the feature
        axis, summed over all plotted samples) are accumulated into an
        ``(n_balls, 30)`` array and saved as ``<args.save>/test_loss.npy``.
        """
        # Imported once here instead of on every sample.
        import seaborn as sns
        from matplotlib import patheffects as pe

        model.eval()
        # Seaborn style for every figure below. The original set it inside the
        # sample loop, so only the very first diagnostic figure used the defaults.
        sns.set_style("whitegrid")
        sns.set_context("talk")  # Use "paper" for smaller font sizes

        # Number of decoder / encoder time steps shown in the plots.
        pd_t = 19 if args.data == "bball" else 30
        pd_enc = 30
        # Row offset between consecutive samples in the flattened
        # (sample * atom) axis. NOTE(review): hard-coded to 4 as in the original;
        # it presumably should equal the atoms per sample. Confirm against the loader.
        stride = 4

        final_loss = np.zeros((args.n_balls, 30))

        # Summary-figure colours: tab10 with alpha 0.6 for visible atoms, and
        # black graph nodes for hidden atoms.
        cmap = plt.get_cmap('tab10')
        atom_colors = []
        for atom in range(args.n_balls):
            rgba = list(cmap(atom))
            rgba[3] = 0.6  # lighter shade
            atom_colors.append(rgba)
        node_colors2 = atom_colors + ['black'] * args.hide_balls
        assert len(node_colors2) == args.n_balls + args.hide_balls

        line_styles = {'pred': '--', 'encoder': '-', 'true': '--'}
        # "Brush stroke" look: a wide translucent black stroke under each line.
        plot_args = {
            'linewidth': 3, 'alpha': 0.7, 'path_effects': [
                pe.Stroke(linewidth=5, foreground='black', alpha=0.3), pe.Normal()]}

        def n_plottable(n_rows):
            """Count samples a whose rows stride*a .. stride*a + n_balls - 1 exist."""
            if n_rows < args.n_balls:
                return 0
            return (n_rows - args.n_balls) // stride + 1

        def normalized_xcorr(true_sig, pred_sig):
            """Return (lags, corr): full cross-correlation scaled to a peak of 1.

            When the peak is 0 the raw correlation is returned (no division by zero).
            """
            corr = signal.correlate(true_sig, pred_sig)
            peak = np.max(corr)
            if peak != 0:
                corr = corr / peak
            lags = signal.correlation_lags(len(true_sig), len(pred_sig))
            return lags, corr

        def save_figure(fig, subdir, batch_idx, sample_idx):
            """Write fig to <args.save>/<subdir>/test_<batch>_atom_<sample>.png."""
            print("saving")
            out_dir = os.path.join(args.save, subdir)
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(os.path.join(
                out_dir, "test_{}_atom_{}.png".format(batch_idx, sample_idx)))

        def plot_diagnostics(true_np, pred_np, enc_np, edges, base):
            """Build the 2x3 diagnostic figure for the sample whose rows start at base."""
            fig, ax = plt.subplots(2, 3, figsize=(20, 10))
            for atom in range(args.n_balls):
                color = colors[atom]
                true_seq = true_np[base + atom, :pd_t]
                pred_seq = pred_np[base + atom, :pd_t]
                enc_seq = enc_np[base + atom, :pd_enc]
                ax[0, 0].scatter(true_seq[:, 0], true_seq[:, 1], label="true_atom_" + str(atom),
                                 color=color, marker="x", alpha=0.3)
                ax[0, 0].scatter(pred_seq[:, 0], pred_seq[:, 1], label="pred_atom" + str(atom),
                                 color=color, marker="o", alpha=0.5)
                ax[0, 0].scatter(enc_seq[:, 0], enc_seq[:, 1], label="encoder_atom_" + str(atom),
                                 color=color, marker="*", alpha=1)
                # Cross-correlation of x (column 0) and y (column 1).
                for dim, panel in ((0, ax[0, 1]), (1, ax[0, 2])):
                    lags, corr = normalized_xcorr(true_seq[:, dim], pred_seq[:, dim])
                    panel.plot(lags, corr, label="corr_atom" + str(atom), color=color)
                for dim, panel in ((0, ax[1, 0]), (1, ax[1, 1])):
                    panel.phase_spectrum(true_seq[:, dim], label="true_atom_" + str(atom),
                                         color=color, alpha=0.3)
                    panel.phase_spectrum(pred_seq[:, dim], label="pred_atom" + str(atom),
                                         color=color, marker="o", alpha=0.5)
            ax[0, 0].set_title("Trajectory")
            ax[0, 0].set_xlabel("X")
            ax[0, 0].set_ylabel("Y")
            for panel, title in ((ax[0, 1], "Correlation X"), (ax[0, 2], "Correlation Y")):
                panel.set_title(title)
                panel.set_xlabel("Lags")
                panel.set_ylabel("Correlation")
            ax[1, 0].set_title("Phase Spectrum of x")
            ax[1, 1].set_title("Phase Spectrum of y")
            if edges is not None:
                ax[1, 2].set_title("Graph")
                nx.draw(nx.DiGraph(edges), ax=ax[1, 2], with_labels=True,
                        node_color=node_colors, node_size=1000, font_size=10)
            return fig

        def plot_summary(true_np, pred_np, enc_np, edges, base):
            """Build the 1x5 summary figure (positions, velocities, graph) for one sample."""
            fig, ax = plt.subplots(1, 5, figsize=(25, 5))
            for atom in range(args.n_balls):
                color = atom_colors[atom]
                pred_seq = pred_np[base + atom, :pd_t]
                true_seq = true_np[base + atom, :pd_t]
                enc_seq = enc_np[base + atom, :pd_enc]
                # (panel, series, (x column, y column), line style); columns 0/1
                # are plotted as position and 2/3 as velocity.
                for panel, seq, (cx, cy), style in (
                        (ax[0], pred_seq, (0, 1), 'pred'),
                        (ax[0], enc_seq, (0, 1), 'encoder'),
                        (ax[1], true_seq, (0, 1), 'true'),
                        (ax[1], enc_seq, (0, 1), 'encoder'),
                        (ax[2], pred_seq, (2, 3), 'pred'),
                        (ax[2], enc_seq, (2, 3), 'encoder'),
                        (ax[3], true_seq, (2, 3), 'true'),
                        (ax[3], enc_seq, (2, 3), 'encoder')):
                    panel.plot(seq[:, cx], seq[:, cy], color=color,
                               linestyle=line_styles[style], **plot_args)
            ax[0].set_title("Predicted Trajectory")
            ax[1].set_title("True Trajectory")
            ax[2].set_title("Predicted Velocity")
            ax[3].set_title("True Velocity")
            # No ticks or tick labels; keep black spines as a frame.
            for axis in ax:
                axis.set_xticklabels([])
                axis.set_yticklabels([])
                axis.set_xticks([])
                axis.set_yticks([])
                for side in ('top', 'bottom', 'right', 'left'):
                    axis.spines[side].set_color('black')
            if edges is not None:
                ax[4].set_title("Graph")
                nx.draw(nx.DiGraph(edges), ax=ax[4], with_labels=True,
                        node_color=node_colors2, node_size=1000, font_size=10)
            fig.tight_layout()
            return fig

        with torch.no_grad():
            for i in tqdm(range(test_batch)):
                batch_dict_encoder = get_next_batch_new(test_encoder, device)
                batch_dict_graph = get_next_batch_new(test_graph, device)
                batch_dict_decoder = get_next_batch(test_decoder, device)
                batch_vis_encoder = get_next_batch(vis_encoder, device)
                batch_full_edges = next(full_edges)
                pred_y, _info, _temporal_weights = model.get_reconstruction(
                    batch_dict_encoder, batch_dict_decoder, batch_dict_graph, n_traj_samples=1)

                # Convert to NumPy once per batch instead of once per plotted series.
                pred_np = pred_y[0].cpu().numpy()
                true_np = batch_dict_decoder["data"].cpu().numpy()
                enc_np = batch_vis_encoder["data"].cpu().numpy()
                # Edge matrices are only used (and converted) for non-bball data.
                edges_np = None if args.data == "bball" else batch_full_edges.cpu().numpy()

                # Only visit samples whose rows actually exist (short final batch).
                n_samples = min(args.batch_size, n_plottable(pred_np.shape[0]),
                                n_plottable(true_np.shape[0]), n_plottable(enc_np.shape[0]))
                if edges_np is not None:
                    n_samples = min(n_samples, edges_np.shape[0])

                # Common time / feature extent for the error accumulation.
                n_t = min(pd_t, final_loss.shape[1], pred_np.shape[1], true_np.shape[1])
                n_f = min(pred_np.shape[2], true_np.shape[2])

                for a in range(n_samples):
                    base = stride * a
                    rows = slice(base, base + args.n_balls)
                    # Per-atom, per-step MSE over features (what the original
                    # commented-out MSELoss loop intended); previously never filled.
                    sq_err = (pred_np[rows, :n_t, :n_f] - true_np[rows, :n_t, :n_f]) ** 2
                    final_loss[:, :n_t] += sq_err.mean(axis=-1)

                    edges = None if edges_np is None else edges_np[a]

                    fig = plot_diagnostics(true_np, pred_np, enc_np, edges, base)
                    if args.save_flag:
                        save_figure(fig, "plot_iclr", i, a)
                    plt.close(fig)

                    fig = plot_summary(true_np, pred_np, enc_np, edges, base)
                    if args.save_flag:
                        save_figure(fig, "plot_iclr2", i, a)
                    plt.close(fig)

                del batch_dict_encoder, batch_dict_graph, batch_dict_decoder
        np.save(os.path.join(args.save, "test_loss.npy"), final_loss)

    test()
