import math

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torch
import torchdyn
from torchdyn.datasets import generate_moons
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms

# Implement some helper functions


def eight_normal_sample(n, dim, scale=1, var=1):
    """Sample ``n`` points from a mixture of eight Gaussians placed on a circle.

    The eight component centers lie on the unit circle (four axis-aligned and
    four diagonal directions) and are multiplied by ``scale``. Each sample is
    a uniformly chosen center plus zero-mean Gaussian noise.

    Args:
        n (int): number of samples to draw.
        dim (int): dimensionality of the noise. The centers are 2-D, so this
            is expected to be 2.
        scale (float): multiplier applied to the center coordinates.
        var (float): noise parameter. NOTE: the covariance matrix handed to
            ``MultivariateNormal`` is ``sqrt(var) * I``, so the effective
            per-axis variance is ``sqrt(var)`` rather than ``var``. This is
            kept unchanged so that existing experiments stay reproducible.

    Returns:
        torch.Tensor: one row per sample, ``(n, 2)`` when ``dim == 2``.
    """
    noise_dist = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(dim), math.sqrt(var) * torch.eye(dim)
    )
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
    ]
    centers = torch.tensor(centers) * scale
    # Draw noise first, then the component ids, so the RNG stream is consumed
    # in the same order as before.
    noise = noise_dist.sample((n,))
    component = torch.multinomial(torch.ones(8), n, replacement=True)
    # One gather + add instead of a Python loop over every sample.
    return centers[component] + noise

def k_normal_sample(n, dim, k, mu_list, sigma_list, scale_list):
    """
    Sample data from a mixture of k axis-aligned normal distributions.

    Each component is defined by a mean and a standard deviation; scale_list
    holds the (unnormalized) mixture weights, so a larger weight yields more
    samples from that component.

    Args:
        n (int): total samples
        dim (int): data dimension
        k (int): number of components
        mu_list (list or np.ndarray or torch.Tensor): means (k, dim)
        sigma_list (list or np.ndarray or torch.Tensor or float): std devs,
            given as (k, dim) per-dimension values, (k,) one value per
            component, or a single scalar shared by every component
        scale_list (list or np.ndarray or torch.Tensor): mixture weights (k,);
            normalized internally, so they need not sum to 1

    Returns:
        data (torch.Tensor): (n, dim), float32
    """
    # as_tensor avoids the copy-construct warning when a tensor is passed in.
    mu = torch.as_tensor(mu_list, dtype=torch.float32)
    sigma = torch.as_tensor(sigma_list, dtype=torch.float32)
    weights = torch.as_tensor(scale_list, dtype=torch.float32)
    weights = weights / weights.sum()  # normalize if not summing to 1

    # Bring sigma to shape (k, dim). The layout is decided once from the rank
    # of the whole input, so a (k, dim) input with dim == k keeps its
    # per-dimension values instead of being mistaken for a (k,) vector.
    if sigma.ndim == 0:
        sigma = sigma.expand(k)
    if sigma.ndim == 1:
        sigma = sigma.unsqueeze(1).expand(-1, dim)

    if n == 0:
        return mu.new_empty((0, dim))

    # Pick a component per sample, then draw every sample in one call.
    component = torch.multinomial(weights, n, replacement=True)  # (n,)
    return torch.normal(mu[component], sigma[component])




def sample_moons(n):
    """Draw ``n`` points from torchdyn's two-moons generator (noise=0.2), mapped by ``3 * x - 1``."""
    points, _labels = generate_moons(n, noise=0.2)
    rescaled = points * 3 - 1
    return rescaled


def sample_8gaussians(n):
    """Draw ``n`` float32 2-D samples from the eight-Gaussian ring (scale=5, var=0.1)."""
    samples = eight_normal_sample(n, 2, scale=5, var=0.1)
    return samples.float()

def sample_unbalanced_kgaussians(n, k, mu_list, sigma_list, scale_list):
    """Draw ``n`` float32 2-D samples from a k-component mixture via ``k_normal_sample``."""
    samples = k_normal_sample(n, 2, k, mu_list, sigma_list, scale_list)
    return samples.float()


