import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from tqdm import tqdm
import random


import jax
import jax.numpy as jnp
import optax

import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu


from utils_pcax.data import (
    compute_energy_landscape,
    create_grid,
    main_energy_landscape_function,
    organize_batches,
    plot_energy_landscape,
)
from utils_pcax.utils import set_seed, sgdld
from utils_pcax.models import Model, initialisation, infer_on_batch, energy

import numpy as np
import matplotlib.pyplot as plt


@pxf.jit(static_argnums=(0, 3))
def train_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    is_up_initialisation: bool,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim,
):
    """Run one jitted training step on a single ``(x, y)`` batch.

    The step first runs ``infer_on_batch`` (which updates the vodes with
    ``optim_h``). It then takes one ``optim_w`` step on the gradient of
    ``energy`` with respect to the model's ``LayerParam``s.

    ``T`` and ``is_up_initialisation`` are static arguments of the jit
    (``static_argnums=(0, 3)``).

    Returns:
        The energy value computed by ``energy(x, y, model=model)`` inside the
        learning step.
    """
    model.train()

    # Inference phase, delegated to the project helper.
    # NOTE(review): the positional 0 is interpreted by infer_on_batch; confirm its meaning there.
    infer_on_batch(T, x, y, is_up_initialisation, 0, model=model, optim_h=optim_h)

    # Learning phase: the mask restricts differentiation to LayerParams, and the
    # resulting gradients feed the weight optimiser.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        grad_fn = pxf.value_and_grad(
            pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
        )(energy)
        (energy_value, _aux), grads = grad_fn(x, y, model=model)
    optim_w.step(model, grads["model"])
    return energy_value


def train(
    dl,
    T,
    is_up_initialisation,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim,
    verbose: bool = False,
):
    """Make one pass over ``dl``, calling :func:`train_on_batch` on every batch.

    Args:
        dl: Iterable of ``(x, y)`` batches.
        T: Forwarded unchanged to ``train_on_batch``. Presumably this is the
            number of inference steps.
        is_up_initialisation: Forwarded unchanged to ``train_on_batch``.
        model: Model being trained (updated in place by the optimisers).
        optim_w: Weight optimiser.
        optim_h: Activity (vode) optimiser.
        verbose: If True, wrap ``dl`` in a tqdm bar showing the latest energy.

    Returns:
        The energy returned by ``train_on_batch`` for the last batch, or
        ``None`` if ``dl`` yields no batches.
    """
    batches = tqdm(dl, desc="Energy: ") if verbose else dl
    # Pre-bind so an empty loader returns None instead of raising UnboundLocalError.
    e = None
    for x, y in batches:
        e = train_on_batch(
            T, x, y, is_up_initialisation, model=model, optim_w=optim_w, optim_h=optim_h
        )
        if verbose:
            batches.set_description_str(f"Energy: {e:.3f}")
    return e


class BaseConfig:
    """Default hyper-parameters for the XOR predictive-coding experiments.

    Every setting is a plain instance attribute, so a run can tweak it after
    construction (``config = BaseConfig(); config.nm_epochs = 10``).
    """

    def __init__(self):
        defaults = {
            # Training schedule and activity (vode) settings.
            "batch_size": 256,
            "is_supervised": True,
            "nm_epochs": 1000,
            "lr": 0.01,
            "momentum": 0.0,
            "activity_decay": 0.0,  # not read by the training code in this file
            "gamma": 0,
            "h_var": 0.0,
            "activity_init": "ff",
            "activity_init_kwargs": {"layer_var": 0.0},  # fresh dict per instance
            # Weight optimiser and model architecture.
            "lr_p": 0.0002,
            "weight_decay": 0.001,
            "is_shared_weights": False,
            "input_var": 1.0,
            "activation": "leaky_relu",
            "latent_dim": 1,
            "hidden_dim": 16,
            "data_dim": 2,
            "nm_layers": 4,
            # Presumably inference steps for training (T) and evaluation (T_eval).
            "T": 8,
            "T_eval": 10000,
            # Progress-bar output.
            "verbose": True,
        }
        # Assign in declaration order so vars(config) keeps the same attribute layout.
        for name, value in defaults.items():
            setattr(self, name, value)
        self.verbose = True


def get_configs(mode="bpc"):
    """Placeholder for a mode-specific configuration lookup.

    No mode is implemented yet: ``None`` is returned for every ``mode``,
    including the default ``"bpc"``.
    """
    # TODO: build and return the configuration for each supported mode.
    return None


