from synthetic_sim_coupling import *
import time
import numpy as np
import argparse
import os
import sys 
sys.path.append("/hdd2/extra_home/hkumawat6/Projects/AttentionNet/")
from scripts.utils import *


# Command-line interface. Simulation type and ball counts come from the
# command line; the remaining experiment settings come from the config file.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=None,
                    help='Config file providing the [Experiment] and [Default] settings.')
parser.add_argument('--simulation', type=str, default='springs',
                    help='What simulation to generate.')
parser.add_argument('--n_balls', type=int, default=10,
                    help='Number of balls in the simulation.')
parser.add_argument('--hide_balls', type=int, default=5,
                    help='Number of balls to hide in the simulation.')
# The output root used to be hard-coded; the default keeps the previous location.
parser.add_argument('--save_dir', type=str,
                    default='/hdd2/extra_home/hkumawat6/Projects/AttentionNet/data_files_coupling2/',
                    help='Root directory under which the generated .npy files are written.')

args = parser.parse_args()
cfg = read_config(args.config)

## READ THE CONFIG FILE

# Dataset sizes are read from the [Experiment] section.
args.num_train = cfg.getint('Experiment', 'num-train')
args.num_test = cfg.getint('Experiment', 'num-test')

# Everything else is read from the [Default] section and attached to `args`,
# which is the object handed to the dataset generators below.
args.ode = cfg.getint('Default', 'ode')
args.num_test_box = cfg.getint('Default', 'num-test-box')
args.num_test_extra = cfg.getint('Default', 'num-test-extra')
args.sample_freq = cfg.getint('Default', 'sample-freq')
# NOTE(review): these two keys use underscores while the others use hyphens;
# they must match the config file exactly, so they are left as-is.
args.ob_max = cfg.getint('Default', 'ob_max')
args.ob_min = cfg.getint('Default', 'ob_min')
args.seed = cfg.getint('Default', 'seed')






def generate_dataset_hide_n_seek_springs(args, num_sims, isTrain=True, hide_balls=2):
    """Run ``num_sims`` simulations and collect full and "broken" (partially hidden) data.

    Relies on the module-level ``sim`` object. For each simulation a graph is
    drawn with ``sim.generate_static_graph_hide_n_seek`` and a trajectory is
    sampled on the returned ``adj_mod`` graph. The "broken" location/velocity
    views keep only the first ``args.n_balls - hide_balls`` entries along the
    leading axis of ``loc`` / ``vel`` (presumably the per-ball axis -- confirm
    against the simulator).

    Args:
        args: Namespace forwarded to the trajectory sampler; ``args.n_balls``
            is also read here.
        num_sims: Number of simulations to run.
        isTrain: Forwarded unchanged to the trajectory sampler.
        hide_balls: Number of hidden balls; forwarded to the graph generator
            and used to truncate the broken views.

    Returns:
        A 7-tuple ``(loc_all, loc_all_broken, vel_all, vel_all_broken, edges,
        edges_broken, timestamps)``. ``edges`` and ``edges_broken`` are stacked
        over all simulations, so both have one entry per simulation.

    Raises:
        ValueError: From ``np.stack`` when ``num_sims`` is 0 (nothing to stack).
    """
    # Number of leading entries kept in the broken (observed-only) views.
    n_visible = args.n_balls - hide_balls

    loc_all, vel_all, timestamps = [], [], []
    loc_all_broken, vel_all_broken = [], []
    edges, edges_broken = [], []

    for i in range(num_sims):
        t = time.time()

        # Graph generation: the broken graph, the full static graph, and the
        # adjacency actually used to drive the simulation.
        broken_graph_edges, static_graph, adj_mod = sim.generate_static_graph_hide_n_seek(
            hide_balls=hide_balls)
        edges.append(static_graph)
        edges_broken.append(broken_graph_edges)

        loc, vel, T_samples = sim.sample_trajectory_static_graph_irregular_difflength_each_no_sampling(
            args, edges=adj_mod, isTrain=isTrain)

        if i % 100 == 0:
            print("Iter: {}, Simulation time: {}".format(i, time.time() - t))

        loc_all.append(loc)
        vel_all.append(vel)
        timestamps.append(T_samples)
        loc_all_broken.append(loc[:n_visible])
        vel_all_broken.append(vel[:n_visible])

    loc_all = np.asarray(loc_all)
    vel_all = np.asarray(vel_all)
    edges = np.stack(edges)
    timestamps = np.asarray(timestamps)
    edges_broken = np.stack(edges_broken)
    loc_all_broken = np.asarray(loc_all_broken)
    vel_all_broken = np.asarray(vel_all_broken)

    # Return the stacked per-simulation broken graphs. The loop variable
    # `broken_graph_edges` only holds the graph of the final simulation, which
    # does not line up with the per-simulation entries of the other outputs.
    return loc_all, loc_all_broken, vel_all, vel_all_broken, edges, edges_broken, timestamps


