from synthetic_sim import *
import time
import numpy as np
import argparse
import os
import sys 
sys.path.append("/hdd2/extra_home/hkumawat6/Projects/AttentionNet/")
from scripts.utils import *


parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=None,
                    help='Path to the config file passed to read_config().')
parser.add_argument('--simulation', type=str, default='springs',
                    help='What simulation to generate.')
parser.add_argument('--n_balls', type=int, default=10,
                    help='Number of balls in the simulation.')
parser.add_argument('--hide_balls', type=int, default=5,
                    help='Number of balls to hide in the simulation.')
parser.add_argument('--save_dir', type=str,
                    default='/hdd2/extra_home/hkumawat6/Projects/AttentionNet/data_files_long2/',
                    help='Root directory under which the generated .npy files are written.')

args = parser.parse_args()

if not 0 <= args.hide_balls <= args.n_balls:
    # n_balls - hide_balls is used as a slice bound when building the
    # "broken" datasets; a negative value would silently slice from the end
    # instead of failing.
    parser.error('--hide_balls must be between 0 and --n_balls ({}), got {}'.format(
        args.n_balls, args.hide_balls))

## READ THE CONFIG FILE
# --simulation, --n_balls and --hide_balls come from the command line; the
# remaining generation settings are read from the config file.
cfg = read_config(args.config)

args.num_train = cfg.getint('Experiment', 'num-train')
args.num_test = cfg.getint('Experiment', 'num-test')
args.ode = cfg.getint('Default', 'ode')
args.num_test_box = cfg.getint('Default', 'num-test-box')
args.num_test_extra = cfg.getint('Default', 'num-test-extra')
args.sample_freq = cfg.getint('Default', 'sample-freq')
args.ob_max = cfg.getint('Default', 'ob_max')
args.ob_min = cfg.getint('Default', 'ob_min')
args.seed = cfg.getint('Default', 'seed')

print(args.n_balls)

# Layout: <save_dir>/<simulation>/<simulation><n_balls>_hide_<k>_ode_<ode>/
# The trailing "" keeps a trailing separator because output file names are
# appended to args.save_dir by plain string concatenation.
args.save_dir = os.path.join(
    args.save_dir,
    args.simulation,
    args.simulation + str(args.n_balls) + "_hide_" + str(args.hide_balls)
    + "_ode_" + str(args.ode),
    "",
)
os.makedirs(args.save_dir, exist_ok=True)

if args.simulation == 'springs':
    sim = SpringSim(noise_var=0.0, n_balls=args.n_balls)
    suffix = "_springs"
elif args.simulation == 'charged':
    sim = ChargedParticlesSim(noise_var=0.0, n_balls=args.n_balls)
    suffix = '_charged'
else:
    raise ValueError('Simulation {} not implemented'.format(args.simulation))

# File-name suffix shared by every saved array, e.g. "_springs10".
suffix += str(args.n_balls)
np.random.seed(args.seed)

print(suffix)


def generate_dataset_hide_n_seek_springs(args, num_sims, isTrain=True, hide_balls=2,
                                         return_all_broken_edges=False):
    """Simulate ``num_sims`` spring systems together with their "broken" variants.

    For every simulation the module-level ``sim`` produces a broken graph and
    the full static graph via ``generate_static_graph_hide_n_seek``. The full
    static graph is then integrated with
    ``sample_trajectory_static_graph_irregular_difflength_each_no_sampling``.
    The "broken" trajectories keep only the first ``args.n_balls - hide_balls``
    entries along the first axis of each ``loc``/``vel``. These are presumably
    the visible balls; confirm against the sampler's output layout.

    Args:
        args: Namespace forwarded to the sampler; ``args.n_balls`` is read here.
        num_sims: Number of simulations to generate. Must be >= 1, because the
            graphs are combined with ``np.stack``, which rejects an empty list.
        isTrain: Forwarded unchanged to the sampler.
        hide_balls: Forwarded to the graph generator and used to truncate the
            "broken" trajectories.
        return_all_broken_edges: If False (the default, original behaviour),
            the sixth return value is the broken graph of the *last*
            simulation only. If True, it is every simulation's broken graph
            stacked with ``np.stack``.

    Returns:
        ``(loc_all, loc_all_broken, vel_all, vel_all_broken, edges,
        broken_graph_edges, timestamps)``. All entries except
        ``broken_graph_edges`` are built with ``np.asarray`` / ``np.stack``.
        ``broken_graph_edges`` is whatever the graph generator returned
        unless ``return_all_broken_edges`` is set.
    """
    loc_all, vel_all, edges, timestamps = [], [], [], []
    loc_all_broken, vel_all_broken, edges_broken = [], [], []

    for i in range(num_sims):
        t = time.time()
        # Graph generation: broken graph plus the full static graph.
        broken_graph_edges, static_graph = sim.generate_static_graph_hide_n_seek(
            hide_balls=hide_balls)
        edges.append(static_graph)
        edges_broken.append(broken_graph_edges)

        # Trajectories are always simulated on the full static graph.
        loc, vel, T_samples = sim.sample_trajectory_static_graph_irregular_difflength_each_no_sampling(
            args, edges=static_graph, isTrain=isTrain)

        if i % 100 == 0:
            print("Iter: {}, Simulation time: {}".format(i, time.time() - t))
        loc_all.append(loc)
        vel_all.append(vel)
        timestamps.append(T_samples)

        n_visible = args.n_balls - hide_balls
        loc_all_broken.append(loc[:n_visible])
        vel_all_broken.append(vel[:n_visible])

    loc_all = np.asarray(loc_all)
    vel_all = np.asarray(vel_all)
    edges = np.stack(edges)
    timestamps = np.asarray(timestamps)
    loc_all_broken = np.asarray(loc_all_broken)
    vel_all_broken = np.asarray(vel_all_broken)

    if return_all_broken_edges:
        broken_graph_edges = np.stack(edges_broken)
    # Otherwise keep the original contract: only the last simulation's
    # broken graph is returned.
    return loc_all, loc_all_broken, vel_all, vel_all_broken, edges, broken_graph_edges, timestamps