def get_model(config, alpha_up, alpha_down):
    """Construct a ``Model`` from ``config`` with the given up/down weights.

    Dimensions, activation, initialisation and weight-sharing settings are read
    from ``config``. ``alpha_up`` and ``alpha_down`` are forwarded unchanged.
    The output activations are fixed: none for the down direction and
    ``"sigmoid"`` for the up direction.
    """
    dimensions = dict(
        input_dim=config.latent_dim,
        hidden_dim=config.hidden_dim,
        output_dim=config.data_dim,
        nm_layers=config.nm_layers,
    )
    settings = dict(
        activation=config.activation,
        input_var=config.input_var,
        is_supervised=config.is_supervised,
        is_shared_weights=config.is_shared_weights,
        activity_init=config.activity_init,
        activity_init_kwargs=config.activity_init_kwargs,
    )
    return Model(
        **dimensions,
        **settings,
        alpha_up=alpha_up,
        alpha_down=alpha_down,
        out_activation_down=None,
        out_activation_up="sigmoid",
    )


def _replace_zeros_with_min_nonzero(values):
    """Return ``values`` with every exact zero replaced by its smallest non-zero entry.

    Zero energies make the landscape plot fail, so they are clamped to the
    smallest non-zero value. If ``values`` has no non-zero entries there is
    nothing to clamp to, and it is returned unchanged rather than failing on an
    empty ``min`` reduction.
    """
    nonzero = values[values != 0]
    if nonzero.size == 0:
        return values
    return jnp.where(values == 0, nonzero.min(), values)