def generate_dataset_hide_n_seek_charged(args, num_sims, isTrain=True, hide_balls=2):
    """Run ``num_sims`` simulations on graphs from ``sim.generate_static_graph``.

    Relies on the module-level ``sim`` object. Each simulation draws a static
    graph together with a diagonal mask, samples a trajectory on that graph,
    and records both the full result and a "broken" view that keeps only the
    first ``args.n_balls - hide_balls`` entries along the leading axis of
    ``loc`` / ``vel``.

    Args:
        args: Namespace forwarded to the trajectory sampler; ``args.n_balls``
            is also read here.
        num_sims: Number of simulations to run.
        isTrain: Forwarded unchanged to the trajectory sampler.
        hide_balls: Number of entries dropped from the end of the broken views.

    Returns:
        A 7-tuple ``(loc_all, loc_all_broken, vel_all, vel_all_broken, edges,
        None, timestamps)``. The sixth slot is always ``None``: no broken
        graph is produced for this simulation type.
    """
    full_loc, full_vel = [], []
    partial_loc, partial_vel = [], []
    graphs = []
    sample_times = []

    for sim_idx in range(num_sims):
        started = time.time()

        # Graph generation.
        static_graph, diag_mask = sim.generate_static_graph()
        graphs.append(static_graph)

        loc, vel, T_samples = sim.sample_trajectory_static_graph_irregular_difflength_each_no_sampling(
            args, edges=static_graph, isTrain=isTrain, diag_mask=diag_mask)

        if sim_idx % 100 == 0:
            print("Iter: {}, Simulation time: {}".format(sim_idx, time.time() - started))

        full_loc.append(loc)
        full_vel.append(vel)
        sample_times.append(T_samples)

        # Keep only the leading, non-hidden entries for the broken view.
        keep = slice(args.n_balls - hide_balls)
        partial_loc.append(loc[keep])
        partial_vel.append(vel[keep])

    loc_array = np.asarray(full_loc)
    vel_array = np.asarray(full_vel)
    edge_array = np.stack(graphs)
    time_array = np.asarray(sample_times)
    loc_broken_array = np.asarray(partial_loc)
    vel_broken_array = np.asarray(partial_vel)

    return (loc_array, loc_broken_array, vel_array, vel_broken_array,
            edge_array, None, time_array)

def _save_split(out_dir, split, suffix, named_arrays):
    """Write one dataset split to disk.

    Each ``(stem, array)`` pair is stored with ``np.save`` as
    ``<out_dir><stem>_<split><suffix>.npy``.
    """
    for stem, array in named_arrays:
        np.save(out_dir + stem + '_' + split + suffix + '.npy', array)


# This driver only generates the spring coupling sweep. Any other --simulation
# value previously fell through without writing anything (the inner
# charged/else branches were unreachable), so fail loudly instead.
if args.simulation != "springs":
    raise ValueError('Simulation {} not implemented'.format(args.simulation))

# Coupling values swept over; one dataset directory is written per value.
coupling = [0.5, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]

save_name = args.save_dir
for c in coupling:
    # args.save_dir is rewritten on every iteration to the per-coupling directory.
    args.save_dir = save_name + "/" + args.simulation + '/' + args.simulation + \
        str(args.n_balls) + "_hide_" + str(args.hide_balls) + "_coupling_" + str(c) + "/"
    # exist_ok avoids the check-then-create race and tolerates re-runs.
    os.makedirs(args.save_dir, exist_ok=True)

    # `sim` must stay a module-level name: the generate_dataset_* functions
    # read it as a global.
    sim = SpringSim(noise_var=0.0, n_balls=args.n_balls, coupling=c)
    suffix = "_springs" + str(args.n_balls)
    # Re-seed NumPy's global RNG for every coupling value.
    np.random.seed(args.seed)

    print(suffix)
    print("coupling", c)

    # The test split is generated before the training split; keep this order,
    # since both draw from the RNG state seeded above.
    print("Generating {} test simulations".format(args.num_test))
    (loc_all_test, loc_all_broken_test, vel_all_test, vel_all_broken_test,
     edges_test, broken_graph_edges_test, timestamps_test) = generate_dataset_hide_n_seek_springs(
        args, args.num_test, isTrain=False, hide_balls=args.hide_balls)
    print(loc_all_test.shape, "loc_all_test")

    _save_split(args.save_dir, 'test', suffix, [
        ('loc_all', loc_all_test),
        ('vel_all', vel_all_test),
        ('edges', edges_test),
        ('times', timestamps_test),
        ('loc_all_broken', loc_all_broken_test),
        ('vel_all_broken', vel_all_broken_test),
        ('broken_graph_edges', broken_graph_edges_test),
    ])

    print("Generating {} training simulations".format(args.num_train))
    (loc_all_train, loc_all_broken_train, vel_all_train, vel_all_broken_train,
     edges_train, broken_graph_edges_train, timestamps_train) = generate_dataset_hide_n_seek_springs(
        args, args.num_train, isTrain=True, hide_balls=args.hide_balls)

    _save_split(args.save_dir, 'train', suffix, [
        ('loc_all', loc_all_train),
        ('vel_all', vel_all_train),
        ('edges', edges_train),
        ('loc_all_broken', loc_all_broken_train),
        ('vel_all_broken', vel_all_broken_train),
        ('broken_graph_edges', broken_graph_edges_train),
        ('times', timestamps_train),
    ])





