from scripts_custom_time_steps.utils import compute_loss_all_batches
from scripts_custom_time_steps.dataLoader import ParseData
# from scripts.dataLoader_bball import ParseData as BBallData 
from tqdm import tqdm
import argparse
import numpy as np
from random import SystemRandom
import torch
import torch.optim as optim
import scripts_custom_time_steps.utils as utils
from torch.distributions.normal import Normal
from scripts_custom_time_steps.create_latent_ode_model import create_LatentODE_model
import os
from scripts_custom_time_steps.utils import *
import sys
import matplotlib.pyplot as plt 
from scripts_custom_time_steps.dataLoader import *
from matplotlib import rc
from scipy import signal
# from scripts.dataLoader_penalize import  ParseData as PenalizeData
# from scripts.dataLoader_motion import ParseData as MotionData
# from scripts.dataLoader_kuramoto import ParseData as KuramotoData
# Project locations. COMMON_CONFIG is defined here but is not applied to
# `args` anywhere in this section.
BASE_DIRECTORY = "/DIR/Projects/AttentionNet/"
DATA_DIRECTORY = BASE_DIRECTORY + "data/"
COMMON_CONFIG = "/DIR/Projects/AttentionNet/configs/common/common.cfg"

# Make the project root importable.
sys.path.append("/DIR/Projects/AttentionNet/")

# Generative model for noisy data based on ODE
# The only command-line option is --config; every other setting is filled in
# by utils.set_args_from_config from the file that option points to.
parser = argparse.ArgumentParser(prog='Latent ODE')
parser.add_argument(
    '--config',
    type=str,
    default=None,
    help='What simulation to generate.',
)
_cli_args = parser.parse_args()
args = utils.set_args_from_config(_cli_args, _cli_args.config)
print(args)

# for section_name in cfg.sections():
#     # loop over all keys in the section
#     for key in cfg[section_name]:
#         # get the value for the key
#         value = cfg[section_name][key]
#         # check if the value can be converted to an integer
#         try:
#             value = int(value)
#         except ValueError: 
#             try :
#                 value = float(value)
#             except ValueError:
#                 try:
#                     value = str(value)
#                     if value == "":
#                         value = None
#                 except ValueError:
#                     pass
               
        
#         setattr(args, key, value)



# Total node count is the sum of the two configured ball counts.
args.total_balls = args.n_balls + args.hide_balls

# Run directory name:
#   <exp_no>_<data>_<total_balls>_<hide_balls>_num_obs_<num_obs>/
_run_tag = "_".join([
    str(args.exp_no),
    args.data,
    str(args.total_balls),
    str(args.hide_balls),
    "num_obs",
    str(args.num_obs),
])
args.logging_dir = (
    BASE_DIRECTORY + "neurips_experiments_custom_timesteps/" + _run_tag + "/"
)

# rec_dims has to be an exact multiple of n_heads.
assert int(args.rec_dims % args.n_heads) == 0

# args.save arrives as a relative name and is anchored under the run directory.
args.save = args.logging_dir + args.save

# Per-dataset file suffixes and ODE step counts.
_balls_tag = str(args.total_balls)
if args.data == "spring":
    args.suffix = '_springs' + _balls_tag
    args.data_load_suffix = "_springs" + _balls_tag
    args.total_ode_step = args.num_pre + args.num_obs
elif args.data == "charged":
    args.suffix = '_charged' + _balls_tag
    args.data_load_suffix = "_charged" + _balls_tag
    args.total_ode_step = 60
elif args.data == "motion":
    args.suffix = 'motion'
    args.total_ode_step = 49
    args.data_load_suffix = "motion"
elif args.data == "kuramoto":
    # Only the suffix is set for this dataset; total_ode_step and
    # data_load_suffix are left as they are.
    args.suffix = '_kuramoto5'
elif args.data == "bball":
    args.dataset = '/DIR/Projects/AttentionNet/data_files/basketball'
    args.suffix = '_bball'
    args.total_ode_step = 49
    args.data_load_suffix = "_bball" + _balls_tag
  


############ CPU AND GPU related, Mode related, Dataset Related
# Select the compute device once: the CUDA device numbered args.gpu when CUDA
# is available, the CPU otherwise.
if torch.cuda.is_available():
    print("Using GPU" + "-" * 80)
    device = torch.device("cuda:" + str(args.gpu))
