import argparse
from pathlib import Path
import numpy as np
import h5py
from tqdm import tqdm
from fipy import CellVariable, Grid1D, TransientTerm, ConvectionTerm, ImplicitSourceTerm, DiffusionTerm, LinearLUSolver
from fipy.solvers.scipy.linearGMRESSolver import LinearGMRESSolver
import torch
from torchvision.utils import save_image, make_grid
import sys
import os


project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)
from vis_utils import draw


def check_reaction_diffusion(path, atol_ic=1e-6, atol_mass=1e-5, total_time=None):
    """Sanity-check a reaction-diffusion HDF5 dataset.

    Two checks are run and reported with ``OK:`` / ``FAIL:`` lines:

    1. Initial condition: ``solution[:, :, 0]`` must match
       ``initial_condition`` (``np.isclose`` with ``atol=atol_ic`` and the
       default ``rtol``).
    2. Mass balance: for u_t = nu*u_xx + rho*u*(1-u) with boundary data
       g_L = -nu*u_x(0) and g_R = -nu*u_x(1), the mass M(t) = integral of u dx
       satisfies dM/dt = integral of rho*u*(1-u) dx + (g_L - g_R).
       The right-hand side is integrated with a left Riemann sum in time.
       The predicted M(t_k) is compared against the stored trajectories with
       absolute tolerance ``atol_mass`` (rtol=0).

    Args:
        path: HDF5 file containing the datasets ``solution`` (nsim, nx, nt),
            ``initial_condition`` (nsim, nx), and the per-simulation arrays
            ``bc_flux_left``, ``bc_flux_right`` and ``rho`` (nsim,).
        atol_ic: Absolute tolerance for the initial-condition check.
        atol_mass: Absolute tolerance for the mass-balance check.
        total_time: Length of the simulated time interval. If None, the file
            attribute ``total_time`` is used when present, otherwise 1.0.

    Returns:
        True if every check that ran passed, False otherwise.
    """
    with h5py.File(path, 'r') as f:
        sol = f['solution'][:]               # shape (nsim, nx, nt)
        ic_ref = f['initial_condition'][:]   # shape (nsim, nx)
        g_L = f['bc_flux_left'][:]           # shape (nsim,)
        g_R = f['bc_flux_right'][:]          # shape (nsim,)
        rho = f['rho'][:]                    # shape (nsim,)
        if total_time is None:
            # Files without the attribute were generated under the old
            # implicit assumption of a unit time interval.
            total_time = float(f.attrs.get('total_time', 1.0))
    nsim, nx, nt = sol.shape

    # ---- 1. initial condition -------------------------------------------
    ic_match = np.isclose(sol[:, :, 0], ic_ref, atol=atol_ic).all(axis=1)  # (nsim,)
    bad_ic = np.flatnonzero(~ic_match)
    ic_ok = bad_ic.size == 0
    if ic_ok:
        print("OK: initial condition check passed")
    else:
        print(f"FAIL: initial condition check failed for {bad_ic.size}/{nsim} "
              f"simulation(s), first indices: {bad_ic[:10].tolist()}")

    if nt < 2:
        # No time step to integrate over (and dt would divide by zero).
        print("SKIP: mass conservation check needs at least two time levels")
        return ic_ok

    # ---- 2. mass balance ------------------------------------------------
    dx = 1.0 / nx                  # assumes a unit-length spatial domain
    dt = total_time / (nt - 1)

    # Accumulate the spatial sums in float64 so the check's own round-off
    # stays well below atol_mass even though the data are float32.
    mass = sol.sum(axis=1, dtype=np.float64) * dx                          # (nsim, nt)
    R = rho[:, None] * (sol * (1 - sol)).sum(axis=1, dtype=np.float64) * dx  # (nsim, nt)
    flux = (g_L.astype(np.float64) - g_R)[:, None]                         # (nsim, 1)

    # Left Riemann sum in time: first-order accurate in dt, so the residual
    # naturally shrinks as nt grows.
    m_pred = mass[:, :1] + np.cumsum((R[:, :-1] + flux) * dt, axis=1)      # (nsim, nt-1)
    m_true = mass[:, 1:]                                                   # (nsim, nt-1)
    diff = m_true - m_pred
    mass_ok = np.allclose(m_true, m_pred, atol=atol_mass, rtol=0)

    if mass_ok:
        print("OK: mass conservation check passed")
    else:
        i_bad, k_bad = np.unravel_index(np.argmax(np.abs(diff)), diff.shape)
        # Column k of diff corresponds to stored time level k + 1.
        print(f"FAIL: mass conservation check failed: max |error| = "
              f"{abs(diff[i_bad, k_bad]):.3e} (atol={atol_mass:g}) at simulation "
              f"{i_bad}, time index {k_bad + 1}")

    return ic_ok and mass_ok



