# Main experiments for the Burgers boundary-condition (BC) problem, comparing PCFM against other sampling methods (DFlow, ECI, vanilla, DiffusionPDE).

import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F

import logging, time 

# Global seed for reproducibility, applied to torch (CPU and every CUDA device)
# and to NumPy's legacy global RNG.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

from models import get_flow_model
from models.functional import make_grid
from datasets.burgers1d import Burgers1DDataset
from models.PCFM_sampling import Residuals, compute_jacobian, fast_project_batched, FFM_sampler
from models.constraints import InitialCondition, DirichletCondition, ChainConstraint, Constraint

# First time index at which the left-boundary constraint is imposed and
# evaluated; earlier time indices (only index 0 for the value 1) are excluded
# from the BC masks, residuals and PINN losses throughout this script.
START_STEP: int = 1

def calculate_metrics(ref, gen):
    """Compare per-location statistics of reference and generated samples.

    Means and standard deviations are taken over dim 0 (the sample axis) with
    the ``Tensor.mean`` / ``Tensor.std`` defaults (std is Bessel-corrected).

    Args:
        ref: Reference samples, stacked along dim 0.
        gen: Generated samples with the same trailing shape as ``ref``.

    Returns:
        ``(mmse, smse)``: the mean squared difference between the two mean
        fields and between the two std fields, as Python floats.
    """
    mean_gap = ref.mean(0) - gen.mean(0)
    std_gap = ref.std(0) - gen.std(0)
    mmse = mean_gap.pow(2).mean().item()
    smse = std_gap.pow(2).mean().item()
    return mmse, smse


def batch_bc_residual_mean(ut, residuals, start_step=START_STEP):
    """Return the batch mean of the L2 norm of each sample's BC residual.

    Each sample ``ut[k]`` is flattened and passed to
    ``residuals.bc_residual_burgers`` with the given ``start_step``.
    """
    per_sample = [
        residuals.bc_residual_burgers(sample.flatten(), start_step=start_step).norm()
        for sample in ut
    ]
    return torch.stack(per_sample).mean().item()

def batch_mass_residual_mean(ut, residuals):
    """Return the batch mean of the L2 norm of each sample's mass residual.

    Each sample ``ut[k]`` is flattened and passed to
    ``residuals.mass_residual_burgers``.
    """
    per_sample = [
        residuals.mass_residual_burgers(sample.flatten()).norm()
        for sample in ut
    ]
    return torch.stack(per_sample).mean().item()


def burgers_pde_loss(u_pred, dx=None, dt=None):
    """Mean squared residual of the inviscid Burgers equation u_t + u * u_x = 0.

    Derivatives use central differences on interior points. The first and last
    rows/columns reuse the nearest interior derivative (replicate padding).
    At least 3 points are needed along each axis.

    Args:
        u_pred: Tensor of shape [B, nx, nt] (space along dim 1, time along dim 2).
        dx: Spatial grid spacing. Defaults to ``1 / (nx - 1)``, i.e. a
            unit-length spatial domain; pass the real spacing otherwise.
        dt: Temporal grid spacing. Defaults to ``1 / (nt - 1)``, i.e. a
            unit-length time window.

    Returns:
        Scalar tensor: the squared residual averaged over all points and samples.
    """
    _, nx, nt = u_pred.shape
    if dx is None:
        dx = 1.0 / (nx - 1)
    if dt is None:
        dt = 1.0 / (nt - 1)

    dudx = (u_pred[:, 2:, :] - u_pred[:, :-2, :]) / (2 * dx)
    # F.pad pads the last dim first: (nt_left, nt_right, nx_left, nx_right).
    dudx = F.pad(dudx, (0, 0, 1, 1), mode='replicate')
    dudt = (u_pred[:, :, 2:] - u_pred[:, :, :-2]) / (2 * dt)
    dudt = F.pad(dudt, (1, 1, 0, 0), mode='replicate')

    res = dudt + u_pred * dudx
    return res.square().mean()


def burgers_bc_pinn_loss(u_pred, left_bc, pinn_weight=1e-2, start_step=0):
    """Physics-informed guidance loss: weighted PDE residual plus two BC penalties.

    Args:
        u_pred: Predicted fields, [B, nx, nt] (space along dim 1).
        left_bc: Target left-boundary values over time, [B, nt].
        pinn_weight: Multiplier applied to ``burgers_pde_loss``.
        start_step: First time index at which the boundary penalties apply.

    Returns:
        ``pinn_weight * pde_loss + dirichlet + neumann``, where
        ``dirichlet`` is the MSE between ``u_pred[:, 0, t]`` and ``left_bc`` and
        ``neumann`` is the MSE between the last two spatial nodes (a discrete
        zero-gradient condition at the right boundary), both for ``t >= start_step``.
    """
    window = slice(start_step, None)
    physics = pinn_weight * burgers_pde_loss(u_pred)
    dirichlet = (u_pred[:, 0, window] - left_bc[:, window]).square().mean()
    neumann = (u_pred[:, -1, window] - u_pred[:, -2, window]).square().mean()
    return physics + dirichlet + neumann

