# Main experiments for the 1D reaction-diffusion (RD) equation: compares PCFM against the
# vanilla, ECI, diffusionPDE and DFlow sampling methods.

import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import logging, time

# Seed the torch (CPU and CUDA) and NumPy global RNGs with one fixed value
# so that repeated runs of this script draw the same random numbers.
seed = 42
for _seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

from models import get_flow_model
from models.functional import make_grid
from datasets.rd1d import RD1DDataset
from models.PCFM_sampling import Residuals, fast_project_batched, FFM_sampler
from models.constraints import InitialCondition

def calculate_metrics(ref, gen):
    """Compare the first- and second-order sample statistics of two ensembles.

    Both tensors are reduced along dim 0 (mean and standard deviation), and
    the resulting statistic fields are compared with a mean squared error.

    Args:
        ref: Reference samples, stacked along dim 0.
        gen: Generated samples, stacked along dim 0.

    Returns:
        Tuple ``(mmse, smse)`` of Python floats: the MSE between the two mean
        fields and the MSE between the two standard-deviation fields.
    """
    mean_gap = ref.mean(0) - gen.mean(0)
    std_gap = ref.std(0) - gen.std(0)

    mmse = mean_gap.pow(2).mean().item()
    smse = std_gap.pow(2).mean().item()
    return mmse, smse

def batch_ic_residual_mean(ut, residuals):
    """Return the batch-averaged norm of the initial-condition residual.

    Each sample in ``ut`` (iterated along dim 0) is flattened, passed to
    ``residuals.ic_residual`` and reduced to a norm; the norms are then
    averaged and returned as a Python float.
    """
    per_sample_norms = []
    for sample in ut:
        ic_res = residuals.ic_residual(sample.flatten())
        per_sample_norms.append(ic_res.norm())
    return torch.stack(per_sample_norms).mean().item()

def batch_mass_residual_mean(ut, residuals):
    """Return the batch-averaged norm of the mass residual.

    Each sample in ``ut`` (iterated along dim 0) is flattened, passed to
    ``residuals.mass_residual_rd`` and reduced to a norm; the norms are then
    averaged and returned as a Python float.
    """
    per_sample_norms = []
    for sample in ut:
        mass_res = residuals.mass_residual_rd(sample.flatten())
        per_sample_norms.append(mass_res.norm())
    return torch.stack(per_sample_norms).mean().item()

def rd_ic_pinn_loss(u_pred, u_true, mask, pinn_weight=1.0, rho=0.01, nu=0.005):
    """Masked data-fit loss plus a finite-difference reaction-diffusion penalty.

    The physics term penalises the residual of
    ``u_t - nu * u_xx - rho * u * (1 - u)`` evaluated with central differences.

    Args:
        u_pred: Predicted field; unpacked as ``(B, nx, nt)``.
        u_true: Target field, broadcastable against ``u_pred``.
        mask: Weights selecting the entries that enter the data-fit term.
        pinn_weight: Multiplier applied to the physics term.
        rho: Reaction coefficient.
        nu: Diffusion coefficient.

    Returns:
        Scalar tensor ``loss_ic + pinn_weight * loss_pinn``.
    """
    B, nx, nt = u_pred.shape

    # Data-fit term: masked squared error, normalised by the mask total and
    # then once more by the batch size.
    loss_ic = ((u_pred - u_true) * mask).square().sum() / mask.sum() / B

    # Grid spacings are derived from the tensor shape alone, i.e. a
    # unit-length domain is assumed along both the nx and nt axes.
    dx = 1.0 / nx
    dt = 1.0 / nt

    # Central second difference along the nx axis; the two boundary rows are
    # filled by replicating the nearest interior value.
    d2udx2 = (u_pred[:, :-2, :] - 2 * u_pred[:, 1:-1, :] + u_pred[:, 2:, :]) / dx**2
    d2udx2 = F.pad(d2udx2, (0, 0, 1, 1), mode='replicate')

    # Central first difference along the nt axis, padded the same way.
    dudt = (u_pred[:, :, 2:] - u_pred[:, :, :-2]) / (2 * dt)
    dudt = F.pad(dudt, (1, 1, 0, 0), mode='replicate')

    # The PDE residual has no first spatial derivative, so none is computed.
    f_react = rho * u_pred * (1 - u_pred)
    residual = dudt - nu * d2udx2 - f_react
    loss_pinn = residual.pow(2).mean()
    return loss_ic + pinn_weight * loss_pinn

