# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Anonymous

import copy
import os
from datetime import datetime
import csv
import shutil

import torch
torch.backends.cuda.matmul.allow_tf32 = True # For faster training
torch.backends.cudnn.allow_tf32 = True # For faster training

from absl import app, flags
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from tqdm import trange, tqdm
from utils_cifar import ema, generate_samples, infiniteloop, plot_fm_weights_histogram, get_real_dataset, classify_generated_images

from torchdiffeq import odeint

import sys
import os
# Move to repository root from this file's directory
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.join(current_dir, '../../../')
sys.path.insert(0, os.path.abspath(root_dir))

from torchcfm.utils import compute_dataset_mean_std
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    SinkhornOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper
from torchcfm.utils import energy_weight, CIFAR10LTDataset_regacy, ImbalanceCIFAR10, ImbalanceCIFAR100, exp_naming

# Command-line configuration (absl flags). The values are read through the
# global FLAGS object by warmup_lr() and test() below; the whole FLAGS object
# is also handed to exp_naming() to derive the experiment directory name.
FLAGS = flags.FLAGS

# NOTE: test() additionally accepts "fm_ot", "fm_dif" and "fm" as model names
# when it selects the flow matcher, although the help text does not list them.
flags.DEFINE_string("model", "icfm", help="flow matching model type[otcfm, sinkhorn_otwfm, sinkhorn_otcfm, sinkhorn_otwfm_dv, icfm, itfm, si]")
flags.DEFINE_string("output_dir", "./results/", help="output_directory")
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
# test() requires resume_step > 0: it is the step of the checkpoint to load.
flags.DEFINE_integer("resume_step", 400000, help="resume from step")
flags.DEFINE_float("lr", 2e-4, help="target learning rate")  # candidate value to try: 2e-4
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer("total_steps", 400001, help="total training steps") 
# Anonymous et al. use 400k steps.
# Number of warmup steps; used as the divisor in warmup_lr().
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size")  # Anonymous et al uses 128
flags.DEFINE_integer("num_workers", 32, help="workers of Dataloader") # a previous default was 4
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")
# test() builds datasets for: cifar10, cifar100, cifar10_lt, cifar100_lt and
# cifar10_lt_regacy (the last one requires --data_root to be set explicitly).
flags.DEFINE_string("dataset_name", "cifar10", help="dataset name")
# Sinkhorn options, forwarded to SinkhornOptimalTransportConditionalFlowMatcher.
flags.DEFINE_float("reg", 0.05, help="regularization parameter")
flags.DEFINE_float("tau_b", 1.0, help="regularization parameter b for Sinkhorn")
flags.DEFINE_string("method", "unbalanced", help="method for Sinkhorn: ['unbalanced_knopp', 'unbalanced']")
# Loss re-weighting options; test() rejects any weight_type other than "none"
# unless the model name contains "sinkhorn".
flags.DEFINE_float("beta", 1.0, help="beta for energy-weighted flow matching")
flags.DEFINE_bool("efm", False, help="energy-weighted flow matching")
flags.DEFINE_string("weight_type", "none", help="weight type for flow matching [inv_tnu]")
flags.DEFINE_float("weight_power_factor", 1.0, help="power factor for weight")
# Output switches.
flags.DEFINE_bool("save_weights_plot", False, help="save weights plot")
flags.DEFINE_bool("save_images", False, help="save images")
flags.DEFINE_bool("save_classwise_img_hist", True, help="save classwise images")
flags.DEFINE_bool("normalize_cost", True, help="normalize cost of OT")
flags.DEFINE_bool("recoupling", True, help="sinkhorn change the coupling of x0 and x1")
flags.DEFINE_bool("fixed_source", False, help="sinkhorn fixed source")
flags.DEFINE_bool("fixed_target", False, help="sinkhorn fixed target")
# When left as None, test() falls back to "./data" for the four known datasets.
flags.DEFINE_string("data_root", None, help="data root")
# NOTE: besides 'adaptive' and 'default', test() also accepts the fixed
# statistics presets 'cifar10', 'cifar100', 'cifar10_lt' and 'cifar100_lt'.
flags.DEFINE_string("data_norm", "default", help="choose data normalization: ['adaptive', 'default']")
# Evaluation
# In test(), save_step is the batch-index interval at which sample grids are
# written when --save_images is enabled.
flags.DEFINE_integer(
    "save_step",
    200000,
    help="frequency of saving checkpoints, 0 to disable during training",
)
flags.DEFINE_string("device", "cuda:0", help="Device to use [cuda:0, cuda:1]")

