import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import argparse
import pickle
import os

from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10
from test_minisom import MiniSom


def load_dataset(name, samples_per_class):
    """Load a training split, flatten the images and keep a per-class subset.

    Args:
        name: One of 'mnist', 'fashion_mnist' or 'cifar10'.
        samples_per_class: Upper bound on how many samples (taken in
            dataset order) are kept for each of the labels 0..9.

    Returns:
        A ``(data, labels, input_len)`` tuple: the selected images flattened
        to ``input_len`` values each and divided by 255.0, their labels, and
        the flattened image length (28*28 or 32*32*3).

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    if name not in ('mnist', 'fashion_mnist', 'cifar10'):
        raise ValueError("Unsupported dataset. Choose from 'mnist', 'fashion_mnist', 'cifar10'")

    if name == 'cifar10':
        (images, targets), _ = cifar10.load_data()
        input_len = 32 * 32 * 3
        # Only the cifar10 labels are flattened; presumably they are not
        # delivered as a 1-D vector -- confirm against the Keras loader.
        targets = targets.flatten()
    else:
        source = mnist if name == 'mnist' else fashion_mnist
        (images, targets), _ = source.load_data()
        input_len = 28 * 28

    # One row per image, scaled by 1/255.
    features = images.reshape(-1, input_len) / 255.0

    # Indices of the first `samples_per_class` occurrences of each label.
    picks = [np.where(targets == cls)[0][:samples_per_class] for cls in range(10)]
    data = np.vstack([features[idx] for idx in picks])
    labels = np.hstack([targets[idx] for idx in picks])
    return data, labels, input_len


def generate_synthetic_samples(som, bmu):
    """Draw one random sample around the running statistics of a SOM node.

    Args:
        som: Object exposing ``_running_mean`` and ``_running_var``, each
            indexable as ``[row][col]`` to obtain a 1-D vector.
        bmu: Indexable whose first two entries are the node's row and column.

    Returns:
        Array of shape ``(1, d)``: the node's running mean plus standard
        normal noise scaled element-wise by the square root of the node's
        running variance, where ``d`` is the length of the mean vector.
    """
    row, col = bmu[0], bmu[1]
    center = som._running_mean[row][col]
    spread = np.sqrt(som._running_var[row][col])
    noise = np.random.randn(1, center.shape[0])
    return center + spread * noise


def plot_som_weights(snapshot, som_size, title, shape, digit, dataset, iterations):
    """Draw every SOM node's weight vector as an image tile, save and show it.

    The figure is written to
    ``synthSamplesSOM1/{dataset}_{som_size[0]}_{iterations}/{digit}.png``
    (the directory is created if needed) and then displayed.

    Args:
        snapshot: Weights indexable as ``snapshot[i, j]``; each entry is
            reshaped to ``shape`` before plotting.
        som_size: ``(rows, cols)`` of the SOM grid.
        title: Figure title.
        shape: Image shape of one tile; a 2-element shape is drawn with the
            'gray' colormap, anything else with the default.
        digit: Used as the PNG file name (without extension).
        dataset: Dataset name, used in the output directory name.
        iterations: Iteration count, used in the output directory name.
    """
    n_rows, n_cols = som_size[0], som_size[1]

    # squeeze=False keeps `axes` two-dimensional even when a grid dimension
    # is 1; with the default, plt.subplots returns a 1-D array (or a single
    # Axes) in that case and axes[i, j] would fail.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
    fig.suptitle(title, fontsize=12)

    cmap = 'gray' if len(shape) == 2 else None
    for i in range(n_rows):
        for j in range(n_cols):
            ax = axes[i, j]
            tile = snapshot[i, j].reshape(*shape)
            # Weights may drift outside the displayable [0, 1] range.
            ax.imshow(np.clip(tile, 0, 1), cmap=cmap)
            ax.axis('off')
    plt.tight_layout()

    out_dir = f"synthSamplesSOM1/{dataset}_{som_size[0]}_{iterations}"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f"{out_dir}/{digit}.png")

    plt.show()
    # pyplot keeps figures alive until they are closed explicitly; release
    # this one so repeated calls do not accumulate open figures.
    plt.close(fig)


def main(args):
    """Train a SOM one class at a time with synthetic replay, then plot and save it.

    For each label 0..9 in order: synthetic samples are generated from the
    BMUs recorded for the previous label (``args.num_synthetic`` per BMU) and
    added to a pool that is kept across labels; the SOM is trained on the
    current label's data stacked with the whole pool; the BMUs of every label
    seen so far are then recomputed and a copy of the weights is stored.

    Afterwards the stored weight snapshots are plotted, a dictionary with the
    model state and histories is pickled to ``args.output``, and a grid with
    the majority label per BMU is printed.

    Args:
        args: Namespace providing ``dataset``, ``samples_per_class``,
            ``som_size``, ``iterations``, ``num_synthetic`` and ``output``.
    """
    selected_data, selected_labels, input_len = load_dataset(args.dataset, args.samples_per_class)
    image_shape = (32, 32, 3) if args.dataset == 'cifar10' else (28, 28)
    side = args.som_size
    grid_shape = (side, side)

    som = MiniSom(side, side, input_len, sigma=0.95, learning_rate=0.5)
    # Initial weights are drawn from the samples of label 0 only.
    som.random_weights_init(selected_data[selected_labels == 0])

    bmu_history = defaultdict(list)   # label -> BMU array from the latest pass
    history_snapshots = {}            # label -> BMUs right after training on it
    history_digits = {}               # label -> weight copy right after training on it
    synthetic_data = []               # replay pool, grows across labels

    for digit in range(10):
        print(f"Training SOM with class {digit}...")
        class_data = selected_data[selected_labels == digit]

        # From the second label on, replay samples come from the BMUs that
        # the previous label mapped to.
        if digit > 0:
            synthetic_data.extend(
                generate_synthetic_samples(som, prev_bmu)
                for prev_bmu in history_snapshots[digit - 1]
                for _ in range(args.num_synthetic)
            )

        if synthetic_data:
            mixed_data = np.vstack((class_data, np.vstack(synthetic_data)))
        else:
            mixed_data = class_data

        print(f"Samples: {len(mixed_data)} (Synthetic: {len(synthetic_data)}, Original: {len(class_data)})")

        som.train(mixed_data, num_iteration=args.iterations, random_order=True, verbose=True, use_epochs=True)

        # Rebuilt on every pass; the mapping from the last pass is the one
        # used for the majority labels after the loop.
        bmu_to_label = defaultdict(list)
        for seen in range(digit + 1):
            winners = [som.winner(sample) for sample in selected_data[selected_labels == seen]]
            for cell in winners:
                bmu_to_label[cell].append(seen)
            bmu_history[seen] = np.array(winners)

        history_snapshots[digit] = bmu_history[digit].copy()
        history_digits[digit] = som.get_weights().copy()

    for digit, weights in history_digits.items():
        plot_som_weights(weights, grid_shape,
                         f"SOM Weights After Training Digit {digit}", image_shape, digit, args.dataset, args.iterations)

    # Most frequent label among the samples mapped to each BMU.
    bmu_majority_labels = {}
    for cell, votes in bmu_to_label.items():
        bmu_majority_labels[cell] = Counter(votes).most_common(1)[0][0]

    model_data = {
        'weights': som.get_weights(),
        'running_mean': som._running_mean,
        'running_var': som._running_var,
        'bmu_history': bmu_history,
        'history_snapshots': history_snapshots,
        'history_digits': history_digits,
        'bmu_hits': som.activation_response(selected_data),
        'som_size': grid_shape,
        'input_len': input_len,
        'sigma': som._sigma,
        'learning_rate': som._learning_rate,
        'som_bmu_labels': bmu_majority_labels
    }

    with open(args.output, 'wb') as out_file:
        pickle.dump(model_data, out_file)

    print(f"✅ SOM model saved to '{args.output}'")

    # -1 marks grid cells that no sample mapped to; they print as " .".
    label_grid = [[-1] * side for _ in range(side)]
    for (r, c), majority in bmu_majority_labels.items():
        label_grid[r][c] = majority

    print("\n🧭 BMU Majority Labels Grid:")
    for row in label_grid:
        print(" ".join(f"{val:2}" if val != -1 else " ." for val in row))


if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Train SOM on image datasets with synthetic replay",
    )
    parser.add_argument(
        '--dataset',
        default='mnist',
        choices=['mnist', 'fashion_mnist', 'cifar10'],
        type=str,
        help='Dataset to use',
    )
    parser.add_argument(
        '--som_size',
        default=10,
        type=int,
        help='Size of the SOM grid (NxN)',
    )
    parser.add_argument(
        '--samples_per_class',
        default=1000,
        type=int,
        help='Number of samples to select per class',
    )
    parser.add_argument(
        '--iterations',
        default=100,
        type=int,
        help='Training iterations per class',
    )
    parser.add_argument(
        '--num_synthetic',
        default=1,
        type=int,
        help='Number of synthetic samples per BMU',
    )
    parser.add_argument(
        '--output',
        default='trained_som_model.pkl',
        type=str,
        help='Output filename for the trained SOM model',
    )

    args = parser.parse_args()
    main(args)

# python synthetic_mean_var.py --dataset mnist --som_size 10 --samples_per_class 5000 --iterations 10 --num_synthetic 1 --output test_trained_som_model.pkl
# python synthetic_mean_var.py --dataset fashion_mnist --som_size 10 --samples_per_class 5000 --iterations 10 --num_synthetic 1 --output trained_fashion_mnist_model.pkl
# python synthetic_mean_var.py --dataset cifar10 --som_size 15 --samples_per_class 5000 --iterations 100 --num_synthetic 1 --output trained_som_cifar10.pkl