import random

import numpy as np
import torch
import torch.optim as optim
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.utils as vutils

import os
import shutil
from tqdm import trange, tqdm
from collections import defaultdict
import argparse

import Helpers as hf
from test_minisom import MiniSom
from vgg19 import VGG19
from RES_VAE import VAE
from torchvision.datasets import CIFAR10

from collections import Counter
import pickle
from torch.utils.data import Subset

from sklearn.preprocessing import StandardScaler

def make_positive_definite(cov, eps=1e-6):
    """
    Project a symmetric matrix onto the positive-definite cone by clamping its eigenvalues.

    :param cov: Square tensor of shape (..., d, d). ``torch.linalg.eigh`` reads only its
        lower triangle, so it is treated as symmetric. Leading batch dimensions are allowed.
    :param eps: Smallest eigenvalue allowed in the result.
    :return: Exactly symmetric tensor with the same eigenvectors as ``cov`` and every
        eigenvalue clamped to be >= ``eps``.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    eigvals = torch.clamp(eigvals, min=eps)
    # V @ diag(w) @ V^T written as a column scaling of V: avoids materialising the dense
    # diagonal matrix and one d x d matmul, and works for batched inputs.
    rebuilt = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.transpose(-2, -1)
    # Rounding can leave the product slightly asymmetric, which can fail the symmetry
    # check of torch.distributions constraints; average with the transpose.
    return 0.5 * (rebuilt + rebuilt.transpose(-2, -1))

def generate_synthetic_samples(som, bmu, num_samples=1):
    """
    Draw diagonal-Gaussian samples around a SOM node's running mean and variance.

    :param som: The trained MiniSom object; must expose ``_running_mean`` and
        ``_running_var`` indexable as ``[row][col]`` -> 1-D array of length latent_dim.
    :param bmu: BMU index (tuple ``(row, col)``)
    :param num_samples: Number of samples to draw (default 1, as before).
    :return: ndarray of shape ``(num_samples, latent_dim)``.
    """
    row, col = bmu[0], bmu[1]
    mean = som._running_mean[row][col]
    # Running variances can drift slightly below zero from floating-point error;
    # clamp so np.sqrt does not return NaN.
    var = np.maximum(som._running_var[row][col], 0.0)
    std = np.sqrt(var)
    return mean + std * np.random.randn(num_samples, mean.shape[0])

def generate_synthetic_samples_torch(som, bmu, num_samples=1, regularization=1e-4, device='cuda'):
    """
    Draw full-covariance Gaussian samples from a SOM node's running mean and covariance.

    :param som: The trained MiniSom object; must expose ``_running_mean`` and
        ``_running_cov`` indexable as ``[row][col]``.
    :param bmu: BMU index (tuple ``(row, col)``)
    :param num_samples: Number of samples to draw.
    :param regularization: Ridge term added to the covariance diagonal before the
        positive-definite projection.
    :param device: Torch device used for the eigendecomposition and sampling.
    :return: ndarray of shape ``(num_samples, latent_dim)``. If building or sampling
        the distribution raises, a warning is printed and zeros are returned instead.
        Errors from the eigendecomposition itself are not caught here.
    """
    mean_np = som._running_mean[bmu[0]][bmu[1]]
    cov_np = som._running_cov[bmu[0]][bmu[1]]

    # Add small regularization in NumPy. Build a NEW array: an in-place ``+=`` would
    # write into the SOM's stored covariance and compound on every call.
    cov_np = cov_np + regularization * np.eye(mean_np.shape[0])

    # Convert to torch tensor for eigendecomposition
    cov_tensor = torch.tensor(cov_np, dtype=torch.float32, device=device)
    cov_tensor = make_positive_definite(cov_tensor)

    # Convert mean to tensor
    mean = torch.tensor(mean_np, dtype=torch.float32, device=device)

    try:
        dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov_tensor)
        samples = dist.sample((num_samples,))
    except Exception as e:
        # Deliberate best-effort: callers get zero vectors rather than an exception.
        print(f"[WARN] Sampling failed at BMU {bmu}: {e}")
        samples = torch.zeros((num_samples, mean.shape[0]), device=device)

    return samples.cpu().numpy()


def get_class_subset(dataset, class_id, max_samples=None):
    """
    Returns a subset of the dataset for the given class_id, optionally limited to max_samples.

    When the dataset exposes a ``targets`` sequence and has no ``target_transform``
    (the torchvision convention, e.g. CIFAR10), labels are read from ``targets`` so no
    image is loaded or transformed. Otherwise every item is fetched to read its label.

    :param dataset: The full dataset
    :param class_id: The class label to filter by
    :param max_samples: Optional maximum number of samples to return; the selection is
        reproducible (fixed seed 42) and does not touch the global ``random`` state.
    :return: Subset of the dataset
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        indices = [i for i, label in enumerate(targets) if label == class_id]
    else:
        indices = [i for i, (_, label) in enumerate(dataset) if label == class_id]

    if max_samples is not None and len(indices) > max_samples:
        # Private RNG seeded with 42 gives the same draw as seeding the module-level
        # RNG, without resetting the global random state for everyone else.
        indices = random.Random(42).sample(indices, max_samples)

    return Subset(dataset, indices)


