# main experiments for the heat equation, comparing PCFM against other methods 

import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import logging, time 

import gc, random

def free_cuda_mem():
    """Run Python garbage collection, then empty PyTorch's CUDA cache."""
    # Collect first so objects that are only reachable through garbage are
    # dropped before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

# seed = 42
# torch.manual_seed(seed)
# np.random.seed(seed)
# torch.cuda.manual_seed_all(seed)

def seed_everything(seed=42):
    """Seed the RNGs used by this script and request deterministic cuDNN.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``; its string form is also stored in the
            ``PYTHONHASHSEED`` environment variable.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for apply_seed in (np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed with the default value at import time, before the project modules
# below are imported.
seed_everything()


from models import get_flow_model
from models.functional import make_grid
from datasets import DiffusionDataset
from models.PCFM_sampling import Residuals, compute_jacobian, fast_project_batched, FFM_sampler
from models.constraints import InitialCondition, RegionConservationLaw, ChainConstraint

def calculate_metrics(ref, gen):
    """Compare per-position statistics of a reference and a generated batch.

    Args:
        ref: Tensor of reference samples; statistics are taken over dim 0.
        gen: Tensor of generated samples; statistics are taken over dim 0.

    Returns:
        Tuple ``(mmse, smse)`` of Python floats: the mean squared difference
        between the dim-0 means, and between the dim-0 standard deviations.
    """
    def _mean_sq_diff(a, b):
        return (a - b).pow(2).mean().item()

    mmse = _mean_sq_diff(ref.mean(0), gen.mean(0))
    smse = _mean_sq_diff(ref.std(0), gen.std(0))
    return mmse, smse

def batch_ic_residual_mean(ut, residuals):
    """Average the norm of ``residuals.ic_residual`` over the samples in ``ut``.

    Args:
        ut: Batch of samples, iterated along its first dimension; each sample
            is flattened before being passed to ``ic_residual``.
        residuals: Object exposing an ``ic_residual`` method.

    Returns:
        The mean of the per-sample residual norms as a Python float.
    """
    per_sample_norms = []
    for sample in ut:
        res = residuals.ic_residual(sample.flatten())
        per_sample_norms.append(res.norm())
    return torch.stack(per_sample_norms).mean().item()

def batch_mass_residual_mean(ut, residuals):
    """Average the norm of ``residuals.mass_residual_heat`` over ``ut``.

    Args:
        ut: Batch of samples, iterated along its first dimension; each sample
            is flattened before being passed to ``mass_residual_heat``.
        residuals: Object exposing a ``mass_residual_heat`` method.

    Returns:
        The mean of the per-sample residual norms as a Python float.
    """
    per_sample_norms = []
    for sample in ut:
        res = residuals.mass_residual_heat(sample.flatten())
        per_sample_norms.append(res.norm())
    return torch.stack(per_sample_norms).mean().item()

def heat_ic_loss(u_pred, u_true, mask, pinn_weight=1e-2):
    """Masked squared error between prediction and target.

    Args:
        u_pred: Predicted tensor; its first dimension is used as batch size.
        u_true: Target tensor, broadcast-compatible with ``u_pred``.
        mask: Weights multiplied into the error before squaring.
        pinn_weight: Accepted but not used by this loss.

    Returns:
        Sum of squared masked errors divided by ``mask.sum()`` and then by
        the batch size, as a tensor.
    """
    masked_error = (u_pred - u_true) * mask
    n_batch = u_pred.size(0)
    return masked_error.square().sum() / mask.sum() / n_batch

def heat_ic_gpinn_loss(u_pred, u_true, mask, pinn_weight=1.0):
    """Masked data loss plus a penalty on drift of the spatial sum over time.

    The data term is the masked squared error, normalised by ``mask.sum()``
    and the batch size. The penalty compares the dim-1 sum of every later
    slice ``u_pred[:, :, k]`` (k >= 1), scaled by ``dx = 1 / (nx - 1)``,
    against the same quantity for slice 0.

    Args:
        u_pred: Predicted tensor of shape ``(B, nx, nt)``.
        u_true: Target tensor, broadcast-compatible with ``u_pred``.
        mask: Weights multiplied into the error before squaring.
        pinn_weight: Multiplier applied to the conservation penalty.

    Returns:
        ``loss_ic + pinn_weight * loss_pinn`` as a tensor. When ``nt < 2``
        there is no later slice to compare against, so only ``loss_ic`` is
        returned.

    Raises:
        ZeroDivisionError: If ``nx == 1`` (and ``nt >= 2``), because the grid
            spacing is computed as ``1 / (nx - 1)``.
    """
    loss_ic = ((u_pred - u_true) * mask).square().sum() / mask.sum() / u_pred.size(0)

    _, nx, nt = u_pred.shape
    if nt < 2:
        # With a single slice the penalty would be a mean over an empty
        # tensor (NaN); treat it as zero instead.
        return loss_ic

    dx = 1.0 / (nx - 1)
    mass_0 = torch.sum(u_pred[:, :, 0], dim=1, keepdim=True) * dx  # (B, 1)
    mass_t = torch.sum(u_pred, dim=1) * dx                         # (B, nt)
    residual = mass_t[:, 1:] - mass_0
    loss_pinn = (residual ** 2).mean()

    return loss_ic + pinn_weight * loss_pinn

    
def save_outputs(data, ut, method, n_step, save_dir):
    """Write the generated samples for one (method, n_step) run to disk.

    Args:
        data: Accepted but not used by this function.
        ut: Tensor of generated samples; moved to CPU before saving.
        method: Method name used as the file-name prefix.
        n_step: Step count embedded in the file name.
        save_dir: Output directory, created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    payload = {"ut": ut.cpu()}
    target = os.path.join(save_dir, f"{method}_{n_step}steps.pt")
    torch.save(payload, target)