else:
    print("Using CPU" + "-" * 80)
    device = torch.device("cpu")

# args.extrap is compared against the *strings* "True" / "False".
# Any other value leaves args.mode untouched.
if args.extrap == "True":
    print("Running extrap mode" + "-" * 80)
    args.mode = "extrap"
elif args.extrap == "False":
    print("Running interp mode" + "-" * 80)
    args.mode = "interp"

### Plot related
# Twelve named colours make up the base palette; `colors` holds that palette
# three times in a row (36 entries in total).
_BASE_PALETTE = [
    'red', 'green', 'blue', 'yellow', 'orange', 'purple',
    'brown', 'pink', 'gray', 'olive', 'cyan', 'black',
]
colors = _BASE_PALETTE * 3

# One colour per node: the first n_balls nodes take palette colours, the
# remaining hide_balls nodes are drawn in black.
node_colors = [
    colors[i] if i < args.n_balls else 'black'
    for i in range(args.n_balls + args.hide_balls)
]

assert len(node_colors) == args.n_balls + args.hide_balls

# Set the global font to be DejaVu Sans, size 10 (or any other sans-serif font of your choice!)
_font_settings = {
    'family': 'sans-serif',
    'sans-serif': ['DejaVu Sans'],
    'size': 10,
}
rc('font', **_font_settings)
rc('mathtext', default='regular')

