import jax.numpy as np
import jax
from jax import jit, vmap
import numpy as onp
import h5py
from functools import partial, reduce
from jax import config
from time import time

# Enable 64-bit (double-precision) arrays in JAX; the HDF5 datasets written by
# this script are declared float64 as well.
config.update("jax_enable_x64", True)

from basisfunctions import num_elements
from simulator import simulate_2D
from rungekutta import FUNCTION_MAP
from training import get_f_phi, get_initial_condition
from flux import Flux
from helper import f_to_FV, _evalf_2D_integrate, nabla
from poissonsolver import get_poisson_solver
from poissonbracket import get_poisson_bracket
from diffusion import get_diffusion_func
from helper import legendre_inner_product, inner_prod_with_legendre

# Module-level alias for pi (taken from jax.numpy); used by the forcing profile
# in generate_eval_data.
PI = np.pi

def create_dataset(args, data_dir, unique_id, nx, ny, nt):
    """Create (or truncate) the HDF5 file that will hold ``nt`` snapshots.

    The file ``{data_dir}/{unique_id}_{nx}x{ny}.hdf5`` is opened in ``"w"``
    mode, so any existing file at that path is overwritten. Two float64
    datasets are created:

    * ``a_data`` with shape ``(nt, nx, ny, 1)`` for the state snapshots;
    * ``t_data`` with shape ``(nt,)`` for the snapshot times.

    Args:
        args: unused; kept so the signature matches existing callers.
        data_dir: directory the file is written into.
        unique_id: prefix identifying the run / flux.
        nx, ny: grid size, used both in the file name and the dataset shape.
        nt: number of snapshot slots to allocate.
    """
    path = "{}/{}_{}x{}.hdf5".format(data_dir, unique_id, nx, ny)
    # The context manager closes the file even if dataset creation raises,
    # instead of leaking the open handle.
    with h5py.File(path, "w") as f:
        f.create_dataset("a_data", (nt, nx, ny, 1), dtype="float64")
        f.create_dataset("t_data", (nt,), dtype="float64")


def write_dataset(args, data_dir, unique_id, a, t, i):
    """Store snapshot ``a`` and its time ``t`` at slot ``i`` of an existing file.

    The target file is ``{data_dir}/{unique_id}_{nx}x{ny}.hdf5``, where
    ``nx, ny`` are taken from the first two axes of ``a``. It must already
    exist with ``a_data`` / ``t_data`` datasets (it is opened in ``"r+"``
    mode).

    Args:
        args: unused; kept so the signature matches existing callers.
        data_dir: directory containing the file.
        unique_id: prefix identifying the run / flux.
        a: state array; its first two axes select the file name.
        t: time associated with ``a``.
        i: snapshot index to write into.
    """
    nx, ny = a.shape[0:2]
    path = "{}/{}_{}x{}.hdf5".format(data_dir, unique_id, nx, ny)
    # The context manager closes the file even if a write raises
    # (e.g. index out of range or shape mismatch).
    with h5py.File(path, "r+") as f:
        f["a_data"][i] = a
        f["t_data"][i] = t


@partial(jit, static_argnums=(1, 2, 3, 4, 7))
def convert_representation(
    a, order_new, order_high, nx_new, ny_new, Lx, Ly, equation, n = 8
):
    """Re-project a batch of high-order states onto a new grid and order.

    Each entry along the leading (vmapped) axis of ``a`` is evaluated with
    ``_evalf_2D_integrate`` as an order-``order_high`` representation on an
    ``nx_high`` x ``ny_high`` grid. The result is then projected with
    ``f_to_FV`` onto an ``nx_new`` x ``ny_new`` grid at order ``order_new``.

    Args:
        a: batched state array. Axes 1 and 2 give the input grid size; assumed
            layout (batch, nx_high, ny_high, ...) — TODO confirm trailing axes.
        order_new: target order (static under jit).
        order_high: order of the input representation (static under jit).
        nx_new, ny_new: target grid size (static under jit).
        Lx, Ly: domain lengths.
        equation: static argument; not used in this function.
        n: forwarded to ``f_to_FV`` as ``n`` (presumably a quadrature
            resolution — confirm against helper.f_to_FV).

    Returns:
        The per-entry ``f_to_FV`` results stacked along a leading batch axis.
    """
    _, nx_high, ny_high = a.shape[:3]
    cell_dx = Lx / nx_high
    cell_dy = Ly / ny_high

    def _project_one(a_single):
        def _evaluate_high(x, y, t):
            return _evalf_2D_integrate(x, y, a_single, cell_dx, cell_dy, order_high)

        return f_to_FV(nx_new, ny_new, Lx, Ly, order_new, _evaluate_high, 0.0, n=n)

    return vmap(_project_one)(a)


