import copy
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tqdm import tqdm
import numpy as np
from train_mlp_pcax import main as train_model
from utils_pcax.data import get_datax_mlp
from utils_pcax.models import energy, initialisation

import pcax.utils as pxu
import pcax.predictive_coding as pxc
import optax
import jax.numpy as jnp
import jax
from utils_pcax.models import Model
import pcax.functional as pxf
import matplotlib.pyplot as plt


class Config:
    """Bag of default settings for the MLP predictive-coding experiments.

    Every instance starts with the same defaults. The fields that are left
    as ``None`` (learning rates, momentum, weight decay, activation function,
    ``alpha_up``/``alpha_down`` and the default initialisation direction) have
    no default and are expected to be assigned by the caller before use.
    """

    def __init__(self):
        # --- training loop ---
        self.batch_size = 256
        self.is_supervised = True
        self.nm_epochs = 25

        # --- activity dynamics and initialisation ---
        self.activity_decay = 0.0
        self.gamma = 0
        self.h_var = 0.0
        self.activity_init = "ff"
        # Built inside __init__ so each instance owns its own dict.
        self.activity_init_kwargs = dict(layer_var=0.0)
        self.input_var = 1.0

        # --- network dimensions ---
        self.latent_dim, self.hidden_dim, self.data_dim = 10, 256, 784
        self.nm_layers = 4

        # --- inference step counts ---
        self.T, self.T_eval = 8, 100

        # --- dataset and split sizes ---
        self.dataset = "fashion_mnist"
        self.train_size = 60000
        self.val_size = self.test_size = 5000

        # --- logging / validation cadence ---
        self.is_wandb = False
        self.verbose = True
        self.epochs_per_val = 5

        # --- architecture switches, all disabled by default ---
        self.make_mean_image = (
            self.is_post_activation
        ) = (
            self.is_shared_weights
        ) = self.is_hybrid = self.is_cnn = self.is_free_latents = False
        self.free_latent_dim = None
        self.is_arbitrary_graph = False

        # --- per-model hyper-parameters: no default, caller must set them ---
        self.alpha_up = self.alpha_down = None
        self.lr_x = self.momentum = self.lr_p = self.weight_decay = None
        self.activation_fn = self.is_up_initialisation_default = None

        # --- reproducibility and checkpoint locations ---
        self.seed = 0
        self.load_path = self.save_path = None

        # --- initialisation used by the accuracy / RMSE evaluations ---
        self.is_acc_init_up, self.acc_init = True, "ff"
        self.is_rmse_init_up, self.rmse_init = False, "ff"

        self.out_activation_fn = None


@pxf.jit(static_argnums=(0, 3))
def infer_on_batch_no_init(
    T: int,
    x: jax.Array,
    y: jax.Array,
    mode: int,
    rows: jax.Array,
    cols: jax.Array,
    values: jax.Array,
    *,
    model: Model,
    optim_h: pxu.Optim,
):
    """Run ``T`` inference steps on ``model`` without re-initialising it.

    The function does not call any initialisation itself; it starts from
    whatever activities ``model`` currently holds. After every optimiser
    step the ``h`` of the last vode is overwritten at ``[rows, cols]`` with
    ``values``, so those entries stay pinned while the rest are free to move.

    Args:
        T: Number of inference steps (static under jit, positional index 0).
        x: First positional argument forwarded to ``energy``. Replaced by
            ``None`` in the "data-only" and "unconstrained" modes.
        y: Second positional argument forwarded to ``energy``. Replaced by
            ``None`` in the "label-only" mode.
        mode: 0-3, mapped to "constrained", "label-only", "data-only" and
            "unconstrained" (static under jit, positional index 3). A value
            that is not one of these keys is used as-is, so the string names
            are accepted too.
        rows: Row indices into the last vode's ``h`` to overwrite.
        cols: Column indices into the last vode's ``h`` to overwrite.
        values: Values written at ``[rows, cols]`` after every step.
        model: Model whose vode activities are updated in place.
        optim_h: Optimiser for the activities; initialised here on the
            non-frozen ``VodeParam``s and cleared before returning.

    Returns:
        Tuple ``(model.vodes[0].get("h"), model.vodes[-1].get("h")``).
    """
    mode_mapping = {
        0: "constrained",
        1: "label-only",
        2: "data-only",
        3: "unconstrained",
    }
    # Fall back to the given value so string mode names pass through.
    mode = mode_mapping.get(mode, mode)

    def h_step(i, x, y, rows, cols, values, *, model, optim_h):
        """One scan iteration: energy gradient step, then re-pin entries."""
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            # Only the gradient ``g`` is used below; the energy value and its
            # auxiliary outputs are discarded.
            # NOTE(review): the mask presumably restricts differentiation to
            # VodeParams that are not frozen - confirm against pcax docs.
            (e, (y_down, x_up)), g = pxf.value_and_grad(
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                has_aux=True,
            )(energy)(x, y, model=model)
        optim_h.step(model, g["model"], True)
        # Overwrite the pinned entries of the last vode so the optimiser step
        # above cannot move them away from ``values``.
        model.vodes[-1].h._value = model.vodes[-1].h._value.at[rows, cols].set(values)
        return (x, y, rows, cols, values), None

    model.train()

    # Select which of the first / last vodes are frozen for this mode. The
    # flags are set before ``optim_h.init`` below, which masks on them.
    if mode == "constrained":
        model.vodes[0].h.frozen = True
        model.vodes[-1].h.frozen = True
    elif mode == "label-only":
        model.vodes[0].h.frozen = True
        model.vodes[-1].h.frozen = False
        y = None
        # is_up_initialisation = False
    elif mode == "data-only":
        model.vodes[0].h.frozen = False
        model.vodes[-1].h.frozen = True
        x = None
        # is_up_initialisation = True
    elif mode == "unconstrained":
        model.vodes[0].h.frozen = False
        model.vodes[-1].h.frozen = False
        x = None
        # y = None  # one should be kept for vmap but will be ignored because there is no init

    # Optimiser state is built only for the VodeParams left unfrozen above.
    optim_h.init(pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))(model))

    # Inference steps
    pxf.scan(h_step, xs=jax.numpy.arange(T))(
        x, y, rows, cols, values, model=model, optim_h=optim_h
    )

    optim_h.clear()

    # restore frozen states (both ends frozen, regardless of the mode used)
    model.vodes[0].h.frozen = True
    model.vodes[-1].h.frozen = True
    return model.vodes[0].get("h"), model.vodes[-1].get("h")


