# Main experiments for the Burgers initial-condition (IC) problem, comparing PCFM against other methods.

import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import logging, time 

# Seed the NumPy and PyTorch (CPU and all CUDA devices) RNGs with one shared value.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

from models import get_flow_model
from models.functional import make_grid
from datasets.burgers1d import Burgers1DDataset
from models.PCFM_sampling import Residuals, compute_jacobian, fast_project_batched, FFM_sampler
from models.constraints import InitialCondition

def calculate_metrics(ref, gen):
    """Compare per-position first and second moments of two sample sets.

    Both tensors are reduced over dimension 0 (the sample axis). The mean and
    (PyTorch default) standard deviation of ``gen`` are compared with those of
    ``ref`` element-wise.

    Returns:
        (mmse, smse): mean squared difference between the per-position means,
        and between the per-position standard deviations, as Python floats.
    """
    mean_gap = ref.mean(0) - gen.mean(0)
    std_gap = ref.std(0) - gen.std(0)
    mmse = mean_gap.pow(2).mean().item()
    smse = std_gap.pow(2).mean().item()
    return mmse, smse

def batch_ic_residual_mean(ut, residuals):
    """Average L2 norm of the initial-condition residual across a batch.

    Every sample along dim 0 of ``ut`` is flattened and passed to
    ``residuals.ic_residual``; the norms of the returned residual vectors are
    averaged and returned as a Python float.
    """
    norms = [residuals.ic_residual(sample.flatten()).norm() for sample in ut]
    return torch.stack(norms).mean().item()

def batch_mass_residual_mean(ut, residuals):
    """Average L2 norm of the mass-conservation residual across a batch.

    Every sample along dim 0 of ``ut`` is flattened and passed to
    ``residuals.mass_residual_burgers``; the norms of the returned residual
    vectors are averaged and returned as a Python float.
    """
    norms = [residuals.mass_residual_burgers(sample.flatten()).norm() for sample in ut]
    return torch.stack(norms).mean().item()

def burgers_ic_loss(u_pred, u_true, mask, pinn_weight=1e-2):
    """Masked squared-error loss restricted to the entries selected by ``mask``.

    The summed squared masked error is divided by ``mask.sum()`` and then by
    the batch size ``u_pred.size(0)``.

    ``pinn_weight`` is accepted for signature parity with
    ``burgers_ic_pinn_loss`` but has no effect here.

    NOTE(review): if ``mask`` has the same shape as ``u_pred``, ``mask.sum()``
    already counts entries from every sample, so the extra division by the
    batch size scales the loss by 1/batch — confirm this is intended.
    """
    masked_err = (u_pred - u_true) * mask
    n_batch = u_pred.size(0)
    return masked_err.square().sum() / mask.sum() / n_batch

def burgers_ic_pinn_loss(u_pred, u_true, mask, pinn_weight=1.0, *, dx=None, dt=None):
    """Masked IC loss plus a penalty on the inviscid Burgers residual u_t + u * u_x.

    Args:
        u_pred: predicted fields, indexed as (batch, nx, nt) by the unpacking
            and slicing below.
        u_true: reference fields, same shape as ``u_pred``.
        mask: weights selecting the entries that enter the IC term.
        pinn_weight: multiplier on the PDE-residual term.
        dx, dt: grid spacings along dims 1 and 2. ``None`` (default) keeps the
            previous unit-domain assumption, 1/(nx-1) and 1/(nt-1).

    Returns:
        Scalar tensor ``loss_ic + pinn_weight * mean(residual ** 2)``.

    Raises:
        ValueError: if either grid dimension has fewer than 3 points, since
            the central differences below need at least one interior point.

    NOTE(review): the IC term is divided by both ``mask.sum()`` and the batch
    size, while the residual term is a plain mean — confirm the relative
    scaling is intended.
    """
    loss_ic = ((u_pred - u_true) * mask).square().sum() / mask.sum() / u_pred.size(0)

    _, nx, nt = u_pred.shape
    if nx < 3 or nt < 3:
        raise ValueError(f"central differences need nx, nt >= 3; got nx={nx}, nt={nt}")
    if dx is None:
        dx = 1.0 / (nx - 1)
    if dt is None:
        dt = 1.0 / (nt - 1)

    # Central differences on the interior; replicate-pad the edges so the
    # derivative tensors match u_pred's shape.
    dudx = (u_pred[:, 2:, :] - u_pred[:, :-2, :]) / (2 * dx)
    dudx = F.pad(dudx, (0, 0, 1, 1), mode='replicate')
    dudt = (u_pred[:, :, 2:] - u_pred[:, :, :-2]) / (2 * dt)
    dudt = F.pad(dudt, (1, 1, 0, 0), mode='replicate')

    residual = dudt + u_pred * dudx
    loss_pinn = (residual ** 2).mean()
    return loss_ic + pinn_weight * loss_pinn