class torch_wrapper(torch.nn.Module):
    """Adapter giving a model the ``forward(t, x)`` call signature used by torchdyn.

    The wrapped model receives a single tensor: ``x`` with the time value
    appended as one extra trailing column.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        # Repeat t once per row of x and shape it as a column; extra
        # positional/keyword arguments are accepted but ignored.
        batch_size = x.shape[0]
        time_column = t.repeat(batch_size).unsqueeze(1)
        model_input = torch.cat((x, time_column), dim=1)
        return self.model(model_input)


def plot_trajectories(traj, vis_grid=False, colors=('#1f77b4', '#ff7f0e', '#2ca02c')):
    """Scatter-plot the start, full path and end of up to 2000 trajectories.

    Args:
        traj: array-like indexed as ``[time, sample, coordinate]``; only the
            first two coordinates are drawn.
        vis_grid (bool): if True, draw a dashed grid with integer ticks
            spanning the range of all trajectory points.
        colors (sequence): colors for the first time step, the whole flow and
            the last time step, in that order.
    """
    max_points = 2000
    shown = traj[:, :max_points]

    plt.figure(figsize=(6, 6))
    # First time step, every time step, last time step.
    plt.scatter(shown[0, :, 0], shown[0, :, 1], s=10, alpha=0.8, c=colors[0])
    plt.scatter(shown[:, :, 0], shown[:, :, 1], s=0.2, alpha=0.2, c=colors[1])
    plt.scatter(shown[-1, :, 0], shown[-1, :, 1], s=4, alpha=1, c=colors[2])
    plt.legend(["Prior sample z(S)", "Flow", "z(0)"])
    plt.xticks([])
    plt.yticks([])

    if vis_grid:
        # Integer ticks on both axes, computed over all samples (not only the
        # plotted subset).
        xs_all = traj[:, :, 0]
        ys_all = traj[:, :, 1]
        plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)
        plt.xticks(range(int(xs_all.min()), int(xs_all.max()) + 1), fontsize=10)
        plt.yticks(range(int(ys_all.min()), int(ys_all.max()) + 1), fontsize=10)

    plt.show()


def plot_trajectories_with_ratios(traj, vis_grid=False, centers=None, names=None, colors=('#1f77b4', '#ff7f0e', '#2ca02c', '#d62728')):
    """
    Plot trajectories with an option to display ratios at target centers.

    Args:
        traj: trajectories of shape (T, N, D); only the first two coordinates
            of the first 2000 samples are drawn.
        vis_grid (bool): if True, draw a dashed grid with integer ticks.
        centers (sequence of (x, y), optional): if given, every final-step
            point is assigned to its nearest center and each center is
            annotated with its count and percentage.
        names (list of str, optional): accepted for interface compatibility;
            currently not rendered on the plot.
        colors (sequence): colors for the prior samples, the flow and the
            final samples (the fourth entry is unused).
    """
    n = 2000
    fig, ax = plt.subplots(figsize=(6, 6))  # Create figure and axes
    ax.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c=colors[0], label="Prior sample z(S)")
    ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c=colors[1], label="Flow")
    ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c=colors[2], label="z(0)")
    ax.legend()
    ax.set_xticks([])
    ax.set_yticks([])

    if vis_grid:
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)
        ax.set_xticks(range(int(traj[:, :, 0].min()), int(traj[:, :, 0].max()) + 1))
        ax.set_yticks(range(int(traj[:, :, 1].min()), int(traj[:, :, 1].max()) + 1))

    if centers is not None:
        # as_tensor accepts both arrays and tensors; unlike torch.tensor it
        # neither copies nor warns when traj[-1] is already a tensor.
        final_points = torch.as_tensor(traj[-1])
        counts, ratios = compute_majority_ratios(final_points, centers)

        offset_x = 0.5  # Simple offset for text
        offset_y = 0.5
        for i, (center, count, ratio) in enumerate(zip(centers, counts.tolist(), ratios.tolist())):
            ax.text(center[0] + offset_x, center[1] + offset_y,
                    f"#{count}, {ratio*100:.2f}%",
                    # Alternate the anchor so neighbouring labels overlap less.
                    fontsize=8, color='black', ha='center', va=('top' if i % 2 == 1 else 'bottom'),
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1))
            ax.plot(center[0], center[1], 'o', color='red', markersize=8, markeredgecolor='black')  # Plot centers

    plt.show()


####### custom functions #######








def visualize_pi(pi, title=None, xlabel='j', ylabel='i', cmap='viridis', fixed_value_scale=None, figsize=(6, 6), save_path=None):
    """
    Render a 2D matrix as a heatmap with a colorbar.

    Args:
        pi (torch.Tensor or numpy.ndarray): matrix of shape (n, m)
        title (str, optional): plot title; omitted when falsy
        xlabel (str, optional): x-axis label
        ylabel (str, optional): y-axis label
        cmap (str, optional): colormap (default: 'viridis')
        fixed_value_scale (float or int, optional): if a number is given, the
            color scale is pinned to [0, fixed_value_scale]
        figsize (tuple, optional): figure size (default: (6, 6))
        save_path (str, optional): where to save the figure (falsy to skip)
    """
    # Tensors are detached and moved to host memory before plotting.
    matrix = pi.detach().cpu().numpy() if isinstance(pi, torch.Tensor) else pi

    imshow_kwargs = dict(cmap=cmap, origin='lower', aspect='auto')
    # None is not a number, so this single check also covers the "unset" case.
    if isinstance(fixed_value_scale, (float, int)):
        imshow_kwargs.update(vmin=0, vmax=fixed_value_scale)

    plt.figure(figsize=figsize)
    image = plt.imshow(matrix, **imshow_kwargs)
    plt.colorbar(image)
    if title:
        plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()


def compute_majority_ratios(points, centers):
    """
    Assign every point to its nearest center and report per-center statistics.

    Args:
        points (torch.Tensor): shape (N, 2)
        centers (list of tuple or array): [(x1, y1), (x2, y2), ...]
    Returns:
        counts (torch.Tensor): shape (K,), number of points assigned to each center
        ratios (torch.Tensor): shape (K,), counts divided by N
    """
    # Centers as a (K, 2) tensor matching the points' dtype and device.
    center_tensor = torch.tensor(centers, dtype=points.dtype, device=points.device)
    # Pairwise squared Euclidean distances, shape (N, K).
    offsets = points.unsqueeze(1) - center_tensor.unsqueeze(0)
    sq_dists = offsets.pow(2).sum(dim=2)
    # Nearest center per point, shape (N,).
    nearest = torch.argmin(sq_dists, dim=1)
    # minlength keeps a slot for centers that attracted no points.
    counts = torch.bincount(nearest, minlength=center_tensor.shape[0])
    ratios = counts.to(torch.float32) / points.size(0)
    return counts, ratios

def print_majority_ratios(points, centers, names=None):
    """
    Print one line per center with the count and percentage from compute_majority_ratios.

    Args:
        points (torch.Tensor): shape (N, 2)
        centers (list of tuple): center coordinates
        names (list of str, optional): display name per center
            (default: str(tuple(centers[i])))
    """
    tallies, fractions = compute_majority_ratios(points, centers)
    # Fall back to the coordinate tuple as the display name.
    labels = names if names is not None else [str(tuple(center)) for center in centers]
    for label, count, fraction in zip(labels, tallies.tolist(), fractions.tolist()):
        print(f"{label}: #{count}, {fraction*100:.2f}%")


def count_duplicate_rows(x):
    """
    Count duplicated and distinct rows of a tensor.

    Args:
        x (torch.Tensor): (N, D)
    Returns:
        (num_duplicates, num_unique): number of rows minus the number of
        distinct rows, and the number of distinct rows
    """
    rows = x.detach().cpu().numpy()
    total = 0
    distinct = set()
    # Rows are hashed as tuples of their elements; the set keeps one of each.
    for row in rows:
        total += 1
        distinct.add(tuple(row))
    return total - len(distinct), len(distinct)

def plot_sample_points(
    x0, x1, x0_recoupled, x1_recoupled,
    labels=('x0', 'x1', 'x0_recoupled', 'x1_recoupled'),
    colors=('red', 'blue', 'green', 'orange'),
    markers=('o', 's', '^', 'x'),
    s=20, alpha=0.8, figsize=(6, 6)
):
    """
    Draw the four sample sets (x0, x1, x0_recoupled, x1_recoupled) on one
    2D scatter plot, each with its own label, color and marker.

    Every sample set must be a torch.Tensor whose first two columns are the
    plotted coordinates. ``s`` and ``alpha`` are shared by all four sets.
    """
    plt.figure(figsize=figsize)
    point_sets = (x0, x1, x0_recoupled, x1_recoupled)
    for tensor, name, colour, shape in zip(point_sets, labels, colors, markers):
        coords = tensor.detach().cpu().numpy()
        xs, ys = coords[:, 0], coords[:, 1]
        plt.scatter(xs, ys, s=s, alpha=alpha, c=colour, marker=shape, label=name)
    plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.show()


def plot_sample_points_weightvis(
    x0, x1, x0_recoupled, x1_recoupled, weights,
    visual_count=20,
    labels=('x0', 'x1', 'x0_recoupled', 'x1_recoupled'),
    colors=('#1f77b4', '#9370db', '#2ca02c', '#ff7f0e'),
    markers=('o', 's', '^', 'x'),
    s=20, alpha=1.0, figsize=(6, 6),
    square_mode=False,
    source_square=None,
    target_squares=None,
    # If True (and square_mode is on), overlay bolder square outlines.
    draw_square_outlines=False,
    savepath=None,
    save_bbox_inches='tight',
    save_pad_inches=0.1,
):
    """
    Plot sample points and draw lines between x0_recoupled and x1_recoupled with color intensity based on weights.

    A random subset of ``visual_count`` recoupled pairs is visualized: the
    pair endpoints are scattered, the x1_recoupled markers are sized by the
    normalized weight, and each pair is joined by a line colored by the
    normalized weight ('coolwarm' colormap).

    Args:
        x0, x1 (torch.Tensor): original source / target samples; all rows are drawn.
        x0_recoupled, x1_recoupled (torch.Tensor): recoupled pairs (row i of
            one is paired with row i of the other).
        weights (torch.Tensor or scalar): per-pair weights (flattened), or a
            single number that is repeated for every pair.
        visual_count (int): maximum number of pairs to visualize.
        labels, colors, markers: per-group legend label, color and marker, in
            the order x0, x1, x0_recoupled, x1_recoupled. The weight-based
            marker sizing is applied to the group labelled 'x1_recoupled'.
        s, alpha, figsize: marker size, opacity and figure size.
        square_mode (bool): if True, draw the source/target squares together
            with ratio and count annotations.
        source_square (sequence, optional): (cx, cy, w, h) of the source square.
        target_squares (sequence, optional): one (cx, cy, w, h[, weight]) entry
            per target square; the optional fifth value is the relative weight
            used for the "GT ratio" text (1.0 when omitted).
        draw_square_outlines (bool): overlay bolder outlines after layout.
        savepath (str, optional): if given, save the figure there and close it
            instead of showing it.
        save_bbox_inches, save_pad_inches: forwarded to ``plt.savefig``.
    """

    # Prepare weights for line coloring (unify to numpy array)
    if torch.is_tensor(weights):
        ws = weights.detach().cpu().numpy().flatten()
    else:
        ws = np.asarray([weights] * x0_recoupled.shape[0], dtype=float)
    wmin, wmax = ws.min(), ws.max()

    # Choose the random subset of pairs to visualize.
    n_vis = min(visual_count, x1_recoupled.shape[0])
    visw_idx = np.random.choice(x1_recoupled.shape[0], n_vis, replace=False)
    # Weights of the visible pairs rescaled to [0, 1]; the epsilon avoids a
    # division by zero when all weights are equal.
    norm_ws = (ws[visw_idx] - wmin) / (wmax - wmin + 1e-12)

    # Scatter original sample points
    samples = [x0, x1, x0_recoupled[visw_idx], x1_recoupled[visw_idx]]
    plt.figure(figsize=figsize)
    for arr, label, color, marker in zip(samples, labels, colors, markers):
        data = arr.detach().cpu().numpy()
        if label == 'x1_recoupled':
            if abs(wmax - wmin) < 1e-12:
                # Constant weights: fall back to the uniform marker size.
                sizes = s
            else:
                sizes = 200 * norm_ws
            plt.scatter(
                data[:, 0], data[:, 1],
                s=sizes,
                alpha=alpha, c=color,
                marker=marker, label=label,
                edgecolors='k', linewidths=1.5
            )
        else:
            plt.scatter(data[:, 0], data[:, 1], s=s, alpha=alpha, c=color, marker=marker, label=label)

    # Display weight range on plot
    plt.text(0.01, 0.99,
             f"weight min: {wmin:.4f}\nweight max: {wmax:.4f}",
             transform=plt.gca().transAxes,
             fontsize=8, va='top')

    # Draw connecting lines with color based on weight.
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # works on both old and new releases.
    cmap = plt.get_cmap('coolwarm')
    for i, w in enumerate(norm_ws):
        p0 = x0_recoupled[visw_idx[i]].detach().cpu().numpy()
        p1 = x1_recoupled[visw_idx[i]].detach().cpu().numpy()
        color = cmap(w)
        plt.plot([p0[0], p1[0]], [p0[1], p1[1]], color=color, linewidth=0.1, alpha=0.8)

    # Draw square outlines and sampling probabilities in squares mode (thin line + text)
    if square_mode:
        ax = plt.gca()
        # Source square outline and probability (always 100% since single source square)
        if source_square is not None and len(source_square) >= 4:
            scx, scy, sw, sh = source_square[0], source_square[1], source_square[2], source_square[3]
            src_rect = patches.Rectangle(
                (scx - sw / 2.0, scy - sh / 2.0), sw, sh,
                linewidth=0.8, edgecolor='#5B7FA8', facecolor='none', linestyle='--'
            )
            ax.add_patch(src_rect)
            prob_src = 1.0
            ax.text(
                scx, scy - sh / 2.0 - 0.2,
                f"{prob_src * 100:.1f}%",
                ha='center', va='top', fontsize=8,
                bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=1)
            )
        # Target squares outlines and probabilities
        if target_squares is not None and len(target_squares) > 0:
            # Relative weight of each square (fifth element, default 1.0).
            counts = []
            for spec in target_squares:
                if len(spec) >= 5:
                    counts.append(float(spec[4]))
                else:
                    counts.append(1.0)
            total = sum(counts) if sum(counts) > 0 else len(counts)
            for idx, spec in enumerate(target_squares):
                cx, cy, w, h = spec[0], spec[1], spec[2], spec[3]
                rect = patches.Rectangle(
                    (cx - w / 2.0, cy - h / 2.0), w, h,
                    linewidth=0.8, edgecolor='#A66A6A', facecolor='none'
                )
                ax.add_patch(rect)
                p = counts[idx] / total if total > 0 else 0.0

                # Count within currently visible points
                left, right = cx - w / 2.0, cx + w / 2.0
                bottom, top = cy - h / 2.0, cy + h / 2.0

                # Count inside x1 (all x1 visualized as-is)
                if isinstance(x1, torch.Tensor) and x1.numel() > 0:
                    mask_x1 = (
                        (x1[:, 0] >= left) & (x1[:, 0] <= right) &
                        (x1[:, 1] >= bottom) & (x1[:, 1] <= top)
                    )
                    count_x1 = int(mask_x1.sum().item())
                else:
                    count_x1 = 0

                # Count inside x1_recoupled (visible subset via visw_idx) with unique rows
                if isinstance(x1_recoupled, torch.Tensor) and x1_recoupled.numel() > 0:
                    x1_rec_vis = x1_recoupled[visw_idx]
                    mask_rec = (
                        (x1_rec_vis[:, 0] >= left) & (x1_rec_vis[:, 0] <= right) &
                        (x1_rec_vis[:, 1] >= bottom) & (x1_rec_vis[:, 1] <= top)
                    )
                    if mask_rec.any():
                        subset = x1_rec_vis[mask_rec].detach().cpu().numpy()
                        unique_rec = np.unique(subset, axis=0)
                        count_x1_rec_unique = int(unique_rec.shape[0])
                    else:
                        count_x1_rec_unique = 0
                else:
                    count_x1_rec_unique = 0

                ax.text(
                    cx, cy - h / 2.0 - 0.2,
                    f"GT ratio: {p * 100:.1f}%\n#x1: {count_x1}\n#x1_rec: {count_x1_rec_unique}",
                    ha='center', va='top', fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=1)
                )

    plt.legend()
    plt.xticks([])
    plt.yticks([])

    # After fixing layout, optionally overlay bold outlines
    plt.tight_layout(pad=0.1)

    if draw_square_outlines and square_mode:
        ax = plt.gca()
        if source_square is not None and len(source_square) >= 4:
            scx, scy, sw, sh = source_square[0], source_square[1], source_square[2], source_square[3]
            src_rect = patches.Rectangle(
                (scx - sw/2.0, scy - sh/2.0), sw, sh,
                linewidth=1.5, edgecolor='blue', facecolor='none', linestyle='--'
            )
            ax.add_patch(src_rect)
        if target_squares is not None:
            for spec in target_squares:
                cx, cy, w, h = spec[0], spec[1], spec[2], spec[3]
                rect = patches.Rectangle(
                    (cx - w/2.0, cy - h/2.0), w, h,
                    linewidth=1.0, edgecolor='red', facecolor='none'
                )
                ax.add_patch(rect)

    # Save or show
    if savepath is not None:
        plt.savefig(savepath, bbox_inches=save_bbox_inches, pad_inches=save_pad_inches)
        plt.close()
    else:
        plt.show()


def energy_weight(energy, beta=1.0):
    """
    Boltzmann weights exp(-beta * energy), normalized over the whole batch
    (based on energy-weighted flow matching).

    Computed as a softmax over all elements, which is mathematically the same
    as exp(-beta * E) / sum(exp(-beta * E)) but does not overflow to inf/NaN
    when -beta * E is large.

    inputs
        energy: (batch_size, 1)
        beta: float
    outputs
        weight: (batch_size, 1), sums to 1 over all elements
    """
    logits = -energy * beta
    if not torch.is_floating_point(logits):
        # torch.exp promoted integer inputs to float; softmax needs it done explicitly.
        logits = logits.to(torch.get_default_dtype())
    # Normalize over every element, not per row, matching the global sum.
    return torch.softmax(logits.reshape(-1), dim=0).reshape(logits.shape)


def exp_naming(FLAGS):
    """Build the experiment name string from the run flags.

    The name starts with ``<model>_<dataset_name>``; further suffixes are
    appended depending on the model family ('sinkhorn' / 'otwfm' substrings
    of ``FLAGS.model``) and on the generic options (parallel, fixed
    source/target, imbalance factor different from 0.01). Flags that only
    matter for a given family are read only when that family is active.
    """
    parts = [FLAGS.model + "_" + FLAGS.dataset_name]

    if 'sinkhorn' in FLAGS.model:
        parts.append("_reg" + str(FLAGS.reg))
        parts.append("+tau" + "inf" + str(FLAGS.tau_b))

    if 'otwfm' in FLAGS.model:
        parts.append("_" + FLAGS.weight_type)
        if FLAGS.weight_power_factor != 0:
            parts.append("^" + str(FLAGS.weight_power_factor))
        if FLAGS.efm:
            parts.append("+efm")
            parts.append("+beta" + str(FLAGS.beta))
        if not FLAGS.recoupling:
            parts.append("+norecoupling")

    if FLAGS.parallel:
        parts.append("_multigpu")
    if FLAGS.fixed_source:
        parts.append("_fixsrc")
    if FLAGS.fixed_target:
        parts.append("_fixtgt")
    if FLAGS.imb_factor != 0.01:
        parts.append(f"_imb{FLAGS.imb_factor}")

    return "".join(parts)


import os
from torch.utils.data import Dataset
class CIFAR10LTDataset_regacy(Dataset):
    """
    Legacy dataset code. Not used currently.
    PyTorch Dataset for CIFAR-10-LT (Long-Tailed) dataset stored in a .npz file.
    - In __init__, load the .npz file from data_dir.
    - In __getitem__, return image and label as tensors.
    """
    def __init__(self, data_dir, split="train", transform=None):
        """
        Args:
            data_dir (str): directory containing the .npz file
            split (str): 'train' or 'test'
            transform (callable, optional): image transform; it receives a
                (H, W, C) uint8 array

        Raises:
            FileNotFoundError: if data_dir contains no .npz file
            ValueError: if split is neither 'train' nor 'test'
        """
        # Pick the alphabetically first .npz file; os.listdir order is
        # arbitrary, so sorting makes the choice deterministic.
        npz_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
        if len(npz_files) == 0:
            raise FileNotFoundError(f"No .npz file found in {data_dir}")

        if split not in ("train", "test"):
            raise ValueError("split must be 'train' or 'test'")

        npz_path = os.path.join(data_dir, npz_files[0])
        # The context manager closes the archive once the arrays are read.
        with np.load(npz_path) as data:
            self.images = data[f"{split}_data"]
            self.labels = data[f"{split}_labels"]

        # Must be assigned before describe(), which prints the transform.
        self.transform = transform

        self.describe()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (image, label) with the image as a float (C, H, W) tensor,
        or whatever the transform returns when it does not return an ndarray."""
        img = self.images[idx]  # (H, W, C), dtype=uint8
        label_tensor = torch.tensor(self.labels[idx], dtype=torch.long)

        if self.transform:
            # Hand the transform its own uint8 (H, W, C) copy so an in-place
            # transform cannot modify the stored data.
            transformed = self.transform(np.array(img, dtype=np.uint8))
            if isinstance(transformed, np.ndarray):
                img_tensor = torch.from_numpy(transformed).permute(2, 0, 1).float()
            else:
                # If transform returns a tensor (or anything else), use directly
                img_tensor = transformed
        else:
            # Image (H, W, C) -> (C, H, W)
            img_tensor = torch.from_numpy(img).permute(2, 0, 1).float()

        return img_tensor, label_tensor

    def describe(self):
        """Print dataset size, array shapes, the transform and per-class counts."""
        print(f"CIFAR10LTDataset: {len(self)} images")
        print(f"CIFAR10LTDataset: {self.images.shape}")
        print(f"CIFAR10LTDataset: {self.labels.shape}")
        print(f"CIFAR10LTDataset: {self.transform}")
        print(f"CIFAR10LTDataset: {self.transform.__class__.__name__}")
        # Per-class counts
        unique_labels, counts = np.unique(self.labels, return_counts=True)
        print("\nclass-wise data count:")
        for label, count in zip(unique_labels, counts):
            print(f"class {label}: {count} images")