def make_mask(p, batch_size, patch_size):
    """Build a random per-image patch mask over 28x28 inputs.

    Each image is split into a grid of ``patch_size`` x ``patch_size``
    patches, indexed from top-left to bottom-right. For every image,
    ``int(p * n_patches)`` distinct patches are drawn with
    ``np.random.choice`` and set to 0 (free to be updated); all other pixels
    stay 1 (fixed).

    Args:
        p: Proportion of patches to mark as not fixed, in [0, 1].
        batch_size: Number of independent masks to generate.
        patch_size: Side length of a square patch in pixels. If it does not
            divide 28, the leftover border pixels are never selected.

    Returns:
        Float array of shape ``(batch_size, 784)``; 1 = fixed, 0 = not fixed.
    """
    side = 28

    # Count the patches that actually fit in the grid. Deriving the count
    # from the pixel total (784 // patch_size**2) over-counts whenever
    # patch_size does not divide 28, producing indices outside the grid.
    n_patches_per_row = side // patch_size
    n_patches = n_patches_per_row**2
    n_selected = int(p * n_patches)

    mask_fixed = np.ones((batch_size, side, side))  # one if fixed zero if not
    for i in range(batch_size):
        idxs = np.random.choice(n_patches, n_selected, replace=False)

        # zero out every selected patch of image i
        for idx in idxs:
            row, col = divmod(idx, n_patches_per_row)
            mask_fixed[
                i,
                row * patch_size : (row + 1) * patch_size,
                col * patch_size : (col + 1) * patch_size,
            ] = 0

    return mask_fixed.reshape(batch_size, side * side)


def _build_configs():
    """Return the (uPC, bPC, dPC) configs with their hard-coded hyper-parameters."""
    dPC_config = Config()
    dPC_config.lr_x = 0.019645968630162165
    dPC_config.momentum = 0.0
    dPC_config.lr_p = 0.00002990748921223216
    dPC_config.weight_decay = 0.006009370555227902
    dPC_config.activation_fn = "l-relu"
    dPC_config.is_up_initialisation_default = False
    dPC_config.is_acc_init_up = False
    dPC_config.acc_init = "ff"
    dPC_config.is_rmse_init_up = False
    dPC_config.rmse_init = "ff"
    dPC_config.alpha_down = 1.0
    dPC_config.alpha_up = 0.0
    # dPC_config.load_path = "results/models/dPC_mnist_acc"

    uPC_config = Config()
    uPC_config.lr_x = 0.002148790907428985
    uPC_config.momentum = 0.0
    uPC_config.lr_p = 0.0003391881333008285
    uPC_config.weight_decay = 0.00001420194109153724
    uPC_config.activation_fn = "gelu"
    uPC_config.is_up_initialisation_default = True
    uPC_config.is_acc_init_up = True
    uPC_config.acc_init = "ff"
    uPC_config.is_rmse_init_up = True
    uPC_config.rmse_init = "ff"
    uPC_config.alpha_down = 0.0
    uPC_config.alpha_up = 1.0
    # uPC_config.load_path = "results/models/uPC__mnist_acc"

    # bPC keeps the Config defaults for the acc/rmse initialisation fields.
    bPC_config = Config()
    bPC_config.lr_x = 0.0016600798552176934
    bPC_config.momentum = 0.9
    bPC_config.lr_p = 0.0004509387949330343
    bPC_config.weight_decay = 0.0021609607779110493
    bPC_config.activation_fn = "gelu"
    bPC_config.is_up_initialisation_default = True
    bPC_config.alpha_down = 0.0001
    bPC_config.alpha_up = 1.0
    # bPC_config.load_path = "results/models/bPC_mnist_acc"

    return uPC_config, bPC_config, dPC_config