# Long-tail dataset options. Only imb_factor is read directly in this file
# (passed to get_real_dataset); imb_type and rand_number are not referenced
# here — presumably consumed via exp_naming(FLAGS) or other modules; verify.
flags.DEFINE_string("imb_type", "exp", help="imbalance type for long-tail dataset [exp, step]")
flags.DEFINE_float("imb_factor", 0.01, help="imbalance factor for long-tail dataset")
flags.DEFINE_integer("rand_number", 0, help="random seed for dataset generation")


def warmup_lr(step):
    """Return the linear learning-rate warmup multiplier for ``step``.

    The multiplier grows linearly from 0 to 1 over the first ``FLAGS.warmup``
    steps and stays at 1 afterwards.

    Args:
        step: Current optimisation step (non-negative number).

    Returns:
        A float in ``[0, 1]`` to be multiplied with the target learning rate.
        When ``FLAGS.warmup`` is zero or negative, warmup is treated as
        disabled and 1.0 is returned (previously this divided by zero).
    """
    warmup_steps = FLAGS.warmup
    if warmup_steps <= 0:
        # No warmup requested: use the full learning rate from the start.
        return 1.0
    return min(step, warmup_steps) / warmup_steps


# Datasets for which "./data" is used when --data_root is not given.
_DEFAULT_ROOT_DATASETS = ("cifar10", "cifar100", "cifar10_lt", "cifar100_lt")

# Fixed per-channel (mean, std) presets selectable through --data_norm,
# together with the label used in the informational print-out.
_FIXED_NORMALIZATION = {
    "default": ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), "default"),
    "cifar10": ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616), "CIFAR10"),
    "cifar100": ((0.5071, 0.48651, 0.44091), (0.2673, 0.2564, 0.2762), "CIFAR100"),
    "cifar10_lt": ((0.4989, 0.5044, 0.4926), (0.2513, 0.2485, 0.2734), "CIFAR10"),
    "cifar100_lt": ((0.5228, 0.4929, 0.4420), (0.2677, 0.2617, 0.2780), "CIFAR100"),
}

_CIFAR10_CLASS_NAMES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]


def _resolve_normalization():
    """Return the per-channel ``(mean, std)`` selected by ``--data_norm``.

    Raises:
        ValueError: if ``--data_norm`` is unknown, or if ``adaptive`` is
            requested for a dataset without a known dataset class.
    """
    if FLAGS.data_norm == "adaptive":
        # Statistics are computed from the training split of the dataset.
        dataset_classes = {
            "cifar10": datasets.CIFAR10,
            "cifar100": datasets.CIFAR100,
            "cifar10_lt": ImbalanceCIFAR10,
            "cifar100_lt": ImbalanceCIFAR100,
        }
        if FLAGS.dataset_name not in dataset_classes:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}, must be one of ['cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt'] or set data_root manually")
        mean, std = compute_dataset_mean_std(
            dataset_classes[FLAGS.dataset_name],
            FLAGS.data_root,
            train=True,
            batch_size=1024,
            num_workers=4,
        )
        print(f"Adaptive normalization for {FLAGS.dataset_name} mean: {mean}, std: {std}")
        return mean, std

    if FLAGS.data_norm in _FIXED_NORMALIZATION:
        mean, std, label = _FIXED_NORMALIZATION[FLAGS.data_norm]
        print(f"Fixed {label} normalization for {FLAGS.dataset_name} mean: {mean}, std: {std}")
        return mean, std

    raise ValueError(
        f"Unknown data normalization {FLAGS.data_norm}, must be one of "
        "['adaptive', 'default', 'cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt']"
    )


def _build_train_dataset(mean, std):
    """Build the training dataset named by ``--dataset_name``.

    Every dataset uses the same transform: random horizontal flip, conversion
    to tensor, and normalisation with the given ``mean`` / ``std``.

    Raises:
        ValueError: if ``--dataset_name`` is not supported.
    """
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    name = FLAGS.dataset_name
    if name == "cifar10":
        return datasets.CIFAR10(root=FLAGS.data_root, train=True, download=True, transform=transform)
    if name == "cifar100":
        return datasets.CIFAR100(root=FLAGS.data_root, train=True, download=True, transform=transform)
    if name == "cifar10_lt_regacy":
        # data_root is expected to point at the dataset file/directory itself,
        # e.g. ./data/cifar10-lt-npz/cifar10-lt-ratio50.npz
        return CIFAR10LTDataset_regacy(data_dir=FLAGS.data_root, split="train", transform=transform)
    if name in ("cifar10_lt", "cifar100_lt"):
        return get_real_dataset(
            dataset_name=name,
            split="train",
            transform=transform,
            data_root=FLAGS.data_root,
            imb_factor=FLAGS.imb_factor,
        )
    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}, must be one of ['cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt']")


