import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import numpy as np
from train_mlp_pcax import main as train_model
from train_mlp_pcax_bp import main as train_model_bp
from utils_pcax.data import get_datax_mlp
from utils_pcax.eval import eval_fid_inception_latent, eval_supervised_specificity
from utils_pcax.models import OModel


import pcax.utils as pxu
import optax


class Config:
    """Hyper-parameters for a single training / evaluation run.

    Every attribute receives a default below. Individual defaults can be
    overridden with keyword arguments, e.g. ``Config(seed=1, T=0)``, or by
    assigning attributes after construction (the original usage, still
    supported).

    Raises:
        TypeError: if a keyword argument does not name an existing
            hyper-parameter. This makes a misspelt name fail loudly instead of
            silently adding an attribute that nothing reads.
    """

    def __init__(self, **overrides):
        self.batch_size = 256
        self.is_supervised = True
        self.nm_epochs = 25
        self.activity_decay = 0.0
        self.gamma = 0
        self.h_var = 0.0
        self.activity_init = "ff"
        # Fresh dict per instance so configs never share mutable state.
        self.activity_init_kwargs = {"layer_var": 0.0}
        self.input_var = 1.0
        self.latent_dim = 10
        self.hidden_dim = 256
        self.data_dim = 784
        self.nm_layers = 4
        self.T = 8
        self.T_eval = 100
        self.dataset = "mnist"
        self.train_size = 60000
        self.val_size = 5000
        self.test_size = 5000
        self.is_wandb = False
        self.verbose = False
        self.epochs_per_val = 5
        self.make_mean_image = False
        self.is_post_activation = False
        self.is_shared_weights = False
        self.is_hybrid = False
        self.is_cnn = False
        self.is_free_latents = False
        self.free_latent_dim = None
        self.is_arbitrary_graph = False

        # No default: these are left as None and are expected to be set per
        # experiment (by keyword argument or by attribute assignment).
        self.alpha_up = None
        self.alpha_down = None
        self.lr_x = None
        self.momentum = None
        self.lr_p = None
        self.weight_decay = None
        self.activation_fn = None
        self.is_up_initialisation_default = None

        self.seed = 0
        self.load_path = None
        self.save_path = None

        self.is_acc_init_up = True
        self.acc_init = "ff"
        self.is_rmse_init_up = False
        self.rmse_init = "ff"

        self.out_activation_fn = None

        # Apply overrides last, and only for names defined above, so the
        # attribute set (and its order) is identical to a default Config.
        known = vars(self)
        for name, value in overrides.items():
            if name not in known:
                raise TypeError(
                    f"Config got an unexpected hyper-parameter {name!r}"
                )
            setattr(self, name, value)


def _make_config(
    seed,
    *,
    alpha_up,
    alpha_down,
    is_up_initialisation_default,
    momentum=0.0,
    T=None,
):
    """Return a Config with the settings shared by every run in this script.

    lr_x, lr_p, weight_decay and activation_fn are identical for all runs.
    Only the arguments vary. ``T`` keeps the Config default unless given.
    """
    config = Config()
    if T is not None:
        config.T = T
    config.alpha_up = alpha_up
    config.alpha_down = alpha_down
    config.lr_x = 0.01
    config.momentum = momentum
    config.lr_p = 0.0001
    config.weight_decay = 0.005
    config.activation_fn = "leaky_relu"
    config.is_up_initialisation_default = is_up_initialisation_default
    config.seed = seed
    return config


