# m_npeff_small_dense_model001.py
R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/m_npeff/m_npeff_small_dense_model001.py

CUDA_VISIBLE_DEVICES=3 python -i local_scripts/m_npeff/m_npeff_small_dense_model001.py

"""
import dataclasses
from importlib import reload
import random

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf

from em import datasets as em_datasets
from em.fishers import generate_dense_m_pefs
from em.tools.m_npeff import m_npeff1

from local_scripts.m_npeff import small_cifar_model_utils


###############################################################################
# Short module-level names for the helpers from small_cifar_model_utils.
make_model, load_datasets, compile_model = (
    small_cifar_model_utils.make_model,
    small_cifar_model_utils.load_datasets,
    small_cifar_model_utils.compile_model,
)
###############################################################################
# Short module-level names for the helpers from generate_dense_m_pefs.
compute_m_pefs_for_batch, compute_flat_m_pefs_for_ds = (
    generate_dense_m_pefs.compute_m_pefs_for_batch,
    generate_dense_m_pefs.compute_flat_m_pefs_for_ds,
)
###############################################################################


# Build and compile the model (helpers aliased from small_cifar_model_utils).
# Everything further down in this script operates on this trained `model`.
model = make_model()
compile_model(model, learning_rate=1e-3)

# load_datasets returns a (train, validation) pair; val_ds is reused later as
# the source of the examples that get decomposed and plotted.
train_ds, val_ds = load_datasets(batch_size=32)

# NOTE(review): presumably a Keras-style fit(), i.e. 4 epochs of 1,000 steps
# each with validation on val_ds — confirm against make_model.
model.fit(train_ds, steps_per_epoch=1_000, epochs=4, validation_data=val_ds)


# for x, _ in val_ds:
#     break

# x = x[:8]
# q = compute_m_pefs_for_batch(model, model.trainable_variables, x[:8])

# Value passed as `n_examples` to compute_flat_m_pefs_for_ds below; the same
# constant is also used as the batch size when fetching the images to plot.
N_EXAMPLES = 32
# N_EXAMPLES = 512
# N_EXAMPLES = 1024

# A is the result of compute_flat_m_pefs_for_ds for the trained model and is
# what gets handed to the factorizer afterwards.
# NOTE(review): presumably one flattened row per example — confirm in
# em.fishers.generate_dense_m_pefs.
A = compute_flat_m_pefs_for_ds(
    model,
    model.trainable_variables,
    # Un-batch the validation set and drop the labels, keeping only the inputs.
    val_ds.unbatch().map(lambda x, y: x),
    # n_examples=128,
    # n_examples=2048,
    n_examples=N_EXAMPLES,
    batch_size=1,
    #
    normalize_pefs=True,
)

# Try the decomposition on lab servers with n_examples=512, 1024, 2048.
# Have the rank vary between 32, 64, 128 depending on the number of examples.

# Reload m_npeff1 before use; presumably so that edits to that module are
# picked up when this file is re-executed in the same interactive session.
reload(m_npeff1)
# Set up the factorizer on A. The commented-out ranks are alternative settings.
factorizer = m_npeff1.Factorizer(
    A,
    rank=16,
    # rank=32,
    # rank=64,
    lr_G=1e-3,
    eps=1e-7,
    #
    loss_frequency=10,
)

# First a short fit of 10 with update_W=False.
# NOTE(review): presumably a warm-up that leaves W fixed — confirm in m_npeff1.
# The value assigned to `losses` here is overwritten by the next fit() call.
# losses = factorizer.fit(1000)
losses = factorizer.fit(10, update_W=False)

# Then the main fit of 400 with default update flags; `losses` ends up holding
# only what this call returns.
# losses = factorizer.fit(50, update_G=False)
losses = factorizer.fit(400)

# plt.imshow(factorizer.W.numpy().T[:, :400]); plt.show()

# Fetch the first N_EXAMPLES validation inputs as a single NumPy batch.
# plot_top_comps indexes this array with row indices of factorizer.W, so it
# assumes these are the same examples, in the same order, that were used to
# compute A — TODO confirm val_ds iterates in a deterministic order.
# next(iter(...)) raises right here if the dataset yields nothing, instead of
# silently leaving `images` undefined until the first plot.
images = next(
    iter(val_ds.unbatch().map(lambda x, y: x).batch(N_EXAMPLES).as_numpy_iterator())
)


def plot_top_comps(component_index: int, n_rows: int, n_cols: int):
    """Show the highest-weighted examples of one component as an image grid.

    Reads the module-level ``factorizer`` and ``images``. Row indices of
    ``factorizer.W`` are used directly as indices into ``images``, so both
    must describe the same examples in the same order.

    Tiles are filled left to right, top to bottom, starting with the largest
    weight. If fewer than ``n_rows * n_cols`` examples are available, the
    remaining tiles stay black.

    Args:
        component_index: Column of ``factorizer.W`` to rank the examples by.
        n_rows: Number of tile rows in the grid.
        n_cols: Number of tile columns in the grid.
    """
    n_examples = n_rows * n_cols
    # Negate the weights so argsort's ascending order puts the largest first.
    top_inds = np.argsort(-factorizer.W.numpy()[:, component_index])[:n_examples]
    # Take the tile height/width from the data instead of hard-coding 32x32.
    tile_h, tile_w = images.shape[1:3]
    img = np.zeros([n_rows * tile_h, n_cols * tile_w, 3], dtype=np.float32)
    for i, ind in enumerate(top_inds):
        row, col = divmod(i, n_cols)
        img[
            tile_h * row : tile_h * (row + 1),
            tile_w * col : tile_w * (col + 1),
        ] = images[ind]
    #
    plt.imshow(img)
    plt.show()


# plot_top_comps(0, 4, 8)
# plot_top_comps(1, 4, 8)
# plot_top_comps(2, 4, 8)
# plot_top_comps(3, 4, 8)

# Plot every component. The count comes from the number of columns of
# factorizer.W (the axis plot_top_comps indexes with component_index) rather
# than a hard-coded 16, so changing the factorizer's rank neither skips
# components nor indexes past the last column.
for i in range(factorizer.W.numpy().shape[1]):
    plot_top_comps(i, 4, 8)