def save_outputs(data, ut, method, n_step, save_dir):
    """Persist generated samples to ``<save_dir>/<method>_<n_step>steps.pt``.

    The file holds a dict with a single key ``"ut"`` containing the CPU copy
    of ``ut``. ``save_dir`` is created if missing. ``data`` is accepted for
    interface compatibility but is not written.
    """
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"{method}_{n_step}steps.pt")
    payload = {"ut": ut.cpu()}
    torch.save(payload, out_path)

def run_sampling_in_batches(method_name, sampler, data, u0_full, batch_size, nx, nt, n_step, residuals, config):
    """Generate samples with one method, processing the inputs in mini-batches.

    ``data`` and ``u0_full`` are sliced along dim 0 in chunks of at most
    ``batch_size``; the final chunk may be smaller when ``data.shape[0]`` is
    not a multiple of ``batch_size``.

    Args:
        method_name: One of ``"PCFM"``, ``"ECI"``, ``"vanilla"``,
            ``"diffusionPDE"`` or ``"DFlow"``.
        sampler: Object exposing the ``*_sample`` method for the chosen method.
        data: Reference solutions; index 0 of the last axis is used as the
            observed slice for the guidance masks and the ECI constraint.
        u0_full: Prior samples used as starting points (unused by DFlow).
        batch_size: Maximum number of samples processed per chunk.
        nx, nt: Trailing dimensions that PCFM outputs are reshaped to.
        n_step: Number of sampling steps forwarded to the sampler.
        residuals: Provides ``full_residual_rd`` (PCFM) and ``rho`` / ``nu``
            (diffusionPDE, DFlow).
        config: Experiment configuration holding the per-method options.

    Returns:
        CPU tensor with every chunk's output concatenated along dim 0.

    Raises:
        NotImplementedError: If ``method_name`` is not recognised.
    """
    B = data.shape[0]
    ut_all = []

    for i in range(0, B, batch_size):
        u0 = u0_full[i:i + batch_size]
        d = data[i:i + batch_size]
        # Actual size of this chunk; differs from batch_size on a ragged tail.
        n_cur = d.shape[0]

        if method_name == "PCFM":
            cfg = config['PCFM']
            out = sampler.pcfm_sample(
                u0, n_step,
                hfunc=residuals.full_residual_rd,
                newtonsteps=1,
                guided_interpolation=cfg["guided_interpolation"],
                interpolation_params=cfg["interpolation_params"],
                use_vmap=cfg["pcfm_use_vmap"]
            )
            # One final projection step on the flattened samples. The leading
            # size is taken from the sampler output rather than batch_size so
            # that a smaller last chunk reshapes correctly.
            n_out = out.shape[0]
            out = fast_project_batched(
                out.reshape(n_out, -1),
                residuals.full_residual_rd,
                max_iter=1
            ).reshape(n_out, nx, nt)

        elif method_name == "ECI":
            # The constraint is built from the first sample of the chunk only.
            # NOTE(review): presumably all samples share one initial
            # condition -- confirm against the dataset layout.
            ic_constraint = InitialCondition(d[:1, :, 0])
            out = sampler.eci_sample(
                u0, n_step=n_step,
                n_mix=config['ECI']['n_mix'],
                resample_step=config['ECI']['resample_step'],
                constraint=ic_constraint
            )

        elif method_name == "vanilla":
            out = sampler.vanilla_sample(u0, n_step)

        elif method_name == "diffusionPDE":
            # Guidance only observes index 0 of the last axis.
            u1_true = d
            mask = torch.zeros_like(u1_true)
            mask[:, :, 0] = 1.0

            dPDE_loss_fn = lambda u_pred, u_true, m: rd_ic_pinn_loss(
                u_pred, u_true, m,
                pinn_weight=config['diffusionPDE'].get('pinn_weight', 1e-2),
                rho=residuals.rho, nu=residuals.nu
            )
            eta = config['diffusionPDE']['eta']
            out = sampler.guided_sample(
                u0=u0, u1_true=u1_true, mask=mask, n_step=n_step,
                loss_fn=dPDE_loss_fn, eta=eta
            )

        elif method_name == "DFlow":
            # The first sample of the chunk is repeated as the target for all
            # samples. Expanding to n_cur (not batch_size) keeps the total
            # number of generated samples equal to data.shape[0].
            u1_true = d[0].unsqueeze(0).expand(n_cur, -1, -1)
            mask = torch.zeros_like(u1_true)
            mask[:, :, 0] = 1.0

            dflow_cfg = config.get('DFlow', {})
            n_iter = dflow_cfg.get('n_iter', 20)
            lr = dflow_cfg.get('lr', 1.0)
            pinn_weight = dflow_cfg.get('pinn_weight', 1e-2)

            dflow_loss_fn = lambda u_pred, u_true, m: rd_ic_pinn_loss(
                u_pred, u_true, m,
                pinn_weight=pinn_weight,
                rho=residuals.rho, nu=residuals.nu
            )

            out = sampler.dflow_sample(
                u1_true=u1_true,
                mask=mask,
                n_sample=n_cur,
                n_step=n_step,
                n_iter=n_iter,
                lr=lr,
                loss_fn=dflow_loss_fn
            )

        else:
            raise NotImplementedError(f"{method_name} not recognized.")

        # Move each chunk off the device right away to bound device memory.
        ut_all.append(out.cpu())

    return torch.cat(ut_all, dim=0)