def _make_xor_batches(batch_size, nm_batches=10):
    """Build the XOR training data.

    The four ``(±1, ±1)`` patterns ``y`` and their labels ``x = [0, 1, 1, 0]``
    are tiled ``batch_size // 4`` times. This gives ``x`` of shape
    ``(batch_size, 1)`` and ``y`` of shape ``(batch_size, 2)``.

    Args:
        batch_size: Rows per batch; must be a positive multiple of 4.
        nm_batches: Number of batches in the returned list.

    Returns:
        ``(x, y, train_dl)``, where ``train_dl`` is a list of ``nm_batches``
        tuples that all reference the same ``x`` and ``y`` arrays.

    Raises:
        ValueError: If ``batch_size`` is not a positive multiple of 4. In that
            case the tiled batch would not have the ``batch_size`` rows the
            model is initialised with.
    """
    if batch_size <= 0 or batch_size % 4:
        raise ValueError(
            f"batch_size must be a positive multiple of 4, got {batch_size}"
        )
    y = jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32)
    x = jnp.array([[0], [1], [1], [0]], dtype=jnp.float32)
    reps = (batch_size // 4, 1)
    x = jnp.tile(x, reps)
    y = jnp.tile(y, reps)
    return x, y, [(x, y) for _ in range(nm_batches)]


def _fit(
    train_dl,
    *,
    nm_epochs,
    T,
    is_up_initialisation,
    model,
    optim_w,
    optim_h,
    verbose,
):
    """Train ``model`` for ``nm_epochs`` epochs, each a full :func:`train` pass.

    ``train_dl`` is shuffled in place with :mod:`random` before every epoch.
    With ``verbose``, a tqdm bar over epochs shows the energy of the last batch
    of the most recent epoch.
    """
    epochs = tqdm(range(nm_epochs), desc="Energy: ") if verbose else range(nm_epochs)
    for _ in epochs:
        random.shuffle(train_dl)
        e = train(
            train_dl,
            T=T,
            is_up_initialisation=is_up_initialisation,
            model=model,
            optim_w=optim_w,
            optim_h=optim_h,
            verbose=False,
        )
        if verbose:
            epochs.set_description_str(f"Energy: {e:.3f}")


def main():
    """Train bPC and its up-only / down-only ablations on XOR and plot energies.

    For each seed:

    1. Train a model with ``alpha_up = alpha_down = 1`` (bPC). Save its
       landscape, produced by ``main_energy_landscape_function``, to
       ``bPC_xor_energy_landscape.png``.
    2. Train a fresh model with only the up term (``alpha_down = 0``) and
       another with only the down term (``alpha_up = 0``, weight learning rate
       ``lr_p / 3``). Each model gets its own freshly initialised weight
       optimiser.
    3. Evaluate both ablations on a grid and combine their per-label energies
       with the bPC alphas. Save a 3x2 figure to
       ``discPC+genPC_xor_energy_landscape.png``: rows are combined / up / down,
       columns are label 0 / label 1.

    Figures are written to the current working directory.
    """
    seeds = [1]  # e.g. np.arange(8, 30) for a seed sweep
    for seed in seeds:
        print(f"Seed: {seed}")
        set_seed(seed)

        config = BaseConfig()
        verbose = config.verbose
        nm_epochs = config.nm_epochs
        batch_size = config.batch_size
        T = config.T
        T_eval = config.T_eval

        alpha_up = 1.0
        alpha_down = 1.0
        is_up_initialisation = True

        model = get_model(config, alpha_up, alpha_down)
        optim_h = pxu.Optim(
            sgdld(config.lr, config.momentum, config.h_var, config.gamma)
        )
        with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
            initialisation(
                jnp.zeros((batch_size, config.latent_dim)),
                jnp.zeros((batch_size, config.data_dim)),
                model=model,
                is_up_initialisation=is_up_initialisation,
            )
            # Give each trained model its own weight optimiser, so that no model
            # starts from Adam moments accumulated while training another one.
            # optim_w_up and optim_w_down are initialised from `model`'s
            # LayerParams and later stepped on model_up / model_down. This relies
            # on get_model producing the same parameter structure for every
            # (alpha_up, alpha_down).
            optim_w = pxu.Optim(
                optax.adamw(config.lr_p, weight_decay=config.weight_decay),
                pxu.Mask(pxnn.LayerParam)(model),
            )
            optim_w_up = pxu.Optim(
                optax.adamw(config.lr_p, weight_decay=config.weight_decay),
                pxu.Mask(pxnn.LayerParam)(model),
            )
            optim_w_down = pxu.Optim(
                optax.adamw(config.lr_p / 3, weight_decay=config.weight_decay),
                pxu.Mask(pxnn.LayerParam)(model),
            )

        x, y, train_dl = _make_xor_batches(batch_size)

        _fit(
            train_dl,
            nm_epochs=nm_epochs,
            T=T,
            is_up_initialisation=is_up_initialisation,
            model=model,
            optim_w=optim_w,
            optim_h=optim_h,
            verbose=verbose,
        )

        # Energy landscape of the jointly trained bPC model.
        main_energy_landscape_function(
            model,
            x,
            y,
            batch_size,
            optim_h,
            T_=T_eval,
            init_up=is_up_initialisation,
            plot=False,
        )
        plt.savefig("bPC_xor_energy_landscape.png")
        plt.close()

        # Grid on which the up and down energies are evaluated separately.
        y_grd, xx, yy = create_grid()
        y_grd_dl = organize_batches(y_grd, batch_size)

        # Retrain from scratch with the up term only, then evaluate on the grid.
        model_up = get_model(config, 1.0, 0.0)
        _fit(
            train_dl,
            nm_epochs=nm_epochs,
            T=T,
            is_up_initialisation=True,
            model=model_up,
            optim_w=optim_w_up,
            optim_h=optim_h,
            verbose=verbose,
        )
        _, _, energy_0_up, _, _, energy_1_up = compute_energy_landscape(
            y_grd, y_grd_dl, batch_size, T_eval, model_up, optim_h, True
        )

        # Retrain from scratch with the down term only, then evaluate on the grid.
        model_down = get_model(config, 0.0, 1.0)
        _fit(
            train_dl,
            nm_epochs=nm_epochs,
            T=T,
            is_up_initialisation=False,
            model=model_down,
            optim_w=optim_w_down,
            optim_h=optim_h,
            verbose=verbose,
        )
        _, energy_0_down, _, _, energy_1_down, _ = compute_energy_landscape(
            y_grd, y_grd_dl, batch_size, T_eval, model_down, optim_h, False
        )

        # Zero energies make the plot fail; clamp them to the smallest non-zero value.
        energy_0_up = _replace_zeros_with_min_nonzero(energy_0_up)
        energy_1_up = _replace_zeros_with_min_nonzero(energy_1_up)
        energy_0_down = _replace_zeros_with_min_nonzero(energy_0_down)
        energy_1_down = _replace_zeros_with_min_nonzero(energy_1_down)

        # Combine the separately trained up and down energies using the bPC weights.
        energy_0 = energy_0_up * alpha_up + energy_0_down * alpha_down
        energy_1 = energy_1_up * alpha_up + energy_1_down * alpha_down

        _, axs = plt.subplots(3, 2, sharex=True, sharey=True)
        plot_data = True
        energy_threshold = None
        # Top row: combined energies, one column per label.
        plot_energy_landscape(
            xx,
            yy,
            energy_0,
            y,
            x,
            "Energy label 0",
            axs[0, 0],
            plot_data=plot_data,
            energy_threshold=energy_threshold,
        )
        plot_energy_landscape(
            xx,
            yy,
            energy_1,
            y,
            x,
            "Energy label 1",
            axs[0, 1],
            plot_data=plot_data,
            energy_threshold=energy_threshold,
        )
        # Lower rows: the up-only and down-only energies.
        lower_rows = (
            (energy_0_up, energy_1_up, "Energy up"),
            (energy_0_down, energy_1_down, "Energy down"),
        )
        for row, (energy_0_part, energy_1_part, row_label) in enumerate(
            lower_rows, start=1
        ):
            plot_energy_landscape(
                xx, yy, energy_0_part, y, x, "", axs[row, 0], plot_data=plot_data
            )
            plot_energy_landscape(
                xx, yy, energy_1_part, y, x, "", axs[row, 1], plot_data=plot_data
            )
            axs[row, 0].set_ylabel(row_label)
        plt.savefig("discPC+genPC_xor_energy_landscape.png")
        plt.close()


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
