# import packages
import sys
import os
import numpy as np
from tqdm import tqdm

sys.path.append("../..")  # add the pde package to the python path
import pde
from tempfile import NamedTemporaryFile
import h5py
from numba import jit

from multiprocessing import pool
import itertools

import gc

# Upper bound of the seed offsets: the __main__ block enumerates
# range(0, NUM_SEEDS, 10) and each offset is expanded into 10 consecutive
# seeds by sim_wrapper. The value is also embedded in the output directory
# name ("./another_new_burger_<NUM_SEEDS>").
NUM_SEEDS = 900

#def burgers_sim(grid_size=64, bcs='periodic', ic="exp", seed=137):
def burgers_sim(grid_size=64, bcs='periodic', ic="exp", seed=137):
    """Run one randomly parameterised 2D Burgers-type simulation with py-pde.

    The equation solved is

        du/dt = 2*c0*laplace(u) - u*(c1*du/dx + c2*du/dy)

    on the square [-0.5, 0.5] x [-0.5, 0.5], where c0 is drawn uniformly from
    [0.005/pi, 0.02/pi) and c1, c2 from [-0.5, 0.5). All random draws use the
    global NumPy RNG, which is seeded with ``seed`` at the start of the call.

    Args:
        grid_size: Number of grid cells passed to ``pde.CartesianGrid``.
        bcs: Boundary condition type: ``'periodic'``, ``'dirichlet'`` or
            ``'neumann'``. For the latter two the boundary value/derivative
            is drawn uniformly from [-0.1, 0.1).
        ic: Initial condition type: ``'exp'``, ``'prod_sin'``, ``'sum_sin'``
            or ``'random'``.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        A 6-tuple ``(data, times, coeffs, bc, ics, h5file)``:
        the contents of the ``'data'`` and ``'times'`` datasets of the
        storage file, the three equation coefficients as an array, the
        boundary condition (the string ``"periodic"`` or a one-entry dict),
        a dict mapping the ``ic`` name to its random factors, and the still
        open ``h5py.File`` handle, which the caller is responsible for
        closing.

    Raises:
        ValueError: If ``bcs`` or ``ic`` is not one of the supported names.
    """
    # Validate both selectors before doing any work. Previously an unknown
    # ``bcs`` left ``bc`` unbound and only failed later with an
    # UnboundLocalError when the PDE was constructed.
    if bcs not in ('periodic', 'dirichlet', 'neumann'):
        raise ValueError("Pick periodic, dirichlet, or neumann bcs.")
    if ic not in ('exp', 'prod_sin', 'sum_sin', 'random'):
        raise ValueError("Pick exp, prod_sin, sum_sin, or random ic.")

    np.random.seed(seed)

    bounds = [[-0.5, 0.5], [-0.5, 0.5]]
    if bcs == 'periodic':
        grid = pde.CartesianGrid(bounds, grid_size, periodic=[True, True])
    else:
        grid = pde.CartesianGrid(bounds, grid_size)

    # Initial condition.
    # NOTE: the order of the random draws (ic factors -> bc value -> coeffs)
    # determines the result for a given seed; do not reorder these sections.
    if ic == 'exp':
        factors = [0., 0.]
        state = pde.ScalarField.from_expression(grid, "exp(-(x**2 + y**2) * 100)")
    elif ic == 'prod_sin':
        # Even integer frequencies in {2, 4, 6, 8}.
        factors = 2*np.random.randint(1, 5, 2)
        state = pde.ScalarField.from_expression(grid, f"sin(x*pi*{factors[0]})*sin(y*pi*{factors[1]})")
    elif ic == 'sum_sin':
        factors = 2*np.random.randint(1, 5, 2)
        state = pde.ScalarField.from_expression(grid, f"sin({factors[0]}*pi*x) + cos({factors[1]}*pi*y)")
    else:  # ic == 'random'
        # The original author notes this cannot be seeded, so it is not used
        # by the dataset driver for now.
        factors = [0., 0.]
        state = pde.ScalarField.random_uniform(grid, vmin=-1.0, vmax=1.0)

    # Boundary conditions
    if bcs == 'periodic':
        bc = "periodic"
    elif bcs == 'dirichlet':
        bc = {"value": 0.2*np.random.random()-0.1}
    else:  # bcs == 'neumann'
        bc = {"derivative": 0.2*np.random.random()-0.1}

    # Equation coefficients: [diffusion, x-advection, y-advection]
    coeffs = [0.015/np.pi*np.random.random()+0.005/np.pi, 1*np.random.random()-0.5, 1*np.random.random()-0.5]

    # Define PDE
    pde_str = f"2*{coeffs[0]} * (laplace(u)) - u * ({coeffs[1]}*get_x(gradient(u)) + {coeffs[2]}*get_y(gradient(u)))"
    eq = pde.PDE({"u": pde_str}, user_funcs={"get_x": lambda arr: arr[0], "get_y": lambda arr: arr[1]}, bc=bc)

    # Simulate the PDE, streaming the trajectory into a temporary HDF5 file in
    # the working directory. The NamedTemporaryFile object must stay referenced
    # until the file has been re-opened below, because the file is removed
    # once that object is closed or garbage-collected.
    # NOTE(review): returning a handle to a file that is unlinked when this
    # function returns relies on POSIX semantics - confirm before porting.
    path = NamedTemporaryFile(dir="./", prefix="burgers", suffix=".hdf5")

    writer = pde.FileStorage(path.name, write_mode="truncate",
                             info={'coeffs': coeffs, 'bcs': bc, 'ics': {ic: factors}})

    eq.solve(state, t_range=2.0, dt=1e-5, tracker=[writer.tracker(0.02)], ret_info=True)
    del eq

    f1 = h5py.File(writer.filename, 'r+')
    try:
        data = f1['data'][:]
        times = f1['times'][:]
    except Exception:
        # Do not leak the handle when the stored file cannot be read.
        f1.close()
        raise

    return data, \
           times, \
           np.array(writer.info['coeffs']), \
           writer.info['bcs'], \
           writer.info['ics'], \
           f1