def save_outputs(data, ut, method, n_step, save_dir):
    """Write ``ut`` to ``<save_dir>/<method>_<n_step>steps.pt``.

    The tensor is moved to CPU and stored under the key ``"ut"``; ``save_dir``
    is created if missing. ``data`` is accepted but not used.
    """
    os.makedirs(save_dir, exist_ok=True)
    payload = {"ut": ut.cpu()}
    torch.save(payload, os.path.join(save_dir, f"{method}_{n_step}steps.pt"))

def run_sampling_in_batches(method_name, sampler, data, u0_full, batch_size, nx, nt, n_step, residuals, config):
    """Generate samples with one method, processing ``data`` in mini-batches.

    Args:
        method_name: one of "PCFM", "ECI", "vanilla", "diffusionPDE", "DFlow".
        sampler: object exposing the matching ``*_sample`` method.
        data: reference fields of shape (N, nx, nt); batches are taken along
            dim 0 and the ``[:, :, 0]`` slice is treated as the IC.
        u0_full: starting samples aligned with ``data`` along dim 0.
        batch_size: maximum samples per sampler call. The last batch is
            smaller when N is not a multiple of ``batch_size``.
        nx, nt: grid sizes used to reshape PCFM output.
        n_step: number of sampling steps passed to the sampler.
        residuals: provides ``full_residual_burgers`` (used by PCFM only).
        config: experiment config; each per-method sub-dict, and every key in
            it, is optional and falls back to the defaults below.

    Returns:
        CPU tensor with the outputs of all batches concatenated along dim 0.

    Raises:
        NotImplementedError: if ``method_name`` is not recognized.
    """
    if method_name not in ("PCFM", "ECI", "vanilla", "diffusionPDE", "DFlow"):
        raise NotImplementedError(f"{method_name} not recognized.")

    method_cfg = config.get(method_name, {})
    ut_all = []

    for i in range(0, data.shape[0], batch_size):
        u0 = u0_full[i:i + batch_size]
        d = data[i:i + batch_size]
        # Size of *this* batch: the final slice can be shorter than batch_size,
        # so never reshape/expand to batch_size directly.
        cur_bs = d.shape[0]

        # Observed fields and a mask selecting the t-index 0 slice (the IC).
        u1_true = d
        mask = torch.zeros_like(u1_true)
        mask[:, :, 0] = 1.0

        if method_name == "PCFM":
            k_steps = method_cfg.get('k_steps', 15)

            def hfunc(u):
                return residuals.full_residual_burgers(u, k=k_steps)

            out = sampler.pcfm_sample(
                u0, n_step,
                hfunc=hfunc,
                newtonsteps=1,
                guided_interpolation=method_cfg.get('guided_interpolation', False),
                interpolation_params=method_cfg.get('interpolation_params', {}),
                use_vmap=method_cfg.get('pcfm_use_vmap', False),
            )
            # One more projection pass (max_iter=1) with the same residual.
            # reshape (not view) also tolerates non-contiguous sampler output.
            out = fast_project_batched(
                out.reshape(cur_bs, -1),
                hfunc,
                max_iter=1,
            ).reshape(cur_bs, nx, nt)

        elif method_name == "ECI":
            # Constraint built from the first sample's t=0 slice (presumably
            # every entry of `data` shares the same IC).
            ic_constraint = InitialCondition(d[:1, :, 0])
            out = sampler.eci_sample(
                u0,
                n_step=n_step,
                n_mix=method_cfg.get('n_mix', 5),
                resample_step=method_cfg.get('resample_step', 5),
                constraint=ic_constraint,
            )

        elif method_name == "vanilla":
            out = sampler.vanilla_sample(u0, n_step)

        elif method_name == "diffusionPDE":
            pinn_weight = method_cfg.get('pinn_weight', 1e-2)
            out = sampler.guided_sample(
                u0=u0, u1_true=u1_true, mask=mask,
                n_step=n_step,
                loss_fn=lambda u_pred, u_true, m: burgers_ic_pinn_loss(
                    u_pred, u_true, m, pinn_weight=pinn_weight
                ),
                eta=method_cfg.get('eta', 1e-2),
            )

        else:  # "DFlow" (method name validated above)
            # Every sample in the batch is conditioned on the first sample's field.
            u1_true = d[0].unsqueeze(0).expand(cur_bs, -1, -1).contiguous()
            pinn_weight = method_cfg.get('pinn_weight', 1e-2)
            out = sampler.dflow_sample(
                u1_true=u1_true,
                mask=mask,  # same shape as u1_true: (cur_bs, nx, nt)
                n_sample=cur_bs,
                n_step=n_step,
                n_iter=method_cfg.get('n_iter', 20),
                lr=method_cfg.get('lr', 1.0),
                loss_fn=lambda u_pred, u_true, m: burgers_ic_pinn_loss(
                    u_pred, u_true, m, pinn_weight=pinn_weight
                ),
            )

        ut_all.append(out.cpu())

    return torch.cat(ut_all, dim=0)