# Code from https://github.com/MediaBrain-SJTU/OC_LT/blob/main/dataset.py
# ("Long-Tailed Diffusion Models With Oriented Calibration", ICLR 2024)
class ImbalanceCIFAR10(datasets.CIFAR10):
    """CIFAR-10 subsampled per class to obtain an imbalanced (long-tailed) split.

    After the regular torchvision dataset is loaded, a target image count is
    computed for every class and a random subset of that size is kept.
    ``self.data`` / ``self.targets`` are replaced by the subsampled versions
    and ``self.num_per_cls_dict`` records how many images each class kept.
    """
    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None, download=False):
        """
        Args:
            root, train, transform, target_transform, download: forwarded to
                ``torchvision.datasets.CIFAR10``.
            imb_type (str): 'exp' (exponential decay over class index),
                'step' (second half of the classes reduced), anything else
                keeps the classes balanced.
            imb_factor (float): ratio between the smallest and the largest
                class size.
            rand_number (int): seed for the subsampling. NOTE: this seeds
                NumPy's *global* random generator.
        """
        super(ImbalanceCIFAR10, self).__init__(root, train, transform, target_transform, download)
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
        self.num_per_cls_dict = dict()
        self.gen_imbalanced_data(img_num_list)

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the target number of images for each of the ``cls_num`` classes."""
        # Size of the largest class: the dataset split evenly over the classes.
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            # Class 0 keeps img_max images; the last class keeps img_max * imb_factor.
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        """Subsample ``self.data`` / ``self.targets`` to the given per-class counts."""
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)

        # Classes actually present in the loaded data (np.unique returns them sorted).
        existing_classes = np.unique(targets_np).tolist()
        missing_classes = sorted(set(range(self.cls_num)) - set(existing_classes))

        # Report classes that have no samples at all.
        if missing_classes:
            print(f"Warning: Classes {missing_classes} are missing from the dataset. These will be skipped.")

        # Subsample each present class, in ascending class order.
        for class_idx in existing_classes:
            the_img_num = img_num_per_cls[class_idx]
            idx = np.where(targets_np == class_idx)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # A class may hold fewer images than requested; use the number of
            # rows actually selected so data and targets stay the same length.
            kept = len(selec_idx)
            self.num_per_cls_dict[class_idx] = kept
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([class_idx] * kept)

        if new_data:  # leave the dataset untouched when nothing was selected
            new_data = np.vstack(new_data)
            self.data = new_data
            self.targets = new_targets

    def get_cls_num_list(self):
        """Return the number of kept images per class, 0 for classes with no data."""
        return [self.num_per_cls_dict.get(i, 0) for i in range(self.cls_num)]


class ImbalanceCIFAR100(datasets.CIFAR100):
    """CIFAR-100 subsampled per class to obtain an imbalanced (long-tailed) split.

    After the regular torchvision dataset is loaded, a target image count is
    computed for every class (at least 1 in the reduced classes) and a random
    subset of that size is kept. ``self.data`` / ``self.targets`` are replaced
    by the subsampled versions and ``self.num_per_cls_dict`` records how many
    images each class kept.
    """
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
    cls_num = 100

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None, download=False):
        """
        Args:
            root, train, transform, target_transform, download: forwarded to
                ``torchvision.datasets.CIFAR100``.
            imb_type (str): 'exp' (exponential decay over class index),
                'step' (second half of the classes reduced), anything else
                keeps the classes balanced.
            imb_factor (float): ratio between the smallest and the largest
                class size.
            rand_number (int): seed for the subsampling. NOTE: this seeds
                NumPy's *global* random generator.
        """
        super(ImbalanceCIFAR100, self).__init__(root, train, transform, target_transform, download)
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
        self.num_per_cls_dict = dict()
        self.gen_imbalanced_data(img_num_list)

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the target number of images for each of the ``cls_num`` classes."""
        # Size of the largest class: the dataset split evenly over the classes.
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            # Class 0 keeps img_max images; the last class keeps img_max * imb_factor.
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(max(1, int(num)))  # keep at least 1 image
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(max(1, int(img_max * imb_factor)))  # keep at least 1 image
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        """Subsample ``self.data`` / ``self.targets`` to the given per-class counts."""
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)

        # Classes actually present in the loaded data (np.unique returns them sorted).
        existing_classes = np.unique(targets_np).tolist()
        missing_classes = sorted(set(range(self.cls_num)) - set(existing_classes))

        # Report classes that have no samples at all.
        if missing_classes:
            print(f"Warning: Classes {missing_classes} are missing from the dataset. These will be skipped.")

        # Subsample each present class, in ascending class order.
        for class_idx in existing_classes:
            the_img_num = img_num_per_cls[class_idx]
            idx = np.where(targets_np == class_idx)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # A class may hold fewer images than requested; use the number of
            # rows actually selected so data and targets stay the same length.
            kept = len(selec_idx)
            self.num_per_cls_dict[class_idx] = kept
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([class_idx] * kept)

        if new_data:  # leave the dataset untouched when nothing was selected
            new_data = np.vstack(new_data)
            self.data = new_data
            self.targets = new_targets

    def get_cls_num_list(self):
        """Return the number of kept images per class, 0 for classes with no data."""
        return [self.num_per_cls_dict.get(i, 0) for i in range(self.cls_num)]


