# main experiments for the Navier-Stokes equation, comparing PCFM against other methods 

import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import logging, time 
import gc
import random

# seed = 42
# torch.manual_seed(seed)
# np.random.seed(seed)
# torch.cuda.manual_seed_all(seed)

def seed_everything(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch with ``seed``,
    exports ``PYTHONHASHSEED`` into the process environment, and switches
    cuDNN to deterministic mode with autotuning (benchmark) disabled.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    # Environment variable first; it is only a string export here.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side and framework generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Seed once at import time so everything below runs with the default seed.
seed_everything()

def free_cuda_mem():
    """Run the Python garbage collector, then empty PyTorch's CUDA cache."""
    # Collect first so tensors that just became unreachable are released
    # before the caching allocator is asked to give memory back.
    gc.collect()
    torch.cuda.empty_cache()

from models import get_flow_model
from models.functional import make_grid
from datasets import NavierStokesDataset
from models.PCFM_sampling import Residuals2D, compute_jacobian, fast_project_batched, FFM_NS_sampler
from models.constraints import InitialCondition, RegionConservationLaw, ChainConstraint

def calculate_metrics(ref, gen):
    """Compare first- and second-order statistics of two tensors.

    The mean and standard deviation of each tensor are taken over dim 0
    (presumably the sample axis -- confirm against callers), and the mean
    squared difference between the reference and generated statistics is
    reported.

    Args:
        ref: Reference tensor.
        gen: Generated tensor.

    Returns:
        Tuple ``(mmse, smse)`` of Python floats: the mean squared error
        between the dim-0 means and between the dim-0 standard deviations.
    """
    mean_gap = ref.mean(0) - gen.mean(0)
    std_gap = ref.std(0) - gen.std(0)

    mmse = mean_gap.pow(2).mean().item()
    smse = std_gap.pow(2).mean().item()
    return mmse, smse

def batch_ic_residual_mean(ut, residuals):
    """Return the batch-averaged norm of the initial-condition residual.

    Each element of ``ut`` (iterated along dim 0) is flattened and passed to
    ``residuals.ic_residual_ns``; the norms of the results are averaged and
    returned as a Python float.
    """
    per_sample_norms = [
        residuals.ic_residual_ns(sample.flatten()).norm() for sample in ut
    ]
    stacked = torch.stack(per_sample_norms)
    return stacked.mean().item()

def batch_mass_residual_mean(ut, residuals):
    """Return the batch-averaged norm of the mass residual.

    Each element of ``ut`` (iterated along dim 0) is flattened and passed to
    ``residuals.mass_residual_ns``; the norms of the results are averaged and
    returned as a Python float.
    """
    per_sample_norms = [
        residuals.mass_residual_ns(sample.flatten()).norm() for sample in ut
    ]
    stacked = torch.stack(per_sample_norms)
    return stacked.mean().item()

def ns_ic_loss(u_pred, u_true, mask, pinn_weight=1e-2):
    """Masked squared-error loss between a prediction and its target.

    Args:
        u_pred: Predicted tensor; its size along dim 0 is used as the
            batch size.
        u_true: Target tensor, combined element-wise with ``u_pred``.
        mask: Weights applied to the error before squaring; the summed
            squared error is normalised by ``mask.sum()``.
        pinn_weight: Accepted for signature parity with the other loss
            functions in this file; not used here.

    Returns:
        Scalar tensor: masked squared error divided by ``mask.sum()`` and by
        the batch size.
    """
    masked_error = (u_pred - u_true) * mask
    total_sq_error = masked_error.square().sum()
    batch = u_pred.size(0)
    return total_sq_error / mask.sum() / batch

def ns_ic_gpinn_loss(u_pred, u_true, mask, pinn_weight=1.0):
    """Masked data loss plus a global mass-conservation penalty.

    The first term is the masked squared error between ``u_pred`` and
    ``u_true`` normalised by ``mask.sum()`` and the batch size. The second
    term penalises, for every sample, the squared difference between the
    spatial integral of the field at each later time index and the spatial
    integral at time index 0.

    Args:
        u_pred: Predicted field of shape ``(B, nx, ny, nt)``.
        u_true: Target tensor, combined element-wise with ``u_pred``.
        mask: Weights applied to the data error before squaring.
        pinn_weight: Multiplier on the conservation penalty.

    Returns:
        Scalar tensor ``loss_ic + pinn_weight * loss_pinn``.
    """
    loss_ic = ((u_pred - u_true) * mask).square().sum() / mask.sum() / u_pred.size(0)

    _, nx, ny, _ = u_pred.shape
    dx = 1.0 / (nx - 1)
    dy = 1.0 / (ny - 1)

    # Spatial integral at every time index, shape (B, nt). Using the same
    # dx * dy cell area for all time indices keeps the initial mass on the
    # same scale as the later ones.
    mass_t = torch.sum(u_pred, dim=(1, 2)) * dx * dy

    # Keep mass_0 as (B, 1) so the subtraction below pairs each sample only
    # with its own initial mass; a (B, 1, 1) shape would broadcast against
    # (B, nt - 1) into (B, B, nt - 1) and mix samples across the batch.
    mass_0 = mass_t[:, :1]

    residual = mass_t[:, 1:] - mass_0
    loss_pinn = (residual ** 2).mean()

    return loss_ic + pinn_weight * loss_pinn


def save_outputs(data, ut, method, n_step, save_dir):
    """Write the generated samples of one run to ``save_dir``.

    Args:
        data: Not used by this function.
        ut: Tensor of generated samples; moved to CPU before saving.
        method: Method name, used as the file-name prefix.
        n_step: Step count, embedded in the file name.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    payload = {"ut": ut.cpu()}
    target_path = os.path.join(save_dir, f"{method}_{n_step}steps.pt")
    torch.save(payload, target_path)

def run_sampling_in_batches(method_name, sampler, data, u0_full, batch_size, nx, ny, nt, n_step, residuals, config):
    """Sample with one method in mini-batches and gather the results on CPU.

    ``data`` and ``u0_full`` are sliced along dim 0 in chunks of at most
    ``batch_size``; each chunk is passed to the sampler entry point selected
    by ``method_name`` and the outputs are concatenated along dim 0.

    Args:
        method_name: One of ``"PCFM"``, ``"ECI"``, ``"vanilla"``,
            ``"diffusionPDE"`` or ``"DFlow"``.
        sampler: Object exposing ``pcfm_sample``, ``eci_sample``,
            ``vanilla_sample``, ``guided_sample`` and ``dflow_sample``; only
            the one matching ``method_name`` is called.
        data: Reference solutions of shape ``(B, nx, ny, nt)``.
        u0_full: Initial noise, sliced with the same indices as ``data``.
        batch_size: Maximum number of samples per chunk.
        nx, ny, nt: Sizes used to reshape the PCFM projection output.
        n_step: Number of sampler steps, forwarded to the sampler.
        residuals: Object whose ``full_residual_ns`` is the PCFM constraint
            function; only accessed for ``"PCFM"``.
        config: Dict that may hold per-method sub-dicts (``"PCFM"``,
            ``"ECI"``, ``"diffusionPDE"``, ``"DFlow"``); missing sections or
            keys fall back to the defaults below.

    Returns:
        CPU tensor with all chunk outputs concatenated along dim 0.

    Raises:
        NotImplementedError: If ``method_name`` is not recognised.
    """
    B = data.shape[0]
    ut_all = []

    for start in range(0, B, batch_size):
        # Release memory left over from the previous chunk before sampling.
        gc.collect()
        torch.cuda.empty_cache()

        stop = min(B, start + batch_size)
        _batch_size = stop - start
        u0 = u0_full[start:stop]
        d = data[start:stop]

        # Default: each method uses its own data slice as the target.
        u1_true = d
        # Observe only the initial condition. Time is the LAST axis of the
        # (B, nx, ny, nt) layout, so the mask selects index 0 of that axis
        # (``mask[:, :, 0]`` would mark the y = 0 row at all times instead).
        mask = torch.zeros_like(u1_true)
        mask[..., 0] = 1.0

        if method_name == "PCFM":
            pcfm_cfg = config.get('PCFM', {})
            guided = pcfm_cfg.get('guided_interpolation', False)
            interp_params = pcfm_cfg.get('interpolation_params', {})

            out = sampler.pcfm_sample(
                u0, n_step,
                hfunc=lambda u: residuals.full_residual_ns(u),
                newtonsteps=1,
                guided_interpolation=guided,
                interpolation_params=interp_params
            )

            gc.collect()
            torch.cuda.empty_cache()

            # One extra projection of the flattened samples onto the
            # constraint set, then back to the field layout.
            out = fast_project_batched(
                out.view(_batch_size, -1),
                lambda u: residuals.full_residual_ns(u),
                max_iter=1
            ).view(_batch_size, nx, ny, nt)

        elif method_name == "ECI":
            # The initial condition is taken from the first sample of the chunk.
            ic_constraint = InitialCondition(d[:1, :, :, 0])

            _, nx_b, ny_b, nt_b = u0.shape

            dx = 1 / (nx_b - 1)
            dy = 1 / (ny_b - 1)
            cell_area = dx * dy
            domain_area = nx_b * ny_b * cell_area

            conserv_constraint = RegionConservationLaw(
                value=torch.zeros(_batch_size, nt_b),  # integral == 0
                dims=(1, 2),                           # enforce over the spatial dims
                area=domain_area,
            )

            full_constraint = ChainConstraint(conserv_constraint, ic_constraint).to(data.device)
            eci_cfg = config.get('ECI', {})
            n_mix = eci_cfg.get('n_mix', 5)
            resample_step = eci_cfg.get('resample_step', 5)

            out = sampler.eci_sample(
                u0, n_step=n_step, n_mix=n_mix, resample_step=resample_step, constraint=full_constraint
            )

        elif method_name == "vanilla":
            out = sampler.vanilla_sample(u0, n_step)

        elif method_name == "diffusionPDE":
            # Read every option through the same optional sub-dict so a
            # missing "diffusionPDE" section falls back to defaults instead
            # of raising KeyError inside the loss callback.
            diff_cfg = config.get('diffusionPDE', {})
            eta = diff_cfg.get('eta', 1e-2)
            pinn_weight = diff_cfg.get('pinn_weight', 1e-2)
            dPDE_loss_fn = lambda u_pred, u_true, m: ns_ic_gpinn_loss(
                u_pred, u_true, m, pinn_weight=pinn_weight
            )
            out = sampler.guided_sample(
                u0=u0, u1_true=u1_true, mask=mask,
                n_step=n_step,
                loss_fn=dPDE_loss_fn,
                eta=eta
            )

        elif method_name == "DFlow":
            # DFlow uses the first sample of the chunk, repeated, as target.
            # The tensor is 4-D, so expand needs one size per dimension.
            u1_single = d[0].unsqueeze(0)
            u1_true = u1_single.expand(_batch_size, -1, -1, -1).contiguous()
            mask = torch.zeros_like(u1_true)
            mask[..., 0] = 1.0

            dflow_cfg = config.get('DFlow', {})
            n_iter = dflow_cfg.get('n_iter', 20)
            lr = dflow_cfg.get('lr', 1.0)

            out = sampler.dflow_sample(
                u1_true=u1_true,
                mask=mask,
                n_sample=_batch_size,
                n_step=n_step,
                n_iter=n_iter,
                lr=lr
            )

        else:
            raise NotImplementedError(f"{method_name} not recognized.")

        # Move each chunk off the device immediately to bound GPU memory.
        ut_all.append(out.cpu())

    return torch.cat(ut_all, dim=0)

def evaluate_all_methods(config):
    """Run every configured sampler and record metrics and outputs.

    Loads the checkpoint and one block of the Navier-Stokes test split, draws
    a shared noise tensor, then for each (method, n_step) pair runs
    ``run_sampling_in_batches``, computes mean/std MSE against the data and
    the mean initial-condition and mass residual norms, saves the samples,
    and rewrites the results CSV.

    Args:
        config: Dict with the keys ``ckpt_path``, ``dataset_root``,
            ``data_file``, ``ic_idx``, ``methods``, ``n_steps``,
            ``batch_sizes``, ``save_dir`` and ``csv_path``, plus the optional
            per-method sections read by ``run_sampling_in_batches``. An
            optional ``device`` key selects the torch device and defaults to
            ``'cuda:1'``.
    """
    device = config.get('device', 'cuda:1')

    # NOTE: the checkpoint carries a pickled config object (accessed via
    # attributes below), so it must come from a trusted source.
    ckpt = torch.load(config['ckpt_path'], map_location=device)
    model = get_flow_model(ckpt['config'].model, ckpt['config'].encoder).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()

    ic_idx = config['ic_idx']

    nf = 100          # number of consecutive dataset entries in one block
    full_size = 100   # number of samples to generate (and loader batch size)

    dataset = NavierStokesDataset(root=config['dataset_root'], split='test', data_file=config['data_file'])
    indices = list(range(ic_idx * nf, (ic_idx + 1) * nf))
    ic_subset = Subset(dataset, indices)
    loader = DataLoader(ic_subset, batch_size=full_size, shuffle=False, num_workers=4)

    data = next(iter(loader)).to(device)

    # Keep every second entry along the last axis.
    data = data[..., ::2]

    print(data.shape)

    # Take grid sizes from the loaded data so the coordinate grids always
    # match the tensors handed to Residuals2D.
    nx, ny, nt = data.shape[1], data.shape[2], data.shape[3]
    Lx = 1.0
    x = torch.linspace(0, Lx, nx, device=device).unsqueeze(1)
    y = torch.linspace(0, Lx, ny, device=device).unsqueeze(1)
    # NOTE(review): the end time 49. is kept from the original script;
    # confirm it matches the time stamps of the subsampled data.
    t_grid = torch.linspace(0, 49., nt, device=device).unsqueeze(1)

    u0_full = torch.randn(full_size, nx, ny, nt, device=device)

    # Make sure the output directory exists before the first write.
    os.makedirs(config["save_dir"], exist_ok=True)
    torch.save(u0_full.cpu(), os.path.join(config["save_dir"], "u0_full.pt"))

    logging.info(f"ic_idx, nx, ny, nt, full_batch: {ic_idx, nx, ny, nt, full_size}")
    logging.info(f"u0 shape: {u0_full.shape}")
    logging.info(f"data shape: {data.shape}")

    sampler = FFM_NS_sampler(model)
    residuals = Residuals2D(
        data=data, x=x, y=y, t_grid=t_grid,
        dx=x[1] - x[0], dt=t_grid[1] - t_grid[0], dy=y[1] - y[0],
        nx=nx, ny=ny, nt=nt,
    )

    methods = config['methods']
    step_list = config['n_steps']
    batch_sizes = config['batch_sizes']

    # The reference data never changes, so copy it to the CPU once.
    data_cpu = data.cpu()

    results = []

    for method in methods:
        for n_step in step_list:
            gc.collect()
            torch.cuda.empty_cache()
            logging.info(f"Running {method} with n_step={n_step}")
            start_time = time.time()
            batch_size = batch_sizes.get(method, 64)

            ut = run_sampling_in_batches(method, sampler, data, u0_full, batch_size, nx, ny, nt, n_step, residuals, config)
            elapsed = time.time() - start_time
            logging.info(f"Completed {method} with n_step={n_step} in {elapsed:.2f} seconds")

            mmse, smse = calculate_metrics(data_cpu, ut.cpu())
            # NOTE(review): ``ut`` is a CPU tensor here while ``residuals``
            # was built from device tensors -- confirm Residuals2D accepts
            # CPU inputs.
            ic_res = batch_ic_residual_mean(ut, residuals)
            mass_res = batch_mass_residual_mean(ut, residuals)
            results.append({
                "method": method,
                "n_step": n_step,
                "mmse": mmse,
                "smse": smse,
                "ic_res": ic_res,
                "mass_res": mass_res
            })

            save_outputs(data, ut, method, n_step, save_dir=config['save_dir'])
            # Rewrite the CSV after every run so partial results survive a crash.
            df = pd.DataFrame(results)
            df.to_csv(config['csv_path'], index=False)
            logging.info(f"\nTemp results saved to {config['csv_path']}")

    df = pd.DataFrame(results)
    df.to_csv(config['csv_path'], index=False)
    logging.info("\nExperiment completed successfully.")

########################

# Experiment configuration: checkpoint/dataset locations, the methods and
# step counts to sweep, per-method batch sizes and per-method options.
config = {
    "ckpt_path": "logs/navier_stokes/480000.pt",
    "dataset_root": "datasets/data",
    "ic_idx": 0,
    "data_file": "ns_nw10_nf100_s64_t50_mu0.001.h5",
    "methods": ["PCFM", "vanilla", "ECI", "diffusionPDE", "DFlow"],
    "n_steps": [10, 20, 50, 100, 200],
    "batch_sizes": {
        "vanilla": 50,
        "ECI": 50,
        "PCFM": 8,
        "diffusionPDE": 10,
        "DFlow": 10,
    },
    "PCFM": {
        "guided_interpolation": False,
        "interpolation_params": {
            "custom_lam": 1.0,
            "step_size": 1e-2,
            "num_steps": 20,
        },
    },
    "ECI": {
        "n_mix": 10,
        "resample_step": 5,
    },
    "diffusionPDE": {
        "eta": 1,
        "pinn_weight": 1e-2,
    },
    "DFlow": {
        "n_iter": 10,  # the expensive knob: runtime grows with it
        "lr": 1.0,
    },
    "save_dir": "ns0507",
}

# All artefacts of a run (CSV, config dump, log, samples) live in save_dir.
os.makedirs(config["save_dir"], exist_ok=True)
config["csv_path"] = os.path.join(config["save_dir"], "results.csv")

# Persist the exact configuration next to the results.
config_path = os.path.join(config["save_dir"], "config.json")
with open(config_path, "w") as f:
    json.dump(config, f, indent=4)
print(f"Config saved to {config_path}")

# Log both to a file inside save_dir and to the console.
log_handlers = [
    logging.FileHandler(os.path.join(config["save_dir"], "run.log")),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=log_handlers,
)

evaluate_all_methods(config)