
import matplotlib.pyplot as plt
from scripts.utils import *
from scripts.utils import compute_loss_all_batches
from scripts.create_latent_ode_model import create_LatentODE_model
from torch.distributions.normal import Normal
import scripts.utils as utils
import torch.optim as optim
import torch
from random import SystemRandom
import numpy as np
import argparse
from tqdm import tqdm
from scripts.dataLoader import ParseData
import sys
import os
import networkx as nx
from matplotlib import rc
from scipy import signal
import pyEDM  
# Length of the third axis of the prediction / ground-truth buffers that the
# evaluation routine allocates (presumably the number of predicted time steps).
pred_size = 30

# Alternative dataset loaders, currently disabled:
# from scripts.dataLoader_bball import ParseData as BBallData
# from scripts.dataLoader_penalize import ParseData as PenalizeData
# from scripts.dataLoader_motion import ParseData as MotionData

# Project root. It is appended to sys.path here; note that this happens after
# the `scripts.*` imports at the top of the file have already been resolved.
BASE_DIRECTORY = "/DIR/Projects/AttentionNet/"
sys.path.append(BASE_DIRECTORY)

DATA_DIRECTORY = BASE_DIRECTORY + "data/"
COMMON_CONFIG = BASE_DIRECTORY + "configs/common/common.cfg"

# Generative model for noisy data based on ODE
parser = argparse.ArgumentParser('Latent ODE')
parser.add_argument(
    '--config',
    type=str,
    default=None,
    help='What simulation to generate.',
)
parser.add_argument(
    '--ckpt',
    type=str,
    default=None,
    help='Path to the model checkpoint',
)
args = parser.parse_args()
# Extend the parsed command-line namespace with the values from the config file.
args = utils.set_args_from_config(args, args.config)
# args = utils.set_args_from_config(args, COMMON_CONFIG)
print(args)



# Derived experiment settings: total ball count, logging/output directories and
# the per-dataset suffixes / ODE step counts.
args.total_balls = args.n_balls + args.hide_balls
args.logging_dir = (
    BASE_DIRECTORY + "neurips_experiments/"
    + str(args.exp_no) + "_" + args.data + "_" + str(args.total_balls)
    + "_" + str(args.hide_balls) + "/"
)

# Explicit check instead of `assert`, so it is not stripped under `python -O`.
if int(args.rec_dims % args.n_heads) != 0:
    raise ValueError(
        "rec_dims ({}) must be divisible by n_heads ({})".format(
            args.rec_dims, args.n_heads))

args.save = args.logging_dir + args.save_tests

# Always make sure both the save directory and its graph sub-directory exist.
# Previously the sub-directory was only created when the save directory itself
# was missing, so a pre-existing save directory could lack it.
os.makedirs(args.save, exist_ok=True)
os.makedirs(args.save + "/" + args.save_graph, exist_ok=True)

# Per-dataset settings. Only the "bball" branch sets args.dataset; for the
# other datasets it has to be present on args already (e.g. from the config).
if args.data == "spring":
    args.suffix = '_springs' + str(args.total_balls)
    args.data_load_suffix = "_springs" + str(args.total_balls)
    args.total_ode_step = 60
elif args.data == "charged":
    args.suffix = '_charged' + str(args.total_balls)
    args.data_load_suffix = "_charged" + str(args.total_balls)
    args.total_ode_step = 60
elif args.data == "motion":
    args.suffix = 'motion'
    args.total_ode_step = 60
    args.data_load_suffix = "motion"
elif args.data == "bball":
    args.dataset = '/DIR/Projects/AttentionNet/data_files/basketball'
    args.suffix = '_bball'
    args.total_ode_step = 49
    args.data_load_suffix = "_bball" + str(args.total_balls)
  


############ CPU AND GPU related, Mode related, Dataset Related
# Pick the compute device: the configured CUDA device when CUDA is available,
# otherwise the CPU. args.gpu is only read in the CUDA case.
_use_gpu = torch.cuda.is_available()
print(("Using GPU" if _use_gpu else "Using CPU") + "-" * 80)
device = (
    torch.device("cuda:" + str(args.gpu)) if _use_gpu else torch.device("cpu")
)

# Translate the string flag args.extrap into args.mode. Any value other than
# "True" / "False" leaves args.mode untouched.
for _flag, _mode in (("True", "extrap"), ("False", "interp")):
    if args.extrap == _flag:
        print("Running " + _mode + " mode" + "-" * 80)
        args.mode = _mode
        break