class TensorLabelDataset(torch.utils.data.Dataset):
    """Wrap a dataset so that every label is returned as a tensor.

    Images pass through unchanged. Labels that are already ``torch.Tensor`` are
    returned as-is; anything else is converted with ``dtype=torch.long``.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample, target = self.base_dataset[idx]
        if isinstance(target, torch.Tensor):
            return sample, target
        return sample, torch.tensor(target, dtype=torch.long)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Training Params")
    # string args
    parser.add_argument("--model_name", "-mn", help="Experiment save name", type=str, required=True)
    parser.add_argument("--dataset_root", "-dr", help="Dataset root dir", type=str, required=True)

    parser.add_argument("--save_dir", "-sd", help="Root dir for saving model and data", type=str, default=".")
    parser.add_argument("--norm_type", "-nt",
                        help="Type of normalisation layer used, BatchNorm (bn) or GroupNorm (gn)", type=str, default="bn")

    # int args
    parser.add_argument("--nepoch", help="Number of training epochs", type=int, default=2000)
    parser.add_argument("--som", help="Som size", type=int, default=10)
    parser.add_argument("--som_iter", help="Som iterations", type=int, default=100)
    parser.add_argument("--batch_size", "-bs", help="Training batch size", type=int, default=128)
    parser.add_argument("--image_size", '-ims', help="Input image size", type=int, default=64)
    parser.add_argument("--ch_multi", '-w', help="Channel width multiplier", type=int, default=64)

    parser.add_argument("--num_res_blocks", '-nrb',
                        help="Number of simple res blocks at the bottle-neck of the model", type=int, default=1)

    parser.add_argument("--device_index", help="GPU device index", type=int, default=0)
    parser.add_argument("--latent_channels", "-lc", help="Number of channels of the latent space", type=int, default=256)
    parser.add_argument("--save_interval", '-si', help="Number of iteration per save", type=int, default=256)
    parser.add_argument("--block_widths", '-bw', help="Channel multiplier for the input of each block",
                        type=int, nargs='+', default=(1, 2, 4, 8))
    # float args
    parser.add_argument("--lr", help="Learning rate", type=float, default=1e-4)
    parser.add_argument("--feature_scale", "-fs", help="Feature loss scale", type=float, default=1)
    parser.add_argument("--kl_scale", "-ks", help="KL penalty scale", type=float, default=1)

    # bool args
    parser.add_argument("--load_checkpoint", '-cp', action='store_true', help="Load from checkpoint")
    parser.add_argument("--deep_model", '-dm', action='store_true',
                        help="Deep Model adds an additional res-identity block to each down/up sampling stage")

    args = parser.parse_args()

    # Use the requested GPU when CUDA is available, otherwise fall back to CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device(args.device_index if use_cuda else "cpu")
    print("")

    # Create dataloaders
    # Only the CIFAR-10 training split (train=True) is loaded; a 90/10 split of it is made below.
    print("-Target Image Size %d" % args.image_size)
    # Images are resized/cropped to image_size and mapped to [-1, 1] by Normalize(0.5, 0.5).
    # RandomHorizontalFlip is part of this single transform, so it also applies to the
    # held-out split and to every later pass over full_dataset.
    transform = transforms.Compose([transforms.Resize(args.image_size),
                                    transforms.CenterCrop(args.image_size),
                                    transforms.RandomHorizontalFlip(0.5),
                                    transforms.ToTensor(),
                                    transforms.Normalize(0.5, 0.5)])

    # Download and prepare dataset using torchvision
    full_dataset = CIFAR10(root=args.dataset_root, train=True, transform=transform, download=True)

    # Create train-test split.
    # NOTE: despite its name, test_split is the TRAIN fraction (90% train / 10% test).
    test_split = 0.9
    full_n_train_examples = int(len(full_dataset) * test_split)
    full_n_test_examples = len(full_dataset) - full_n_train_examples
    full_train_set, full_test_set = torch.utils.data.random_split(full_dataset, [full_n_train_examples, full_n_test_examples],
                                                        generator=torch.Generator().manual_seed(42))

    # NOTE(review): full_train_loader is not used elsewhere in this script.
    full_train_loader = DataLoader(full_train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    full_test_loader = DataLoader(full_test_set, batch_size=args.batch_size, shuffle=False)

    # Get a test image batch from the test_loader to visualise the reconstruction quality etc.
    # In this script the batch is only used to read the input channel count for the VAE.
    full_dataiter = iter(full_test_loader)
    full_test_images, _ = next(full_dataiter)

    # Validate the normalisation type before building anything that depends on it.
    if args.norm_type == "bn":
        print("-Using BatchNorm")
    elif args.norm_type == "gn":
        print("-Using GroupNorm")
    else:
        raise ValueError("norm_type must be bn or gn")

    # Create AE network; the input channel count comes from a real test batch.
    vae_net = VAE(channel_in=full_test_images.shape[1],
                  ch=args.ch_multi,
                  blocks=args.block_widths,
                  latent_channels=args.latent_channels,
                  num_res_blocks=args.num_res_blocks,
                  norm_type=args.norm_type,
                  deep_model=args.deep_model).to(device)

    # Setup optimizer
    optimizer = optim.Adam(vae_net.parameters(), lr=args.lr)

    # AMP gradient scaler. It cannot scale without CUDA, so disable it explicitly on
    # CPU instead of letting it disable itself with a warning.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    # Create the feature loss module if required
    if args.feature_scale > 0:
        feature_extractor = VGG19().to(device)
        print("-VGG19 Feature Loss ON")
    else:
        feature_extractor = None
        print("-VGG19 Feature Loss OFF")

    # Let's see how many Parameters our Model has!
    num_model_params = sum(param.numel() for param in vae_net.parameters())

    print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))
    # Spatial size of the latent map; assumes each entry of block_widths corresponds to
    # one stage that halves H and W.
    fm_size = args.image_size//(2 ** len(args.block_widths))
    print("-The Latent Space Size Is %dx%dx%d!" % (args.latent_channels, fm_size, fm_size))

    # Create the save directories (no-op when they already exist).
    os.makedirs(args.save_dir + "/Models", exist_ok=True)
    os.makedirs(args.save_dir + "/Results", exist_ok=True)

    # Checks if a checkpoint has been specified to load, if it has, it loads the checkpoint
    # If no checkpoint is specified, it checks if a checkpoint already exists and raises an error if
    # it does to prevent accidental overwriting. If no checkpoint exists, it starts from scratch.
    # Checkpoint path: <save_dir>/Models/<model_name>_<image_size>.pt
    # NOTE(review): start_epoch is reset to 0 at the top of every class iteration of the
    # training loop, so a restored epoch counter does not skip any training.
    save_file_name = args.model_name + "_" + str(args.image_size)
    if args.load_checkpoint:
        if os.path.isfile(args.save_dir + "/Models/" + save_file_name + ".pt"):
            # torch.load unpickles the file; only load checkpoints from trusted sources.
            checkpoint = torch.load(args.save_dir + "/Models/" + save_file_name + ".pt",
                                    map_location="cpu")
            print("-Checkpoint loaded!")
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            vae_net.load_state_dict(checkpoint['model_state_dict'])

            # The command-line lr takes precedence over the one stored in the optimizer state.
            if not optimizer.param_groups[0]["lr"] == args.lr:
                print("Updating lr!")
                optimizer.param_groups[0]["lr"] = args.lr

            start_epoch = checkpoint["epoch"]
            # Restore the metric logger as a defaultdict so new keys start as empty lists.
            data_logger = defaultdict(lambda: [], checkpoint["data_logger"])
        else:
            raise ValueError("Warning Checkpoint does NOT exist -> check model name or save directory")
    else:
        # If checkpoint does exist raise an error to prevent accidental overwriting
        if os.path.isfile(args.save_dir + "/Models/" + save_file_name + ".pt"):
            raise ValueError("Warning Checkpoint exists -> add -cp flag to use this checkpoint")
        else:
            print("Starting from scratch")
            start_epoch = 0
            # Loss and metrics logger
            data_logger = defaultdict(lambda: [])
    print("")

    # SOM input dimension = flattened latent map (channels x fm_size x fm_size); latents
    # are fed to the SOM as mu.view(batch, -1).
    latent_dim = args.latent_channels * fm_size * fm_size
    som_size = args.som  # grid size of SOM (e.g., 10x10)

    # A single SOM, created once and trained further on every class iteration.
    som = MiniSom(som_size, som_size, latent_dim, sigma=0.95, learning_rate=0.5)
    # class_id -> list of (BMU, label) pairs recorded after that class's SOM training.
    history_snapshots = {}
    # Synthetic replay images/labels produced at the end of one class iteration and
    # mixed into the next class's training data.
    buffer_synth_images = []
    buffer_synth_labels = []
    # Start training loop
    # NOTE(review): standard_scaler is only referenced by commented-out code.
    standard_scaler = StandardScaler()
    # Class-incremental training over the 10 CIFAR-10 classes, one class per iteration.
    for class_id in range(10):
        # Every class trains for the full epoch range (this overrides any checkpoint epoch).
        start_epoch = 0
        print(f"\n--- Training on Class {class_id} ---")

        # Real images of the current class only (a Subset of full_dataset).
        class_subset = get_class_subset(full_dataset, class_id)

        # 90/10 split of this class. Despite its name, test_split is the TRAIN fraction.
        # The fixed generator seed keeps the split reproducible.
        test_split = 0.9
        n_train_examples = int(len(class_subset) * test_split)
        n_test_examples = len(class_subset) - n_train_examples
        train_set, test_set = torch.utils.data.random_split(class_subset,
                                                                      [n_train_examples, n_test_examples],
                                                                      generator=torch.Generator().manual_seed(42))

        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

        if class_id != 0:
            # Build the replay set from the synthetic images/labels produced at the end of
            # the previous class iteration.
            images_tensor = torch.stack(buffer_synth_images)
            buffer_synth_labels = [int(lbl) for lbl in buffer_synth_labels]
            labels_tensor = torch.tensor(buffer_synth_labels, dtype=torch.long)

            replay_dataset = torch.utils.data.TensorDataset(images_tensor, labels_tensor)
            # Wrap the real samples so their labels are tensors, like the replay labels.
            train_set = TensorLabelDataset(train_set)

            # Fraction of each source kept for training (0.9 is the TRAIN share).
            test_split = 0.9
            # Split train_set
            n_train_main = int(len(train_set) * test_split)
            n_test_main = len(train_set) - n_train_main
            train_main, test_main = torch.utils.data.random_split(train_set, [n_train_main, n_test_main],
                                                                  generator=torch.Generator().manual_seed(42))

            # Split replay_dataset
            n_train_replay = int(len(replay_dataset) * test_split)
            n_test_replay = len(replay_dataset) - n_train_replay
            train_replay, test_replay = torch.utils.data.random_split(replay_dataset, [n_train_replay, n_test_replay],
                                                                      generator=torch.Generator().manual_seed(42))

            # Combine splits. Only the train half of the replay set is trained on, so
            # the replay samples in combined_test are never seen during training.
            combined_train = torch.utils.data.ConcatDataset([train_main, train_replay])
            combined_test = torch.utils.data.ConcatDataset([test_main, test_replay])

            new_train_loader = DataLoader(combined_train, batch_size=args.batch_size, shuffle=True, num_workers=4)
            new_test_loader = DataLoader(combined_test, batch_size=args.batch_size, shuffle=False)

            def get_label_dist(loader):
                """Count labels over one full pass of ``loader`` (this loads every batch)."""
                all_labels = []
                for _, lbl in loader:
                    all_labels.extend(lbl.tolist())
                return Counter(all_labels)

            print("Train label distribution:", get_label_dist(new_train_loader))
            print("Test label distribution:", get_label_dist(new_test_loader))
        # Evaluation source for this class: the mixed real + replay test split once replay
        # exists, otherwise the plain per-class test split.
        test = new_test_loader if class_id != 0 else test_loader

        # Keep the first batch, then drain the iterator so last_images ends up holding the
        # final batch (it stays equal to the first batch when there is only one).
        dataiter = iter(test)
        first_images, _ = next(dataiter)
        last_images = first_images
        for last_images, _ in dataiter:
            pass

        # The training loader is fixed for the whole class iteration, so select it once.
        # `data` is also read after training (latent extraction for the SOM); binding it
        # here keeps it defined even when the epoch range is empty.
        data = new_train_loader if class_id != 0 else train_loader

        for epoch in trange(start_epoch, args.nepoch, leave=False):
            vae_net.train()
            for images, _ in tqdm(data, leave=False):
                images = images.to(device)

                # We will train with mixed precision!
                with torch.cuda.amp.autocast():
                    recon_img, mu, log_var = vae_net(images)

                    kl_loss = hf.kl_loss(mu, log_var)
                    mse_loss = F.mse_loss(recon_img, images)
                    loss = args.kl_scale * kl_loss + mse_loss

                    # Perception loss
                    if feature_extractor is not None:
                        feat_in = torch.cat((recon_img, images), 0)
                        feature_loss = feature_extractor(feat_in)
                        loss += args.feature_scale * feature_loss
                        data_logger["feature_loss"].append(feature_loss.item())

                # Scaled backward pass. Gradients are unscaled before clipping so the clip
                # threshold (10) applies to the true gradient norm.
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(vae_net.parameters(), 10)
                scaler.step(optimizer)
                scaler.update()

                # Log losses and other metrics for evaluation!
                data_logger["mu"].append(mu.mean().item())
                data_logger["mu_var"].append(mu.var().item())
                data_logger["log_var"].append(log_var.mean().item())
                data_logger["log_var_var"].append(log_var.var().item())

                data_logger["kl_loss"].append(kl_loss.item())
                data_logger["img_mse"].append(mse_loss.item())

            # After the last batch of the final epoch (once per class): reconstruct the
            # first and last evaluation batches, save them as images, and checkpoint.
            if epoch == args.nepoch - 1:
                # In eval mode the model will use mu as the encoding instead of sampling from the distribution
                vae_net.eval()
                with torch.no_grad():
                    with torch.cuda.amp.autocast():
                        # Save an example from testing and log a test loss
                        recon_img, mu, log_var = vae_net(first_images.to(device))
                        data_logger['test_mse_loss'].append(F.mse_loss(recon_img,
                                                                       first_images.to(device)).item())

                        # Reconstructions stacked above the inputs (concatenated along height).
                        img_cat = torch.cat((recon_img.cpu(), first_images), 2).float()
                        vutils.save_image(img_cat,
                                          "%s/%s/%s_%d_%d_%d_test.png" % (args.save_dir,
                                                                          "Results",
                                                                          args.model_name,
                                                                          args.image_size,
                                                                          class_id,
                                                                          epoch),
                                          normalize=True)

                        lrecon_img, lmu, llog_var = vae_net(last_images.to(device))
                        data_logger['test_mse_loss'].append(F.mse_loss(lrecon_img,
                                                                       last_images.to(device)).item())

                        limg_cat = torch.cat((lrecon_img.cpu(), last_images), 2).float()
                        vutils.save_image(limg_cat,
                                          "%s/%s/%s_%d_%d_%d_test_l.png" % (args.save_dir,
                                                                            "Results",
                                                                            args.model_name,
                                                                            args.image_size,
                                                                            class_id,
                                                                            epoch),
                                          normalize=True)

                checkpoint_path = args.save_dir + "/Models/" + save_file_name + ".pt"
                # Keep a copy of the previous save in case we accidentally save a model that has exploded...
                if os.path.isfile(checkpoint_path):
                    shutil.copyfile(src=checkpoint_path,
                                    dst=args.save_dir + "/Models/" + save_file_name + "_copy.pt")

                # Save a checkpoint
                torch.save({
                            'epoch': epoch + 1,
                            'data_logger': dict(data_logger),
                            'model_state_dict': vae_net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                             }, checkpoint_path)

                vae_net.train()

        # Inference mode for latent extraction.
        vae_net.eval()

        # Encode this class's real training images (train_loader) and keep the flattened mu.
        # NOTE(review): class_som_latent_vectors is only referenced by the commented-out
        # StandardScaler code below, so this pass currently has no effect on results.
        class_latent_vectors = []
        with torch.no_grad():
            for images, _ in train_loader:
                images = images.to(device)
                recon_img, mu, log_var = vae_net(images.to(device))
                latent = mu.view(mu.size(0), -1).cpu().numpy()
                class_latent_vectors.append(latent)

        class_som_latent_vectors = np.concatenate(class_latent_vectors, axis=0)

        # Encode the loader this class was trained on (`data`: real + replay after class 0,
        # real only for class 0). Labels are kept aligned with the latent rows; this is
        # the SOM's training data.
        combined_latent_vectors = []
        combined_labels = []
        with torch.no_grad():
            for images, labels in data:
                images = images.to(device)
                recon_img, mu, log_var = vae_net(images.to(device))
                latent = mu.view(mu.size(0), -1).cpu().numpy()
                combined_latent_vectors.append(latent)
                combined_labels.extend(labels.cpu().numpy())
                # Print count for the current batch (optional)
                print("Current batch label count:", Counter(labels.cpu().numpy()))

                # Print count of all labels seen so far
                print("Combined label count:", Counter(combined_labels))

        som_latent_vectors = np.concatenate(combined_latent_vectors, axis=0)
        # if class_id == 0:
        #     standard_scaler.fit(som_latent_vectors)
        # else:
        #     standard_scaler.partial_fit(class_som_latent_vectors)
        # som_latent_vectors = standard_scaler.transform(som_latent_vectors)
        # Weights are initialised from data only on the first class; later classes keep
        # training the same map.
        if class_id == 0:
            som.random_weights_init(som_latent_vectors)
        print("Som training started")
        # Redundant: Counter is already imported at module level.
        from collections import Counter

        label_counts = Counter(combined_labels)
        print(f"\n[class_id={class_id}] Combined label distribution:")
        for label, count in sorted(label_counts.items()):
            print(f"  Class {label}: {count} samples")
        som.train(som_latent_vectors, num_iteration=args.som_iter, random_order=True, verbose=True, use_epochs=True)
        print("Som training finished")

        # Record the winning node for every latent vector the SOM was just trained on,
        # paired with its label; these pairs drive synthetic replay for the next class.
        bmus = [(som.winner(vec), lbl) for vec, lbl in zip(som_latent_vectors, combined_labels)]
        history_snapshots[class_id] = bmus

        if class_id == 9:
            # Final class: group labels by winning node for majority labelling at save time.
            bmu_to_label = defaultdict(list)
            for vec, lbl in zip(som_latent_vectors, combined_labels):
                bmu_to_label[som.winner(vec)].append(lbl)


        # Same normalisation as the training transform; applied to replay images later on.
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        # Fetch the codebook once instead of once per node.
        som_weights = som.get_weights()
        som_size = som_weights.shape[0]  # assuming square SOM (N x N)
        print("Creating som bmu plots: 1) Bmu decode visualization 2) Synth generation of 4 images;2x2 grid per bmu")

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 SOM, where plt.subplots
        # would otherwise return a single Axes that cannot be indexed with [i, j].
        fig, axes = plt.subplots(som_size, som_size, figsize=(som_size, som_size), squeeze=False)

        # Decode each node's weight vector (a point in the flattened latent space) to an image.
        for i in range(som_size):
            for j in range(som_size):
                weight = som_weights[i][j]  # shape: (latent_dim,)
                # weight = standard_scaler.inverse_transform(weight.reshape(1, -1)).squeeze()
                # Convert to tensor and reshape to latent image shape
                z_tensor = torch.tensor(weight, dtype=torch.float32).to(device)
                z_tensor = z_tensor.view(1, args.latent_channels, fm_size, fm_size)

                # Decode
                with torch.no_grad():
                    img = vae_net.decoder(z_tensor).detach().cpu()
                    img = (img + 1) / 2  # bring to [0, 1]

                axes[i, j].imshow(img[0].permute(1, 2, 0).numpy().clip(0, 1))
                axes[i, j].axis("off")

        plt.tight_layout()
        os.makedirs(f"{args.save_dir}/samples_SOM", exist_ok=True)
        plt.savefig(f"{args.save_dir}/samples_SOM/som_grid_decoded_{class_id}.png")
        plt.close(fig)

        # row_blocks = []
        # for i in range(som_size):  # SOM rows
        #     col_blocks = []
        #     for j in range(som_size):  # SOM cols
        #         bmu = (i, j)
        #         try:
        #             latent_samples = generate_synthetic_samples_torch(som, bmu, num_samples=4, device = device)
        #         except Exception as e:
        #             print(f"[WARN] Sampling failed at BMU {bmu}: {e}")
        #             # Fill with blank if needed
        #             blank = np.zeros((32 * 2, 32 * 2, 3))
        #             col_blocks.append(blank)
        #             continue
        #
        #         decoded_imgs = []
        #         for z in latent_samples:
        #             z = standard_scaler.inverse_transform(z.reshape(1, -1)).squeeze()
        #             z_tensor = torch.tensor(z, dtype=torch.float32).to(device)
        #             z_tensor = z_tensor.view(1, args.latent_channels, fm_size, fm_size)
        #             with torch.no_grad():
        #                 img = vae_net.decoder(z_tensor).detach().cpu()
        #             img = (img + 1) / 2
        #             decoded_imgs.append(img[0].permute(1, 2, 0).numpy())  # HWC
        #
        #         # Arrange 4 images as 2×2 block for this BMU
        #         top_row = np.hstack(decoded_imgs[:2])
        #         bottom_row = np.hstack(decoded_imgs[2:])
        #         bmu_block = np.vstack([top_row, bottom_row])  # shape: 2H x 2W x 3
        #
        #         col_blocks.append(bmu_block)
        #     row_blocks.append(np.hstack(col_blocks))  # combine 15 BMUs horizontally
        #
        # final_image = np.vstack(row_blocks)  # combine 15 rows vertically

        # One sampled-and-decoded image per SOM node, tiled into a som_size x som_size mosaic.
        # Placeholder for nodes whose sampling raises. It is sized like a decoded image so
        # np.hstack/np.vstack see matching shapes. This assumes the decoder reconstructs at
        # args.image_size with the input's channel count, as the reconstruction losses
        # against the input images imply.
        blank_tile = np.zeros((args.image_size, args.image_size, full_test_images.shape[1]))
        row_blocks = []
        for i in range(som_size):  # SOM rows
            col_blocks = []
            for j in range(som_size):  # SOM cols
                bmu = (i, j)
                try:
                    latent_samples = generate_synthetic_samples_torch(som, bmu, num_samples=1, device=device)
                except Exception as e:
                    print(f"[WARN] Sampling failed at BMU {bmu}: {e}")
                    col_blocks.append(blank_tile)
                    continue

                z = latent_samples[0]
                # z = standard_scaler.inverse_transform(z.reshape(1, -1)).squeeze()
                z_tensor = torch.tensor(z, dtype=torch.float32).to(device)
                z_tensor = z_tensor.view(1, args.latent_channels, fm_size, fm_size)

                with torch.no_grad():
                    img = vae_net.decoder(z_tensor).detach().cpu()
                img = (img + 1) / 2
                img_np = img[0].permute(1, 2, 0).numpy()  # HWC

                col_blocks.append(img_np)
            row_blocks.append(np.hstack(col_blocks))  # horizontally stack images for the row

        final_image = np.vstack(row_blocks)  # vertically stack rows
        # Plot
        plt.figure(figsize=(som_size, som_size))
        plt.imshow(final_image.clip(0, 1))
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(f"{args.save_dir}/samples_SOM/class_{class_id:02}_decoded.png")
        plt.show()
        # Release the figure; under a non-interactive backend show() does not close it,
        # so one figure per class would otherwise accumulate.
        plt.close()

        print("Generating synthetic images and storing for next iteration")
        buffer_synth_labels = []
        buffer_synth_images = []
        new_latent_vectors = []
        if history_snapshots:
            print(f"\n[Class {class_id}] Generating synthetic data from previous BMUs...")

            bmus = history_snapshots[class_id]  # List of (i, j) BMU coordinates for previous class
            for count, (bmu, label) in enumerate(history_snapshots[class_id], 1):
                print(f"Generating sample {count}/{len(bmus)} from BMU {bmu} with label {label}")
                try:
                    samples = generate_synthetic_samples_torch(som, bmu, num_samples=1, device=device)
                    new_latent_vectors.append((samples, label))
                except Exception as e:
                    print(f"[WARN] Sampling failed at BMU {bmu}: {e}")

            if new_latent_vectors:
                print(f"[Class {class_id}] Decoding {len(new_latent_vectors)} synthetic samples...")
                cur_decoded_imgs = []

                for z, label in new_latent_vectors:
                    # z = standard_scaler.inverse_transform(z.reshape(1, -1)).squeeze()
                    z_tensor = torch.tensor(z, dtype=torch.float32).to(device)
                    z_tensor = z_tensor.view(-1, args.latent_channels, fm_size, fm_size)
                    with torch.no_grad():
                        img = vae_net.decoder(z_tensor).detach().cpu()
                    img = (img + 1) / 2
                    for i in range(img.shape[0]):
                        single_img = img[i]
                        if not isinstance(single_img, torch.Tensor):
                            single_img = torch.tensor(single_img).permute(2, 0, 1).float()
                        single_img = normalize(single_img)

                        buffer_synth_images.append(single_img.detach())
                        buffer_synth_labels.append(label)
                label_counts = Counter(buffer_synth_labels)
                print("Label distribution in buffer_synth_labels:", label_counts)

        else:
            print(f"[Class {class_id}] No BMUs available in history for synthetic generation.")

        print("Complete! Onto next iteration.")

        # After the last class: persist the SOM state and per-node metadata.
        if class_id == 9:
            print("Saving trained SOM and metadata...")

            # Compute majority class per BMU
            # (bmu_to_label was filled from the final class's SOM training data).
            bmu_majority_labels = {
                bmu: Counter(labels).most_common(1)[0][0]
                for bmu, labels in bmu_to_label.items()
            }

            # Create activation map (BMU hit count) over the final SOM training data.
            # Presumably per-node hit counts as in upstream MiniSom; confirm for test_minisom.
            all_data_latents = som_latent_vectors  # already built
            activation_map = som.activation_response(all_data_latents)

            # Prepare model components to save
            # (reads private MiniSom attributes: _running_mean/_var/_cov, _sigma, _learning_rate)
            som_model_data = {
                'weights': som.get_weights(),
                'running_mean': som._running_mean,
                'running_var': som._running_var,
                'running_cov': som._running_cov,
                'bmu_history': history_snapshots,  # historical bmus per class
                'bmu_hits': activation_map,
                'som_size': som_size,
                'input_len': latent_dim,
                'sigma': som._sigma,
                'learning_rate': som._learning_rate,
                'som_bmu_labels': bmu_majority_labels,
                # 'scaler': standard_scaler
            }

            # Pickle format: only load this file from trusted locations.
            save_path = os.path.join(f"{args.save_dir}/samples_SOM", "trained_som_model.pkl")
            with open(save_path, 'wb') as f:
                pickle.dump(som_model_data, f)

            print(f"✅ SOM model and associated data saved to: {save_path}")