def run_sampling_in_batches(method_name, sampler, data, u0_full, batch_size, nx, nt, n_step, residuals, config):
    """Generate samples with one method, processing the data in mini-batches.

    Args:
        method_name: One of "PCFM", "ECI", "vanilla", "diffusionPDE", "DFlow".
        sampler: Object providing ``pcfm_sample``, ``eci_sample``,
            ``vanilla_sample``, ``guided_sample`` and ``dflow_sample``.
        data: Reference tensor indexed as ``data[i:j]`` and ``[:, :, 0]``;
            its first dimension determines how many samples are produced.
        u0_full: Initial noise tensor, sliced along dim 0 in step with ``data``.
        batch_size: Maximum number of samples per mini-batch.
        nx, nt: Sizes used to reshape the projected PCFM output to
            ``(batch, nx, nt)``.
        n_step: Number of sampler steps.
        residuals: Object providing ``full_residual_heat`` (PCFM only).
        config: Experiment configuration; the optional sub-dicts ``'PCFM'``,
            ``'ECI'``, ``'diffusionPDE'`` and ``'DFlow'`` override the
            per-method defaults used below.

    Returns:
        The per-batch outputs, moved to CPU and concatenated along dim 0.

    Raises:
        NotImplementedError: If ``method_name`` is not one of the names above.
    """
    B = data.shape[0]
    ut_all = []

    for start in tqdm(range(0, B, batch_size), desc=str(method_name)):
        free_cuda_mem()
        stop = min(B, start + batch_size)
        _batch_size = stop - start  # the last batch may be smaller
        u0 = u0_full[start:stop]
        d = data[start:stop]

        # Reference target and a mask selecting only the [:, :, 0] slice.
        u1_true = d
        mask = torch.zeros_like(u1_true)
        mask[:, :, 0] = 1.0

        if method_name == "PCFM":
            pcfm_cfg = config.get('PCFM', {})
            # NOTE: a 'k_steps' entry in the PCFM config is not consumed by
            # the sampler call below.
            guided = pcfm_cfg.get('guided_interpolation', False)
            interp_params = pcfm_cfg.get('interpolation_params', {})

            out = sampler.pcfm_sample(
                u0, n_step,
                hfunc=lambda u: residuals.full_residual_heat(u),
                newtonsteps=1,
                guided_interpolation=guided,
                interpolation_params=interp_params
            )
            # One extra projection of the flattened samples onto the
            # constraint residual, then restore the (batch, nx, nt) layout.
            out = fast_project_batched(
                out.view(_batch_size, -1),
                lambda u: residuals.full_residual_heat(u),
                max_iter=1
            ).view(_batch_size, nx, nt)

        elif method_name == "ECI":
            # Built from the [:, :, 0] slice of the first sample in the batch.
            ic_constraint = InitialCondition(d[:1, :, 0])

            # Size the conservation target from the noise batch itself,
            # without rebinding the nx/nt arguments.
            _, _, nt_batch = u0.shape

            conserv_constraint = RegionConservationLaw(
                value=torch.zeros(_batch_size, nt_batch),  # integral == 0
                dims=(1,),              # enforce along the spatial dimension
                area=2*torch.pi
            )
            full_constraint = ChainConstraint(ic_constraint, conserv_constraint).to(data.device)

            eci_cfg = config.get('ECI', {})
            n_mix = eci_cfg.get('n_mix', 5)
            resample_step = eci_cfg.get('resample_step', 5)

            out = sampler.eci_sample(
                u0, n_step=n_step, n_mix=n_mix, resample_step=resample_step, constraint=full_constraint
            )

        elif method_name == "vanilla":
            out = sampler.vanilla_sample(u0, n_step)

        elif method_name == "diffusionPDE":
            diff_cfg = config.get('diffusionPDE', {})
            eta = diff_cfg.get('eta', 1e-2)
            # Read through diff_cfg so a missing 'diffusionPDE' section falls
            # back to the default instead of raising KeyError.
            pinn_weight = diff_cfg.get('pinn_weight', 1e-2)
            dPDE_loss_fn = lambda u_pred, u_true, m: heat_ic_gpinn_loss(
                u_pred, u_true, m, pinn_weight=pinn_weight
            )
            out = sampler.guided_sample(
                u0=u0, u1_true=u1_true, mask=mask,
                n_step=n_step,
                loss_fn=dPDE_loss_fn,
                eta=eta
            )

        elif method_name == "DFlow":
            # Use the first sample of the batch as the target for every
            # element, and rebuild the mask for that expanded target.
            u1_single = d[0].unsqueeze(0)
            u1_true = u1_single.expand(_batch_size, -1, -1).contiguous()
            mask = torch.zeros_like(u1_true)
            mask[:, :, 0] = 1.0

            dflow_cfg = config.get('DFlow', {})
            n_iter = dflow_cfg.get('n_iter', 20)
            lr = dflow_cfg.get('lr', 1.0)

            out = sampler.dflow_sample(
                u1_true=u1_true,
                mask=mask,
                n_sample=_batch_size,
                n_step=n_step,
                n_iter=n_iter,
                lr=lr
            )

        else:
            raise NotImplementedError(f"{method_name} not recognized.")

        # Keep finished batches on the host.
        ut_all.append(out.cpu())

    free_cuda_mem()
    return torch.cat(ut_all, dim=0)

