
import matplotlib.pyplot as plt
from scripts.utils import *
from scripts.utils import compute_loss_all_batches
from scripts.create_latent_ode_model import create_LatentODE_model
from torch.distributions.normal import Normal
import scripts.utils as utils
import torch.optim as optim
import torch
from random import SystemRandom
import numpy as np
import argparse
from tqdm import tqdm
from scripts.dataLoader_long import ParseData
import sys
import os
import networkx as nx
from matplotlib import rc
from scipy import signal
import torch.nn.functional as F
# from scripts.dataLoader_bball import ParseData as BBallData 
# from scripts.dataLoader_penalize import  ParseData as PenalizeData
# from scripts.dataLoader_motion import ParseData as MotionData 
# Project root, also appended to sys.path.
# NOTE(review): the `scripts.*` imports at the top of the file execute before
# this append, so they only work if the package is already importable --
# confirm the intended ordering.
BASE_DIRECTORY = "/DIR/Projects/AttentionNet/"
sys.path.append(BASE_DIRECTORY)

# Locations derived from the project root.
DATA_DIRECTORY = BASE_DIRECTORY + "data/"
COMMON_CONFIG = BASE_DIRECTORY + "configs/common/common.cfg"

# Generative model for noisy data based on ODE
parser = argparse.ArgumentParser('Latent ODE')
parser.add_argument('--config', type=str, default=None,
                    help='Path to the experiment configuration file.')
parser.add_argument('--ckpt', type=str, default=None,
                    help='Path to the model checkpoint')
args = parser.parse_args()
# The attributes read from `args` further down (n_balls, data, save_tests, ...)
# are not declared on the parser; presumably set_args_from_config attaches
# them from the file given by --config -- verify against scripts.utils.
args = utils.set_args_from_config(args, args.config)
# args = utils.set_args_from_config(args, COMMON_CONFIG)
print(args)

# cfg = utils.read_config(args.config)

# for section_name in cfg.sections():
#     # loop over all keys in the section
#     for key in cfg[section_name]:
#         # get the value for the key
#         value = cfg[section_name][key]
#         # check if the value can be converted to an integer
#         try:
#             value = int(value)
#         except ValueError:
#             try:
#                 value = float(value)
#             except ValueError:
#                 try:
#                     value = str(value)
#                     if value == "":
#                         value = None
#                 except ValueError:
#                     pass

#         setattr(args, key, value)
        # print the key and value
        # print(f"{key}: {value} (int value: {value})")
# print(args.load)


# Derived experiment settings: total ball count and the output directories.
args.total_balls = args.n_balls + args.hide_balls
args.logging_dir = (
    BASE_DIRECTORY + "neurips_experiments/"
    + str(args.exp_no) + "_" + args.data + "_" + str(args.total_balls)
    + "_" + str(args.hide_balls) + "/"
)

# Explicit check instead of `assert`, which is stripped when Python runs
# with -O and would let an invalid configuration through silently.
if int(args.rec_dims % args.n_heads) != 0:
    raise ValueError(
        "rec_dims ({}) must be divisible by n_heads ({})".format(
            args.rec_dims, args.n_heads))

args.save = args.logging_dir + args.save_tests

# Create both directories unconditionally. Previously the graph directory was
# only created together with a brand-new `args.save`, so it stayed missing
# whenever `args.save` already existed.
os.makedirs(args.save, exist_ok=True)
os.makedirs(args.save + "/" + args.save_graph, exist_ok=True)
# Per-dataset settings, keyed by `args.data`. Each entry lists the attributes
# written onto `args`; a dataset name that is not listed changes nothing.
# Only "bball" overrides `args.dataset`.
_ball_count = str(args.total_balls)
_dataset_settings = {
    "spring": {
        "suffix": '_springs' + _ball_count,
        "data_load_suffix": "_springs" + _ball_count,
        "total_ode_step": 90,
    },
    "charged": {
        "suffix": '_charged' + _ball_count,
        "data_load_suffix": "_charged" + _ball_count,
        "total_ode_step": 60,
    },
    "motion": {
        "suffix": 'motion',
        "total_ode_step": 60,
        "data_load_suffix": "motion",
    },
    "bball": {
        "dataset": '/DIR/Projects/AttentionNet/data_files/basketball',
        "suffix": '_bball',
        "total_ode_step": 49,
        "data_load_suffix": "_bball" + _ball_count,
    },
}

_selected_settings = _dataset_settings.get(args.data)
if _selected_settings is not None:
    for _attr_name, _attr_value in _selected_settings.items():
        setattr(args, _attr_name, _attr_value)
  


# ############ CPU AND GPU related, Mode related, Dataset Related
# if torch.cuda.is_available():
# 	print("Using GPU" + "-"*80)
# 	device = torch.device("cuda:" + str(args.gpu))
# else:
# 	print("Using CPU" + "-" * 80)
# 	device = torch.device("cpu")

# if args.extrap == "True":
#     print("Running extrap mode" + "-"*80)
#     args.mode = "extrap"
# elif args.extrap == "False":
#     print("Running interp mode" + "-" * 80)
#     args.mode = "interp"

############ CPU AND GPU related, Mode related, Dataset Related
# Use the configured CUDA device when CUDA is available, otherwise the CPU.
_cuda_available = torch.cuda.is_available()
print(("Using GPU" if _cuda_available else "Using CPU") + "-" * 80)
device = torch.device(
    ("cuda:" + str(args.gpu)) if _cuda_available else "cpu")

# `args.extrap` is compared against the strings "True" / "False"; any other
# value leaves `args.mode` exactly as it was.
if args.extrap in ("True", "False"):
    _is_extrap = args.extrap == "True"
    print(("Running extrap mode" if _is_extrap else "Running interp mode")
          + "-" * 80)
    args.mode = "extrap" if _is_extrap else "interp"