def save_outputs(data, ut, method, n_step, save_dir):
    """Persist the generated samples of one (method, n_step) run.

    Writes ``{save_dir}/{method}_{n_step}steps.pt`` holding ``{"ut": ut}`` on
    the CPU, creating ``save_dir`` if necessary. ``data`` is accepted for
    call-site compatibility but is not used.
    """
    os.makedirs(save_dir, exist_ok=True)
    payload = {"ut": ut.cpu()}
    torch.save(payload, os.path.join(save_dir, f"{method}_{n_step}steps.pt"))


def run_sampling_in_batches(method_name, sampler, data, u0_full, batch_size, nx, nt, n_step, residuals, config):
    """Generate samples with one method, processing the inputs in mini-batches.

    Args:
        method_name: One of "PCFM", "ECI", "vanilla", "diffusionPDE", "DFlow".
        sampler: Sampler exposing the matching ``*_sample`` method.
        data: Reference solutions batched along dim 0; ``data[:, 0, :]`` is used
            as the left-boundary trajectory (presumably [B, nx, nt]).
        u0_full: Starting samples for the sampler, aligned row-for-row with ``data``.
        batch_size: Maximum batch size; the final batch is smaller when
            ``batch_size`` does not divide ``data.shape[0]``.
        nx, nt: Grid sizes used to reshape the projected PCFM output.
        n_step: Number of sampler steps.
        residuals: Object providing ``full_residual_burgers2`` (used by PCFM only).
        config: Experiment config; per-method options live under the method's name.

    Returns:
        Detached CPU tensor of all generated samples, concatenated along dim 0.

    Raises:
        NotImplementedError: If ``method_name`` is not recognised.
    """
    B = data.shape[0]
    ut_all = []

    for i in range(0, B, batch_size):
        u0 = u0_full[i:i + batch_size]
        d = data[i:i + batch_size]
        # Actual size of this batch: the last one is short when batch_size does not divide B.
        cur_bs = d.shape[0]

        u1_true = d
        left_bc_batch = d[:, 0, :]  # left-boundary values over all time steps, [cur_bs, nt]

        if method_name == "PCFM":
            pcfm_cfg = config.get('PCFM', {})

            def pcfm_residual(u):
                # Constraint residual shared by the PCFM sampler and the final projection.
                return residuals.full_residual_burgers2(u, start_step=START_STEP)

            out = sampler.pcfm_sample(
                u0, n_step,
                hfunc=pcfm_residual,
                newtonsteps=1,
                guided_interpolation=pcfm_cfg.get('guided_interpolation', False),
                interpolation_params=pcfm_cfg.get('interpolation_params', {}),
                use_vmap=pcfm_cfg.get('pcfm_use_vmap', False),
            )

            # Final projection (max_iter=1) on flattened samples, then restore the
            # [batch, nx, nt] grid. Use the real batch size, not batch_size.
            out = fast_project_batched(
                out.reshape(cur_bs, -1),
                pcfm_residual,
                max_iter=1,
            ).reshape(cur_bs, nx, nt)

        elif method_name == "ECI":
            eci_cfg = config.get('ECI', {})
            # Dirichlet constraint pinning the left boundary (spatial index 0) to
            # the reference values for time indices >= START_STEP.
            mask_left = torch.zeros_like(d, dtype=torch.bool)
            mask_left[:, 0, START_STEP:] = True
            eci_constraint = DirichletCondition(value=d[:, 0:1, :], mask=mask_left)

            out = sampler.eci_sample(
                u0, n_step=n_step,
                n_mix=eci_cfg.get('n_mix', 5),
                resample_step=eci_cfg.get('resample_step', 5),
                constraint=eci_constraint,
            )

        elif method_name == "vanilla":
            out = sampler.vanilla_sample(u0, n_step)

        elif method_name == "diffusionPDE":
            diff_cfg = config.get('diffusionPDE', {})
            eta = diff_cfg.get('eta', 1e-2)
            pinn_w = diff_cfg['pinn_weight']

            def loss_dPDE(u_pred, u_true=None, m=None):
                # u_true and m are accepted to fit the sampler's loss_fn signature but ignored.
                return burgers_bc_pinn_loss(
                    u_pred, left_bc_batch,
                    pinn_weight=pinn_w,
                    start_step=START_STEP,
                )

            # Placeholder all-True mask; loss_dPDE does not read it.
            dummy_mask = torch.ones_like(u1_true, dtype=torch.bool)

            out = sampler.guided_sample(
                u0=u0, u1_true=u1_true, mask=dummy_mask,
                n_step=n_step,
                loss_fn=loss_dPDE,
                eta=eta,
            )

        elif method_name == "DFlow":
            dflow_cfg = config['DFlow']

            def loss_pinn(u_pred, u_true=None, m=None):
                # u_true and m are accepted to fit the sampler's loss_fn signature but ignored.
                return burgers_bc_pinn_loss(
                    u_pred, left_bc_batch,
                    pinn_weight=dflow_cfg['pinn_weight'],
                    start_step=START_STEP,
                )

            # Placeholder all-True mask; loss_pinn does not read it.
            dummy_mask = torch.ones_like(u1_true, dtype=torch.bool)

            out = sampler.dflow_sample(
                u1_true=u1_true,
                mask=dummy_mask,
                n_sample=cur_bs,  # must match the (possibly short) final batch
                n_step=n_step,
                n_iter=dflow_cfg['n_iter'],
                lr=dflow_cfg['lr'],
                loss_fn=loss_pinn,
            )
        else:
            raise NotImplementedError(f"{method_name} not recognized.")

        # Detach so no autograd graph from guided samplers is kept alive across batches.
        ut_all.append(out.detach().cpu())

    return torch.cat(ut_all, dim=0)