if __name__ == '__main__':

    # Seed the NumPy and PyTorch generators from the configured seed.
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ############ Saving Path and Preload.
    # Script file name with its last three characters (".py") removed.
    file_name = os.path.basename(__file__)[:-3]
    utils.makedirs(args.save)
    utils.makedirs(args.save_graph)
    # Reuse the checkpoint name as the experiment ID; otherwise draw a
    # random integer below 100000.
    experimentID = (
        args.load
        if args.load is not None
        else int(SystemRandom().random() * 100000)
    )

    print("Loading dataset: " + args.dataset)
    # NOTE(review): BBallData, PenalizeData and MotionData are not bound by
    # any active import statement in this file; presumably they are expected
    # to come from one of the wildcard imports -- confirm before relying on
    # these branches.
    _loader_kwargs = dict(suffix=args.suffix, mode=args.mode, args=args)
    if args.data == "bball":
        dataloader = BBallData(args.dataset, **_loader_kwargs)
    elif args.penalize:
        dataloader = PenalizeData(args.dataset, **_loader_kwargs)
    elif args.data == "motion":
        dataloader = MotionData(args.dataset, **_loader_kwargs)
    else:
        dataloader = ParseData(
            args.dataset,
            num_obs=args.num_obs,
            num_pre=args.num_pre,
            **_loader_kwargs
        )

    # Only the test split is loaded here.
    (test_encoder, test_decoder, test_graph,
     test_batch, vis_encoder, full_edges) = dataloader.load_data(
        sample_percent=args.sample_percent_test,
        batch_size=args.batch_size,
        data_type="test",
    )

    input_dim = dataloader.feature  # TODO: feature dimension

    # Rebuild the command line as one string, dropping "--load <value>" when
    # the flag occurs exactly once.
    input_command = sys.argv
    ind = [pos for pos, token in enumerate(input_command)
           if token == "--load"]
    if len(ind) == 1:
        ind = ind[0]
        input_command = input_command[:ind] + input_command[(ind + 2):]
    input_command = " ".join(input_command)

    # Observation noise and a standard-normal prior, both placed on `device`.
    obsrv_std = torch.Tensor([0.01]).to(device)
    z0_prior = Normal(
        loc=torch.Tensor([0.0]).to(device),
        scale=torch.Tensor([1.]).to(device),
    )

    model = create_LatentODE_model(
        args, input_dim, z0_prior, obsrv_std, device)

    # Load checkpoint and evaluate the model
    if args.load is not None:
        ckpt_path = os.path.join(args.save, args.load)
        utils.get_ckpt_model(ckpt_path, model, device)
        print("Loaded checkpoint from {}".format(ckpt_path))

    wait_until_kl_inc = 10
    best_test_mse = np.inf
    n_iters_to_viz = 1

    def _save_and_close(fig, sub_dir, batch_idx, sample_idx):
        """Save ``fig`` under ``args.save + sub_dir`` if saving is enabled, then close it.

        The output directory is created on demand, so neither "plot/" nor
        "plot2/" has to exist beforehand.
        """
        if args.save_flag:
            print("saving")
            out_dir = args.save + sub_dir
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(
                out_dir + "test_{}_atom_{}.png".format(batch_idx, sample_idx))
        plt.close(fig)

    def _plot_xcorr(axis, truth, pred, atom, color):
        """Plot the cross-correlation of ``truth`` and ``pred`` (scaled by its maximum) against lag."""
        corr = signal.correlate(truth, pred)
        corr = corr / np.max(corr)
        lags = signal.correlation_lags(len(truth), len(pred))
        axis.plot(lags, corr, label="corr_atom" + str(atom), color=color)

    def _plot_diagnostics(truth, pred, enc, first_row, adjacency, pd_t, pd_enc):
        """Build the 2x3 diagnostics figure for one sample and return it.

        Panels: trajectory scatter, cross-correlation of x and of y, phase
        spectrum of x and of y, and (when ``adjacency`` is not None) the
        interaction graph.
        """
        fig, ax = plt.subplots(2, 3, figsize=(20, 10))
        for atom in range(args.n_balls):
            color = colors[atom]
            row = first_row + atom
            true_x, true_y = truth[row, :pd_t, 0], truth[row, :pd_t, 1]
            pred_x, pred_yc = pred[row, :pd_t, 0], pred[row, :pd_t, 1]
            true_label = "true_atom_" + str(atom)
            pred_label = "pred_atom" + str(atom)

            ax[0, 0].scatter(true_x, true_y, label=true_label,
                             color=color, marker="x", alpha=0.3)
            ax[0, 0].scatter(pred_x, pred_yc, label=pred_label,
                             color=color, marker="o", alpha=0.5)
            ax[0, 0].scatter(enc[row, :pd_enc, 0], enc[row, :pd_enc, 1],
                             label="encoder_atom_" + str(atom),
                             color=color, marker="*", alpha=1)

            _plot_xcorr(ax[0, 1], true_x, pred_x, atom, color)
            _plot_xcorr(ax[0, 2], true_y, pred_yc, atom, color)

            ax[1, 0].phase_spectrum(true_x, label=true_label,
                                    color=color, alpha=0.3)
            ax[1, 0].phase_spectrum(pred_x, label=pred_label,
                                    color=color, marker="o", alpha=0.5)
            ax[1, 1].phase_spectrum(true_y, label=true_label,
                                    color=color, alpha=0.3)
            ax[1, 1].phase_spectrum(pred_yc, label=pred_label,
                                    color=color, marker="o", alpha=0.5)

        # Titles and axis labels do not depend on the atom, so set them once.
        ax[0, 0].set_title("Trajectory")
        ax[0, 0].set_xlabel("X")
        ax[0, 0].set_ylabel("Y")
        for axis, title in ((ax[0, 1], "Correlation X"),
                            (ax[0, 2], "Correlation Y")):
            axis.set_title(title)
            axis.set_xlabel("Lags")
            axis.set_ylabel("Correlation")
        ax[1, 0].set_title("Phase Spectrum of x")
        ax[1, 1].set_title("Phase Spectrum of y")

        if adjacency is not None:
            # NOTE(review): `nx` has no import statement in this file;
            # presumably it is provided by one of the wildcard imports --
            # confirm.
            ax[1, 2].set_title("Graph")
            gr = nx.DiGraph(adjacency)
            nx.draw(gr, ax=ax[1, 2], with_labels=True,
                    node_color=node_colors, node_size=1000, font_size=10)
        return fig

    def _plot_model_vs_truth(truth, pred, enc, first_row, pd_t, pd_enc):
        """Build the 2x2 figure comparing model output and ground truth; return it.

        Top row uses feature columns 0/1 (labelled X/Y), bottom row uses
        feature columns 2/3 (labelled Xdot/Ydot).
        """
        fig, ax = plt.subplots(2, 2, figsize=(20, 10))
        for atom in range(args.n_balls):
            color = colors[atom]
            row = first_row + atom
            true_label = "true_atom_" + str(atom)
            pred_label = "pred_atom" + str(atom)
            enc_label = "encoder_atom_" + str(atom)

            ax[0, 0].scatter(pred[row, :pd_t, 0], pred[row, :pd_t, 1],
                             label=pred_label, color=color, marker="o", alpha=0.5)
            ax[0, 0].scatter(enc[row, :pd_enc, 0], enc[row, :pd_enc, 1],
                             label=enc_label, color=color, marker="*", alpha=1)

            ax[0, 1].scatter(truth[row, :pd_t, 0], truth[row, :pd_t, 1],
                             label=true_label, color=color, marker="x", alpha=0.5)
            ax[0, 1].scatter(enc[row, :pd_enc, 0], enc[row, :pd_enc, 1],
                             label=enc_label, color=color, marker="*", alpha=1)

            # NOTE(review): the encoder velocities are cut at pd_t while the
            # encoder positions above are cut at pd_enc; kept as-is, but
            # confirm this asymmetry is intended.
            ax[1, 0].scatter(pred[row, :pd_t, 2], pred[row, :pd_t, 3],
                             label=pred_label, color=color, marker="o", alpha=0.5)
            ax[1, 0].scatter(enc[row, :pd_t, 2], enc[row, :pd_t, 3],
                             label=enc_label, color=color, marker="*", alpha=1)

            ax[1, 1].scatter(truth[row, :pd_t, 2], truth[row, :pd_t, 3],
                             label=true_label, color=color, marker="x", alpha=0.5)
            ax[1, 1].scatter(enc[row, :pd_t, 2], enc[row, :pd_t, 3],
                             label=enc_label, color=color, marker="*", alpha=1)

        panels = (
            (ax[0, 0], "Model Trajectory", "X", "Y"),
            (ax[0, 1], " GT Trajectory", "X", "Y"),
            (ax[1, 0], "Model Velocity", "Xdot", "Ydot"),
            (ax[1, 1], "GT Velocity", "Xdot", "Ydot"),
        )
        for axis, title, x_label, y_label in panels:
            axis.set_title(title)
            axis.set_xlabel(x_label)
            axis.set_ylabel(y_label)
        ax[0, 1].legend()
        return fig

    def test():
        """Run the model over every test batch and plot per-sample diagnostics.

        For each batch the model reconstruction is obtained with a single
        trajectory sample, and for every sample index in the batch two
        figures are produced (saved under "plot/" and "plot2/" when
        ``args.save_flag`` is set). Finally ``final_loss`` is written to
        ``args.save + "test_loss.npy"``.
        """
        model.eval()
        # NOTE(review): final_loss is allocated and saved but never
        # accumulated in this function, so the saved array is all zeros.
        final_loss = np.zeros((args.n_balls, 30))

        # NOTE(review): rows of one sample are addressed as 4 * sample + atom;
        # presumably this assumes 4 nodes per sample -- confirm against the
        # data loader for other node counts.
        rows_per_sample = 4

        # Number of predicted / encoder time steps to plot.
        pd_t = 19 if args.data == "bball" else 30
        pd_enc = 30

        with torch.no_grad():
            for i in tqdm(range(test_batch)):
                batch_dict_encoder = get_next_batch_new(test_encoder, device)
                batch_dict_graph = get_next_batch_new(test_graph, device)
                batch_dict_decoder = get_next_batch(test_decoder, device)
                batch_vis_encoder = get_next_batch(vis_encoder, device)
                batch_full_edges = next(full_edges)

                pred_y, _, _ = model.get_reconstruction(
                    batch_dict_encoder, batch_dict_decoder, batch_dict_graph,
                    n_traj_samples=1)
                print(pred_y.shape, "pred_y")

                # Move each tensor to host memory once per batch rather than
                # once per plotted slice. Only trajectory sample 0 is plotted.
                truth = batch_dict_decoder["data"].cpu().numpy()
                pred = pred_y[0].cpu().numpy()
                enc = batch_vis_encoder["data"].cpu().numpy()

                for a in range(args.batch_size):
                    first_row = rows_per_sample * a

                    # The graph panel is skipped for the basketball data.
                    adjacency = None
                    if args.data != "bball":
                        adjacency = batch_full_edges[a, :, :].cpu().numpy()

                    fig = _plot_diagnostics(
                        truth, pred, enc, first_row, adjacency, pd_t, pd_enc)
                    _save_and_close(fig, "plot/", i, a)

                    fig = _plot_model_vs_truth(
                        truth, pred, enc, first_row, pd_t, pd_enc)
                    _save_and_close(fig, "plot2/", i, a)

                del batch_dict_encoder, batch_dict_graph, batch_dict_decoder

        # `np` is the NumPy alias this file imports.
        np.save(args.save + "test_loss.npy", final_loss)
    # print(test_batch)
    test()
