import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tqdm import tqdm
import numpy as np
from train_mlp_pcax_bp import main as train_model_bp
from utils_pcax.data import get_datax_mlp
from utils_pcax.models import fp_up

import jax.numpy as jnp


class Config:
    """Hyper-parameter container for the MLP experiments run by this script.

    Every instance starts from the same defaults below. Some fields are
    left as ``None`` here (``alpha_up``, ``alpha_down``, ``lr_p``,
    ``weight_decay``, ``activation_fn``, ``is_up_initialisation_default``);
    ``main`` assigns them before training. The defaults mapping is rebuilt
    on every call, so each instance owns its own ``activity_init_kwargs``
    dict and mutating one config never leaks into another.
    """

    def __init__(self):
        defaults = dict(
            # --- data, model size and training loop ---
            batch_size=256,
            is_supervised=True,
            nm_epochs=25,
            activity_decay=0.0,
            gamma=0,
            h_var=0.0,
            activity_init="ff",
            activity_init_kwargs={"layer_var": 0.0},
            input_var=1.0,
            latent_dim=10,
            hidden_dim=256,
            data_dim=784,
            nm_layers=4,
            T=8,
            T_eval=100,
            dataset="mnist",
            train_size=60000,
            val_size=5000,
            test_size=5000,
            is_wandb=False,
            verbose=True,
            epochs_per_val=5,
            make_mean_image=False,
            is_post_activation=False,
            is_shared_weights=False,
            is_hybrid=False,
            is_cnn=False,
            is_free_latents=False,
            free_latent_dim=None,
            is_arbitrary_graph=False,
            # --- optimisation; the None entries are filled in by the caller ---
            alpha_up=None,
            alpha_down=None,
            lr_x=0.01,
            momentum=0.0,
            lr_p=None,
            weight_decay=None,
            activation_fn=None,
            is_up_initialisation_default=None,
            # --- seeding and checkpoint paths ---
            seed=0,
            load_path=None,
            save_path=None,
            # --- accuracy / RMSE initialisation options ---
            is_acc_init_up=True,
            acc_init="ff",
            is_rmse_init_up=False,
            rmse_init="ff",
            # --- output layer ---
            out_activation_fn=None,
        )
        # Plain instance attributes, set in the same order as listed above.
        vars(self).update(defaults)


def make_mask(p, batch_size, patch_size, rng=None, img_size=28):
    """Build per-image binary masks that hide randomly chosen square patches.

    Each ``img_size x img_size`` image is tiled into non-overlapping
    ``patch_size x patch_size`` patches, indexed row-major from the top
    left. For every image independently, ``int(p * n_patches)`` distinct
    patches are drawn without replacement and set to 0. Every other pixel
    is 1 ("fixed"). When ``patch_size`` does not divide ``img_size``, only
    complete patches are candidates and the leftover right/bottom border
    always stays 1.

    Args:
        p: Fraction of patches to hide (left uninitialised), in [0, 1].
        batch_size: Number of masks (one per image) to generate.
        patch_size: Side length, in pixels, of each square patch.
        rng: Object with a NumPy-style ``choice(a, size, replace=...)``
            method, e.g. ``np.random.default_rng(seed)``. Defaults to the
            global ``np.random`` state (the previous behaviour).
        img_size: Side length of the square image; 28 for MNIST.

    Returns:
        Float array of shape ``(batch_size, img_size**2)``: 1 where the pixel
        is kept/fixed, 0 where it is hidden.
    """
    if rng is None:
        rng = np.random

    # Only full patches are candidates. The previous `784 // patch_size**2`
    # over-counted for non-divisor patch sizes and produced clipped patches.
    n_patches_per_row = img_size // patch_size
    n_patches = n_patches_per_row**2
    n_hidden = int(p * n_patches)

    # One row of patch flags per image: 1 = patch kept, 0 = patch hidden.
    patch_flags = np.ones((batch_size, n_patches))
    for flags in patch_flags:
        flags[rng.choice(n_patches, n_hidden, replace=False)] = 0

    # Expand every patch flag into a patch_size x patch_size block of pixels.
    grid = patch_flags.reshape(batch_size, n_patches_per_row, n_patches_per_row)
    pixels = grid.repeat(patch_size, axis=1).repeat(patch_size, axis=2)

    covered = n_patches_per_row * patch_size
    mask_fixed = np.ones((batch_size, img_size, img_size))  # one if fixed zero if not
    mask_fixed[:, :covered, :covered] = pixels
    return mask_fixed.reshape(batch_size, img_size * img_size)


