import h5py
import torch
import numpy as np
import itertools

# NUM_SEEDS is used below as the name of the directory ("./<NUM_SEEDS>/")
# that holds the per-run input files and receives the merged outputs.
# NOTE(review): a value of 3 was previously kept here as a commented-out
# alternative -- presumably for a smaller data set; confirm before relying
# on it.
NUM_SEEDS = 200

# Boundary-condition and initial-condition labels; they appear verbatim in
# the per-run file names ("<bc>_<ic>_<quantity>.npy").
bcs = ['periodic', 'neumann', 'dirichlet']
ics = ['exp', 'prod_sin', 'sum_sin']

# Every (bc, ic) pairing, with the boundary condition varying slowest.
runs = [(bc, ic) for bc in bcs for ic in ics]

# One accumulator list per quantity; each receives one entry per run.
all_data, all_time, all_coeff = [], [], []
all_bc_name, all_bc_fac = [], []
all_ic_name, all_ic_fac = [], []

for bc, ic in runs:
    # Every file belonging to a run shares the prefix "./<NUM_SEEDS>/<bc>_<ic>".
    prefix = "./{}/{}_{}".format(NUM_SEEDS, bc, ic)

    # Numeric arrays are wrapped in torch.Tensor right after loading.
    # NOTE(review): only the data file is read with allow_pickle=True; since
    # the result is passed straight to torch.Tensor, pickling is presumably
    # unnecessary -- confirm, and avoid it if the files may be untrusted.
    data = torch.Tensor(np.load(prefix + "_data.npy", allow_pickle=True))
    time = torch.Tensor(np.load(prefix + "_time.npy"))
    coeff = torch.Tensor(np.load(prefix + "_coeff.npy"))
    # The *_name arrays stay as NumPy arrays (they are not converted).
    bc_name = np.load(prefix + "_bc_name.npy")
    bc_fac = torch.Tensor(np.load(prefix + "_bc_fac.npy"))
    ic_name = np.load(prefix + "_ic_name.npy")
    ic_fac = torch.Tensor(np.load(prefix + "_ic_fac.npy"))
    # grid is not accumulated: it is overwritten on every iteration, so only
    # the value from the final run is still bound after the loop.
    grid = torch.Tensor(np.load(prefix + "_grid.npy"))

    # Progress output: shape of this run's data tensor.
    print(data.shape)

    all_data.append(data)
    all_time.append(time)
    all_coeff.append(coeff)
    all_bc_name.append(bc_name)
    all_bc_fac.append(bc_fac)
    all_ic_name.append(ic_name)
    all_ic_fac.append(ic_fac)

print()
# Merge the per-run tensors along dim 0, so the runs' samples are stacked
# back to back in run order.
all_data = torch.cat(all_data, dim=0)

# Collapse dims 1-3 of every sample into a single number; any NaN inside a
# sample makes that sum NaN, which is used here as the divergence marker.
# NOTE(review): a sample whose sum is +/-inf (but not NaN) is kept by this
# test -- confirm that this is intended.
diverge_check = all_data.sum(dim=(1, 2, 3))
keeps = torch.logical_not(diverge_check.isnan())

# Report how many samples will be dropped. The count is computed with the
# vectorised Tensor.sum() rather than the builtin sum(), which would walk
# the mask element by element in Python.
print(len(keeps) - keeps.sum())

# Explicit NumPy copy of the boolean mask for indexing the NumPy name arrays,
# so that step does not rely on NumPy implicitly converting a torch tensor
# used as an index.
keeps_np = keeps.numpy()

# Tensors: concatenate the per-run pieces along dim 0, then keep only the
# rows selected by the mask.
all_data = all_data[keeps]
all_time = torch.cat(all_time, dim=0)[keeps]
all_coeff = torch.cat(all_coeff, dim=0)[keeps]
# For the *_fac tensors only column 0 is retained.
all_bc_fac = torch.cat(all_bc_fac, dim=0)[:, 0][keeps]
all_ic_fac = torch.cat(all_ic_fac, dim=0)[:, 0][keeps]

# Name arrays: flatten each run's array and join them in run order. For runs
# of identical shape this equals np.stack(...).flatten(), but unlike np.stack
# it does not require every run to have the same shape -- matching how
# torch.cat treats the tensors above.
all_bc_name = np.concatenate([np.ravel(names) for names in all_bc_name])[keeps_np]
all_ic_name = np.concatenate([np.ravel(names) for names in all_ic_name])[keeps_np]

# Shape report of everything that is about to be written out.
print()
for merged in (all_data, all_time, all_coeff,
               all_bc_name, all_bc_fac, all_ic_name, all_ic_fac):
    print(merged.shape)

# Write every merged quantity to "./<NUM_SEEDS>/<stem>.npy".
# The grid is not merged across runs: the value written as "all_grid" is
# whatever `grid` was bound to by the last iteration of the loading loop.
out_dir = "./{}".format(NUM_SEEDS)
outputs = (
    ("all_data", all_data),
    ("all_time", all_time),
    ("all_coeff", all_coeff),
    ("all_bc_name", all_bc_name),
    ("all_bc_fac", all_bc_fac),
    ("all_ic_name", all_ic_name),
    ("all_ic_fac", all_ic_fac),
    ("all_grid", grid),
)
for stem, payload in outputs:
    np.save("{}/{}.npy".format(out_dir, stem), payload)