def generate_random_ic(mesh):
    """Sample a random non-negative initial profile on the cells of ``mesh``.

    The profile is a sine wave shifted up by its own amplitude plus a
    Gaussian bump. Negative values are clamped to zero. Peak values can
    exceed 1 (roughly up to 2*0.4 + 0.5).

    Draws come from NumPy's global RNG in this fixed order:
    sine amplitude, sine frequency, bump centre, bump amplitude, bump width.

    Args:
        mesh: A mesh exposing ``cellCenters.value[0]`` as the array of cell
            x-coordinates.

    Returns:
        A float array with one value per cell.
    """
    x = mesh.cellCenters.value[0]

    # Shifted sine: amplitude in [0.1, 0.4), integer frequency in [1, 4].
    amp = np.random.uniform(0.1, 0.4)
    freq = np.random.randint(1, 5)
    shifted_sine = amp * np.sin(2 * np.pi * freq * x) + amp

    # Gaussian bump with random centre, height and width.
    centre = np.random.uniform(0.2, 0.8)
    height = np.random.uniform(0.2, 0.5)
    width = np.random.uniform(0.05, 0.15)
    gaussian = height * np.exp(-((x - centre) ** 2) / (2 * width ** 2))

    profile = shifted_sine + gaussian
    return np.where(profile < 0, 0.0, profile)

def generate_random_bcs():
    """Sample a random ``(g_L, g_R)`` pair of boundary-flux values.

    Each value is drawn independently from U(-0.01, 0.01) using NumPy's
    global RNG. The left value is drawn first, then the right.

    Returns:
        A 2-tuple ``(g_L, g_R)``.
    """
    low, high = -0.01, 0.01
    return tuple(np.random.uniform(low, high) for _side in ("left", "right"))


def visualize_test_samples(h5_path, output_path='test_samples_visualization.png', n_samples=16, nrow=4):
    """Render the first trajectories of an HDF5 dataset as a tiled image.

    Each trajectory ``solution[i]`` is converted to an image tensor by
    ``draw(..., vmin=0.0, vmax=1.0, cmap='bwr')``. The images are tiled
    with ``make_grid`` and written to ``output_path`` with ``save_image``.

    Args:
        h5_path: HDF5 file containing a ``solution`` dataset whose first axis
            indexes simulations.
        output_path: Destination image path. Missing parent directories are
            created.
        n_samples: Maximum number of leading trajectories to render.
        nrow: Number of images per row in the grid.

    Returns:
        The grid tensor that was saved.

    Raises:
        ValueError: If no trajectories were read.
    """
    with h5py.File(h5_path, 'r') as f:
        # Only the trajectories are plotted, so only they are read.
        solution = f['solution'][:n_samples]  # shape: (n_samples, nx, nt)

    if solution.shape[0] == 0:
        raise ValueError(f"no trajectories to visualize in {h5_path!r}")

    solution_tensor = torch.from_numpy(solution).float()

    # Iterating a tensor yields its first-axis slices, i.e. one trajectory each.
    img_tensors = [
        draw(trajectory, vmin=0.0, vmax=1.0, cmap='bwr')
        for trajectory in solution_tensor
    ]

    grid_img = make_grid(img_tensors, nrow=nrow, padding=2, normalize=False)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    save_image(grid_img, output_path)

    return grid_img

def _create_datasets(f, total_simulations, nx, nt):
    """Allocate every per-simulation float32 dataset in ``f`` and return them keyed by name."""
    shapes = {
        'solution': (total_simulations, nx, nt),
        'initial_condition': (total_simulations, nx),
        'bc_flux_left': (total_simulations,),
        'bc_flux_right': (total_simulations,),
        'rho': (total_simulations,),
        'nu': (total_simulations,),
    }
    return {
        name: f.create_dataset(name, shape=shape, dtype=np.float32)
        for name, shape in shapes.items()
    }