def evaluate_all_methods(config):
    """Run every configured sampling method and record its metrics.

    For each method in ``config['methods']`` and each step count in
    ``config['n_steps']`` this generates samples, computes the mean/std MSE
    against the reference data and the mean initial-condition and mass
    residual norms, saves the samples, and rewrites the results CSV.

    Args:
        config: Experiment configuration. Required keys: ``'ckpt_path'``,
            ``'methods'``, ``'n_steps'``, ``'batch_sizes'``, ``'ic_idx'``,
            ``'save_dir'`` and ``'csv_path'``. Optional key ``'device'``
            selects the torch device (default ``'cuda:2'``). The whole dict
            is also forwarded to ``run_sampling_in_batches``.
    """
    # Device is configurable; the default keeps the previously hard-coded GPU.
    device = config.get('device', 'cuda:2')
    # NOTE(review): ckpt['config'] is accessed by attribute below, so the
    # checkpoint appears to hold a pickled Python object, not only tensors.
    # Load trusted checkpoints only; recent PyTorch releases default
    # torch.load to weights_only=True, which may reject such files - confirm
    # against the installed version.
    ckpt = torch.load(config['ckpt_path'], map_location=device)
    model = get_flow_model(ckpt['config'].model, ckpt['config'].encoder).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()

    nx = 100
    nt = 100

    Lx = 2 * np.pi
    x = torch.linspace(0, Lx, nx, device=device).unsqueeze(1)  # shape (nx,1)

    methods = config['methods']
    step_list = config['n_steps']
    batch_sizes = config['batch_sizes']

    dataset = DiffusionDataset(
        split='test',
        nx=nx,
        nt=nt,
        visc_range=(1., 5.),    # diffusivity in [1,5]
        phi_range=(np.pi/4, np.pi/4),  # fix phi = pi/4
        t_range=(0., 1.),
    )

    # ic_idx is only logged here; it is not used to select samples.
    ic_idx = config['ic_idx']

    # Evaluate on the first batch of up to `full_size` test samples.
    full_size = 512
    loader = DataLoader(dataset, batch_size=full_size, shuffle=False)
    data = next(iter(loader)).to(device)

    t_grid = torch.linspace(0, 1., nt, device=device).unsqueeze(1)
    # From here on use the grid sizes of the loaded data.
    nx, nt = data.shape[1], data.shape[2]
    grid = make_grid((nx, nt), device)
    u0_full = model.gp.sample(grid, (nx, nt), n_samples=full_size).to(device)

    # Create the output directory here so the function does not rely on the
    # caller having done it.
    os.makedirs(config["save_dir"], exist_ok=True)
    torch.save(u0_full.cpu(), os.path.join(config["save_dir"], "u0_full.pt"))

    logging.info(f"ic_idx, nx, nt, full_batch: {ic_idx, nx, nt, full_size}")
    logging.info(f"u0 shape: {u0_full.shape}")
    logging.info(f"data shape: {data.shape}")

    sampler = FFM_sampler(model, model.gp)
    residuals = Residuals(data=data, x=x, t_grid=t_grid, dx=x[1] - x[0], dt=t_grid[1] - t_grid[0], nx=nx, nt=nt)

    results = []

    for method in methods:
        for n_step in step_list:
            free_cuda_mem()
            logging.info(f"Running {method} with n_step={n_step}")
            start_time = time.time()
            # Methods without an explicit entry fall back to 64 per batch.
            batch_size = batch_sizes.get(method, 64)

            ut = run_sampling_in_batches(method, sampler, data, u0_full, batch_size, nx, nt, n_step, residuals, config)
            elapsed = time.time() - start_time
            logging.info(f"Completed {method} with n_step={n_step} in {elapsed:.2f} seconds")

            mmse, smse = calculate_metrics(data.cpu(), ut.cpu())
            ic_res = batch_ic_residual_mean(ut, residuals)
            mass_res = batch_mass_residual_mean(ut, residuals)
            results.append({
                "method": method,
                "n_step": n_step,
                "mmse": mmse,
                "smse": smse,
                "ic_res": ic_res,
                "mass_res": mass_res
            })

            save_outputs(data, ut, method, n_step, save_dir=config['save_dir'])
            # Rewrite the CSV after every run so partial results survive an
            # interrupted experiment.
            df = pd.DataFrame(results)
            df.to_csv(config['csv_path'], index=False)
            logging.info(f"\nTemp results saved to {config['csv_path']}")

    df = pd.DataFrame(results)
    df.to_csv(config['csv_path'], index=False)
    logging.info("\nExperiment completed successfully.")