def compute_dataset_mean_std(dataset_class, root, train=True, batch_size=1024, num_workers=4, **kwargs):
    """
    Compute per-channel mean and std over the whole dataset using the given dataset class.

    The dataset is instantiated with ``download=True`` and a ``ToTensor``
    transform and read in a single pass; sums are accumulated in float64.

    Args:
        dataset_class: dataset class such as torchvision.datasets.CIFAR10
        root: dataset root directory
        train: whether to use training split
        batch_size: batch size
        num_workers: number of DataLoader workers
        **kwargs: extra args passed to dataset class

    Returns:
        mean: per-channel mean (list of float)
        std: per-channel std (list of float)

    Raises:
        ValueError: if the dataset yields no samples
    """
    dataset = dataset_class(root=root, train=train, download=True, transform=transforms.ToTensor(), **kwargs)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Accumulators are created on the first batch, once the channel count is
    # known; this avoids an extra loader iteration just to read the shape.
    channel_sum = None
    channel_sqsum = None
    pixel_count = 0

    for x, _ in loader:
        x = x.to(dtype=torch.float64)          # [B, C, H, W], [0,1]
        b, c, h, w = x.shape
        if channel_sum is None:
            channel_sum = torch.zeros(c, dtype=torch.float64)
            channel_sqsum = torch.zeros(c, dtype=torch.float64)
        pixel_count += b * h * w
        x = x.reshape(b, c, -1)
        channel_sum += x.sum(dim=[0, 2])
        channel_sqsum += (x * x).sum(dim=[0, 2])

    if channel_sum is None or pixel_count == 0:
        raise ValueError("Cannot compute mean/std: the dataset yielded no samples")

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    var = channel_sqsum / pixel_count - mean * mean
    std = var.clamp_min(0).sqrt()

    return [float(v) for v in mean.to(torch.float32)], [float(v) for v in std.to(torch.float32)]