def sim_wrapper(inputs):
    """Run a chunk of 10 simulations and checkpoint the results as .npy files.

    Intended as the worker function for ``Pool.map``.

    Args:
        inputs: Sequence ``(bcs, ic, seed_offset)``. Simulations are run on a
            64-cell grid for the seeds ``seed_offset + 1`` through
            ``seed_offset + 10``.

    Side effects:
        After every simulation the results accumulated so far are written to
        ``./another_new_burger_<NUM_SEEDS>/<bcs>_<ic>_<seed_offset>_<kind>.npy``
        for ``kind`` in data, time, coeff, bc_name, bc_fac, ic_name, ic_fac
        and grid, overwriting the previous checkpoint. The output directory
        must already exist.
    """
    print(inputs)
    all_data, all_time, all_coeff, all_bc_name, all_bc_fac, all_ic_name, all_ic_fac = [], [], [], [], [], [], []

    # Build the coordinate mesh locally instead of relying on a module-level
    # `grid` global: such a global only exists in worker processes when the
    # pool is forked from a process that already defined it, and is missing
    # under the "spawn" start method (NameError in the worker).
    ref_grid = pde.CartesianGrid([[-0.5, 0.5], [-0.5, 0.5]], 64, periodic=[True, True])
    grid_coords = np.stack(np.meshgrid(*ref_grid.axes_coords), axis=-1)

    prefix = "./another_new_burger_{}/{}_{}_{}".format(NUM_SEEDS, inputs[0], inputs[1], inputs[2])
    # Insertion order fixes the order in which the files are written. The
    # lists are mutated in place, so the mapping always reflects the results
    # gathered so far.
    checkpoint = {
        "data": all_data,
        "time": all_time,
        "coeff": all_coeff,
        "bc_name": all_bc_name,
        "bc_fac": all_bc_fac,
        "ic_name": all_ic_name,
        "ic_fac": all_ic_fac,
        "grid": grid_coords,
    }

    for seed in tqdm(range(inputs[2], inputs[2]+10)):
        data, time, coeff, bc, ic, f = burgers_sim(64, bcs=inputs[0], ic=inputs[1], seed=seed+1)
        # The arrays were already read into memory; release the HDF5 handle.
        f.close()
        all_data.append(data)
        all_time.append(time)
        all_coeff.append(coeff)

        # Periodic boundaries come back as a plain string, the others as a
        # one-entry dict {kind: value}.
        if isinstance(bc, str):
            all_bc_name.append([bc])
            all_bc_fac.append([0.])
        else:
            all_bc_name.append(list(bc.keys()))
            all_bc_fac.append(list(bc.values()))

        all_ic_name.append(list(ic.keys()))
        all_ic_fac.append(list(ic.values()))

        # Checkpoint after every simulation so a crashed worker loses at most
        # one run.
        try:
            for kind, payload in checkpoint.items():
                np.save(f"{prefix}_{kind}.npy", payload)
        except ValueError:
            # np.save could not convert one of the lists to an array; dump the
            # metadata to help diagnose which run is inconsistent.
            print(all_bc_name)
            print(all_ic_name)

    # Release the accumulated arrays before handing the worker back to the pool.
    del checkpoint
    del all_data
    del all_time
    del all_coeff
    del all_bc_name
    del all_bc_fac
    del all_ic_name
    del all_ic_fac
    gc.collect()

if __name__ == '__main__':
    # Coordinate mesh of the 64-cell simulation grid, built from the grid's
    # axes_coords. It is bound to the module-level name `grid` before the pool
    # is created so that worker code referring to a global `grid` can see it
    # when the workers are forked from this process.
    grid1 = pde.CartesianGrid([[-0.5, 0.5], [-0.5, 0.5]], 64, periodic=[True, True])
    grid_points = grid1.axes_coords
    X, Y = np.meshgrid(*grid_points)
    grid = np.stack((X, Y), axis=-1)

    bcs = ['periodic', 'neumann', 'dirichlet']
    ics = ['exp', 'prod_sin', 'sum_sin']

    # Every offset stands for a chunk of 10 consecutive seeds.
    seed_offsets = list(range(0, NUM_SEEDS, 10))
    print(len(seed_offsets))

    # The offsets are processed in slices across separate invocations
    # ([:30], [30:60], [60:]); this run covers the last slice.
    seeds = seed_offsets[60:]
    print(seeds)

    # One task per (boundary condition, initial condition, seed offset).
    runs = list(itertools.product(bcs, ics, seeds))
    os.makedirs("./another_new_burger_{}".format(NUM_SEEDS), exist_ok=True)

    print(len(runs))
    # The context manager guarantees the worker processes are cleaned up even
    # if a task raises; map() blocks until all tasks have finished, so no work
    # is cut short when the pool is terminated on exit.
    with pool.Pool(50) as p:
        p.map(sim_wrapper, runs)