########################

# Experiment configuration. evaluate_all_methods reads the top-level keys;
# the per-method sub-dicts ("PCFM", "ECI", "diffusionPDE", "DFlow") are read
# by run_sampling_in_batches.
config = {
    "ckpt_path": "logs/heat/20000.pt",
    # NOTE(review): "dataset_root" is not read anywhere in this script.
    "dataset_root": "datasets/data",
    "ic_idx": 0,

    "methods": ["PCFM", "vanilla", "ECI", "diffusionPDE", "DFlow"],
    "n_steps": [10, 20, 50, 100, 200],
    # Per-method mini-batch sizes for sampling.
    "batch_sizes": {
        "vanilla": 512,
        "ECI": 512,
        "PCFM": 128,
        "diffusionPDE": 512,
        "DFlow": 256, 
    },

    "PCFM": {
        "guided_interpolation": True,
        "interpolation_params": {
            "custom_lam": 1.0,
            "step_size": 1e-2,
            "num_steps": 3
        }
    },
    "ECI": {
        "n_mix": 1,
        "resample_step": 0
    },
    "diffusionPDE": {
        "eta": 1, 
        "pinn_weight": 1e-2
    },
    "DFlow": {
        "n_iter": 10, 
        "lr": 1.0 
    },
    "save_dir": "heat0506",
    # Placeholder; replaced below with a path inside save_dir.
    "csv_path": "results.csv"
}
os.makedirs(config["save_dir"], exist_ok=True)
config["csv_path"] = os.path.join(config["save_dir"], "results.csv")

# Persist the exact configuration (including the resolved csv_path) next to
# the outputs.
config_path = os.path.join(config["save_dir"], "config.json")
with open(config_path, "w") as f:
    json.dump(config, f, indent=4)
print(f"Config saved to {config_path}")

# Log both to run.log inside save_dir and to the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(os.path.join(config["save_dir"], "run.log")),
        logging.StreamHandler()
    ]
)

# Run the full experiment.
evaluate_all_methods(config)