def evaluate_all_methods(config):
    """Run every configured (method, n_step) combination and record metrics.

    Loads the flow model checkpoint and the block of reference solutions
    selected by ``config['ic_idx']``, draws one shared set of prior samples,
    and then, for each method and step count, samples in batches, computes
    the mean/std MSE against the reference data together with the mean
    initial-condition and mass residual norms, saves the generated tensor and
    rewrites the results CSV.

    Args:
        config: Dict with the keys ``ckpt_path``, ``dataset_root``,
            ``data_file``, ``ic_idx``, ``methods``, ``n_steps``,
            ``batch_sizes``, ``save_dir`` and ``csv_path``, plus the
            per-method sections read by ``run_sampling_in_batches``. An
            optional ``device`` entry overrides the automatic device choice.
    """
    # Use an explicitly configured device if given; otherwise prefer CUDA and
    # fall back to CPU instead of failing on machines without a GPU.
    device = config.get('device') or ('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")

    # Files are written below before any sampling starts, so make sure the
    # output locations exist even when this function is called on its own.
    os.makedirs(config["save_dir"], exist_ok=True)
    csv_dir = os.path.dirname(config["csv_path"])
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    # NOTE: the checkpoint is unpickled by torch.load (it carries a config
    # object besides the weights) -- only load checkpoints from trusted sources.
    ckpt = torch.load(config['ckpt_path'], map_location=device)
    model = get_flow_model(ckpt['config'].model, ckpt['config'].encoder).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()

    dataset = RD1DDataset(
        root=config['dataset_root'],
        split="train",
        data_file=config['data_file']
    )

    # Take the ic_idx-th block of N_bc consecutive dataset items and load it
    # as a single batch.
    ic_idx = config['ic_idx']
    N_bc = dataset.N_bc
    indices = list(range(ic_idx * N_bc, (ic_idx + 1) * N_bc))
    data = next(iter(DataLoader(Subset(dataset, indices), batch_size=N_bc))).to(device)

    x = torch.tensor(dataset.file['x'][:]).to(device)
    t_grid = torch.tensor(dataset.file['t'][:]).to(device)
    nx, nt = data.shape[1], data.shape[2]
    grid = make_grid((nx, nt), device)

    # One set of prior samples is drawn up front and shared by all methods;
    # it is saved alongside the results.
    u0_full = model.gp.sample(grid, (nx, nt), n_samples=N_bc).to(device)
    torch.save(u0_full.cpu(), os.path.join(config["save_dir"], "u0_full.pt"))

    logging.info(f"ic_idx, N_bc, nx, nt: {ic_idx, N_bc, nx, nt}")
    logging.info(f"u0 shape: {u0_full.shape}")
    logging.info(f"data shape: {data.shape}")

    sampler = FFM_sampler(model, model.gp)
    residuals = Residuals(
        data=data, x=x, t_grid=t_grid,
        dx=x[1] - x[0], dt=t_grid[1] - t_grid[0],
        nx=nx, nt=nt,
        rho=dataset.file.attrs['rho'],
        nu=dataset.file.attrs['nu']
    )

    results = []
    for method in config['methods']:
        for n_step in config['n_steps']:
            logging.info(f"Running {method} with n_step={n_step}")
            start_time = time.time()
            ut = run_sampling_in_batches(method, sampler, data, u0_full, config['batch_sizes'][method], nx, nt, n_step, residuals, config)
            logging.info(f"{method} done in {time.time() - start_time:.2f}s")

            mmse, smse = calculate_metrics(data.cpu(), ut.cpu())
            # NOTE(review): ut is on the CPU here while residuals was built
            # from tensors on `device` -- confirm that Residuals copes with
            # inputs living on a different device.
            ic_res = batch_ic_residual_mean(ut, residuals)
            mass_res = batch_mass_residual_mean(ut, residuals)

            results.append({
                "method": method, "n_step": n_step,
                "mmse": mmse, "smse": smse,
                "ic_res": ic_res, "mass_res": mass_res
            })

            save_outputs(data, ut, method, n_step, config["save_dir"])
            # Rewrite the CSV after every run so partial results survive an
            # interrupted sweep.
            pd.DataFrame(results).to_csv(config["csv_path"], index=False)
            logging.info(f"\nTemp results saved to {config['csv_path']}")

    pd.DataFrame(results).to_csv(config["csv_path"], index=False)
    logging.info("All experiments done.")

# Experiment configuration read by evaluate_all_methods() and
# run_sampling_in_batches().
config = {
    # Checkpoint opened with torch.load(); its 'config' and 'model' entries are used.
    "ckpt_path": "logs/rd0423/20000.pt",
    # Forwarded to RD1DDataset(root=..., data_file=...).
    "dataset_root": "datasets/data",
    "data_file": "RD_sampling_diffICs_nIC20_nBC512.h5",
    # Selects dataset items [ic_idx * N_bc, (ic_idx + 1) * N_bc) as reference data.
    "ic_idx": 5,

    # Every method is run with every n_step value, in the order listed.
    "methods": ["PCFM", "DFlow", "vanilla", "ECI", "diffusionPDE"],
    "n_steps": [200, 10, 20, 50, 100],
    # Mini-batch size used while sampling, looked up per method name.
    "batch_sizes": {
        "vanilla": 128,
        "ECI": 128,
        "PCFM": 64,
        "diffusionPDE": 128,
        "DFlow": 32
    },
    # Options forwarded to sampler.pcfm_sample().
    "PCFM": {
        "guided_interpolation": False,
        "interpolation_params": {
            "custom_lam": 1.0, "step_size": 1e-2, "num_steps": 20
        },
        "pcfm_use_vmap": False,  # passed as use_vmap (original note: in pcfm_batched())
    },
    # Options forwarded to sampler.eci_sample().
    "ECI": {
        "n_mix": 5,
        "resample_step": 5
    },
    # 'eta' goes to sampler.guided_sample(); 'pinn_weight' scales the PDE
    # term of rd_ic_pinn_loss().
    "diffusionPDE": {
        "eta": 1,
        "pinn_weight": 1e-2,
    },
    # 'n_iter' and 'lr' go to sampler.dflow_sample(); 'pinn_weight' scales
    # the PDE term of rd_ic_pinn_loss().
    "DFlow": {
        "n_iter": 20,
        "lr": 1.0,
        "pinn_weight": 1e-2,
    },
    # Directory receiving the generated tensors, config.json and run.log.
    "save_dir": "final/rd_0511",
    # Results table, rewritten after every (method, n_step) run.
    "csv_path": "final/rd_0511/results.csv"
}

# Prepare the run directory and keep a JSON snapshot of the configuration
# next to the results.
run_dir = config["save_dir"]
os.makedirs(run_dir, exist_ok=True)

config_snapshot = os.path.join(run_dir, "config.json")
with open(config_snapshot, "w") as f:
    f.write(json.dumps(config, indent=4))

# Log both to <save_dir>/run.log and to the console.
log_handlers = [
    logging.FileHandler(os.path.join(run_dir, "run.log")),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=log_handlers,
)

evaluate_all_methods(config)