####################
# GENERATE EVAL DATA
####################

def _cfl_dt(args, nx, ny, order):
    """Return the CFL-limited time step for an nx x ny grid at polynomial ``order``."""
    dx = args.Lx / nx
    dy = args.Ly / ny
    return args.cfl_safety * ((dx * dy) / (dx + dy)) / (2 * order + 1)


def _make_diffusion_func(args, order):
    """Return the diffusion operator for ``order``, or None when diffusion is off."""
    if args.diffusion_coefficient > 0.0:
        return get_diffusion_func(order, args.Lx, args.Ly, args.diffusion_coefficient)
    return None


def _make_forcing_func(args, nx, ny, order):
    """Return the forcing term for an nx x ny grid at ``order``, or None when forcing is off."""
    if not args.is_forcing:
        return None
    leg_ip = np.asarray(legendre_inner_product(order))
    forcing_profile = lambda x, y, t: -4 * (2 * PI / args.Ly) * np.cos(4 * (2 * PI / args.Ly) * y)
    y_term = inner_prod_with_legendre(
        nx, ny, args.Lx, args.Ly, order, forcing_profile, 0.0, n=2 * order + 1
    )
    dx = args.Lx / nx
    dy = args.Ly / ny
    # Constant y-forcing minus a linear damping of zeta, scaled by forcing_coefficient.
    return lambda zeta: (
        y_term - dx * dy * args.damping_coefficient * zeta * leg_ip
    ) * args.forcing_coefficient