def _build_flow_matcher(sigma=0.0):
    """Instantiate the conditional flow matcher selected by ``--model``.

    Raises:
        NotImplementedError: if ``--model`` is not a known model name.
    """
    model = FLAGS.model
    if model == "otcfm":
        return ExactOptimalTransportConditionalFlowMatcher(sigma=sigma, normalize_cost=FLAGS.normalize_cost)
    if model in ("sinkhorn_otcfm", "sinkhorn_otwfm"):
        return SinkhornOptimalTransportConditionalFlowMatcher(
            sigma=sigma,
            method=FLAGS.method,
            reg=FLAGS.reg,
            reg_m=(float("inf"), FLAGS.tau_b),
            normalize_cost=FLAGS.normalize_cost,
            recoupling=FLAGS.recoupling,
            fixed_source=FLAGS.fixed_source,
            fixed_target=FLAGS.fixed_target,
        )
    if model in ("icfm", "fm_ot"):  # (independent) OT path flow matching (Anonymous et al.)
        return ConditionalFlowMatcher(sigma=sigma)
    if model in ("itfm", "fm_dif", "fm"):  # (independent) diffusion path flow matching (Anonymous et al.)
        return TargetConditionalFlowMatcher(sigma=sigma)
    if model == "si":  # Stochastic Interpolants flow matching (Anonymous et al.)
        return VariancePreservingConditionalFlowMatcher(sigma=sigma)
    raise NotImplementedError(
        f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'sinkhorn_otcfm', "
        "'sinkhorn_otwfm', 'icfm', 'fm_ot', 'itfm', 'fm_dif', 'fm', 'si']"
    )


def _load_checkpoint(net_model, ema_model, savedir, device):
    """Load the ``--resume_step`` checkpoint from ``savedir`` into both models.

    Raises:
        ValueError: if ``--resume_step`` is not positive, or if the checkpoint
            cannot be read or applied. The underlying error is chained as the
            cause so the real failure is not hidden.
    """
    if FLAGS.resume_step <= 0:
        raise ValueError("resume_step must be > 0 for testing")

    checkpoint_path = os.path.join(
        savedir, f"{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{FLAGS.resume_step}.pt"
    )
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        net_model.load_state_dict(checkpoint["net_model"])
        ema_model.load_state_dict(checkpoint["ema_model"])
        print(f"Resuming from step {checkpoint['step']}")
    except Exception as err:
        # Keep ValueError as the exception type, but report the path together
        # with the reason and let KeyboardInterrupt/SystemExit propagate.
        raise ValueError(f"Failed to load checkpoint {checkpoint_path}: {err}") from err


