R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/m_npeff/stiefel/make_small_dense_cifar_lrm_pefs001.py

"""
import os
import dataclasses
from importlib import reload
import random

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf

from em import datasets as em_datasets
from em.fishers import generate_dense_m_pefs
from em.tools.m_npeff import m_npeff2
from em.fishers import lrm_pefs

from local_scripts.m_npeff import small_cifar_model_utils


###############################################################################
# Short module-level names for the small-CIFAR model helpers, so they can be
# called directly here and from the interactive session (`python -i`).
make_model, load_datasets, compile_model = (
    small_cifar_model_utils.make_model,
    small_cifar_model_utils.load_datasets,
    small_cifar_model_utils.compile_model,
)
###############################################################################
# Short names for the dense per-example-Fisher helpers. They are not referenced
# elsewhere in this script; presumably kept for interactive use.
compute_m_pefs_for_batch, compute_flat_m_pefs_for_ds = (
    generate_dense_m_pefs.compute_m_pefs_for_batch,
    generate_dense_m_pefs.compute_flat_m_pefs_for_ds,
)
###############################################################################

# Build the small CIFAR model and compile it with a learning rate of 1e-3.
model = make_model()
compile_model(model, learning_rate=0.001)

# Training / validation pipelines (batch_size=32 is forwarded to the helper).
train_ds, val_ds = load_datasets(batch_size=32)

# Short training run: 4 epochs of 1000 steps each, validating on val_ds.
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=4,
    steps_per_epoch=1000,
)

###############################################################################
# Number of examples whose per-example Fishers (PEFs) are computed and saved.
N_EXAMPLES = 512

# Output location. The commented-out line is an alternative output directory
# kept for reference.
# PEF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
PEF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers"
# Derive the "<N>ex" tag from N_EXAMPLES so the filename always matches the
# number of examples actually saved.
PEF_NAME = f"small_cifar_lrm_pefs.{N_EXAMPLES}ex.h5"
PEF_PATH = os.path.join(PEF_DIR, PEF_NAME)

###############################################################################

# Examples to compute PEFs for: the first N_EXAMPLES examples of the
# 'cifar10/binarized' test split, batched one example at a time.
pefs_ds = em_datasets.load('cifar10/binarized', split='test', tokenizer=None, sequence_length=None)
pefs_ds = pefs_ds.take(N_EXAMPLES).batch(1)

# Variables the PEFs are computed with respect to, and their total number of
# scalar entries. Summing static shapes as Python ints avoids the int32
# overflow that tf.size's default out_type would risk for very large models.
variables = model.trainable_variables
nvpe = sum(v.shape.num_elements() for v in variables)

fisher_computer = lrm_pefs.SparseLrmPefComputer(
    model=model,
    # Pass the same list nvpe was computed from so the two always agree.
    variables=variables,
    n_values_per_example=nvpe,
)

saver = lrm_pefs.StreamingLrmPefSaver(
    fisher_computer=fisher_computer,
    n_examples=N_EXAMPLES,
    use_tqdm=True,
)

output_path = os.path.expanduser(PEF_PATH)
# Make sure the destination directory exists before the saver writes to it.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
saver.compute_and_save_pefs(output_path, pefs_ds)