### Plot related
# Base palette of 12 named colors; `colors` repeats it three times (36 entries).
_base_colors = ['red', 'green', 'blue', 'yellow', 'orange',
                'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'black']
colors = _base_colors * 3

# One color per node: the first args.n_balls nodes take palette colors, the
# remaining args.hide_balls nodes are drawn in black. The palette index wraps
# around so that more than len(colors) balls no longer raises IndexError.
node_colors = [
    colors[i % len(colors)] if i < args.n_balls else 'black'
    for i in range(args.n_balls + args.hide_balls)
]

# Global matplotlib text settings: DejaVu Sans (sans-serif family) at size 10,
# with math text rendered in the regular font.
_font_settings = {
    'family': 'sans-serif',
    'sans-serif': ['DejaVu Sans'],
    'size': 10,
}
rc('font', **_font_settings)
rc('mathtext', default='regular')

def store_array_to_file(array, filename):
    """Serialize *array* with pickle and write it to *filename*.

    Args:
        array: Any picklable object.
        filename: Path of the file to create or overwrite (opened in binary
            write mode).
    """
    # Imported locally so the helper does not depend on `pickle` leaking in
    # through the file's `from scripts.utils import *`.
    import pickle

    with open(filename, 'wb') as file:
        pickle.dump(array, file)

def read_array_from_file(filename):
    """Load and return the object pickled in *filename*.

    Args:
        filename: Path of a file containing a pickle stream (opened in binary
            read mode).

    Returns:
        The unpickled object.

    WARNING: unpickling can execute arbitrary code; only use this on files
    produced by a trusted source.
    """
    # Imported locally so the helper does not depend on `pickle` leaking in
    # through the file's `from scripts.utils import *`.
    import pickle

    with open(filename, 'rb') as file:
        return pickle.load(file)

if __name__ == '__main__':
    # Script entry point: seed the RNGs, load the test split, build the model,
    # optionally restore a checkpoint, then print MSE statistics on the test set.

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    ############ Saving Path and Preload.
    file_name = os.path.basename(__file__)[:-3]  # run_models
    utils.makedirs(args.save)
    # NOTE(review): this path is not prefixed with args.save, unlike the graph
    # directory used elsewhere in this file - confirm that is intended.
    utils.makedirs(args.save_graph)
    experimentID = args.load
    if experimentID is None:
        # Make a new experiment ID
        experimentID = int(SystemRandom().random() * 100000)
    args.test_data_size = 2000
    print("Loading dataset: " + args.dataset)

    # The specialised loaders are imported lazily in the branch that needs
    # them; their top-of-file imports are commented out, so referencing them
    # directly raised NameError.
    if args.data == "bball":
        from scripts.dataLoader_bball import ParseData as BBallData
        dataloader = BBallData(
            args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    elif args.penalize:
        from scripts.dataLoader_penalize import ParseData as PenalizeData
        dataloader = PenalizeData(
            args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    elif args.data == "motion":
        from scripts.dataLoader_motion import ParseData as MotionData
        dataloader = MotionData(
            args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    else:
        dataloader = ParseData(
            args.dataset, suffix=args.suffix, mode=args.mode, args=args)

    (test_encoder, test_decoder, test_graph, test_batch,
     vis_encoder, full_edges) = dataloader.load_data(
        sample_percent=args.sample_percent_test,
        batch_size=args.batch_size,
        data_type="test")

    input_dim = dataloader.feature  # TODO: feature dimension

    # Rebuild the command line without the "--load <value>" pair.
    input_command = sys.argv
    ind = [pos for pos, token in enumerate(input_command) if token == "--load"]
    if len(ind) == 1:
        ind = ind[0]
        input_command = input_command[:ind] + input_command[(ind + 2):]
    input_command = " ".join(input_command)

    obsrv_std = 0.01
    obsrv_std = torch.Tensor([obsrv_std]).to(device)
    # Standard-normal prior (mean 0, std 1) handed to the model factory.
    z0_prior = Normal(torch.Tensor([0.0]).to(
        device), torch.Tensor([1.]).to(device))

    model = create_LatentODE_model(
        args, input_dim, z0_prior, obsrv_std, device)

    # Load checkpoint and evaluate the model
    if args.load is not None:
        ckpt_path = os.path.join(args.save, args.load)
        utils.get_ckpt_model(ckpt_path, model, device)
        print("Loaded checkpoint from {}".format(ckpt_path))

    wait_until_kl_inc = 10
    best_test_mse = np.inf
    n_iters_to_viz = 1

    def test():
        """Run the model over the test batches and print MSE statistics.

        Predictions and ground truth are gathered into two buffers of shape
        (args.test_data_size, args.n_balls, pred_size, 4). The squared error
        is averaged over the last axis and over the ball axis, giving one
        value per (sample, step), and summary statistics of that are printed.
        Buffer rows that are never filled stay zero in both tensors and
        therefore contribute zero error.
        """
        # Imported once up front: F is needed after the batch loop as well, so
        # importing it inside the loop failed when there were no batches.
        import torch.nn.functional as F

        # NOTE(review): the flattened node index uses a hard-coded stride of 4
        # per batch element; this only lines up when each sample contributes
        # exactly 4 nodes. Confirm against the data loader before changing.
        nodes_per_sample = 4

        model.eval()
        with torch.no_grad():
            predictions = torch.zeros(
                (args.test_data_size, args.n_balls, pred_size, 4))
            gt_truth = torch.zeros(
                (args.test_data_size, args.n_balls, pred_size, 4))

            for i in tqdm(range(test_batch)):
                batch_dict_encoder = get_next_batch_new(test_encoder, device)
                batch_dict_graph = get_next_batch_new(test_graph, device)
                batch_dict_decoder = get_next_batch(test_decoder, device)
                # These two are still drawn once per step, as before, although
                # their values are not used below.
                get_next_batch(vis_encoder, device)
                next(full_edges)

                pred_y, info, temporal_weights = model.get_reconstruction(
                    batch_dict_encoder, batch_dict_decoder, batch_dict_graph,
                    n_traj_samples=1)

                for a in range(args.batch_size):
                    sample_idx = i * args.batch_size + a
                    if sample_idx >= args.test_data_size:
                        # Buffers are full; later entries of this batch are
                        # skipped.
                        break
                    for atom in range(args.n_balls):
                        node_idx = nodes_per_sample * a + atom
                        predictions[sample_idx, atom, :, :] = \
                            pred_y[0, node_idx, :, :]
                        gt_truth[sample_idx, atom, :, :] = \
                            batch_dict_decoder["data"][node_idx, :, :]

            # Elementwise squared error, computed once and reused.
            squared_error = F.mse_loss(predictions, gt_truth, reduction='none')
            print(squared_error.shape)
            # Average over the last axis, then over balls -> (sample, step).
            loss = squared_error.mean(dim=-1).mean(dim=1)
            print(loss.mean(dim=0), loss.std(dim=0))
            print(loss.shape)
            print(loss.mean(dim=0).mean(), loss.mean(dim=0).std())

    test()