def _save_classwise_image_histogram(ema_model, savedir, device, mean, std):
    """Generate images with the EMA model and organise them by predicted class.

    Generates 50000 images by integrating the EMA model from t=0 to t=1,
    writes them as PNG files, classifies them with
    ``classify_generated_images``, moves each file into a per-class directory
    (confidence embedded in the file name), writes the per-class counts to
    ``classwise_counts.csv`` and appends a summary to ``test_log.txt``.

    Args:
        ema_model: Model used as the ODE vector field.
        savedir: Experiment directory receiving all outputs.
        device: Torch device used for sampling and classification.
        mean, std: Per-channel normalisation statistics the data was
            normalised with; used to map samples back to pixel values.

    Raises:
        ValueError: if ``--dataset_name`` has no known class list.
    """
    print("Starting classwise image histogram generation...")
    classwise_start_time = datetime.now()

    # Determine number of classes based on dataset
    if FLAGS.dataset_name in ["cifar10", "cifar10_lt"]:
        class_names = _CIFAR10_CLASS_NAMES
    elif FLAGS.dataset_name in ["cifar100", "cifar100_lt"]:
        class_names = [f'class_{i}' for i in range(100)]
    else:
        raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
    num_classes = len(class_names)

    # Create directory structure: one sub-directory per class.
    classwise_dir = os.path.join(savedir, "classwise_images")
    os.makedirs(classwise_dir, exist_ok=True)
    class_dirs = [
        os.path.join(classwise_dir, f"class_{i:02d}_{class_names[i]}")
        for i in range(num_classes)
    ]
    for class_dir in class_dirs:
        os.makedirs(class_dir, exist_ok=True)

    num_gen_images = 50000
    temp_gen_dir = os.path.join(savedir, "temp_generated_images")
    os.makedirs(temp_gen_dir, exist_ok=True)

    # Invert the training normalisation (x * std + mean), scale to [0, 255]
    # and add 0.5 so the uint8 truncation rounds to nearest. For the default
    # mean = std = 0.5 this reduces to the former x * 127.5 + 128.
    # Assumes mean/std are per-channel sequences or tensors — TODO confirm for
    # the 'adaptive' statistics returned by compute_dataset_mean_std.
    std_t = torch.as_tensor(std, dtype=torch.float32, device=device).reshape(1, -1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=device).reshape(1, -1, 1, 1)
    pixel_scale = std_t * 255
    pixel_offset = mean_t * 255 + 0.5

    print(f"Generating {num_gen_images} images...")
    num_batches = (num_gen_images + FLAGS.batch_size - 1) // FLAGS.batch_size
    generated_count = 0
    to_pil = ToPILImage()
    t_span = torch.linspace(0, 1, 2, device=device)

    # Sampling must run without dropout. Helpers called earlier (e.g.
    # generate_samples) may have switched the model back to training mode.
    ema_model.eval()
    with torch.no_grad():
        for batch_idx in tqdm(range(num_batches)):
            current_batch_size = min(FLAGS.batch_size, num_gen_images - generated_count)

            # Integrate the learned vector field from noise (t=0) to data (t=1)
            # with torchdiffeq's adaptive dopri5 solver.
            x = torch.randn(current_batch_size, 3, 32, 32, device=device)
            traj = odeint(ema_model, x, t_span, rtol=1e-5, atol=1e-5, method="dopri5")

            # Final state -> uint8 pixels; move the whole batch to CPU once.
            images = (traj[-1] * pixel_scale + pixel_offset).clip(0, 255).to(torch.uint8).cpu()
            for img_tensor in images:  # each is [3, 32, 32]
                img_path = os.path.join(temp_gen_dir, f"generated_{generated_count:06d}.png")
                to_pil(img_tensor).save(img_path)
                generated_count += 1

            if (batch_idx + 1) % 50 == 0:
                print(f"Generated {generated_count}/{num_gen_images} images")

    print(f"Generated {generated_count} images. Starting classification...")

    # Classify generated images and get confidences
    predictions, confidences = classify_generated_images(
        temp_gen_dir, FLAGS.dataset_name, device, num_gen_images, return_confidence=True
    )

    print("Classification completed. Organizing images by class...")

    # Count images per class and move each image into its class directory,
    # with the confidence (4 decimals) embedded in the file name.
    class_counts = [0] * num_classes
    for i, (pred, conf) in enumerate(zip(predictions, confidences)):
        class_counts[pred] += 1
        src_path = os.path.join(temp_gen_dir, f"generated_{i:06d}.png")
        dst_path = os.path.join(class_dirs[pred], f"generated_{i:06d}_conf{conf:.4f}.png")
        if os.path.exists(src_path):
            shutil.move(src_path, dst_path)

    # Save class counts to CSV
    csv_path = os.path.join(savedir, "classwise_counts.csv")
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['class_id', 'class_name', 'count'])
        for i, count in enumerate(class_counts):
            writer.writerow([i, class_names[i], count])

    # Clean up temporary directory
    shutil.rmtree(temp_gen_dir)

    class_distribution = {f'class_{i}': count for i, count in enumerate(class_counts)}

    # Log results
    with open(os.path.join(savedir, 'test_log.txt'), 'a') as f:
        f.write(f"\nClasswise image histogram generation completed at {datetime.now()}\n")
        f.write(f"Total classwise generation time: {datetime.now() - classwise_start_time}\n")
        f.write(f"Generated {num_gen_images} images, classified and organized by class\n")
        f.write(f"Class distribution: {class_distribution}\n")

    print(f"Classwise image histogram saved to: {classwise_dir}")
    print(f"Class counts saved to: {csv_path}")
    print(f"Class distribution: {class_distribution}")