def evaluate_all_methods(config):
    """Run every configured sampler and record sample-quality / constraint metrics.

    Loads the model checkpoint and the solutions for boundary condition
    ``config['bc_idx']``, and draws one set of initial samples from ``model.gp``
    that is shared by every run. Then, for each ``(n_step, method)`` pair, it:
    samples; computes MMSE/SMSE against the data plus the mean BC and mass
    residual norms; saves the samples; and rewrites the results CSV.

    Args:
        config: Experiment configuration. Uses 'ckpt_path', 'dataset_root',
            'data_file', 'bc_idx', 'methods', 'n_steps', 'batch_sizes',
            'save_dir', 'csv_path', and the per-method sub-dicts read by
            ``run_sampling_in_batches``. Optional 'device' (default 'cuda').
    """
    device = config.get('device', 'cuda')
    save_dir = config['save_dir']
    # u0_full.pt is written below; don't rely on the caller having created the directory.
    os.makedirs(save_dir, exist_ok=True)

    # ckpt['config'] is an attribute-style config object, not plain tensors/containers,
    # which the weights-only unpickler (torch.load's default since torch 2.6)
    # rejects. This is a full unpickle: only point ckpt_path at trusted files.
    ckpt = torch.load(config['ckpt_path'], map_location=device, weights_only=False)
    model = get_flow_model(ckpt['config'].model, ckpt['config'].encoder).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()

    dataset = Burgers1DDataset(
        root=config['dataset_root'],
        split="train",
        data_file=config['data_file']
    )

    bc_idx = config['bc_idx']
    data = torch.from_numpy(dataset.file['u'][bc_idx]).to(device)  # [512, 101, 101]
    N_ic = data.shape[0]
    x = torch.tensor(dataset.file['x'][:]).to(device)
    t_grid = torch.tensor(dataset.file['t'][:]).to(device)
    nx, nt = data.shape[1], data.shape[2]
    grid = make_grid((nx, nt), device)

    # Single draw of starting samples, reused by every method and step count.
    u0_full = model.gp.sample(grid, (nx, nt), n_samples=N_ic).to(device)

    torch.save(u0_full.cpu(), os.path.join(save_dir, "u0_full.pt"))
    logging.info(f"bc_idx, N_ic, nx, nt: {bc_idx, N_ic, nx, nt}")
    logging.info(f"u0 shape: {u0_full.shape}")
    logging.info(f"data shape: {data.shape}")

    sampler = FFM_sampler(model, model.gp)

    # Residuals is built around one left-boundary trajectory (that of sample 0),
    # presumably because all samples for a given bc_idx share the BC -- confirm.
    left_bc_shared = data[0, 0, START_STEP:]
    logging.info(f"left_bc_shared.shape: {left_bc_shared.shape}")

    bc_value = torch.tensor(dataset.file['bc'][bc_idx]).to(data.device)  # scalar
    logging.info(f"boundary value for bc_idx={bc_idx}: {bc_value.item()}")

    residuals = Residuals(
        data=data, x=x, t_grid=t_grid,
        dx=x[1] - x[0], dt=t_grid[1] - t_grid[0],
        nx=nx, nt=nt,
        left_bc=left_bc_shared
    )

    batch_sizes = config['batch_sizes']
    ref_cpu = data.cpu()  # reference for calculate_metrics, copied once
    results = []

    for n_step in config['n_steps']:
        for method in config['methods']:
            logging.info(f"Running {method} with n_step={n_step}")
            start_time = time.perf_counter()
            batch_size = batch_sizes.get(method, 64)

            ut = run_sampling_in_batches(method, sampler, data, u0_full, batch_size, nx, nt, n_step, residuals, config)
            elapsed = time.perf_counter() - start_time
            logging.info(f"Completed {method} with n_step={n_step} in {elapsed:.2f} seconds")

            mmse, smse = calculate_metrics(ref_cpu, ut.cpu())
            # Samples come back on the CPU, but Residuals was built from tensors
            # on `device`; evaluate residuals there to avoid a device mismatch.
            ut_dev = ut.to(device)
            bc_res = batch_bc_residual_mean(ut_dev, residuals)
            mass_res = batch_mass_residual_mean(ut_dev, residuals)
            results.append({
                "method": method,
                "n_step": n_step,
                "mmse": mmse,
                "smse": smse,
                "bc_res": bc_res,
                "mass_res": mass_res
            })

            save_outputs(data, ut, method, n_step, save_dir=save_dir)
            # Rewrite the CSV after every run so partial results survive a crash.
            df = pd.DataFrame(results)
            df.to_csv(config['csv_path'], index=False)
            logging.info(f"\nTemp results saved to {config['csv_path']}")

    df = pd.DataFrame(results)
    df.to_csv(config['csv_path'], index=False)
    logging.info("\nExperiment completed successfully.")