def _plot_accuracy(ps, results_mean, results_sem):
    """Save the seed-averaged accuracy curves (with +/- SEM bands) versus ``ps``."""
    # (result key, legend label, colour, line style), in legend order
    curves = (
        ("bpc_inf", "bPC inference", "C0", "-"),
        ("bpc_ff", "bPC feed-forward", "C0", "--"),
        ("upc_ff", "uPC feed-forward", "C1", "-"),
        ("dpc_inf", "dPC inference", "C2", "-"),
    )

    plt.figure(figsize=(4, 3))
    for key, label, color, linestyle in curves:
        # The curve must be the across-seed mean so that it matches the
        # across-seed SEM band drawn around it.
        mean = np.array([float(r[key]) for r in results_mean])
        sem = np.array([float(r[key]) for r in results_sem])
        plt.plot(ps, mean, label=label, color=color, linestyle=linestyle)
        plt.fill_between(ps, mean - sem, mean + sem, color=color, alpha=0.3)
    plt.legend()
    plt.xlabel("ratio of missing pixels")
    plt.ylabel("accuracy")
    plt.ylim(0.0, 1)
    plt.tight_layout()
    plt.savefig("missing_inputs_bpc_upc_dpc_fmist.png")
    plt.close()


def _plot_reconstructions(
    ps, base_imgs, masked_imgs, reconstructions_bpc, reconstructions_dpc
):
    """Save a grid of original / masked / bPC / dPC images, one row per sample."""
    n_imgs = 10
    _, axs = plt.subplots(n_imgs, 1 + len(ps) * 3)
    for idx in range(n_imgs):
        # column 0: the unmasked image
        axs[idx, 0].imshow(base_imgs[0][idx].reshape(28, 28), cmap="gray")
        axs[idx, 0].axis("off")
        # then, per masking ratio: masked input, bPC output, dPC output
        for i in range(len(ps)):
            panels = (masked_imgs[i], reconstructions_bpc[i], reconstructions_dpc[i])
            for offset, imgs in enumerate(panels, start=1):
                ax = axs[idx, 3 * i + offset]
                ax.imshow(imgs[idx].reshape(28, 28), cmap="gray")
                ax.axis("off")
    plt.tight_layout()
    plt.savefig("missing_inputs_bpc_upc_dpc_reconstructions_fmist.svg")
    plt.close()