def generate_dataset_hide_n_seek_charged(args, num_sims, isTrain=True, hide_balls=2):
    """Simulate ``num_sims`` charged-particle systems and collect their trajectories.

    Each simulation draws a fresh static graph and diagonal mask from the
    module-level ``sim`` (``generate_static_graph``). It then integrates them
    with ``sample_trajectory_static_graph_irregular_difflength_each_no_sampling``.
    The "broken" copies keep only the first ``args.n_balls - hide_balls``
    entries along the first axis of each ``loc``/``vel``. These are presumably
    the visible balls; confirm against the sampler's output layout.

    Args:
        args: Namespace forwarded to the sampler; ``args.n_balls`` is read here.
        num_sims: Number of simulations to generate.
        isTrain: Forwarded unchanged to the sampler.
        hide_balls: Number of trailing entries dropped from the "broken"
            trajectories. The graph itself is not broken for this simulation.

    Returns:
        ``(loc_all, loc_all_broken, vel_all, vel_all_broken, edges, None,
        timestamps)``. The sixth slot is always ``None`` because no broken
        graph is produced here.
    """
    locs, vels, graphs, times = [], [], [], []
    locs_visible, vels_visible = [], []

    for sim_idx in range(num_sims):
        started = time.time()
        graph, diag_mask = sim.generate_static_graph()
        graphs.append(graph)

        loc, vel, t_samples = sim.sample_trajectory_static_graph_irregular_difflength_each_no_sampling(
            args, edges=graph, isTrain=isTrain, diag_mask=diag_mask)

        if sim_idx % 100 == 0:
            print("Iter: {}, Simulation time: {}".format(sim_idx, time.time() - started))

        n_visible = args.n_balls - hide_balls
        locs.append(loc)
        vels.append(vel)
        times.append(t_samples)
        locs_visible.append(loc[:n_visible])
        vels_visible.append(vel[:n_visible])

    # Convert in the same order as before so any conversion error surfaces identically.
    loc_arr = np.asarray(locs)
    vel_arr = np.asarray(vels)
    graph_arr = np.stack(graphs)
    time_arr = np.asarray(times)
    loc_visible_arr = np.asarray(locs_visible)
    vel_visible_arr = np.asarray(vels_visible)
    return loc_arr, loc_visible_arr, vel_arr, vel_visible_arr, graph_arr, None, time_arr

def _save_split(split, data):
    """Write one split's generated arrays into ``args.save_dir``.

    Args:
        split: Split tag used in the file names (``'test'`` or ``'train'``).
        data: The 7-tuple returned by a ``generate_dataset_hide_n_seek_*``
            function. Each element is saved as ``<name>_<split><suffix>.npy``.

    Raises:
        ValueError: If ``data`` does not have exactly one entry per file name.
    """
    names = ('loc_all', 'loc_all_broken', 'vel_all', 'vel_all_broken',
             'edges', 'broken_graph_edges', 'times')
    if len(data) != len(names):
        raise ValueError('Expected {} arrays, got {}'.format(len(names), len(data)))
    for name, array in zip(names, data):
        # A None entry (charged simulations have no broken graph) is saved as
        # a 0-d object array, exactly as np.save(..., None) did before.
        np.save(args.save_dir + name + '_' + split + suffix + '.npy', array)


if args.simulation == "springs":
    generate_dataset = generate_dataset_hide_n_seek_springs
elif args.simulation == "charged":
    generate_dataset = generate_dataset_hide_n_seek_charged
else:
    raise ValueError('Simulation {} not implemented'.format(args.simulation))

# Test is generated before train, as before. The simulator presumably draws
# from the global NumPy RNG seeded above, so swapping the order would change
# the generated data.
print("Generating {} test simulations".format(args.num_test))
test_data = generate_dataset(args, args.num_test, isTrain=False, hide_balls=args.hide_balls)
print(test_data[0].shape, "loc_all_test")
print(test_data[4].shape, "edges_test")
print(test_data[6].shape, "timestamps_test")
_save_split('test', test_data)

print("Generating {} training simulations".format(args.num_train))
train_data = generate_dataset(args, args.num_train, isTrain=True, hide_balls=args.hide_balls)
_save_split('train', train_data)
    