def _load_flow_model(ckpt_path, device):
    """Build the flow model described by a checkpoint, load its weights, and set eval mode."""
    # NOTE(review): the checkpoint holds a config object (ckpt['config']);
    # PyTorch versions whose torch.load defaults to weights_only=True may
    # refuse to unpickle it — pass weights_only=False there if loading fails.
    ckpt = torch.load(ckpt_path, map_location=device)
    model = get_flow_model(ckpt['config'].model, ckpt['config'].encoder).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model


def _load_ic_slice(config, device):
    """Load the ``N_bc`` consecutive training entries for ``config['ic_idx']``.

    Reads indices ``ic_idx * N_bc`` through ``ic_idx * N_bc + N_bc - 1`` as one
    batch (presumably every sample sharing that initial condition).

    Returns:
        (dataset, data) with ``data`` moved to ``device``.
    """
    dataset = Burgers1DDataset(
        root=config['dataset_root'],
        split="train",
        data_file=config['data_file'],
    )
    n_bc = dataset.N_bc
    start = config['ic_idx'] * n_bc
    subset = Subset(dataset, list(range(start, start + n_bc)))
    data = next(iter(DataLoader(subset, batch_size=n_bc))).to(device)
    return dataset, data


def evaluate_all_methods(config):
    """Run every configured method on one IC slice and record metrics.

    For each method in ``config['methods']`` and each step count in
    ``config['n_steps']`` this samples, computes moment errors and IC/mass
    residuals, saves the samples under ``config['save_dir']``, and rewrites
    ``config['csv_path']``.

    The optional key ``config['device']`` selects the device; by default CUDA
    is used when available, otherwise CPU.
    """
    device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
    save_dir = config['save_dir']
    csv_path = config['csv_path']
    # u0_full.pt is written below, before save_outputs would create the dir.
    os.makedirs(save_dir, exist_ok=True)

    # Keep this order: building the model may consume RNG state, which
    # affects the GP draw below.
    model = _load_flow_model(config['ckpt_path'], device)
    dataset, data = _load_ic_slice(config, device)

    ic_idx = config['ic_idx']
    N_bc = dataset.N_bc
    x = torch.tensor(dataset.file['x'][:]).to(device)
    t_grid = torch.tensor(dataset.file['t'][:]).to(device)
    nx, nt = data.shape[1], data.shape[2]
    grid = make_grid((nx, nt), device)
    # One GP prior draw per data entry, shared by every method and step count.
    u0_full = model.gp.sample(grid, (nx, nt), n_samples=N_bc).to(device)

    torch.save(u0_full.cpu(), os.path.join(save_dir, "u0_full.pt"))
    logging.info(f"ic_idx, N_bc, nx, nt: {ic_idx, N_bc, nx, nt}")
    logging.info(f"u0 shape: {u0_full.shape}")
    logging.info(f"data shape: {data.shape}")

    sampler = FFM_sampler(model, model.gp)
    residuals = Residuals(data=data, x=x, t_grid=t_grid, dx=x[1] - x[0], dt=t_grid[1] - t_grid[0], nx=nx, nt=nt)

    batch_sizes = config['batch_sizes']
    results = []

    for method in config['methods']:
        for n_step in config['n_steps']:
            logging.info(f"Running {method} with n_step={n_step}")
            start_time = time.perf_counter()
            batch_size = batch_sizes.get(method, 64)

            ut = run_sampling_in_batches(method, sampler, data, u0_full, batch_size, nx, nt, n_step, residuals, config)
            elapsed = time.perf_counter() - start_time
            logging.info(f"Completed {method} with n_step={n_step} in {elapsed:.2f} seconds")

            mmse, smse = calculate_metrics(data.cpu(), ut.cpu())
            results.append({
                "method": method,
                "n_step": n_step,
                "mmse": mmse,
                "smse": smse,
                "ic_res": batch_ic_residual_mean(ut, residuals),
                "mass_res": batch_mass_residual_mean(ut, residuals),
            })

            save_outputs(data, ut, method, n_step, save_dir=save_dir)
            # Rewrite the CSV after every run so partial results survive a crash.
            pd.DataFrame(results).to_csv(csv_path, index=False)
            logging.info(f"\nTemp results saved to {csv_path}")

    pd.DataFrame(results).to_csv(csv_path, index=False)
    logging.info("\nExperiment completed successfully.")