def generate_eval_data(args, data_dir, N, T, nxs, fluxes, unique_ids, seed):
    """Generate evaluation trajectories and write them to HDF5 files.

    A random initial condition (from ``seed``) is first evolved to
    ``args.burn_in_time`` on the reference grid (``args.nx_max`` x
    ``args.ny_max``, order ``args.order_max``, flux ``args.exact_flux``).
    Then, for every ``flux`` in ``fluxes`` and every ``nx`` in the matching
    entry of ``nxs``, the burnt-in state is projected to an order-0
    ``nx`` x ``nx`` grid and simulated. Snapshot ``j`` (``0 <= j < int(T)``)
    is stored at time ``args.burn_in_time + j`` in
    ``{data_dir}/{unique_id}_{nx}x{nx}.hdf5``.

    Args:
        args: configuration namespace (grid sizes, domain, CFL safety, physics
            coefficients, solver options).
        data_dir: output directory for the HDF5 files.
        N: unused in this function.
        T: number of unit-time snapshots per trajectory (``int(T)``).
        nxs: per-flux iterables of grid resolutions.
        fluxes: fluxes to evaluate; indexed in parallel with ``nxs`` and
            ``unique_ids``.
        unique_ids: per-flux file-name prefixes.
        seed: PRNG seed for the initial condition.
    """
    key = jax.random.PRNGKey(seed)
    f_init = get_initial_condition(key, args)
    t0 = 0.0

    a0 = f_to_FV(args.nx_max, args.ny_max, args.Lx, args.Ly, args.order_max, f_init, t0, n=8)

    #########
    # burn_in
    #########
    dt_exact = _cfl_dt(args, args.nx_max, args.ny_max, args.order_max)
    nt_burn_in = int(args.burn_in_time // dt_exact)

    f_poisson_exact = get_poisson_solver(
        args.poisson_dir, args.nx_max, args.ny_max, args.Lx, args.Ly, args.order_max
    )
    f_phi_exact = lambda zeta, t: f_poisson_exact(zeta)
    f_diffusion_exact = _make_diffusion_func(args, args.order_max)
    f_forcing_exact = _make_forcing_func(args, args.nx_max, args.ny_max, args.order_max)
    f_poisson_bracket_exact = get_poisson_bracket(args.poisson_dir, args.order_max, args.exact_flux)

    # nt (number of steps) must be static: it controls the loop length in simulate_2D.
    @partial(jit, static_argnums=(2,))
    def sim_exact(a_i, t_i, nt, dt):
        a_f, t_f = simulate_2D(
            a_i,
            t_i,
            args.nx_max,
            args.ny_max,
            args.Lx,
            args.Ly,
            args.order_max,
            dt,
            nt,
            args.exact_flux,
            equation=args.equation,
            output=False,
            f_poisson_bracket=f_poisson_bracket_exact,
            f_phi=f_phi_exact,
            f_diffusion=f_diffusion_exact,
            f_forcing=f_forcing_exact,
            rk=FUNCTION_MAP[args.runge_kutta],
        )
        return a_f, t_f

    a_burn_in_exact, t_burn_in = sim_exact(a0, t0, nt_burn_in, dt_exact)
    # One short extra step so the burn-in ends exactly at burn_in_time.
    a_burn_in_exact, t_burn_in = sim_exact(
        a_burn_in_exact, t_burn_in, 1, args.burn_in_time - t_burn_in
    )

    print("burnt in")

    # The coarse evaluation runs are all piecewise-constant (order 0).
    order = 0
    num_snapshots = int(T)

    for num_id, flux in enumerate(fluxes):
        print("Flux is now {}".format(flux))
        unique_id = unique_ids[num_id]

        for nx in nxs[num_id]:
            create_dataset(args, data_dir, unique_id, nx, nx, num_snapshots)

        for nx in nxs[num_id]:
            print("nx is {}".format(nx))
            ny = nx
            dt = _cfl_dt(args, nx, ny, order)
            # Number of full CFL steps that fit in one unit of time between snapshots.
            nt_sim = int(1 // dt)

            f_poisson_bracket = get_poisson_bracket(args.poisson_dir, order, flux)
            f_poisson_solve = get_poisson_solver(
                args.poisson_dir, nx, ny, args.Lx, args.Ly, order
            )
            f_phi = lambda zeta, t: f_poisson_solve(zeta)
            f_diffusion = _make_diffusion_func(args, order)
            f_forcing_sim = _make_forcing_func(args, nx, ny, order)

            # Redefined for each nx, so the closure always sees this iteration's
            # nx/ny/flux/operators; nt is static as in sim_exact.
            @partial(jit, static_argnums=(2,))
            def simulate(a_i, t_i, nt, dt):
                return simulate_2D(
                    a_i,
                    t_i,
                    nx,
                    ny,
                    args.Lx,
                    args.Ly,
                    order,
                    dt,
                    nt,
                    flux,
                    equation=args.equation,
                    a_data=None,
                    output=False,
                    f_phi=f_phi,
                    f_poisson_bracket=f_poisson_bracket,
                    f_diffusion=f_diffusion,
                    f_forcing=f_forcing_sim,
                    rk=FUNCTION_MAP[args.runge_kutta],
                    inner_loop_steps=1,
                )

            a_f = convert_representation(
                a_burn_in_exact[None], order, args.order_max, nx, ny, args.Lx, args.Ly, args.equation
            )[0]
            t_f = t_burn_in

            write_dataset(args, data_dir, unique_id, a_f, t_f, 0)

            for j in range(1, num_snapshots):
                a_f, t_f = simulate(a_f, t_f, nt_sim, dt)
                # Snapshot j lives at burn_in_time + j. The previous target of
                # (j + 1) + burn_in_time made the first correction step about one
                # time unit long (far beyond the CFL limit) and skipped a snapshot time.
                a_f, t_f = simulate(a_f, t_f, 1, j + args.burn_in_time - t_f)

                write_dataset(args, data_dir, unique_id, a_f, t_f, j)