def main():
    """Measure how uPC, bPC and dPC classify images with missing pixels.

    For every seed the three models are trained with ``train_model``. Then,
    for each masking ratio ``p``, a random pixel mask is applied to every
    test batch and four accuracies are accumulated: uPC and bPC right after
    initialisation from the masked image ("ff"), and bPC and dPC after
    running ``infer_on_batch_no_init`` with the unmasked pixels pinned
    ("inf"). Mean and standard error across seeds are saved as ``.npy``
    files and plotted; reconstructions from the first seed are plotted too.
    """
    uPC_config, bPC_config, dPC_config = _build_configs()

    seeds = [0, 1, 2, 3, 4]
    ps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    # Example images for the reconstruction figure. These live outside the
    # seed loop: they are filled during the first seed only and must survive
    # until plotting after the last seed.
    base_imgs = []
    masked_imgs = []
    reconstructions_bpc = []
    reconstructions_dpc = []

    results_across_seeds = []
    for seed in seeds:
        uPC_config.seed = seed
        bPC_config.seed = seed
        dPC_config.seed = seed

        model_uPC, _, _, _ = train_model(uPC_config)
        model_bPC, _, _, _ = train_model(bPC_config)
        model_dPC, _, _, _ = train_model(dPC_config)

        optim_h_eval = pxu.Optim(optax.adam(0.005))

        _, _, test_dl = get_datax_mlp(uPC_config)

        results = []
        for p in ps:
            # Test the models' robustness to pixels removed from the inputs.
            n_correct = {"upc_ff": 0, "bpc_ff": 0, "bpc_inf": 0, "dpc_inf": 0}
            n_total = 0
            # NOTE(review): x is compared through argmax with the first vode
            # and y is reshaped to 28x28 below, so x appears to hold one-hot
            # labels and y the flattened images - confirm against the loader.
            for x, y in tqdm(test_dl):
                n_total += len(x)
                # Size the mask from the actual batch so that a smaller final
                # batch still multiplies element-wise with y.
                mask_fixed = make_mask(p, len(y), 1)  # one if fixed zero if not
                y_masked = y * mask_fixed

                # store index and values of pixels that are not masked (mask = 1)
                y = jnp.array(y)
                rows, cols = jnp.where(mask_fixed)
                values = y.at[(rows, cols)].get()

                ################### UPC ###################
                # initialise the model from the masked image and read the
                # prediction straight from the first vode
                with pxu.step(model_uPC, "init", clear_params=pxc.VodeParam.Cache):
                    initialisation(
                        None, y_masked, model=model_uPC, is_up_initialisation=True
                    )
                n_correct["upc_ff"] += np.sum(
                    np.argmax(model_uPC.vodes[0].get("h"), axis=1)
                    == np.argmax(x, axis=1)
                )

                ################### BPC ###################
                # same initialisation for bPC, scored before and after inference
                with pxu.step(model_bPC, "init", clear_params=pxc.VodeParam.Cache):
                    initialisation(
                        None, y_masked, model=model_bPC, is_up_initialisation=True
                    )
                n_correct["bpc_ff"] += np.sum(
                    np.argmax(model_bPC.vodes[0].get("h"), axis=1)
                    == np.argmax(x, axis=1)
                )

                # NOTE(review): 600000 inference steps per batch, far above
                # Config.T_eval (100) - confirm this is intended.
                infer_on_batch_no_init(
                    600000,
                    None,
                    y,
                    3,
                    rows,
                    cols,
                    values,
                    model=model_bPC,
                    optim_h=optim_h_eval,
                )
                n_correct["bpc_inf"] += np.sum(
                    np.argmax(model_bPC.vodes[0].get("h"), axis=1)
                    == np.argmax(x, axis=1)
                )

                ################### DPC ###################
                # dPC is initialised with an all-zero first-vode input and
                # is_up_initialisation=False, then scored after inference only
                pseudo_input = jnp.zeros_like(x)
                with pxu.step(model_dPC, "init", clear_params=pxc.VodeParam.Cache):
                    initialisation(
                        pseudo_input,
                        y_masked,
                        model=model_dPC,
                        is_up_initialisation=False,
                    )

                infer_on_batch_no_init(
                    600000,
                    None,
                    y,
                    3,
                    rows,
                    cols,
                    values,
                    model=model_dPC,
                    optim_h=optim_h_eval,
                )
                n_correct["dpc_inf"] += np.sum(
                    np.argmax(model_dPC.vodes[0].get("h"), axis=1)
                    == np.argmax(x, axis=1)
                )

            # convert each count to a fraction of the evaluated samples
            n_correct = {k: v / n_total for k, v in n_correct.items()}
            results.append(n_correct)
            if seed == seeds[0]:
                # keep the last batch of the first seed as plotting examples
                reconstructions_bpc.append(model_bPC.vodes[-1].get("h"))
                reconstructions_dpc.append(model_dPC.vodes[-1].get("h"))
                base_imgs.append(y)
                masked_imgs.append(y_masked)
        results_across_seeds.append(results)

    # average results across seeds; SEM uses np.std's default (ddof=0)
    keys = list(results_across_seeds[0][0])
    results_mean = []
    results_sem = []
    for i in range(len(ps)):
        per_seed = {k: [r[i][k] for r in results_across_seeds] for k in keys}
        results_mean.append({k: np.mean(v) for k, v in per_seed.items()})
        results_sem.append(
            {k: np.std(v) / np.sqrt(len(seeds)) for k, v in per_seed.items()}
        )

    # save results_mean and results_sem
    np.save("missing_inputs_bpc_upc_dpc_mean_fmist.npy", results_mean)
    np.save("missing_inputs_bpc_upc_dpc_sem_fmist.npy", results_sem)

    _plot_accuracy(ps, results_mean, results_sem)
    _plot_reconstructions(
        ps, base_imgs, masked_imgs, reconstructions_bpc, reconstructions_dpc
    )


# Run the full experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()