def _evaluate(
    model, config, is_up_initialisation, *, save, save_path, is_recalibrate=None
):
    """Generate samples from ``model`` and score them; return ``(is_mean, fid)``.

    A fresh eval optimizer is created per call, so no optimizer state can carry
    over from one model's evaluation to the next.
    ``is_recalibrate`` is forwarded only when given. With ``None``,
    ``eval_supervised_specificity`` uses its own default.
    """
    optim_h_eval = pxu.Optim(optax.adam(0.001))
    _, val_dl, test_dl = get_datax_mlp(config)

    extra = {} if is_recalibrate is None else {"is_recalibrate": is_recalibrate}
    imgs, _ = eval_supervised_specificity(
        model,
        optim_h_eval,
        test_dl,
        is_up_initialisation,
        nm_samples=256,
        is_supervised=True,
        save=save,
        save_path=save_path,
        **extra,
    )

    imgs = np.squeeze(imgs) * 2 - 1  # scale back to [-1, 1]
    is_mean, fid, _ = eval_fid_inception_latent(
        imgs, config.dataset, val_dl, verbose=True, subset="val"
    )
    return is_mean, fid


def main():
    """Train and evaluate bPC, PC and BP models for three seeds.

    PC and BP are each built from a separately trained "up" and "down" model,
    combined with ``OModel``. Sample grids are saved only for seed 0.
    Per-method IS/FID mean and std across seeds are printed at the end.

    Returns:
        dict mapping method name to a list of ``(is_mean, fid)`` per seed.
    """
    results = {"bPC": [], "PC": [], "BP": []}

    for seed in range(3):
        save = seed == 0  # only write sample images for the first seed

        ############################ BPC ############################
        bpc_config = _make_config(
            seed, alpha_up=1.0, alpha_down=0.0001, is_up_initialisation_default=True
        )
        model_bpc, _, _, is_up_initialisation = train_model(bpc_config)
        results["bPC"].append(
            _evaluate(
                model_bpc,
                bpc_config,
                is_up_initialisation,
                save=save,
                save_path="samples_gen_bpc.png",
            )
        )

        ############################ PC ############################
        up_config = _make_config(
            seed, alpha_up=1.0, alpha_down=0.0, is_up_initialisation_default=True
        )
        model_up, _, _, _ = train_model(up_config)

        down_config = _make_config(
            seed, alpha_up=0.0, alpha_down=1.0, is_up_initialisation_default=False
        )
        # As before, the combined model is evaluated with the down model's flag.
        model_down, _, _, is_up_initialisation = train_model(down_config)

        model_combined = OModel(model_up, model_down, alpha_up=1.0, alpha_down=1.0)
        results["PC"].append(
            _evaluate(
                model_combined,
                down_config,
                is_up_initialisation,
                save=save,
                save_path="samples_gen_pc.png",
                is_recalibrate=True,
            )
        )

        ############################ BP ############################
        up_config = _make_config(
            seed,
            alpha_up=1.0,
            alpha_down=0.0,
            is_up_initialisation_default=True,
            T=0,
        )
        model_up, _, _, _ = train_model_bp(up_config)

        # Previously the down model mistakenly kept the default T because
        # `up_config.T = 0` was written here. T=0 now applies to both BP models.
        # NOTE(review): momentum 0.9 differs from every other run (0.0); kept
        # as in the original. Confirm it is intended.
        down_config = _make_config(
            seed,
            alpha_up=0.0,
            alpha_down=1.0,
            is_up_initialisation_default=False,
            momentum=0.9,
            T=0,
        )
        model_down, _, _, is_up_initialisation = train_model_bp(down_config)

        model_combined = OModel(model_up, model_down, alpha_up=1.0, alpha_down=1.0)
        results["BP"].append(
            _evaluate(
                model_combined,
                down_config,
                is_up_initialisation,
                save=save,
                save_path="samples_gen_bp.png",
                is_recalibrate=True,
            )
        )

    for name, runs in results.items():
        is_means, fids = zip(*runs)
        print(
            f"{name}: FID {np.mean(fids):.3f} ± {np.std(fids):.3f} | "
            f"IS {np.mean(is_means):.3f} ± {np.std(is_means):.3f} "
            f"(n={len(runs)} seeds)"
        )
    return results


if __name__ == "__main__":
    # Run the full bPC / PC / BP sweep only when executed as a script,
    # not when this module is imported.
    main()