def _simulate_one(mesh, ic, g_L, g_R, rho, nu, dt, nt):
    """Integrate u_t = nu*u_xx + rho*u*(1-u) from ``ic``; return a (len(ic), nt) float32 history."""
    u = CellVariable(name="u", mesh=mesh, hasOld=True)
    u.setValue(ic)

    # Neumann data: du/dx = -g/nu, so the diffusive flux -nu*du/dx equals
    # g_L at the left face and g_R at the right face (positive = +x direction).
    u.faceGrad.constrain(-g_L / nu, where=mesh.facesLeft)
    u.faceGrad.constrain(-g_R / nu, where=mesh.facesRight)

    # rho*u*(1-u) split into an implicit linear part and a -rho*u^2 source.
    # NOTE(review): with a single solve per step (no sweeps) the quadratic
    # source is presumably evaluated at the previous time level, i.e. the
    # scheme is semi-implicit — confirm if exact discrete balance matters.
    reaction = ImplicitSourceTerm(coeff=rho) - rho * u * u
    eq = TransientTerm() == DiffusionTerm(coeff=nu) + reaction

    # One solver instance is reused for every step of this trajectory.
    solver = LinearLUSolver(tolerance=1e-12)

    history = np.zeros((len(ic), nt), dtype=np.float32)
    history[:, 0] = ic
    for t in range(1, nt):
        u.updateOld()
        eq.solve(var=u, dt=dt, solver=solver)
        history[:, t] = u.value
    return history


def main(args):
    """Generate the reaction-diffusion dataset described by ``args`` and write it to ``args.out``.

    Every one of ``args.num_ics`` random initial conditions is paired with
    every one of ``args.num_bcs`` random boundary-flux pairs. Each pair is
    simulated on a mesh of ``args.nx`` cells of width 1/nx, with ``args.nt``
    stored time levels spanning ``args.total_time``. ``args.rho`` and
    ``args.nu`` are shared by all simulations.

    The file also carries a ``total_time`` attribute so readers can
    recover dt = total_time / (nt - 1).

    Raises:
        ValueError: If ``args.nt`` < 2 (no time step to take).
    """
    if args.nt < 2:
        raise ValueError(f"--nt must be at least 2, got {args.nt}")

    dx = 1.0 / args.nx
    dt = args.total_time / (args.nt - 1)
    total_simulations = args.num_ics * args.num_bcs

    # One mesh serves IC sampling and every simulation. Boundary constraints
    # are attached to each simulation's own CellVariable, not to the mesh.
    mesh = Grid1D(dx=dx, nx=args.nx)
    list_of_ics = [generate_random_ic(mesh) for _ in tqdm(range(args.num_ics), desc="Generate ICs")]
    list_of_bcs = [generate_random_bcs() for _ in tqdm(range(args.num_bcs), desc="Generate BCs")]
    print("-" * 30)

    # h5py does not create missing parent directories (the default --out is nested).
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)

    with h5py.File(args.out, 'w') as f:
        f.attrs['total_time'] = args.total_time
        ds = _create_datasets(f, total_simulations, args.nx, args.nt)

        sim_index = 0
        with tqdm(total=total_simulations) as pbar:
            for ic in list_of_ics:
                for g_L, g_R in list_of_bcs:
                    ds['solution'][sim_index, ...] = _simulate_one(
                        mesh, ic, g_L, g_R, args.rho, args.nu, dt, args.nt
                    )
                    ds['initial_condition'][sim_index, :] = ic
                    ds['bc_flux_left'][sim_index] = g_L
                    ds['bc_flux_right'][sim_index] = g_R
                    ds['rho'][sim_index] = args.rho
                    ds['nu'][sim_index] = args.nu

                    sim_index += 1
                    pbar.update(1)

if __name__ == '__main__':
    # CLI entry point: either generate (and then sanity-check) a dataset, or,
    # with --visualize, render trajectories from an existing one.
    parser = argparse.ArgumentParser(
        description="1D diffusion-reaction dataset generation with FiPy, output to HDF5."
    )
    parser.add_argument('--nx', type=int, default=128)
    parser.add_argument('--nt', type=int, default=100)
    parser.add_argument('--total_time', type=float, default=1.0)
    parser.add_argument('--num_ics', type=int, default=100)
    parser.add_argument('--num_bcs', type=int, default=100)
    parser.add_argument('--rho', type=float, default=0.01)
    parser.add_argument('--nu', type=float, default=0.005)
    parser.add_argument('--out', type=str, default='datasets/data/rd/reaction_diffusion_test.h5')

    parser.add_argument('--visualize', action='store_true')
    parser.add_argument('--test_data', type=str, default=None,
                        help="HDF5 file to visualize (defaults to --out)")
    parser.add_argument('--vis_output', type=str, default='test_samples_visualization.png',
                        help="output image path for the visualization grid")
    parser.add_argument('--n_samples', type=int, default=16)

    args = parser.parse_args()

    if args.visualize:
        test_data = args.test_data if args.test_data is not None else args.out
        visualize_test_samples(test_data, args.vis_output, args.n_samples)
    else:
        main(args)
        check_reaction_diffusion(args.out, atol_mass=1e-6, atol_ic=1e-10)