### Plot related
# Twelve distinct colours; `colors` repeats them three times (36 entries).
_BASE_COLORS = ['red', 'green', 'blue', 'yellow', 'orange',
                'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'black']
colors = _BASE_COLORS * 3

# One colour per node: the first `n_balls` nodes take palette colours and the
# remaining `hide_balls` nodes are drawn black. The index wraps around the
# palette so more than len(colors) balls no longer raises IndexError.
node_colors = [
    colors[i % len(colors)] if i < args.n_balls else 'black'
    for i in range(args.n_balls + args.hide_balls)
]

assert len(node_colors) == args.n_balls + args.hide_balls

# Global matplotlib text settings: DejaVu Sans (sans-serif family) at size 10,
# with math text rendered in the regular font.
_font_settings = {
    'family': 'sans-serif',
    'sans-serif': ['DejaVu Sans'],
    'size': 10,
}
rc('font', **_font_settings)
rc('mathtext', default='regular')

if __name__ == '__main__':

    # Seed torch and numpy from the configured seed.
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    ############ Saving Path and Preload.
    utils.makedirs(args.save)
    # NOTE(review): args.save_graph is passed as-is here, not joined with
    # args.save -- confirm this is the intended location.
    utils.makedirs(args.save_graph)

    print("Loading dataset: " + args.dataset)
    # The specialised loaders are imported only in the branch that needs them.
    # Their module-level imports are commented out at the top of the file, so
    # referencing the classes directly raised NameError.
    if args.data == "bball":
        from scripts.dataLoader_bball import ParseData as BBallData
        loader_cls = BBallData
    elif args.penalize:
        from scripts.dataLoader_penalize import ParseData as PenalizeData
        loader_cls = PenalizeData
    elif args.data == "motion":
        from scripts.dataLoader_motion import ParseData as MotionData
        loader_cls = MotionData
    else:
        loader_cls = ParseData
    dataloader = loader_cls(
        args.dataset, suffix=args.suffix, mode=args.mode, args=args)

    # Only the test split is loaded; this script performs evaluation only.
    (test_encoder, test_decoder, test_graph, test_batch,
     vis_encoder, full_edges) = dataloader.load_data(
        sample_percent=args.sample_percent_test,
        batch_size=args.batch_size,
        data_type="test")

    input_dim = dataloader.feature  # TODO: feature dimension

    # Fixed observation noise and a standard-normal prior, both on `device`.
    obsrv_std = 0.01
    obsrv_std = torch.Tensor([obsrv_std]).to(device)
    z0_prior = Normal(torch.Tensor([0.0]).to(device),
                      torch.Tensor([1.]).to(device))

    model = create_LatentODE_model(
        args, input_dim, z0_prior, obsrv_std, device)

    # Load checkpoint and evaluate the model
    # NOTE(review): the --ckpt command-line option is parsed but not consulted
    # here; the checkpoint is located only through args.load inside args.save.
    if args.load is not None:
        ckpt_path = os.path.join(args.save, args.load)
        utils.get_ckpt_model(ckpt_path, model, device)
        print("Loaded checkpoint from {}".format(ckpt_path))

    def test():
        """Run the model over the test batches and print the accumulated MSE.

        For every processed batch the element-wise squared error between the
        first prediction sample and the decoder target is averaged over the
        last axis and summed over the first axis, then accumulated across
        batches. The accumulated value is normalised and printed.

        Returns:
            The normalised MSE tensor, or None when no batch was processed.
        """
        # NOTE(review): presumably the number of test trajectories -- TODO confirm.
        n_test_samples = 2000
        # Number of leading entries averaged in the final summary print.
        summary_steps = 60

        model.eval()
        mse_sum = None
        with torch.no_grad():
            for i in tqdm(range(test_batch)):
                # The last batch is left out, as in the original loop.
                # NOTE(review): presumably because it can be incomplete -- confirm.
                if i == test_batch - 1:
                    break
                batch_dict_encoder = get_next_batch_new(test_encoder, device)
                batch_dict_graph = get_next_batch_new(test_graph, device)
                batch_dict_decoder = get_next_batch(test_decoder, device)
                # These two iterators are advanced once per batch but their
                # values are not used by the evaluation below.
                get_next_batch(vis_encoder, device)
                next(full_edges)

                pred_y, _, _ = model.get_reconstruction(
                    batch_dict_encoder, batch_dict_decoder, batch_dict_graph,
                    n_traj_samples=1)
                print(pred_y.shape, "pred_y")

                # Compute the squared error once and reuse it for both the
                # shape print and the accumulation.
                squared_error = F.mse_loss(
                    pred_y[0], batch_dict_decoder["data"], reduction='none')
                print(squared_error.shape, "MSE_time")

                # assumes pred_y[0] is (trajectories, time, features) -- TODO confirm
                batch_mse = squared_error.view(
                    pred_y.size(1), pred_y.size(2), -1).mean(dim=-1).sum(dim=0)
                mse_sum = batch_mse if mse_sum is None else mse_sum + batch_mse

            if mse_sum is None:
                # Without any processed batch there is no tensor to report on;
                # previously this path failed with AttributeError on `.shape`.
                print("No test batch was evaluated; skipping MSE report.")
                return None

            print(mse_sum, "MSE_time")
            mse_result = mse_sum / (n_test_samples * args.n_balls)
            print(mse_result.shape, "MSE_time")
            print(mse_result[:summary_steps].mean(dim=0), "MSE_time")
            return mse_result

    test()