########################

config = {
    "ckpt_path": "logs/burgers0504/20000.pt",
    "dataset_root": "datasets/data",
    "data_file": "burgers_sampling_diffICs2_nIC20_nBC512.h5",
    "ic_idx": 1, 

    "methods": ["PCFM", "DFlow", "vanilla", "ECI", "diffusionPDE"],

    "n_steps": [200, 10, 20, 50, 100],
    
    "batch_sizes": {
        "vanilla": 128,
        "ECI": 128,
        "PCFM": 64,
        "diffusionPDE": 128,
        "DFlow": 32,  
    },

    "PCFM": {
        "k_steps": 5,
        "guided_interpolation": False,
        "interpolation_params": {
            "custom_lam": 1.0,
            "step_size": 1e-2,
            "num_steps": 20
        },
        "pcfm_use_vmap": False,  # in pcfm_batched()
    },
    "ECI": {
        "n_mix": 5,
        "resample_step": 5
    },
    "diffusionPDE": {
        "eta": 1, 
        "pinn_weight": 1e-2
        # eta=10 with pinn_weight=1 produced NaNs
    },
    "DFlow": {
        "n_iter": 20,
        "lr": 1.0, 
        "pinn_weight": 1e-2,
    },
    "save_dir": "final/burgers0511",
    "csv_path": "results.csv"
}
# The results CSV always lives inside save_dir. Reassigning an existing key
# keeps its position, so the dumped config.json key order is unchanged.
config["csv_path"] = os.path.join(config["save_dir"], "results.csv")

if __name__ == "__main__":
    # Side effects (directories, files, logging, the GPU run) happen only when
    # executed as a script, not when this module is imported.
    os.makedirs(config["save_dir"], exist_ok=True)

    config_path = os.path.join(config["save_dir"], "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)
    print(f"Config saved to {config_path}")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(config["save_dir"], "run.log")),
            logging.StreamHandler()
        ]
    )

    evaluate_all_methods(config)