def test(argv):
    """Evaluate a trained flow-matching checkpoint (absl entry point).

    Steps:
      1. Resolve the data root and normalisation, build the training dataset
         and loader, the UNet and its EMA copy, and the flow matcher.
      2. Load the ``--resume_step`` checkpoint from the experiment directory.
      3. Run one pass over the data loader, computing the (optionally
         weighted) flow-matching loss without gradients and, when requested,
         collecting weight statistics and saving sample grids.
      4. Write hyper-parameters, timings and statistics to ``test_log.txt``.
      5. If ``--save_classwise_img_hist`` is set, generate images and organise
         them by predicted class.

    Args:
        argv: Positional command-line arguments passed by ``absl.app.run``;
            unused.

    Raises:
        ValueError: on unknown dataset / normalisation / weight type, an
            unsupported weight type for the chosen model, or a checkpoint
            that cannot be loaded.
        NotImplementedError: on an unknown ``--model``.
    """
    device = torch.device(FLAGS.device if torch.cuda.is_available() else "cpu")

    print(
        "lr, total_steps, ema decay, save_step:",
        FLAGS.lr,
        FLAGS.total_steps,
        FLAGS.ema_decay,
        FLAGS.save_step,
    )

    # Set dataset root location
    if FLAGS.data_root is None:
        if FLAGS.dataset_name not in _DEFAULT_ROOT_DATASETS:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}, must be one of ['cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt'] or set data_root manually")
        FLAGS.data_root = "./data"

    # Data normalisation and dataset / dataloader
    mean, std = _resolve_normalization()
    dataset = _build_train_dataset(mean, std)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )

    # MODELS
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(device)

    ema_model = copy.deepcopy(net_model)
    if FLAGS.parallel:
        print(
            "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory."
        )
        net_model = torch.nn.DataParallel(net_model)
        ema_model = torch.nn.DataParallel(ema_model)

    # show model size
    model_size = sum(param.data.nelement() for param in net_model.parameters())
    print("Model params: %.2f M" % (model_size / 1024 / 1024))

    FM = _build_flow_matcher(sigma=0.0)
    uses_sinkhorn = "sinkhorn" in FLAGS.model

    # The inverse-marginal weights need the coupling returned by Sinkhorn.
    if not uses_sinkhorn and FLAGS.weight_type != "none":
        raise ValueError(f"Weight type {FLAGS.weight_type} is not supported for {FLAGS.model}")

    savedir = os.path.join(FLAGS.output_dir, exp_naming(FLAGS))
    os.makedirs(savedir, exist_ok=True)

    _load_checkpoint(net_model, ema_model, savedir, device)

    # Evaluation: disable dropout in both networks.
    net_model.eval()
    ema_model.eval()

    # Per-batch weight statistics (collected for the first 10 batches only).
    fm_weight_list = []
    fm_weight_min_list = []
    fm_weight_max_list = []
    fm_weight_mean_list = []
    fm_weight_std_list = []
    fm_weight_inv_std_list = []
    fm_weight_inv_mean_list = []

    # Make log file
    log_path = os.path.join(savedir, 'test_log.txt')
    start_time = datetime.now()
    with open(log_path, 'w') as f:
        f.write("===== Hyperparameters(FLAGS) =====\n")
        for name in FLAGS:
            f.write(f"{name}: {getattr(FLAGS, name)}\n")
        f.write("===== Additional Information =====\n")
        f.write(f"data is normalized to (0,I) from mean: {mean}, std: {std}\n")
        f.write("==================================\n")
        f.write(f"Start Testing at {start_time}\n")
        f.write('\n')

    loss_total = 0.0
    num_eval_batches = 0
    for batch_idx, batch in enumerate(dataloader):
        # No parameters are updated here, so skip building autograd graphs.
        with torch.no_grad():
            x1 = batch[0].to(device)
            x0 = torch.randn_like(x1)
            if uses_sinkhorn:
                # Sinkhorn matcher also returns the coupling `pi` and the
                # re-coupled target indices `j` used for the weights below.
                t, xt, ut, pi, _, _, _, j = FM.sample_location_and_conditional_flow(x0, x1)
            else:
                t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)

            if FLAGS.weight_type == "none":
                fm_weight = 1.0
            elif FLAGS.weight_type == "inv_tnu":
                tnu = pi.sum(dim=0)
                tnu = tnu.reshape(tnu.size(0), 1, 1, 1).to(device)
                tnu = tnu / (1 / x1.size(0))  # normalized by batch size
                fm_weight = 1 / tnu  # inverse weight (minority)
                fm_weight = fm_weight[j]  # select weight for re-coupling target samples
                fm_weight = fm_weight ** FLAGS.weight_power_factor
                if FLAGS.efm:
                    fm_weight = energy_weight(fm_weight, beta=FLAGS.beta)
                if FLAGS.save_weights_plot and batch_idx < 10:
                    fm_weight_list.append(fm_weight)
                    fm_weight_min_list.append(fm_weight.min().item())
                    fm_weight_max_list.append(fm_weight.max().item())
                    fm_weight_mean_list.append(fm_weight.mean().item())
                    fm_weight_std_list.append(fm_weight.std().item())
                    fm_weight_inv_std_list.append(tnu.std().item())
                    fm_weight_inv_mean_list.append(tnu.mean().item())
            else:
                raise ValueError(f"Unknown weight type {FLAGS.weight_type}, must be one of ['none', 'inv_tnu']")

            # weighted FM loss (equals the naive FM loss when fm_weight == 1)
            loss = torch.mean(((vt - ut) ** 2) * fm_weight)
            loss_total += loss.item()
            num_eval_batches += 1

        # sample and save images at intervals; default settings disable this behavior
        if FLAGS.save_images and FLAGS.save_step > 0 and (batch_idx % FLAGS.save_step == 0):
            generate_samples(net_model, FLAGS.parallel, savedir, step=batch_idx, net_="normal", device=device)
            generate_samples(ema_model, FLAGS.parallel, savedir, step=batch_idx, net_="ema", device=device)
            # generate_samples is defined elsewhere and may toggle the module
            # mode; make sure evaluation continues without dropout.
            net_model.eval()
            ema_model.eval()

    if FLAGS.save_weights_plot and len(fm_weight_list) > 0:
        print(f"Saving weights plot for {FLAGS.weight_type}")
        extra_info = f"reg{FLAGS.reg}_tau{FLAGS.tau_b}"
        fm_weight_tensor = torch.cat(fm_weight_list, dim=0)
        plot_fm_weights_histogram(fm_weight_tensor.reshape(-1), FLAGS.weight_type, savedir, extra_info=extra_info)
        plot_fm_weights_histogram(torch.tensor(fm_weight_min_list), FLAGS.weight_type, savedir, data_type="min of batch", extra_info=extra_info)
        plot_fm_weights_histogram(torch.tensor(fm_weight_max_list), FLAGS.weight_type, savedir, data_type="max of batch", extra_info=extra_info)
        plot_fm_weights_histogram(torch.tensor(fm_weight_mean_list), FLAGS.weight_type, savedir, data_type="mean of batch", extra_info=extra_info)

        # Averages over the collected batches.
        avg_weight_std = torch.tensor(fm_weight_std_list).mean().item()
        avg_weight_mean = torch.tensor(fm_weight_mean_list).mean().item()
        avg_weight_inv_std = torch.tensor(fm_weight_inv_std_list).mean().item()
        avg_weight_inv_mean = torch.tensor(fm_weight_inv_mean_list).mean().item()

        # Append one line to the log file
        with open(log_path, 'a') as f:
            f.write(f"avg_weight_std: {avg_weight_std:.6f} avg_weight_mean: {avg_weight_mean:.6f} avg_weight_inv_std: {avg_weight_inv_std:.6f} avg_weight_inv_mean: {avg_weight_inv_mean:.6f}\n")

    end_time = datetime.now()
    with open(log_path, 'a') as f:
        if num_eval_batches > 0:
            # Report the loss that the evaluation pass computed.
            f.write(f"avg_fm_loss: {loss_total / num_eval_batches:.6f} over {num_eval_batches} batches\n")
        f.write(f"\nEnd Testing at {end_time}\n")
        f.write(f"Total testing time: {end_time - start_time}\n")

    # Save classwise image histogram if requested
    if FLAGS.save_classwise_img_hist:
        _save_classwise_image_histogram(ema_model, savedir, device, mean, std)


# Script entry point: absl parses the command-line flags defined above and
# then calls test() with the remaining positional arguments.
if __name__ == "__main__":
    app.run(test)


"""
CUDA_VISIBLE_DEVICES=1 python test_cifar10.py --model icfm --dataset_name cifar10 --output_dir results_cifar10 --save_classwise_img_hist True
"""