# Quick test script for dataset loading
# Tests the same dataset loading logic as compute_eval.py

import os
import sys
from collections import Counter
import torch
from absl import app, flags
from torchvision import transforms

# Put the directory three levels above this script (presumably the
# conditional-flow-matching repository root) at the front of sys.path so the
# project-local `utils_cifar` import below resolves.
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.join(current_dir, '../../../')
_repo_root = os.path.abspath(root_dir)
sys.path.insert(0, _repo_root)

from utils_cifar import get_real_dataset

# Flag definitions copied from compute_eval.py so the same command line works
# here. Only dataset_measure, data_external_paths and imb_factor are read by
# this script; the remaining flags exist for command-line compatibility.
FLAGS = flags.FLAGS

# Measures
flags.DEFINE_bool('measure_fid', True, help='measure fid')
flags.DEFINE_bool('measure_precall', True, help='measure precall')
flags.DEFINE_bool('measure_pca', False, help='perform PCA visualization')
flags.DEFINE_bool('measure_likelihood', False, help='compute per-sample log-likelihood via CNF')

flags.DEFINE_list('dataset_measure', ['cifar10','cifar10_lt'], help='list of datasets to evaluate (e.g., cifar10,cifar10_lt, cifar100, cifar100_lt)')

# External paths and PR method
flags.DEFINE_string('gen_external_path', None, help='path to the generated images, if None, generate images in each measure function')
flags.DEFINE_list('data_external_paths', [None, None], help='list of real image dirs, [None] or [auto_dir] for auto directory, [auto_builtin] for auto builtin dataset')
flags.DEFINE_string('method_pr', 'fast', help='method to compute precision and recall, fast or slow')

# PCA
flags.DEFINE_integer('pca_samples', 1024, help='number of samples for PCA visualization')
flags.DEFINE_string('pca_model', 'vgg16', help='feature extraction model for PCA')
flags.DEFINE_string('pca_mode', 'raw_pixel', help='PCA mode: feature (VGG16 features) or raw_pixel (image pixels)')
flags.DEFINE_integer('max_images_per_class', 7, help='maximum number of images to display per class in individual class visualization')

# UNet (unused here; kept for compute_eval.py compatibility)
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training (unused here; kept for compute_eval.py compatibility)
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
flags.DEFINE_integer("step", 400000, help="training steps, default: 400000")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate, should be larger than batch_size_fid")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative, default: 1e-5)")
flags.DEFINE_integer("batch_size_fid", 1024, help="Batch size to compute FID")
flags.DEFINE_string("device", "cuda:0", help="Device to use [cuda:0, cuda:1]")

# Path setting
flags.DEFINE_string("input_dir", "./results", help="input directory containing the results")

# sub directory path setting option #0: manually set the sub directory name (directory="none", training_params=None)
flags.DEFINE_string("directory", "none", help="directly indicate the directory name")

## sub directory path setting option #1: manually set the model and hyperparameters (directory="none", training_params!=None)
flags.DEFINE_string("model", "sinkhorn_otwfm", help="flow matching model type")
flags.DEFINE_string("training_params", None, help="model hyperparameters")

## sub directory path setting option #2: automatically set the directory name (directory="auto")
flags.DEFINE_string("dataset_name", "cifar10", help="name of training dataset for model")
flags.DEFINE_string("weight_type", "none", help="weight type for flow matching [inv_tnu]")
flags.DEFINE_float("reg", 1.0, help="regularization parameter")
flags.DEFINE_float("tau_b", 1.0, help="regularization parameter b for Sinkhorn")
flags.DEFINE_float("beta", 1.0, help="beta for energy-weighted flow matching")
flags.DEFINE_bool("efm", False, help="energy-weighted flow matching")
flags.DEFINE_float("weight_power_factor", 2.0, help="weight power factor for Sinkhorn")
flags.DEFINE_bool("parallel", False, help="parallel training")
flags.DEFINE_bool("recoupling", True, help="sinkhorn change the coupling of x0 and x1")
flags.DEFINE_bool("fixed_source", False, help="sinkhorn fixed source")
flags.DEFINE_bool("fixed_target", False, help="sinkhorn fixed target")