########################

# Experiment configuration; snapshotted to <save_dir>/config.json further below.
config = dict(
    # Model checkpoint and dataset.
    ckpt_path="logs/burgers0504/20000.pt",
    dataset_root="datasets/data",
    data_file="burgers_sampling_diffBCs_nBC20_nIC512.h5",
    bc_idx=1,  # index of the boundary condition in the data file to evaluate

    # Every method is run for every step count.
    methods=["PCFM", "DFlow", "ECI", "vanilla", "diffusionPDE"],
    n_steps=[200, 10, 20, 50, 100],

    # Per-method mini-batch size (evaluate_all_methods falls back to 64 for unlisted methods).
    batch_sizes=dict(
        vanilla=128,
        ECI=128,
        PCFM=64,
        diffusionPDE=128,
        DFlow=32,
    ),

    # Per-method sampler options, read by run_sampling_in_batches.
    PCFM=dict(
        guided_interpolation=False,
        interpolation_params=dict(
            custom_lam=1.0,
            step_size=1e-2,
            num_steps=20,
        ),
        pcfm_use_vmap=False,  # in pcfm_batched()
    ),
    ECI=dict(
        n_mix=5,
        resample_step=5,
    ),
    diffusionPDE=dict(
        eta=1,
        pinn_weight=1e-2,
        # eta=10 with pinn_weight=1 produced NaNs
    ),
    DFlow=dict(
        n_iter=20,
        lr=1.0,
        pinn_weight=1e-2,
    ),

    # Outputs; csv_path is joined onto save_dir after this literal.
    save_dir="final/burgersbc0514",
    csv_path="results.csv",
)
os.makedirs(config["save_dir"], exist_ok=True)
# Resolve csv_path inside save_dir *before* the snapshot below, so config.json
# records the final location.
config["csv_path"] = os.path.join(config["save_dir"], config["csv_path"])

# Snapshot the effective configuration next to the outputs.
config_path = os.path.join(config["save_dir"], "config.json")
with open(config_path, "w") as f:
    f.write(json.dumps(config, indent=4))
print(f"Config saved to {config_path}")

# Send INFO-and-above records both to <save_dir>/run.log and to the console.
logging.basicConfig(
    handlers=[
        logging.FileHandler(os.path.join(config["save_dir"], "run.log")),
        logging.StreamHandler(),
    ],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Launch the full sweep over config["n_steps"] x config["methods"]. There is no
# `if __name__ == "__main__":` guard, so this also runs when the module is imported.
evaluate_all_methods(config)