# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Anonymous

import os
import sys
from datetime import datetime
import math
from collections import OrderedDict, defaultdict
import torch
from absl import app, flags
from cleanfid import fid
from torchdiffeq import odeint
from torchdyn.core import NeuralODE
from torchvision import transforms
from tqdm import tqdm
import shutil
import traceback
import glob

# Make the directory three levels above this file importable (presumably the
# repository root) so the project-local imports below resolve.
# NOTE: this only extends the import search path; the working directory is not changed.
current_dir = os.path.dirname(os.path.abspath(__file__))  # directory containing this script
root_dir = os.path.join(current_dir, '../../../')  # three levels up from this script
sys.path.insert(0, os.path.abspath(root_dir))  # prepended, so it is searched before existing entries

from torchcfm.models.unet.unet import UNetModelWrapper
from torchcfm.utils import exp_naming
from utils_cifar import compute_pr, visualize_pca_comparison, lt_cache_dir, ensure_lt_dataset_dir, get_real_dataset, find_model_directory




# Global absl flag registry; main() reads every option from it.
FLAGS = flags.FLAGS

# Switches selecting which evaluations main() runs.
flags.DEFINE_bool('measure_fid', True, help='measure fid')
flags.DEFINE_bool('measure_precall', True, help='measure precall')
flags.DEFINE_bool('measure_pca', False, help='perform PCA visualization')
flags.DEFINE_bool('measure_likelihood', False, help='compute per-sample log-likelihood via CNF')

# Datasets to evaluate against. main() requires data_external_paths to have
# exactly one entry per dataset listed here.
flags.DEFINE_list('dataset_measure', ['cifar10','cifar10_lt'], help='list of datasets to evaluate (e.g., cifar10,cifar10_lt, cifar100, cifar100_lt)')

# When set, main() writes the generated PNG files into this directory and
# reuses them for the metrics instead of generating on the fly.
flags.DEFINE_string('gen_external_path', None, help='path to the generated images, if None, generate images in each measure function')
# One real-image location per dataset. Entries equal to None / "none" / "null" / ""
# or "auto_dir" are resolved to ./data/<dataset> paths in main(); "auto_builtin"
# resolves non-LT datasets to None (main() then passes dataset_name instead of a directory).
flags.DEFINE_list('data_external_paths', [None, None], help='list of real image dirs, [None] or [auto_dir] for auto directory, [auto_builtin] for auto builtin dataset')
flags.DEFINE_string('method_pr', 'fast', help='method to compute precision and recall, fast or slow') # 'slow' is the known code path, 'fast' is the newer code


# PCA visualization options; forwarded to visualize_pca_comparison when
# --measure_pca is enabled.
flags.DEFINE_integer('pca_samples', 1024, help='number of samples for PCA visualization')
# Feature extractor name; passed as ext_model_name to visualize_pca_comparison.
flags.DEFINE_string('pca_model', 'vgg16', help='feature extraction model for PCA')
flags.DEFINE_string('pca_mode', 'raw_pixel', help='PCA mode: feature (VGG16 features) or raw_pixel (image pixels)')
flags.DEFINE_integer('max_images_per_class', 7, help='maximum number of images to display per class in individual class visualization')


# UNet
# Passed as num_channels to UNetModelWrapper; must match the checkpoint being loaded.
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Checkpoint selection and sampling (the flags below drive inference in this
# script; no training happens in the visible code).
# integration_steps is only read when integration_method == "euler".
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
# "euler" goes through torchdyn's NeuralODE; any other value is handed to torchdiffeq.odeint.
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
# Selects the checkpoint file "<model>_<dataset_name>_weights_step_<step>.pt".
flags.DEFINE_integer("step", 400000, help="training steps, default: 400000")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate, should be larger than batch_size_fid")
# Used as both rtol and atol in the odeint calls.
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative, default: 1e-5)") # 1e-6: 2 iter per min, 1e-7: 10 iter per hour
# Also the batch size used when generating images.
flags.DEFINE_integer("batch_size_fid", 1024, help="Batch size to compute FID")
# main() falls back to "cpu" when CUDA is not available.
flags.DEFINE_string("device", "cuda:0", help="Device to use [cuda:0, cuda:1]")

# Path setting
# Root directory that main() searches (via find_model_directory) for the
# directory holding the checkpoint; log.txt is written next to the checkpoint.
flags.DEFINE_string("input_dir", "./results", help="output_directory")

# How main() picks the checkpoint directory, depending on --directory:
#   "none" and training_params unset -> directory named after --model
#   "none" and training_params set   -> directory named "<model>_<training_params>"
#   "auto"                           -> directory name built by exp_naming(FLAGS)
#   any other value                  -> that value is used as the directory name
flags.DEFINE_string("directory", "none", help="directly indicate the directory name")

## Used for the directory name when directory="none"; --model is also the
## prefix of the checkpoint file name in every mode.
flags.DEFINE_string("model", "sinkhorn_otwfm", help="flow matching model type")
flags.DEFINE_string("training_params", None, help="model hyperparameters")

## Experiment hyperparameters; presumably consumed by exp_naming(FLAGS) when
## directory="auto" (verify against torchcfm.utils.exp_naming).
## dataset_name is additionally part of the checkpoint file name in every mode.
flags.DEFINE_string("dataset_name", "cifar10", help="name of training dataset for model")
flags.DEFINE_string("weight_type", "none", help="weight type for flow matching [inv_tnu]")
flags.DEFINE_float("reg", 1.0, help="regularization parameter")
flags.DEFINE_float("tau_b", 1.0, help="regularization parameter b for Sinkhorn")
flags.DEFINE_float("beta", 1.0, help="beta for energy-weighted flow matching")
flags.DEFINE_bool("efm", False, help="energy-weighted flow matching")
flags.DEFINE_float("weight_power_factor", 2.0, help="weight power factor for Sinkhorn")
flags.DEFINE_bool("parallel", False, help="parallel training")
flags.DEFINE_bool("recoupling", True, help="sinkhorn change the coupling of x0 and x1")
flags.DEFINE_bool("fixed_source", False, help="sinkhorn fixed source")
flags.DEFINE_bool("fixed_target", False, help="sinkhorn fixed target")


# Used to build the long-tail dataset directory names
# (./data/<dataset>_imb<imb_factor>) and forwarded to the PCA helper for
# datasets whose name contains "lt".
flags.DEFINE_float("imb_factor", 0.01, help="imbalance factor for long-tail dataset")

# Likelihood (CNF) evaluation
flags.DEFINE_string('likelihood_split', 'train', help='dataset split for likelihood [train|test]')
flags.DEFINE_integer('batch_size_ll', 4, help='batch size for likelihood evaluation')
flags.DEFINE_string('trace_estimator_ll', 'hutch', help='trace estimator [hutch|exact]')
flags.DEFINE_integer('ll_max_samples', None, help='limit #samples for likelihood (None for all)')
flags.DEFINE_string('ll_output_csv', None, help='output csv path; default: results dir per dataset')
# NOTE(review): the _get_mean_std helper in main() only special-cases
# 'cifar10' and 'cifar100'; 'adaptive' falls back to the 0.5 defaults there.
flags.DEFINE_string('data_norm', 'default', help='data normalization used in training [default|cifar10|cifar100|adaptive]')
flags.DEFINE_bool('ll_manual_euler', False, help='True: use memory-friendly manual Euler integrator for likelihood, False: use odeint (dopri5)')
flags.DEFINE_integer('ll_euler_steps', 20, help='number of Euler steps (only if ll_manual_euler=True)')
flags.DEFINE_string('ll_time_direction', 'backward', help='likelihood time direction: backward(1->0) or forward(0->1)')
flags.DEFINE_bool('ll_midpoint', False, help='use midpoint for divergence and state update (less bias), True: high memory cost, False: low memory cost')
# ll_trace_noise and ll_trace_mc are read by LLVectorField.forward for the Hutchinson estimator.
flags.DEFINE_string('ll_trace_noise', 'rademacher', help='trace noise type: rademacher|gaussian')
flags.DEFINE_integer('ll_trace_mc', 1, help='Hutchinson MC samples per step (>=1)')
flags.DEFINE_bool('ll_classwise_stats', True, help='log classwise likelihood mean/var per dataset')


# Parse the command line immediately, at import time, so flag values are
# available before main() is invoked.
FLAGS(sys.argv)



def main(argv):

    # Define the model
    use_cuda = torch.cuda.is_available()
    device = torch.device(FLAGS.device if use_cuda else "cpu")

    new_net = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(device)

    def gen_1_img(unused_latent):
        """Sample one batch of images from the loaded flow model.

        The latent supplied by the caller is ignored: fresh Gaussian noise of
        shape (FLAGS.batch_size_fid, 3, 32, 32) is drawn on ``device`` and
        integrated from t=0 to t=1 with the configured integration method.

        Args:
            unused_latent: Ignored.

        Returns:
            torch.uint8 tensor: the last state of the integrated trajectory,
            rescaled with ``x * 127.5 + 128`` and clipped to [0, 255].
        """
        solver_name = FLAGS.integration_method
        with torch.no_grad():
            noise = torch.randn(FLAGS.batch_size_fid, 3, 32, 32, device=device)
            print("Use method: ", solver_name)
            if solver_name != "euler":
                # torchdiffeq path: only the two end points of [0, 1] are requested.
                trajectory = odeint(
                    new_net,
                    noise,
                    torch.linspace(0, 1, 2, device=device),
                    rtol=FLAGS.tol,
                    atol=FLAGS.tol,
                    method=solver_name,
                )
            else:
                # Fixed-step path through the NeuralODE wrapper created in main().
                time_grid = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
                trajectory = node.trajectory(noise, t_span=time_grid)
        final_state = trajectory[-1, :]
        return (final_state * 127.5 + 128).clip(0, 255).to(torch.uint8)


    # Load the model
    if FLAGS.directory == "none":
        if FLAGS.training_params != None:
            target_dir = f"{FLAGS.model}_{FLAGS.training_params}"
            found_dir = find_model_directory(FLAGS.input_dir, target_dir)
            if found_dir is None:
                raise FileNotFoundError(f"Model directory not found: {target_dir}")
            PATH = f"{found_dir}/{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{FLAGS.step}.pt"
        else:
            target_dir = FLAGS.model
            found_dir = find_model_directory(FLAGS.input_dir, target_dir)
            if found_dir is None:
                raise FileNotFoundError(f"Model directory not found: {target_dir}")
            PATH = f"{found_dir}/{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{FLAGS.step}.pt"
    elif FLAGS.directory == "auto":
        exp_name = exp_naming(FLAGS)
        found_dir = find_model_directory(FLAGS.input_dir, exp_name)
        if found_dir is None:
            raise FileNotFoundError(f"Experiment directory not found: {exp_name}")
        PATH = f"{found_dir}/{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{FLAGS.step}.pt"
    else:
        found_dir = find_model_directory(FLAGS.input_dir, FLAGS.directory)
        if found_dir is None:
            raise FileNotFoundError(f"Specified directory not found: {FLAGS.directory}")
        PATH = f"{found_dir}/{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{FLAGS.step}.pt"
    
    print("path: ", PATH)
    
    # Check if .pt file exists
    if not os.path.exists(PATH):
        raise FileNotFoundError(f"Model file not found: {PATH}")
    
    checkpoint = torch.load(PATH, map_location=device)
    state_dict = checkpoint["ema_model"]
    try:
        new_net.load_state_dict(state_dict)
    except RuntimeError:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k[7:]] = v
        new_net.load_state_dict(new_state_dict)
    new_net.eval()


    # Define the integration method if euler is used
    if FLAGS.integration_method == "euler":
        node = NeuralODE(new_net, solver=FLAGS.integration_method)


    if FLAGS.gen_external_path is not None:
        print(f"image generated in: {FLAGS.gen_external_path}")
        os.makedirs(FLAGS.gen_external_path, exist_ok=True)
        
        num_batches = (FLAGS.num_gen + FLAGS.batch_size_fid - 1) // FLAGS.batch_size_fid
        generated_count = 0
        
        for batch_idx in range(num_batches):
            # For the last batch, generate only the remaining number of images
            current_batch_size = min(FLAGS.batch_size_fid, FLAGS.num_gen - generated_count)
            
            with torch.no_grad():
                x = torch.randn(current_batch_size, 3, 32, 32, device=device)
                if FLAGS.integration_method == "euler":
                    t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
                    traj = node.trajectory(x, t_span=t_span)
                else:
                    t_span = torch.linspace(0, 1, 2, device=device)
                    traj = odeint(
                        new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
                    )
            
            traj = traj[-1, :]
            img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8)
            
            # Save every image of the batch as an individual PNG file
            for i in range(current_batch_size):
                img_tensor = img[i]  # [3, 32, 32]
                img_path = os.path.join(FLAGS.gen_external_path, f"generated_{generated_count:06d}.png")
                # Convert tensor to PIL Image and save as PNG
                img_pil = transforms.ToPILImage()(img_tensor.cpu())
                img_pil.save(img_path)
                generated_count += 1
            if (batch_idx + 1) % 10 == 0:
                print(f"progress: {generated_count}/{FLAGS.num_gen} images generated")
        
        print(f"total {generated_count} images saved in {FLAGS.gen_external_path}")


    output_dir = os.path.dirname(PATH)
    log_file = os.path.join(output_dir, f"log.txt")

    # Build dataset list and prepare a single generation directory for multi-dataset evaluation
    datasets = FLAGS.dataset_measure  # Always a list
    # Normalize and validate data_external_paths
    raw_paths = FLAGS.data_external_paths
    if raw_paths is None:
        data_paths = [None] * len(datasets)
    else:
        if len(raw_paths) != len(datasets):
            raise ValueError("len(data_external_paths) must equal len(dataset_measure)")
        # Map 'None'/''/None strings to None; otherwise keep path strings
        def _norm_path(p):
            if p is None:
                return None
            if isinstance(p, str) and p.strip().lower() in ("none", "null", ""):
                return None
            return p
        data_paths = [_norm_path(p) for p in raw_paths]

    # Auto-set real data paths per dataset
    for i, ds in enumerate(datasets):
        if (data_paths[i] is None or data_paths[i] == "auto_dir") and isinstance(ds, str):
            if ds == "cifar10":
                data_paths[i] = "./data/cifar10"
            elif ds == "cifar100":
                data_paths[i] = "./data/cifar100"
            elif ds == "cifar10_lt":
                data_paths[i] = f"./data/cifar10_lt_imb{FLAGS.imb_factor}"
            elif ds == "cifar100_lt":
                data_paths[i] = f"./data/cifar100_lt_imb{FLAGS.imb_factor}"
            if data_paths[i] is not None and not os.path.exists(data_paths[i]):
                raise FileNotFoundError(f"data path does not exist: {data_paths[i]}")
        elif (data_paths[i] == "auto_builtin") and isinstance(ds, str):
            if ds == "cifar10":
                data_paths[i] = None
            elif ds == "cifar100":
                data_paths[i] = None
            elif ds == "cifar10_lt":
                data_paths[i] = f"./data/cifar10_lt_imb{FLAGS.imb_factor}"
            elif ds == "cifar100_lt":
                data_paths[i] = f"./data/cifar100_lt_imb{FLAGS.imb_factor}"
            if data_paths[i] is not None and not os.path.exists(data_paths[i]):
                # Instead of raising an error, download the dataset
                if ds == "cifar10":
                    from torchvision import datasets
                    _ = datasets.CIFAR10(root="./data", train=True, download=True)
                    data_paths[i] = "./data/cifar10"
                elif ds == "cifar100":
                    from torchvision import datasets
                    _ = datasets.CIFAR100(root="./data", train=True, download=True)
                    data_paths[i] = "./data/cifar100"
                else:
                    raise FileNotFoundError(f"data path does not exist: {data_paths[i]}")

    # Auto-fill FID real data path for LT datasets (create cache if missing)
    for i, ds in enumerate(datasets):
        if data_paths[i] is not None and isinstance(ds, str) and ds.endswith("_lt"):
            # Create the long-tail cache when the expected directory is missing
            if not os.path.exists(data_paths[i]):
                ensure_lt_dataset_dir(ds, FLAGS.imb_factor, split="train", data_root="./data", per_class_subdir=True)
                data_paths[i] = lt_cache_dir(ds, FLAGS.imb_factor, data_root="./data")

    gen_dir_used = FLAGS.gen_external_path
    generated_temp_dir = None
    if len(datasets) > 1 and gen_dir_used is None:
        # Generate once and reuse for multi-dataset evaluation
        gen_dir_used = os.path.join(output_dir, f"gen_tmp_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
        generated_temp_dir = gen_dir_used
        print(f"image generated in: {gen_dir_used}")
        os.makedirs(gen_dir_used, exist_ok=True)

        num_batches = (FLAGS.num_gen + FLAGS.batch_size_fid - 1) // FLAGS.batch_size_fid
        generated_count = 0
        for batch_idx in range(num_batches):
            current_batch_size = min(FLAGS.batch_size_fid, FLAGS.num_gen - generated_count)
            with torch.no_grad():
                x = torch.randn(current_batch_size, 3, 32, 32, device=device)
                if FLAGS.integration_method == "euler":
                    t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
                    traj = node.trajectory(x, t_span=t_span)
                else:
                    t_span = torch.linspace(0, 1, 2, device=device)
                    traj = odeint(
                        new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
                    )
            traj = traj[-1, :]
            img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8)
            for i in range(current_batch_size):
                img_tensor = img[i]
                img_path = os.path.join(gen_dir_used, f"generated_{generated_count:06d}.png")
                img_pil = transforms.ToPILImage()(img_tensor.cpu())
                img_pil.save(img_path)
                generated_count += 1
            if (batch_idx + 1) % 10 == 0:
                print(f"progress: {generated_count}/{FLAGS.num_gen} images generated")
        print(f"total {generated_count} images saved in {gen_dir_used}")

    # === Likelihood helpers ===
    def _get_mean_std(name, mode):
        if mode == 'default':
            return (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        if mode == 'cifar10':
            return (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
        if mode == 'cifar100':
            return (0.5071, 0.48651, 0.44091), (0.2673, 0.2564, 0.2762)
        return (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

    def _draw_noise(shape_like, kind='rademacher'):
        if kind == 'rademacher':
            # ±1 with equal prob
            return (torch.randint_like(shape_like, low=0, high=2).float().mul_(2.0).sub_(1.0))
        # gaussian
        return torch.randn_like(shape_like)

    def _hutch_trace(y, x, noise_kind='rademacher', mc=1):
        """Hutchinson estimate of the per-sample Jacobian trace of ``y`` w.r.t. ``x``.

        Uses E[eps^T J^T eps] = tr(J), which is unbiased for both Rademacher
        and Gaussian probes.

        Args:
            y: Output tensor computed from ``x`` through an autograd graph.
            x: 2-D input tensor (batch, features) that requires grad.
            noise_kind: Probe distribution forwarded to ``_draw_noise``.
            mc: Number of probes to average; values <= 1 use a single probe.

        Returns:
            Detached 1-D tensor holding one trace estimate per batch row.
        """
        n_probes = int(mc)

        def _single_estimate(keep_graph):
            # One probe: eps^T J via a vector-Jacobian product, then dotted with eps.
            eps = _draw_noise(x, noise_kind)
            (vjp,) = torch.autograd.grad(
                y, x, eps, create_graph=False, retain_graph=keep_graph
            )
            return torch.einsum('bi,bi->b', vjp, eps).detach()

        if n_probes <= 1:
            return _single_estimate(False)

        # Several grad calls reuse the same graph, so it must be retained
        # until the last probe has been evaluated.
        running_total = _single_estimate(True)
        for probe_idx in range(1, n_probes):
            running_total = running_total + _single_estimate(probe_idx < n_probes - 1)
        return running_total / float(n_probes)

    def _autograd_trace(y, x):
        tr = 0.0
        for i in range(x.shape[1]):
            tr = tr + torch.autograd.grad(y[:, i].sum(), x, create_graph=True)[0][:, i]
        return tr

    class LLVectorField(torch.nn.Module):
        """Augmented vector field used for likelihood evaluation.

        Each row of the ODE state is laid out as ``[s, x_flat]``: one scalar
        accumulator column followed by the flattened sample (D values).
        """

        def __init__(self, net, dim, trace='hutch'):
            """Store the velocity network and the trace-estimator choice.

            Args:
                net: Callable invoked as ``net(t, x)``.
                dim: Per-sample shape used to un-flatten the state.
                trace: 'hutch' selects the Hutchinson estimator; any other
                    value selects the exact autograd trace.
            """
            super().__init__()
            self.net = net
            self.dim = dim
            # Number of elements in one sample (product of ``dim``).
            element_count = torch.tensor(dim).prod().item()
            self.D = int(element_count)
            self.trace = trace

        def forward(self, t, state):
            """Return the time derivative of the augmented state.

            The result concatenates the negated Jacobian trace of the velocity
            (one column) with the flattened velocity itself.
            """
            # Column 0 (the accumulator) does not enter the dynamics; only the
            # sample part of the state is read.
            x_flat = state[:, 1:].requires_grad_(True)
            velocity = self.net(t, x_flat.view(-1, *self.dim))
            velocity_flat = velocity.reshape(x_flat.shape)
            if self.trace != 'hutch':
                divergence = _autograd_trace(velocity_flat, x_flat)
            else:
                divergence = _hutch_trace(
                    velocity_flat,
                    x_flat,
                    noise_kind=FLAGS.ll_trace_noise,
                    mc=max(1, int(FLAGS.ll_trace_mc)),
                )
            return torch.cat([-divergence[:, None], velocity_flat], dim=1)



    if FLAGS.measure_fid:

        start_time = datetime.now()
        with open(log_file, "a") as f:
            f.write(f"Start computing FID at {start_time}\n")

        # https://github.com/GaParmar/clean-fid/blob/main/cleanfid/fid.py
        try:
            # Unified loop: branch per dataset for builtin/external path
            fid_scores = {}
            # Decide fdir1 source: pre-generated dir or external generation dir
            fdir1_source = gen_dir_used if gen_dir_used is not None else FLAGS.gen_external_path
            for idx, ds in enumerate(datasets):
                real_dir = data_paths[idx]
                if fdir1_source is not None:
                    # Use pre-generated images
                    if real_dir is None:
                        score = fid.compute_fid(
                            fdir1=fdir1_source,
                            dataset_name=ds,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            mode="legacy_tensorflow",
                            device=device,
                        )
                    else:
                        if not os.path.exists(real_dir):
                            raise ValueError(f"data_external_paths[{idx}] does not exist: {real_dir}")
                        score = fid.compute_fid(
                            fdir1=fdir1_source,
                            fdir2=real_dir,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            mode="legacy_tensorflow",
                            device=device,
                        )
                else:
                    # Use on-the-fly generator function
                    if real_dir is None:
                        score = fid.compute_fid(
                            gen=gen_1_img,
                            dataset_name=ds,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            mode="legacy_tensorflow",
                            device=device,
                        )
                    else:
                        if not os.path.exists(real_dir):
                            raise ValueError(f"data_external_paths[{idx}] does not exist: {real_dir}")
                        score = fid.compute_fid(
                            gen=gen_1_img,
                            fdir2=real_dir,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            mode="legacy_tensorflow",
                            device=device,
                        )
                print(f"FID ({ds}): {score}")
                fid_scores[ds] = score
        except Exception:
            with open(log_file, "a") as f:
                f.write("\nERROR during computing FID\n")
                f.write(traceback.format_exc())
            raise

        print()
        print("FID has been computed")
        # print()
        # print("Total NFE: ", new_net.nfe)
        print()
        with open(log_file, "a") as f:
            for ds, sc in fid_scores.items():
                f.write(f"step: {FLAGS.step}, training dataset: {FLAGS.dataset_name}, FID measured on: {ds} -> score: {sc}\n")
            f.write(f"computing FID finished at {datetime.now()}\n")
            f.write(f"Total FID measuring time: {datetime.now() - start_time}\n")

    if FLAGS.measure_precall:


        start_time = datetime.now()
        with open(log_file, "a") as f:
            f.write(f"Start computing precision and recall at {start_time}\n")

        try:
            # Unified loop: branch per dataset for builtin/external path
            # Decide fdir1 source: pre-generated dir or external generation dir
            fdir1_source = gen_dir_used if gen_dir_used is not None else FLAGS.gen_external_path
            last_precision, last_recall = None, None
            for idx, ds in enumerate(datasets):
                real_dir = data_paths[idx]
                if fdir1_source is not None:
                    if real_dir is None:
                        precision, recall = compute_pr(
                            fdir1=fdir1_source,
                            dataset_name=ds,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            device=device,
                            ext_model_name="vgg16",
                            real_data_path=None,
                            method=FLAGS.method_pr,
                        )
                    else:
                        if not os.path.exists(real_dir):
                            raise ValueError(f"data_external_paths[{idx}] does not exist: {real_dir}")
                        precision, recall = compute_pr(
                            fdir1=fdir1_source,
                            fdir2=real_dir,
                            dataset_name=ds,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            device=device,
                            ext_model_name="vgg16",
                            real_data_path=None,
                            method=FLAGS.method_pr,
                        )
                else:
                    if real_dir is None:
                        precision, recall = compute_pr(
                            gen=gen_1_img,
                            dataset_name=ds,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            device=device,
                            ext_model_name="vgg16",
                            real_data_path=None,
                            method=FLAGS.method_pr,
                        )
                    else:
                        if not os.path.exists(real_dir):
                            raise ValueError(f"data_external_paths[{idx}] does not exist: {real_dir}")
                        precision, recall = compute_pr(
                            gen=gen_1_img,
                            fdir2=real_dir,
                            dataset_name=ds,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            num_gen=FLAGS.num_gen,
                            dataset_split="train",
                            device=device,
                            ext_model_name="vgg16",
                            real_data_path=None,
                            method=FLAGS.method_pr,
                        )
                last_precision, last_recall = precision, recall
                print(f"Precision ({ds}): ", precision)
                print(f"Recall ({ds}): ", recall)
                with open(log_file, "a") as f:
                    f.write(f"step: {FLAGS.step} training dataset: {FLAGS.dataset_name}, measured on: {ds} -> precision: {precision} recall: {recall}\n")
            
        except Exception:
            with open(log_file, "a") as f:
                f.write("\nERROR during computing precision and recall\n")
                f.write(traceback.format_exc())
            raise

        # Echo the precision/recall of the last evaluated dataset
        if last_precision is not None and last_recall is not None:
            print("Precision: ", last_precision)
            print("Recall: ", last_recall)

        with open(log_file, "a") as f:
            # Record the finish time and the total elapsed time
            f.write(f"computing precision and recall finished at {datetime.now()}\n")
            f.write(f"Total precision and recall measuring time: {datetime.now() - start_time}\n")

    if FLAGS.measure_pca:
        start_time = datetime.now()

        try:
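            # Unified loop: branch per dataset for builtin/external path
            # Decide fdir1 source: pre-generated dir or external generation dir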
            fdir1_source = gen_dir_used if gen_dir_used is not None else FLAGS.gen_external_path
            
            for idx, ds in enumerate(datasets):
                real_dir = data_paths[idx]
                if fdir1_source is not None:
                    # Use pre-generated images
                    if real_dir is None:
                        result = visualize_pca_comparison(
                            gen_dir=fdir1_source,
                            real_dir=None,
                            gen_func=None,
                            dataset_name=ds,
                            num_samples=FLAGS.pca_samples,
                            device=device,
                            save_path=output_dir,
                            step=FLAGS.step,
                            ext_model_name=FLAGS.pca_model,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            dataset_split="train",
                            pca_mode=FLAGS.pca_mode,
                            max_images_per_class=FLAGS.max_images_per_class,
                            imb_factor=FLAGS.imb_factor if "lt" in ds else None,
                        )
                        # Unpack defensively: the helper may return a 2- or 3-element tuple
                        if len(result) == 3:
                            explained_variance, save_path, log_msg = result
                        else:
                            explained_variance, save_path = result
                    else:
                        if not os.path.exists(real_dir):
                            raise ValueError(f"data_external_paths[{idx}] does not exist: {real_dir}")
                        result = visualize_pca_comparison(
                            gen_dir=fdir1_source,
                            real_dir=real_dir,
                            gen_func=None,
                            dataset_name=ds,
                            num_samples=FLAGS.pca_samples,
                            device=device,
                            save_path=output_dir,
                            step=FLAGS.step,
                            ext_model_name=FLAGS.pca_model,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            dataset_split="train",
                            pca_mode=FLAGS.pca_mode,
                            max_images_per_class=FLAGS.max_images_per_class,
                            imb_factor=FLAGS.imb_factor if "lt" in ds else None,
                        )
                        # Unpack defensively: the helper may return a 2- or 3-element tuple
                        if len(result) == 3:
                            explained_variance, save_path, log_msg = result
                        else:
                            explained_variance, save_path = result
                else:
                    # Use on-the-fly generator function
                    if real_dir is None:
                        result = visualize_pca_comparison(
                            gen_dir=None,
                            real_dir=None,
                            gen_func=gen_1_img,
                            dataset_name=ds,
                            num_samples=FLAGS.pca_samples,
                            device=device,
                            save_path=output_dir,
                            step=FLAGS.step,
                            ext_model_name=FLAGS.pca_model,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            dataset_split="train",
                            pca_mode=FLAGS.pca_mode,
                            max_images_per_class=FLAGS.max_images_per_class,
                            imb_factor=FLAGS.imb_factor if "lt" in ds else None,
                        )
                        # Unpack defensively: the helper may return a 2- or 3-element tuple
                        if len(result) == 3:
                            explained_variance, save_path, log_msg = result
                        else:
                            explained_variance, save_path = result
                    else:
                        if not os.path.exists(real_dir):
                            raise ValueError(f"data_external_paths[{idx}] does not exist: {real_dir}")
                        result = visualize_pca_comparison(
                            gen_dir=None,
                            real_dir=real_dir,
                            gen_func=gen_1_img,
                            dataset_name=ds,
                            num_samples=FLAGS.pca_samples,
                            device=device,
                            save_path=output_dir,
                            step=FLAGS.step,
                            ext_model_name=FLAGS.pca_model,
                            batch_size=FLAGS.batch_size_fid,
                            dataset_res=32,
                            dataset_split="train",
                            pca_mode=FLAGS.pca_mode,
                            max_images_per_class=FLAGS.max_images_per_class,
                            imb_factor=FLAGS.imb_factor if "lt" in ds else None,
                        )
                        # Defensive tuple unpacking: the helper may return 2 or 3 values
                        if len(result) == 3:
                            explained_variance, save_path, log_msg = result
                        else:
                            explained_variance, save_path = result
                
                if explained_variance is not None:
                    print(f"PCA Visualization ({ds}): PC1={explained_variance[0]:.1%}, PC2={explained_variance[1]:.1%}, Total={explained_variance.sum():.1%}")
                    with open(log_file, "a") as f:
                        f.write(f"step: {FLAGS.step}, training dataset: {FLAGS.dataset_name}, PCA measured on: {ds} -> PC1: {explained_variance[0]:.1%} PC2: {explained_variance[1]:.1%} Total: {explained_variance.sum():.1%} File: {save_path}\n")
                else:
                    print(f"PCA Visualization ({ds}): Failed")
                    with open(log_file, "a") as f:
                        f.write(f"step: {FLAGS.step}, training dataset: {FLAGS.dataset_name}, PCA measured on: {ds} -> Failed\n")
                        
        except Exception:
            with open(log_file, "a") as f:
                f.write("\nERROR during PCA visualization\n")
                f.write(traceback.format_exc())
            raise

        print()
        print("PCA visualization has been completed")
        print()
        with open(log_file, "a") as f:
            f.write(f"Total PCA visualization time: {datetime.now() - start_time}. computing PCA visualization finished at {datetime.now()}\n")

    # === Likelihood (CNF) ===
    if FLAGS.measure_likelihood:
        start_time = datetime.now()

        try:
            mean, std = _get_mean_std(FLAGS.dataset_name, FLAGS.data_norm)
            vf = LLVectorField(new_net, dim=(3, 32, 32), trace=('hutch' if FLAGS.trace_estimator_ll == 'hutch' else 'exact')).to(device)
            t_span = torch.tensor([1.0, 0.0], device=device)

            for ds in datasets:
                # Real-data preprocessing: convert to tensor, then normalize with (mean, std)
                tfm = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ])
                if ds in ['cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt']:
                    imb = FLAGS.imb_factor if 'lt' in ds else None
                    real_dataset = get_real_dataset(ds, split=FLAGS.likelihood_split, transform=tfm, data_root="./data", imb_factor=imb)
                else:
                    raise ValueError(f"likelihood currently supports cifar10/100(_lt); got {ds}")


                # Optionally restrict evaluation to the first ll_max_samples items
                if FLAGS.ll_max_samples is not None:
                    real_dataset = torch.utils.data.Subset(real_dataset, list(range(min(int(FLAGS.ll_max_samples), len(real_dataset)))))

                loader = torch.utils.data.DataLoader(real_dataset, batch_size=FLAGS.batch_size_ll, shuffle=False, num_workers=4)

                # Output CSV path: fall back to a per-dataset file when unset or when several datasets are evaluated
                out_csv = FLAGS.ll_output_csv
                if out_csv is None or len(datasets) > 1:
                    out_csv = os.path.join(output_dir, f"ll_{ds}_step_{FLAGS.step}.csv")

                print(f"Computing likelihood on {ds} ({FLAGS.likelihood_split}), saving to {out_csv}")
                with open(out_csv, 'w') as fcsv:
                    fcsv.write("index,label,loglik,logp0,s_final,bpd,npd\n")
                    offset = 0
                    batch_iter = tqdm(loader, desc=f"LL ({ds})", ncols=80)
                    # Running totals of the log-likelihood over all samples
                    tot_n = 0
                    sum_ll = 0.0
                    sumsq_ll = 0.0
                    # Per-class BPD values (collected for the classwise BPD CSV statistics)
                    class_bpd_values = defaultdict(list)
                    # Per-class running sums used later to compute class averages
                    class_avg = defaultdict(lambda: {
                        'n': 0,
                        'sum_loglik': 0.0,
                        'sum_logp0': 0.0,
                        'sum_s_final': 0.0,
                        'sum_bpd': 0.0,
                        'sumsq_bpd': 0.0,
                        'sum_npd': 0.0,
                    })
                    classwise_error = None
                    for xb, yb in batch_iter:
                        xb = xb.to(device)
                        if FLAGS.ll_manual_euler:
                            # Manual fixed-step Euler integration (t=1->0 by default; 'forward' starts at t=0)
                            steps = max(1, int(FLAGS.ll_euler_steps))
                            # Select the integration direction and the signed step size
                            if FLAGS.ll_time_direction.lower() == 'forward':
                                t = torch.zeros((), device=device)
                                dt = 1.0 / steps
                            else:
                                t = torch.ones((), device=device)
                                dt = -1.0 / steps
                            x = xb.clone()
                            s_acc = torch.zeros(x.size(0), dtype=torch.float64, device=device)
                            for _ in range(steps):
                                # Evaluate the augmented field at the current state: column 0 is the divergence term (ds), the rest is dx
                                x_flat = x.view(x.size(0), -1).requires_grad_(True)
                                state = torch.cat([torch.zeros(x.size(0), 1, device=device), x_flat], dim=1)
                                dxs_now = vf(t, state)
                                ds_now = dxs_now[:, 0]
                                dx_now = dxs_now[:, 1:]

                                if FLAGS.ll_midpoint:
                                    # Midpoint rule: re-evaluate the field at the half step and use that slope
                                    x_mid = (x_flat + dx_now * (0.5 * dt)).view_as(x)
                                    x_mid_flat = x_mid.view(x.size(0), -1).requires_grad_(True)
                                    state_mid = torch.cat([torch.zeros(x.size(0), 1, device=device), x_mid_flat], dim=1)
                                    dxs_mid = vf(t + 0.5 * dt, state_mid)
                                    ds = dxs_mid[:, 0]
                                    dx_use = dxs_mid[:, 1:]
                                else:
                                    ds = ds_now
                                    dx_use = dx_now

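                                # Accumulate the integrated ds term (in float64) under no_grad.
                                # NOTE(review): `ds` assigned above rebinds the outer `for ds in datasets` loop
                                # variable, so later uses of `ds` in file names / log lines see a tensor — consider renaming.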
                                with torch.no_grad():
                                    s_acc.add_(ds.detach().double() * dt)
                                x = (x_flat + dx_use * dt).view_as(x)
                                x = x.detach()
                                t = t + dt
                            x0_flat = x.view(x.size(0), -1)
                            s_final = s_acc.float()
                        else:
                            # Integrate the augmented state with odeint over t_span (1 -> 0)
                            s0 = torch.zeros(xb.size(0), 1, device=device)
                            x_flat = xb.view(xb.size(0), -1)
                            state0 = torch.cat([s0, x_flat], dim=1)
                            state_T = odeint(vf, state0, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method)[-1]
                            s_final = state_T[:, 0]
                            x0_flat = state_T[:, 1:]

                        # log p0 (standard normal log-density of the integrated sample)
                        quad = (x0_flat ** 2).sum(dim=1)
                        D = x0_flat.shape[1]
                        log_p0 = -0.5 * (quad + D * torch.log(torch.tensor(2 * 3.141592653589793, device=device)))
                        # Combine log p0 with the accumulated s_final; the sign depends on the time direction
                        if FLAGS.ll_time_direction.lower() == 'forward':
                            log_p1 = log_p0 + s_final
                        else:
                            log_p1 = log_p0 - s_final

                        # Per-dimension metrics: bpd divides by D*ln(2), npd divides by D
                        D = x0_flat.shape[1]
                        bpd = (-log_p1 / (D * math.log(2.0)))
                        npd = (-log_p1 / D)

                        for i in range(xb.size(0)):
                            lbl = yb[i]
                            if isinstance(lbl, torch.Tensor):
                                lbl = int(lbl.item())
                            else:
                                lbl = int(lbl)
                            fcsv.write(f"{offset + i},{lbl},{float(log_p1[i].item())},{float(log_p0[i].item())},{float(s_final[i].item())},{float(bpd[i].item())},{float(npd[i].item())}\n")
                        offset += xb.size(0)
                        # Update the running totals
                        vals = log_p1.detach()
                        tot_n += vals.numel()
                        sum_ll += vals.sum().item()
                        sumsq_ll += (vals * vals).sum().item()
                        # Classwise BPD accumulation (best effort: the first failure is recorded and disables it for later batches)
                        if FLAGS.ll_classwise_stats and classwise_error is None:
                            try:
                                labels = yb
                                if not isinstance(labels, torch.Tensor):
                                    labels = torch.as_tensor(labels)
                                labels = labels.view(-1).cpu()
                                bpd_cpu = bpd.view(-1).detach().cpu()
                                loglik_cpu = log_p1.view(-1).detach().cpu()
                                logp0_cpu = log_p0.view(-1).detach().cpu()
                                s_final_cpu = s_final.view(-1).detach().cpu()
                                npd_cpu = npd.view(-1).detach().cpu()
                                unique_classes = torch.unique(labels)
                                for c in unique_classes.tolist():
                                    mask = (labels == c)
                                    n_c = int(mask.sum().item())
                                    if n_c == 0:
                                        continue
                                    # Keep the raw BPD values (for the classwise CSV)
                                    vals_c = bpd_cpu[mask].tolist()
                                    class_bpd_values[int(c)].extend(vals_c)
                                    # Accumulate per-class sums
                                    st = class_avg[int(c)]
                                    st['n'] += n_c
                                    st['sum_loglik'] += float(loglik_cpu[mask].sum().item())
                                    st['sum_logp0'] += float(logp0_cpu[mask].sum().item())
                                    st['sum_s_final'] += float(s_final_cpu[mask].sum().item())
                                    sum_bpd_c = float(bpd_cpu[mask].sum().item())
                                    st['sum_bpd'] += sum_bpd_c
                                    st['sumsq_bpd'] += float((bpd_cpu[mask] * bpd_cpu[mask]).sum().item())
                                    st['sum_npd'] += float(npd_cpu[mask].sum().item())
                            except Exception as e:
                                classwise_error = str(e)
                        # Release unused cached GPU memory after each batch
                        torch.cuda.empty_cache()

                print(f"Saved: {out_csv}")
                # Log summary: the negated mean bpd over all evaluated samples
                if tot_n > 0:
                    mean_ll = sum_ll / tot_n
                    # mean bpd = -mean(log_p1)/(D*log(2))
                    # The value written to the log is the negation of mean_bpd.
                    # Note: D is hard-coded to 3*32*32 here.
                    D_total = int(3 * 32 * 32)
                    mean_bpd = - (mean_ll / (D_total * math.log(2.0)))
                    neg_mean_bpd = -mean_bpd
                    with open(log_file, "a") as f:
                        f.write(f"neg_mean_bpd({ds}): {neg_mean_bpd:.6f}\n")
                # Write the per-class BPD summary CSV (only when classwise stats are enabled)
                if FLAGS.ll_classwise_stats:
                    if classwise_error is not None or len(class_bpd_values) == 0:
                        reason = classwise_error if classwise_error is not None else "no class labels found"
                        print(f"LL classwise BPD stats skipped on {ds}: {reason}")
                    else:
                        out_csv_class = os.path.join(output_dir, f"ll_classwise_bpd_{ds}_step_{FLAGS.step}.csv")
                        with open(out_csv_class, 'w') as fcsvc:
                            fcsvc.write("class,count,mean,var,max,min,median\n")
                            for cls_id in sorted(class_bpd_values.keys()):
                                vals_list = class_bpd_values[cls_id]
                                if len(vals_list) == 0:
                                    continue
                                t = torch.tensor(vals_list, dtype=torch.float32)
                                n = int(t.numel())
                                mean_c = float(t.mean().item())
                                # Population variance: unbiased=False
                                var_c = float(t.var(unbiased=False).item()) if n > 0 else 0.0
                                max_c = float(t.max().item())
                                min_c = float(t.min().item())
                                med_c = float(t.median().item())
                                fcsvc.write(f"{cls_id},{n},{mean_c:.6f},{var_c:.6f},{max_c:.6f},{min_c:.6f},{med_c:.6f}\n")
                        print(f"Saved classwise BPD CSV: {out_csv_class}")

                # Write the per-label mean statistics CSV (no index column, one row per label)
                if FLAGS.ll_classwise_stats and classwise_error is None and len(class_avg) > 0:
                    out_csv_labelwise = os.path.join(output_dir, f"ll_classwise_{ds}_stats_step_{FLAGS.step}.csv")
                    with open(out_csv_labelwise, 'w') as fcsvl:
                        # Header: per-label means, plus the bpd variance
                        fcsvl.write("label,loglik,logp0,s_final,bpd,npd,bpd_var\n")
                        for cls_id in sorted(class_avg.keys()):
                            st = class_avg[cls_id]
                            if st['n'] <= 0:
                                continue
                            n = float(st['n'])
                            mean_loglik = st['sum_loglik'] / n
                            mean_logp0 = st['sum_logp0'] / n
                            mean_s_final = st['sum_s_final'] / n
                            mean_bpd = st['sum_bpd'] / n
                            mean_npd = st['sum_npd'] / n
                            var_bpd = max(0.0, (st['sumsq_bpd'] / n) - (mean_bpd * mean_bpd))
                            fcsvl.write(
                                f"{cls_id},{mean_loglik:.6f},{mean_logp0:.6f},{mean_s_final:.6f},{mean_bpd:.6f},{mean_npd:.6f},{var_bpd:.6f}\n"
                            )
                    print(f"Saved classwise LL stats CSV: {out_csv_labelwise}")

        except Exception:
            with open(log_file, "a") as f:
                f.write("\nERROR during computing likelihood\n")
                f.write(traceback.format_exc())
            raise

        with open(log_file, "a") as f:
            f.write(f"Total likelihood measuring time: {datetime.now() - start_time}. computing likelihood finished at {datetime.now()}\n")




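    # Delete all temporary generated images.
    # NOTE(review): the user-supplied --gen_external_path directory is also removed below — confirm this is intended.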
    if 'generated_temp_dir' in locals() and generated_temp_dir is not None:
        print(f"Deleting generated images in {generated_temp_dir}")
        shutil.rmtree(generated_temp_dir)

    if FLAGS.gen_external_path is not None:
        print(f"Deleting generated images in {FLAGS.gen_external_path}")
        shutil.rmtree(FLAGS.gen_external_path)



# Script entry point: when executed directly (not imported), hand control to
# absl's app.run, which parses the command-line flags and then calls main.
if __name__ == "__main__":
    app.run(main)


"""
# Basic usage:
 python compute_eval.py \
  --dataset_measure=cifar10,cifar10_lt \
  --data_external_paths=None,./data/cifar10_lt \
  --measure_fid=True \
  --measure_precall=True \
  --measure_pca=True \
  --pca_samples=5000 \
  --pca_model=vgg16 \
  --pca_mode=feature \
  --max_images_per_class=3 \
  --device=cuda:0

# For CIFAR-100 long-tail dataset:
 python compute_eval.py \
  --dataset_measure=cifar100_lt \
  --measure_pca=True \
  --pca_mode=raw_pixel \
  --pca_samples=5000 \
  --max_images_per_class=5 \
  --device=cuda:0

# For raw pixel PCA with more images per class:
 python compute_eval.py \
  --dataset_measure=cifar10 \
  --measure_pca=True \
  --pca_mode=raw_pixel \
  --pca_samples=5000 \
  --max_images_per_class=7 \
  --device=cuda:0

# For likelihood evaluation:
CUDA_VISIBLE_DEVICES=1 python compute_eval.py \
  --input_dir ./results_cifar10_lt \
  --directory auto \
  --model icfm \
  --dataset_measure=cifar10_lt \
  --data_external_paths=None \
  --measure_fid=False \
  --measure_precall=False \
  --measure_pca=True \
  --measure_likelihood=True \
  --device=cuda:0

"""