flags.DEFINE_float("imb_factor", 0.01, help="imbalance factor for long-tail dataset")

# Likelihood (unused here; kept for compute_eval.py compatibility)
flags.DEFINE_string('likelihood_split', 'train', help='dataset split for likelihood [train|test]')
flags.DEFINE_integer('batch_size_ll', 4, help='batch size for likelihood evaluation')
flags.DEFINE_string('trace_estimator_ll', 'hutch', help='trace estimator [hutch|exact]')
flags.DEFINE_integer('ll_max_samples', None, help='limit #samples for likelihood (None for all)')
flags.DEFINE_string('ll_output_csv', None, help='output csv path; default: results dir per dataset')
flags.DEFINE_string('data_norm', 'default', help='data normalization used in training [default|cifar10|cifar100|adaptive]')
flags.DEFINE_bool('ll_manual_euler', False, help='True: use memory-friendly manual Euler integrator for likelihood, False: use odeint (dopri5)')
flags.DEFINE_integer('ll_euler_steps', 20, help='number of Euler steps (only if ll_manual_euler=True)')
flags.DEFINE_string('ll_time_direction', 'backward', help='likelihood time direction: backward(1->0) or forward(0->1)')
flags.DEFINE_bool('ll_midpoint', False, help='use midpoint for divergence and state update (less bias), True: high memory cost, False: low memory cost')
flags.DEFINE_string('ll_trace_noise', 'rademacher', help='trace noise type: rademacher|gaussian')
flags.DEFINE_integer('ll_trace_mc', 1, help='Hutchinson MC samples per step (>=1)')
flags.DEFINE_bool('ll_classwise_stats', True, help='log classwise likelihood mean/var per dataset')

# Flags are parsed by app.run(main) at the bottom of the file. They are
# deliberately not parsed here at import time: doing so duplicates that work
# and breaks importing this module under a foreign argv (e.g. pytest).


# Number of classes for every dataset this script knows how to load; also
# serves as the whitelist of supported dataset names.
_NUM_CLASSES = {
    "cifar10": 10,
    "cifar10_lt": 10,
    "cifar100": 100,
    "cifar100_lt": 100,
}


def _norm_path(p):
    """Return None for None or the strings "none"/"null"/"" (case- and
    whitespace-insensitive); return any other value unchanged."""
    if p is None:
        return None
    if isinstance(p, str) and p.strip().lower() in ("none", "null", ""):
        return None
    return p


def _resolve_data_paths(datasets, raw_paths, imb_factor):
    """Resolve one real-data path per dataset (same rules as compute_eval.py).

    Args:
        datasets: dataset names, e.g. from --dataset_measure.
        raw_paths: values of --data_external_paths, or None.
        imb_factor: imbalance factor embedded in long-tail directory names.

    Returns:
        A list with one entry per dataset. Unset / "auto_dir" entries become
        ./data/<dataset> (long-tail: ./data/<dataset>_imb<imb_factor>).
        For "auto_builtin", cifar10/cifar100 become None and the long-tail sets
        use the same directory as "auto_dir". Explicit paths are kept as-is;
        entries for unrecognized dataset names are left unchanged.

    Raises:
        ValueError: if raw_paths is given and its length differs from datasets.
    """
    if raw_paths is None:
        data_paths = [None] * len(datasets)
    else:
        if len(raw_paths) != len(datasets):
            raise ValueError("len(data_external_paths) must equal len(dataset_measure)")
        data_paths = [_norm_path(p) for p in raw_paths]

    auto_dirs = {
        "cifar10": "./data/cifar10",
        "cifar100": "./data/cifar100",
        "cifar10_lt": f"./data/cifar10_lt_imb{imb_factor}",
        "cifar100_lt": f"./data/cifar100_lt_imb{imb_factor}",
    }
    # Under "auto_builtin" the balanced datasets map to None (flag help:
    # "auto builtin dataset"); long-tail sets still need their own directory.
    builtin = ("cifar10", "cifar100")

    for i, ds in enumerate(datasets):
        if not isinstance(ds, str):
            continue
        path = data_paths[i]
        if path is None or path == "auto_dir":
            path = auto_dirs.get(ds, path)
        elif path == "auto_builtin":
            path = None if ds in builtin else auto_dirs.get(ds, path)
        else:
            continue  # explicit user-supplied path: keep as-is, no check
        data_paths[i] = path
        if path is not None and not os.path.exists(path):
            print(f"Warning: data path does not exist: {path}")
    return data_paths