def _make_up_config():
    """Return the config of the feed-forward ("up") MLP trained with backprop."""
    up_config = Config()
    up_config.T = 0
    up_config.alpha_up = 1.0
    up_config.alpha_down = 0.0
    up_config.lr_p = 0.0005120529713738246
    up_config.weight_decay = 0.004344597680057968
    up_config.activation_fn = "gelu"
    up_config.is_up_initialisation_default = True
    return up_config


def _missing_input_accuracy(model_up, test_dl, p, patch_size=1):
    """Return ``{"ubp_ff": accuracy}`` on ``test_dl`` with a fraction ``p`` of input patches zeroed.

    Accuracy is a fraction in [0, 1] (correct / total), not a percentage.
    """
    n_correct = 0
    n_total = 0
    for x, y in tqdm(test_dl):
        # Here ``y`` is the image fed to the model and ``x`` the target,
        # compared via argmax (presumably one-hot labels -- TODO confirm).
        n_total += len(x)
        # Size the mask from the actual batch: a final partial batch would
        # not broadcast against a mask of `config.batch_size` rows.
        mask_fixed = make_mask(p, len(y), patch_size)  # one if fixed zero if not
        y_masked = y * mask_fixed

        ################### UPC ###################
        pred = fp_up(y_masked, model=model_up)
        n_correct += np.sum(np.argmax(pred, axis=1) == np.argmax(x, axis=1))

    return {"ubp_ff": n_correct / n_total}


def _summarise_across_seeds(results_across_seeds):
    """Return (mean, SEM) lists, one dict per condition, aggregated over seeds."""
    n_seeds = len(results_across_seeds)
    results_mean = []
    results_sem = []
    # zip(*...) yields, for each condition, the tuple of per-seed result dicts.
    for per_seed in zip(*results_across_seeds):
        keys = per_seed[0].keys()
        values = {k: [r[k] for r in per_seed] for k in keys}
        results_mean.append({k: np.mean(v) for k, v in values.items()})
        # NOTE(review): np.std defaults to ddof=0 (population SD); a
        # sample-SD standard error would use ddof=1. Kept unchanged so the
        # saved numbers match previous runs.
        results_sem.append(
            {k: np.std(v) / np.sqrt(n_seeds) for k, v in values.items()}
        )
    return results_mean, results_sem


def main():
    """Measure how a backprop-trained MLP copes with missing input pixels.

    For each seed, train the "up" model, then for each fraction ``p`` of
    removed pixels report test accuracy with those pixels zeroed. The
    per-``p`` mean and standard error across seeds are saved to
    ``missing_inputs_bp_mean.npy`` / ``missing_inputs_bp_sem.npy`` as
    lists of dicts keyed by method name.
    """
    up_config = _make_up_config()

    seeds = [0, 1, 2, 3, 4]
    ps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    results_across_seeds = []
    for seed in seeds:
        up_config.seed = seed
        model_up, _, _, _ = train_model_bp(up_config)

        _, _, test_dl = get_datax_mlp(up_config)

        # make_mask draws from NumPy's global RNG; seed it so the masks,
        # and therefore the reported accuracies, are reproducible per seed.
        np.random.seed(seed)

        ## test the models robustness to removed pixels to the input mnist images
        results = [_missing_input_accuracy(model_up, test_dl, p) for p in ps]
        results_across_seeds.append(results)

    results_mean, results_sem = _summarise_across_seeds(results_across_seeds)

    # save results_mean and results_sem
    np.save("missing_inputs_bp_mean.npy", results_mean)
    np.save("missing_inputs_bp_sem.npy", results_sem)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