def _count_labels(dataset):
    """Count samples per class by indexing every item of *dataset*.

    Each item is fetched via dataset[i] (so its transform runs), which
    deliberately exercises the full loading path this script is testing.
    """
    counts = Counter()
    for i in range(len(dataset)):
        _, label = dataset[i]
        # int() turns 0-dim tensor / NumPy labels into plain ints. Tensors hash
        # by identity, so using them as Counter keys would split every class.
        counts[int(label)] += 1
    return counts


def main(argv):
    """Load each dataset in --dataset_measure and print its class statistics.

    argv is supplied by absl.app.run and is not used.
    """
    print("=" * 60)
    print("Dataset Loading Test")
    print("=" * 60)

    datasets = FLAGS.dataset_measure
    print(f"Testing datasets: {datasets}")

    data_paths = _resolve_data_paths(datasets, FLAGS.data_external_paths, FLAGS.imb_factor)
    # NOTE(review): the resolved paths are only reported; get_real_dataset below
    # is called with data_root="./data" as in the original script. Confirm this
    # matches how compute_eval.py hands paths to the loader.
    print(f"Data paths: {data_paths}")
    print(f"Imbalance factor: {FLAGS.imb_factor}")
    print()

    # Same transform for every dataset, so build it once.
    tfm = transforms.Compose([
        transforms.ToTensor(),
    ])

    for ds in datasets:
        print(f"Testing dataset: {ds}")
        print("-" * 40)

        num_classes = _NUM_CLASSES.get(ds)
        if num_classes is None:
            print(f"Unknown dataset: {ds}")
            print()
            continue

        try:
            imb = FLAGS.imb_factor if ds.endswith("_lt") else None
            dataset = get_real_dataset(ds, split="train", transform=tfm, data_root="./data", imb_factor=imb)

            print(f"Dataset type: {type(dataset)}")
            print(f"Total samples: {len(dataset)}")

            class_counts = _count_labels(dataset)
            print(f"Number of classes: {len(class_counts)}")
            print("Class distribution:")
            for class_id in sorted(class_counts):
                print(f"  Class {class_id:2d}: {class_counts[class_id]:5d} samples")

            if class_counts:
                counts = list(class_counts.values())
                print(f"Min samples per class: {min(counts)}")
                print(f"Max samples per class: {max(counts)}")
                print(f"Mean samples per class: {sum(counts)/len(counts):.1f}")

            expected_classes = set(range(num_classes))
            actual_classes = set(class_counts)
            missing_classes = sorted(expected_classes - actual_classes)
            unexpected_classes = sorted(actual_classes - expected_classes)
            if missing_classes:
                print(f"Missing classes: {missing_classes}")
            if unexpected_classes:
                print(f"Unexpected class ids: {unexpected_classes}")
            if not missing_classes and not unexpected_classes:
                print("All expected classes are present")

        except Exception as e:
            # Top-level boundary of a diagnostic script: report the failure with
            # its traceback and continue with the next dataset.
            print(f"Error loading dataset {ds}: {e}")
            import traceback
            traceback.print_exc()

        print()

    print("=" * 60)
    print("Test completed")
    print("=" * 60)


if __name__ == "__main__":
    # absl parses the command-line flags before invoking main.
    app.run(main)
