#!#! EXPERIMENTS  #!#!
#! VAE with copula with NSF
#! Neural spline flow, maf and so on

import pickle as pkl
from time import time
import os
import pandas as pd
import jax.numpy as jnp
import torch
import math
import numpy as np
import seaborn as sns
from matplotlib import cm
from matplotlib import pyplot as plt
from sklearn.mixture import BayesianGaussianMixture
import scipy as sp
from sklearn.model_selection import train_test_split
from sklearn.linear_model import BayesianRidge
from jax.scipy.special import logsumexp

# from sklearn.datasets import fetch_openml
from tqdm import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf
import wandb
import importlib

import sys
import os

# Put the parent directory of this script's folder on the import path so the
# `utils` module imported below can be resolved from there when the file is
# executed directly (presumably that is where `utils` lives -- confirm).
current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
if parent not in sys.path:
    # Appending unconditionally adds a duplicate entry every time this module
    # body is executed again; a duplicate at the end never changes resolution.
    sys.path.append(parent)

# One plain import is sufficient: a preceding
# `utils = importlib.import_module("utils")` binds the very same module object
# to the very same name, so it was redundant.
import utils


@hydra.main(config_path=utils.get_project_root() + "/conf", config_name="config")
def main(cfg: DictConfig) -> None:
    import data

    utils.seed_everything(cfg.base.seed)

    print(OmegaConf.to_yaml(cfg))

    wandb.init(
        project="arcopula",
        reinit=True,
        settings=wandb.Settings(start_method="thread"),
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
    )
    assert wandb.config.base["class"] + wandb.config.base["regress"] < 2

    # %%
    # region load data
    if wandb.config.base["class"] and cfg.data.dataset != "statlog":
        if cfg.data.dataset == "breast":
            data = pd.read_csv(
                f"{utils.get_project_root()}/datasets/wdbc.data", header=None,
            )

            data[data == "?"] = np.nan
            data.dropna(axis=0, inplace=True)
            y_data = data.iloc[:, 1].values  # convert strings to integer
            x_data = data.iloc[:, 2:,].values

            # set to binary
            y_data[y_data == "B"] = 0  # benign
            y_data[y_data == "M"] = 1  # malignant
            y_data = y_data.astype("int")

        elif cfg.data.dataset == "ionosphere":
            data = pd.read_csv(
                f"{utils.get_project_root()}/datasets/ionosphere_class.data",
                header=None,
            )
            y_data = data.iloc[:, 34].values  # convert strings to integer
            x_data = data.iloc[:, 0:34]
            x_data = x_data.drop(1, axis=1).values  # drop constant columns

            # set to binary
            y_data[y_data == "g"] = 1  # good
            y_data[y_data == "b"] = 0  # bad
            y_data = y_data.astype("int")

        elif cfg.data.dataset == "parkinsons":
            data = pd.read_csv(
                f"{utils.get_project_root()}/datasets/parkinsons_class.data"
            )
            data[data == "?"] = np.nan
            data.dropna(axis=0, inplace=True)
            y_data = data["status"].values  # convert strings to integer
            x_data = data.drop(columns=["name", "status"]).values

        elif cfg.data.dataset == "mnist":
            # y, lab = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
            # k = ckern.Conv(GPflow.kernels.RBF(25, ARD=self.run_settings['kernel_ard']), [28, 28],
            #                [5, 5]) + GPflow.kernels.White(1, 1e-3)
            from torchvision.datasets import MNIST

            train_data = MNIST(utils.get_data_root(), train=True, download=True)
            x_data = train_data.data.numpy().reshape((-1, 784))
            y_data = train_data.targets.numpy()
            y_data, x_data = y_data[y_data < 2], x_data[y_data < 2]

            trn_idx, val_idx = train_test_split(
                np.arange(len(y_data)), test_size=0.05, random_state=cfg.base.seed + 3
            )
            y_val = np.concatenate((x_data[val_idx], y_data[val_idx][:, None]), axis=1)

            x_data = x_data[trn_idx]
            y_data = y_data[trn_idx]
            test_data = MNIST(utils.get_data_root(), train=False, download=True)
            x_test = test_data.data.numpy().reshape((-1, 784))
            y_test = test_data.targets.numpy()
            y_test = np.concatenate(
                (x_test[y_test < 2], y_test[y_test < 2][:, None]), axis=1
            )

        else:
            print("Dataset doesn't exist")
            raise NotImplementedError

        y = np.concatenate((x_data, y_data[:, None]), axis=1)

    elif wandb.config.base["regress"]:
        if cfg.data.dataset == "concrete":
            data = pd.read_excel(
                f"{utils.get_project_root()}/datasets/Concrete_Data.xls"
            )
            y_data = data.iloc[:, 8].values
            x_data = data.iloc[:, 0:8].values
        elif cfg.data.dataset == "wine":
            data = pd.read_csv(
                f"{utils.get_project_root()}/datasets/winequality-red.csv", sep=";"
            )
            y_data = data.iloc[:, 11].values  # convert strings to integer
            x_data = data.iloc[:, 0:11].values
        elif cfg.data.dataset == "boston":
            from sklearn.datasets import load_boston

            x_data, y_data = load_boston(return_X_y=True)
        elif cfg.data.dataset == "diabetes":
            from sklearn.datasets import load_diabetes

            x_data, y_data = load_diabetes(return_X_y=True)
        else:
            print("Dataset doesn't exist")
            raise NotImplementedError

        y = np.concatenate((x_data, y_data[:, None]), axis=1)

    elif cfg.data.dataset == "gmm":
        y = data.load_gmm_data(
            cfg.data.n_data_points, cfg.data.d, cfg.data.K, cfg.base.seed
        )
    elif cfg.data.dataset == "image":
        train_dataset = data.load_face_dataset(
            name=cfg.data.dataset_file, num_points=cfg.data.n_data_points, reverse=True
        )
        y = train_dataset.data.cpu().numpy()
    elif cfg.data.dataset == "plane":
        train_dataset = data.load_plane_dataset(
            name=cfg.data.dataset_file, num_points=cfg.data.n_data_points
        )
        y = train_dataset.data.cpu().numpy()
    elif cfg.data.dataset in ["power", "gas", "hepmass", "miniboone", "bsds300"]:
        y, y_val, y_test = data.return_train_test_maf_data(cfg.data.dataset)
    elif cfg.data.dataset in ["breast", "wine", "ionosphere", "parkinsons"]:
        y = pd.read_csv(
            f"{utils.get_project_root()}/datasets/{cfg.data.dataset}.data", header=None
        )
        y = y.values
    elif cfg.data.dataset == "digits":
        from sklearn.datasets import load_digits

        # load the data
        y = load_digits().data
    elif cfg.data.dataset == "mnist":
        # y, lab = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
        # k = ckern.Conv(GPflow.kernels.RBF(25, ARD=self.run_settings['kernel_ard']), [28, 28],
        #                [5, 5]) + GPflow.kernels.White(1, 1e-3)
        from torchvision.datasets import MNIST

        train_data = MNIST(utils.get_data_root(), train=True, download=True)
        y = train_data.data.numpy().reshape((-1, 784))
        test_data = MNIST(utils.get_data_root(), train=False, download=True)
        y_test = test_data.data.numpy().reshape((-1, 784))
        y, y_val = train_test_split(y, test_size=0.05, random_state=42)
    elif cfg.data.dataset == "statlog":
        data = pd.read_csv(
            f"{utils.get_project_root()}/datasets/german.data",
            header=None,
            delim_whitespace=True,
        )
        data[data == "?"] = np.nan
        data.dropna(axis=0, inplace=True)
        y_data = data.iloc[:, 20].values  # convert strings to integer
        x_data = data.iloc[:, 0:20]

        # set to binary
        y_data[y_data == 1] = 0  # good
        y_data[y_data == 2] = 1  # bad
        y_data = y_data.astype("int")

        # convert to dummy
        d = np.shape(x_data)[1]
        types = np.array(
            [
                1.0,
                0.0,
                1.0,
                1.0,
                0.0,
                1.0,
                1.0,
                0.0,
                1.0,
                1.0,
                0.0,
                1.0,
                0.0,
                1.0,
                1.0,
                0.0,
                1.0,
                0.0,
                1.0,
                1.0,
            ]
        )
        for j in range(d):
            if types[j] == 1.0:
                x_data.iloc[:, j] = pd.factorize(x_data.iloc[:, j])[0]
        x_data = x_data.values
        y = np.concatenate((x_data, y_data[:, None]), axis=1)
    elif cfg.data.dataset == "boston":
        from sklearn.datasets import load_boston

        y = load_boston(return_X_y=False).data
    else:
        raise NotImplementedError

    if (
        cfg.data.dataset
        not in ["power", "gas", "hepmass", "miniboone", "bsds300", "mnist",]
        and not wandb.config.base["class"]
        and not wandb.config.base["regress"]
    ):
        y = utils.drop_corr_and_constant(y, threshold=0.98, return_columns=False)
        y, y_test = train_test_split(y, test_size=cfg.data.test_size, random_state=42)
        y_test, y_val = train_test_split(
            y_test, test_size=cfg.data.val_size, random_state=42
        )
    elif cfg.data.dataset == "mnist":
        to_drop = [
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
            9,
            10,
            11,
            15,
            16,
            17,
            18,
            19,
            20,
            21,
            22,
            23,
            24,
            25,
            26,
            27,
            28,
            29,
            30,
            31,
            33,
            52,
            53,
            54,
            55,
            56,
            57,
            82,
        ]
        # y = np.delete(y, to_drop, 1)
        to_drop = np.where(np.std(y, axis=0) == 0)[0]
        y = np.delete(y, to_drop, 1)
        y_val = np.delete(y_val, to_drop, 1)
        y_test = np.delete(y_test, to_drop, 1)

        norm_idx = y.shape[1] - wandb.config.base["class"]
        y[:, :norm_idx] = y[:, :norm_idx] + np.random.uniform(
            0, 1, size=y[:, :norm_idx].shape
        )
        y_val[:, :norm_idx] = y_val[:, :norm_idx] + np.random.uniform(
            0, 1, size=y_val[:, :norm_idx].shape
        )
        y_test[:, :norm_idx] = y_test[:, :norm_idx] + np.random.uniform(
            0, 1, size=y_test[:, :norm_idx].shape
        )
        y = y.astype("float")
        y_val = y_val.astype("float")
        y_test = y_test.astype("float")
        y[:, :norm_idx] = y[:, :norm_idx] / 256.0
        y_val[:, :norm_idx] = y_val[:, :norm_idx] / 256.0
        y_test[:, :norm_idx] = y_test[:, :norm_idx] / 256.0

        def logit_transform(x):
            return utils.logit(1e-10 + (1 - 2e-10) * x)

        y[:, :norm_idx] = logit_transform(y[:, :norm_idx])
        y_val[:, :norm_idx] = logit_transform(y_val[:, :norm_idx])
        y_test[:, :norm_idx] = logit_transform(y_test[:, :norm_idx])

    elif not (wandb.config.base["class"] or wandb.config.base["regress"]):
        y, (y_val, y_test) = utils.drop_corr_and_constant(
            y, y_test=[y_val, y_test], threshold=0.98, return_columns=False
        )
    else:
        y, y_test = train_test_split(
            y, test_size=cfg.data.test_size, random_state=cfg.base.seed
        )

        y_val = y_test.copy()

    y = y[: cfg.data.max_n]
    to_drop = np.where(np.std(y, axis=0) == 0)[0]
    y = np.delete(y, to_drop, 1)
    y_val = np.delete(y_val, to_drop, 1)
    y_test = np.delete(y_test, to_drop, 1)
    y = pd.DataFrame(y).dropna(axis=0, inplace=False).values
    # if cfg.data.dataset == "mnist":
    #     to_drop = np.where(np.std(y, axis=0) == 0)[0]
    #     y = np.delete(y, to_drop, 1)
    #     y_val = np.delete(y_val, to_drop, 1)
    #     y_test = np.delete(y_test, to_drop, 1)
    # else:
    #     y, (y_val, y_test) = utils.drop_corr_and_constant(
    #         y, y_test=[y_val, y_test], threshold=0.98, return_columns=False
    #     )

    norm_idx = y.shape[1] - wandb.config.base["class"]
    if cfg.data.init == "normal":
        mean_norm = np.mean(y[:, :norm_idx], axis=0)
        std_norm = np.std(y[:, :norm_idx], axis=0)  # np.std(y, axis=0)
        y[:, :norm_idx] = (y[:, :norm_idx] - mean_norm) / std_norm
        y_val[:, :norm_idx] = (y_val[:, :norm_idx] - mean_norm) / std_norm
        y_test[:, :norm_idx] = (y_test[:, :norm_idx] - mean_norm) / std_norm
        wandb.log({"correcting_constant": -np.log(std_norm).sum()})
    elif cfg.data.init == "uniform":
        min = np.min(y[:, :norm_idx], axis=0)
        max = np.max(y[:, :norm_idx], axis=0)
        y[:, :norm_idx] = (y[:, :norm_idx] - min) / (max - min)
        y_val[:, :norm_idx] = (y_val[:, :norm_idx] - min) / (max - min)
        y_test[:, :norm_idx] = (y_test[:, :norm_idx] - min) / (max - min)
    else:
        raise NotImplementedError

    # evaluate_classification_copula(
    #     y[:, -1].astype(int), y[:, :-1], y_test[:, -1].astype(int), y_test[:, :-1]
    # )

    # pd.DataFrame(((y- y[0])**2).sum(-1)).quantile(1e-3)

    wandb.config.data.update({"n_data_points": y.shape[0], "d": y.shape[1]})
    wandb.config.base.update(
        {"plot_density": (y.shape[1] == 2) and cfg.base.plot_density}
    )
    # wandb.config.data.update({"batching": y.shape[0] > y})

    if "net" in cfg.model.diff or cfg.model.model_class in ["maf", "rq-nsf"]:
        wandb.config.model.update({"scipy_opt": False})
    if "d_perm" in cfg.model:
        wandb.config.model.update(
            {
                "d_perm": cfg.model.d_perm
                if cfg.model.d_perm < math.factorial(y.shape[1])
                else math.factorial(y.shape[1])
            }
        )

    # endregion

    # %%
    if cfg.model.model_class == "copula":
        # region fit (or load) copula obj

        # region imports
        from models.main_copula_AR import (
            fit_copula_density,
            predict_copula_density,
            smc_sample_from_copula,
            sample_copula_density,
        )

        # endregion
        copula_kws = {
            k: v
            for k, v in wandb.config.model.items()
            if k
            in [
                "n_perm",
                "d_perm",
                "n_perm_optim",
                "n_optim",
                "maxiter",
                "bern",
                "init_rho",
                "init_length",
            ]
        }
        copula_kws["seed"] = cfg.base.seed
        model_keys = [
            l
            for l in list(cfg.model.keys())
            if l
            not in [
                "train_batch_size",
                "learning_rate",
                "num_flow_steps",
                "hidden_features",
                "num_bins",
                "num_transform_blocks",
                "dropout_probability",
                "monitor_interval2",
                "monitor_interval",
            ]
        ]
        model_name = utils.make_model_name(cfg, model_keys)
        model_path = f"{utils.get_checkpoint_root()}/{model_name}.pkl"
        density_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.data.num_points_per_axis}_logdens.txt"

        if not cfg.base.reload and os.path.exists(model_path):
            from types import SimpleNamespace

            with open(model_path, "rb") as f:
                copula_density_obj = pkl.load(f)
            copula_density_obj = SimpleNamespace(**copula_density_obj)

        else:
            model_name_opt = utils.make_model_name(cfg, model_keys, ["n_data_points"])
            opt_path = f"{utils.get_checkpoint_root()}/{model_name_opt}_opt.pkl"
            copula_density_obj = fit_copula_density(
                y, y_val, opt_path=opt_path, **copula_kws
            )
            print("Preq loglik is {}".format(copula_density_obj.preq_loglik))

            with open(model_path, "wb") as f:
                pkl.dump(copula_density_obj._asdict(), f)

        wandb.log({"train loss": -copula_density_obj.preq_loglik})
        wandb.log({"model_path": model_path})
        # endregion

        if wandb.config.base["class"] == True:

            def predict_density(batch):
                test_scores = predict_copula_density(copula_density_obj, batch)[1][
                    :, -1
                ]
                test_scores = y_test[:, -1] * test_scores + (
                    1 - y_test[:, -1]
                ) * np.log(1 - jnp.exp(test_scores))
                return test_scores

        else:

            def predict_density(batch):
                return predict_copula_density(copula_density_obj, batch)[1][:, -1]

    elif cfg.model.model_class == "ori_copula":
        model_keys = list(cfg.model.keys())
        model_name = utils.make_model_name(cfg, model_keys)
        model_path = f"{utils.get_checkpoint_root()}/{model_name}.pkl"
        density_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.data.num_points_per_axis}_logdens.txt"

        from models.old_copula import ori_main_copula_AR

        copula_density_obj = ori_main_copula_AR.fit_copula_density(
            y,
            seed=cfg.base.seed,
            n_perm_optim=cfg.model.n_perm_optim,
            n_optim=cfg.model.n_optim,
            n_perm=cfg.model.n_perm,
        )
        print("Bandwidth is {}".format(copula_density_obj.rho_lengths_opt))
        print("Preq loglik is {}".format(copula_density_obj.preq_loglik))

        def predict_density(batch):
            batch = jnp.array(batch)
            return ori_main_copula_AR.predict_copula_density(copula_density_obj, batch)[
                1
            ][:, -1]

    elif cfg.model.model_class == "ori_class_copula":
        model_keys = list(cfg.model.keys())
        model_name = utils.make_model_name(cfg, model_keys)
        model_path = f"{utils.get_checkpoint_root()}/{model_name}.pkl"
        density_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.data.num_points_per_axis}_logdens.txt"

        from models.old_copula import ori_main_copula_AR_class as ori_main_copula_AR

        copula_density_obj = ori_main_copula_AR.fit_copula_classification(
            y[:, -1],
            y[:, :-1],
            seed=cfg.base.seed,
            n_perm=wandb.config.model["n_perm"],
            n_perm_optim=wandb.config.model["n_perm_optim"],  # None,
            single_x_bandwidth=False,
        )
        print("Bandwidth is {}".format(copula_density_obj.rho_opt))
        print("Preq loglik is {}".format(copula_density_obj.preq_loglik))

        if wandb.config.base["class"] == True:

            def predict_density(batch):
                test_scores = ori_main_copula_AR.predict_copula_density(
                    copula_density_obj, batch
                )[:, -1]
                test_scores = y_test[:, -1] * test_scores + (
                    1 - y_test[:, -1]
                ) * np.log(1 - jnp.exp(test_scores))
                return test_scores

        else:

            def predict_density(batch):
                return ori_main_copula_AR.predict_copula_density(
                    copula_density_obj, batch
                )[:, -1]

    # %%
    # region baselines
    elif cfg.model.model_class in ["maf", "rq-nsf"]:
        from experiments import run_nsf

        print(f"running {cfg.model.model_class}...")

        model_keys = [
            "train_batch_size",
            "learning_rate",
            "num_flow_steps",
            "hidden_features",
            "num_bins",
            "num_transform_blocks",
            "dropout_probability",
        ]

        model_name = utils.make_model_name(cfg, model_keys)
        model_path = f"{utils.get_checkpoint_root()}/{model_name}.pkl"
        density_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.data.num_points_per_axis}_logdens.txt"

        if not cfg.base.reload and os.path.exists(model_path):
            from types import SimpleNamespace

            with open(model_path, "rb") as f:
                flow = pkl.load(f)

        else:
            flow = run_nsf.main(cfg, y, y_val, y_test)

            with open(model_path, "wb") as f:
                pkl.dump(flow, f)
                # pkl.dump(GMM._asdict(), f)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        def predict_density(batch):
            with torch.no_grad():
                log_density = flow.log_prob(torch.Tensor(batch).to(device))
            return log_density.cpu().numpy()

    elif cfg.model.model_class == "bgmm":
        print("running GMM...")
        if y.shape[0] < cfg.model.n_components:
            wandb.config.model.update({"n_components": y.shape[0] // 2})
        model_keys = [
            "n_components",
            "covariance_type",
            "weight_concentration_prior_type",
            "n_init",
        ]

        model_name = utils.make_model_name(cfg, model_keys)
        model_path = f"{utils.get_checkpoint_root()}/{model_name}.pkl"
        density_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.data.num_points_per_axis}_logdens.txt"

        if not cfg.base.reload and os.path.exists(model_path):
            from types import SimpleNamespace

            with open(model_path, "rb") as f:
                GMM = pkl.load(f)

        else:
            from sklearn.model_selection import GridSearchCV

            params = {"weight_concentration_prior": np.logspace(-40, 1, 40)}
            start = time()
            grid = GridSearchCV(
                BayesianGaussianMixture(
                    **{k: v for k, v in wandb.config.model.items() if k in model_keys},
                    verbose=False,
                    random_state=cfg.base.seed,
                ),
                params,
                cv=5,
            )
            grid.fit(y[:10000])
            print(
                "best weight_concentration_prior: {0}".format(
                    grid.best_estimator_.weight_concentration_prior
                )
            )
            end = time()
            print("time for BGMM = {}".format(end - start))
            wandb.log({"opt_time": end - start})
            GMM = grid.best_estimator_
            if len(y) > 10000:
                GMM.fit(y, verbose_interval=50, n_init=10)
            end = time()
            print("time for Bayesian GMM = {}".format(end - start))
            wandb.log({"fit_time": end - start})

            with open(model_path, "wb") as f:
                pkl.dump(GMM, f)
                # pkl.dump(GMM._asdict(), f)

        def predict_density(batch):
            return GMM.score_samples(batch)

    elif cfg.model.model_class == "kde":
        print("running KDE...")
        model_keys = []
        model_name = utils.make_model_name(cfg, model_keys)

        model_path = f"{utils.get_checkpoint_root()}/{model_name}.pkl"
        density_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.data.num_points_per_axis}_logdens.txt"

        if not cfg.base.reload and os.path.exists(model_path):
            from types import SimpleNamespace

            with open(model_path, "rb") as f:
                KDE = pkl.load(f)

        else:
            from sklearn.neighbors import KernelDensity
            from sklearn.model_selection import GridSearchCV

            params = {"bandwidth": np.logspace(-1, 2, 80)}

            if True:  # cfg.model.bandwidth is None or cfg.model.bandwidth == "None":
                start = time()
                grid = GridSearchCV(KernelDensity(), params, cv=10)
                grid.fit(y[:10000])
                print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth))
                end = time()
                print("time for KDE = {}".format(end - start))
                wandb.log({"fit_time": end - start})

                # use the best bandwidth to compute the kernel density estimate
                KDE = grid.best_estimator_
                if len(y) > 10000:
                    print("fitting full dataset...")
                    KDE.fit(y)

            else:
                KDE = KernelDensity(
                    kernel="gaussian", bandwidth=cfg.model.bandwidth
                ).fit(y)

            with open(model_path, "wb") as f:
                pkl.dump(KDE, f)
                # pkl.dump(GMM._asdict(), f)

        def predict_density(batch):
            return KDE.score_samples(batch)

    elif cfg.model.model_class == "gaussian":
        train_mean = y.mean(axis=0)
        train_cov = np.cov(y.T)

        model_keys = []
        model_name = utils.make_model_name(cfg, model_keys)
        density_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.data.num_points_per_axis}_logdens.txt"

        def predict_density(batch):
            return sp.stats.multivariate_normal.logpdf(
                batch, mean=train_mean, cov=train_cov
            )

    elif wandb.config.base["class"]:

        if cfg.model.model_class == "gp":
            # GP
            from sklearn.gaussian_process.kernels import (
                RBF,
                ConstantKernel,
                WhiteKernel,
            )
            from sklearn.gaussian_process import GaussianProcessClassifier

            kernel = ConstantKernel() * RBF() + WhiteKernel()
            gp = GaussianProcessClassifier(kernel=kernel, n_restarts_optimizer=10).fit(
                y[:, :-1], y[:, -1]
            )

            def predict_density(batch):
                logp = np.log(gp.predict_proba(batch[:, :-1]))
                return batch[:, -1] * logp[:, 1] + (1 - batch[:, -1]) * logp[:, 0]

        elif cfg.model.model_class == "linear":

            from sklearn.linear_model import LogisticRegression
            from sklearn.model_selection import GridSearchCV

            from sklearn.model_selection import GridSearchCV

            params = {"C": np.logspace(-2, 5, 80)}
            start = time()
            grid = GridSearchCV(LogisticRegression(), params, cv=5,)
            grid.fit(y[:, :-1], y[:, -1])

            logreg = grid.best_estimator_
            logreg.fit(y[:, :-1], y[:, -1])

            def predict_density(batch):
                logp = logreg.predict_log_proba(batch[:, :-1])
                return batch[:, -1] * logp[:, 1] + (1 - batch[:, -1]) * logp[:, 0]

        elif cfg.model.model_class == "lightgbm":

            import lightgbm as lgb

            # LightGBM gradient-boosted classifier with default hyperparameters
            linreg = lgb.LGBMClassifier()
            linreg.fit(y[:, :-1], y[:, -1])

            def predict_density(batch):
                logp = np.log(linreg.predict_proba(batch[:, :-1]))
                return batch[:, -1] * logp[:, 1] + (1 - batch[:, -1]) * logp[:, 0]

        elif cfg.model.model_class == "mlp":
            from sklearn.neural_network import MLPClassifier
            from sklearn.model_selection import GridSearchCV

            params = {"hidden_layer_sizes": [16, 32, 64, 128, 256]}
            start = time()
            grid = GridSearchCV(MLPClassifier(), params, cv=5,)
            grid.fit(y[:, :-1], y[:, -1])

            logreg = grid.best_estimator_
            logreg.fit(y[:, :-1], y[:, -1])

            def predict_density(batch):
                logp = np.clip(logreg.predict_log_proba(batch[:, :-1]), -1000, None)
                return batch[:, -1] * logp[:, 1] + (1 - batch[:, -1]) * logp[:, 0]

    elif wandb.config.base["regress"]:

        if cfg.model.model_class == "gp":
            # GP
            from sklearn.gaussian_process.kernels import (
                RBF,
                ConstantKernel,
                WhiteKernel,
            )
            from sklearn.gaussian_process import GaussianProcessRegressor

            kernel = ConstantKernel() * RBF() + WhiteKernel()
            gp = GaussianProcessRegressor(
                kernel=kernel, n_restarts_optimizer=10, normalize_y=True
            ).fit(y[:, :-1], y[:, -1])

            def predict_density(batch):
                mean_gp_test, std_gp_test = gp.predict(batch[:, :-1], return_std=True)
                return sp.stats.norm.logpdf(
                    batch[:, -1], loc=mean_gp_test, scale=std_gp_test
                )

        elif cfg.model.model_class == "linear":

            # Bayesian linear ridge regression
            from sklearn.linear_model import BayesianRidge
            from sklearn.model_selection import GridSearchCV

            start = time()
            params = {
                "alpha_init": [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.9],
                "lambda_init": [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-9],
            }
            grid = GridSearchCV(BayesianRidge(solver="lbfgs"), params, cv=5,)
            grid.fit(y[:, :-1], y[:, -1])
            linreg = BayesianRidge(
                alpha_init=grid.best_estimator_.alpha_init,
                lambda_init=grid.best_estimator_.lambda_init,
            )
            linreg.fit(y[:, :-1], y[:, -1])

            def predict_density(batch):
                mean_lr, std_lr = linreg.predict(batch[:, :-1], return_std=True)
                return sp.stats.norm.logpdf(batch[:, -1], loc=mean_lr, scale=std_lr)

        elif cfg.model.model_class == "lightgbm":

            from ngboost import NGBRegressor

            # NGBoost probabilistic regressor with default hyperparameters
            # (this is what the "lightgbm" model_class maps to for regression)
            linreg = NGBRegressor()
            linreg.fit(y[:, :-1], y[:, -1])

            def predict_density(batch):
                Y_dists = linreg.pred_dist(batch[:, :-1])

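                # Returns the *negative* log-likelihood of the targets.
                # NOTE(review): the other predict_density implementations in
                # this function return the log-density (no sign flip) and the
                # result is logged as a log-likelihood -- confirm the sign.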
                return -Y_dists.logpdf(batch[:, -1])

        elif cfg.model.model_class == "mlp":
            from sklearn.neural_network import MLPRegressor

            from sklearn.model_selection import GridSearchCV

            params = {"hidden_layer_sizes": [16, 32, 64, 128, 256]}
            start = time()
            grid = GridSearchCV(MLPRegressor(solver="lbfgs"), params, cv=5,)
            grid.fit(y[:, :-1], y[:, -1])
            mlps = []
            for _ in range(10):
                linreg = MLPRegressor(
                    solver="lbfgs",
                    max_iter=1000,
                    hidden_layer_sizes=(grid.best_estimator_.hidden_layer_sizes,),
                )
                linreg.fit(y[:, :-1], y[:, -1])
                mlps.append(linreg)

            def predict_density(batch):
                res = []
                for linreg in mlps:
                    res.append(linreg.predict(batch[:, :-1]))

                mean_lr, std_lr = (
                    np.stack(res, 1).mean(axis=1),
                    np.stack(res, 1).std(axis=1),
                )

                return sp.stats.norm.logpdf(batch[:, -1], loc=mean_lr, scale=1)

    else:
        raise NotImplementedError
    # endregion

    # %%
    # region test likelihood
    if wandb.config.base["val"]:
        # Validation log-likelihood: mean and standard error of the per-row
        # log-densities of y_val, printed and logged to wandb.
        print("validating...")
        start = time()
        if wandb.config.data["batching_test"]:
            # Score y_val in consecutive slices of at most `batch_size_test`
            # rows (the final slice holds the remainder) and keep the scores of
            # every slice. Previously the remainder was never appended and only
            # the last batch reached the statistics below.
            val_batch_size = wandb.config.data["batch_size_test"]
            log_density_np = []
            for lo in range(0, len(y_val), val_batch_size):
                log_density = predict_density(y_val[lo : lo + val_batch_size])
                log_density_np = np.concatenate((log_density_np, log_density))
            log_density = np.asarray(log_density_np)
        else:
            log_density = predict_density(y_val)
        test_scores = log_density
        test_mean = test_scores.mean()
        # Standard error of the mean log-likelihood.
        test_ste = test_scores.std() / np.sqrt(len(test_scores))
        print(
            "val loglik for {} = {} \\pm {}".format(
                cfg.model.model_class, test_mean, test_ste,
            )
        )
        wandb.log({"val_loglik_mean": test_mean})
        wandb.log({"val_loglik_se": test_ste})
        wandb.log({"val_loglik_time": time() - start})
        print("time for test = {}".format(time() - start))

    if wandb.config.base["test"]:
        # Test log-likelihood: mean and standard error of the per-row
        # log-densities of y_test, printed and logged to wandb. Every branch
        # also binds `best_perm_d`, which the sampling/imputation code reads.
        print("testing...")
        start = time()
        if wandb.config.base["choose_best_dperm"]:
            from models.main_copula_AR import predict_copula_density_per_perm

            # Select permutations on the validation set ...
            logcdf, logpdf = predict_copula_density_per_perm(copula_density_obj, y_val)
            log_dens_per_perm = logpdf[..., -1].mean(0)
            best_perm_all = np.argmax(log_dens_per_perm.flatten())
            best_perm_d = np.argmax(logsumexp(logpdf[..., -1], 1))
            best_dens_all = np.max(log_dens_per_perm.flatten())
            best_dens_d = (
                (logsumexp(logpdf[..., -1], 1) - np.log(len(log_dens_per_perm[0])))
                .mean(0)
                .max()
            )

            # ... then score the test set with the selected permutations.
            logcdf, logpdf = predict_copula_density_per_perm(copula_density_obj, y_test)
            test_scores_bestd = logsumexp(logpdf[..., -1][:, best_perm_d], 1) - np.log(
                len(log_dens_per_perm[0])
            )
            test_scores_bestall = logpdf[..., -1].reshape(-1, len(y_test))[
                :, best_perm_all
            ]
            test_scores_all = logsumexp(
                logsumexp(logpdf[..., -1], 1) - np.log(len(log_dens_per_perm[0])), 1
            ) - np.log(len(log_dens_per_perm))
            wandb.log({"best_dens_d": test_scores_bestd.mean()})
            wandb.log({"best_dens_all": test_scores_bestall.mean()})
            wandb.log({"test_scores_all": test_scores_all.mean()})
            print({"best_dens_d": test_scores_bestd.mean()})
            print({"best_dens_all": test_scores_bestall.mean()})
            print({"test_scores_all": test_scores_all.mean()})
            # Report whichever selection strategy scored higher on validation.
            if best_dens_d > best_dens_all:
                log_density = test_scores_bestd
            else:
                log_density = test_scores_bestall

        elif wandb.config.data["batching_test"]:
            best_perm_d = None
            # Score y_test in consecutive slices of at most `batch_size_test`
            # rows (the final slice holds the remainder) and report on the
            # concatenation of all slices. Previously the accumulated array was
            # discarded and only the last batch reached the statistics below.
            test_batch_size = wandb.config.data["batch_size_test"]
            log_density_np = []
            for lo in range(0, len(y_test), test_batch_size):
                log_density = predict_density(y_test[lo : lo + test_batch_size])
                log_density_np = np.concatenate((log_density_np, log_density))
            log_density = np.asarray(log_density_np)

        else:
            best_perm_d = None
            log_density = predict_density(y_test)
        test_scores = (
            log_density  #! - np.log(std_norm).sum()  # predict_density(y_test)
        )
        end = time()
        test_mean = test_scores.mean()
        # Standard error of the mean log-likelihood.
        test_ste = test_scores.std() / np.sqrt(len(test_scores))

        # Image density estimation is reported through utils.calc_bits_per_pixel.
        if cfg.data.dataset in ["mnist", "digits"] and not wandb.config.base["class"]:
            test_mean, test_ste = utils.calc_bits_per_pixel(
                test_mean, test_ste, y_test.shape[1]
            )

        print(
            "test loglik for {} = {} \\pm {}".format(
                cfg.model.model_class, test_mean, test_ste,
            )
        )
        wandb.log({"test_loglik_mean": test_mean})
        wandb.log({"test_loglik_se": test_ste})
        print("time for test = {}".format(end - start))
        wandb.log({"test_loglik_time": end - start})

    # endregion

    if wandb.config.base["plot_density"]:
        # Draw two figures and log both to wandb: a heatmap of the model density
        # evaluated on a regular grid, and a 2-D histogram of the first two
        # columns of y on the same bounds and colour scale.
        # NOTE(review): the code plots y[:, 0] against y[:, 1] and builds
        # two-axis bounds, so it presumably expects 2-D data -- confirm.

        # region test grid data
        # Both axes share one range: the global min/max of y, shrunk by 1e-3.
        low, high = y.min(), y.max()
        bounds = np.array([[low + 1e-3, high - 1e-3], [low + 1e-3, high - 1e-3]])
        # NOTE(review): `data` must be the module imported at the top of main();
        # the "breast" dataset loader rebinds `data` to a DataFrame -- confirm
        # this section is not reached in that configuration.
        grid_dataset = data.TestGridDataset(
            num_points_per_axis=cfg.data.num_points_per_axis, bounds=bounds
        )
        # The copula model receives the whole grid as a single batch; every
        # other model class is evaluated in chunks of 50,000 points.
        grid_loader = torch.utils.data.DataLoader(
            dataset=grid_dataset,
            batch_size=int(5e4)
            if cfg.model.model_class != "copula"
            else int(cfg.data.num_points_per_axis ** 2),
            drop_last=False,
        )
        # endregion

        # region plot density
        print("plotting density...")

        # estimate/load density estimates
        # The cached file is reused only when cfg.base.reload is falsy and the
        # file exists; otherwise the grid is re-scored and the cache rewritten.
        if not cfg.base.reload and os.path.exists(density_path):
            from types import SimpleNamespace

            # NOTE: unpickling can execute arbitrary code; only load cache files
            # produced by this script.
            with open(density_path, "rb") as f:
                log_density_np = pkl.load(f)

        else:

            # Accumulate the per-batch log-densities into one flat array.
            log_density_np = []
            for batch in tqdm(grid_loader):
                batch = batch.cpu().numpy()
                log_density = predict_density(batch)
                log_density_np = np.concatenate((log_density_np, log_density))

            with open(density_path, "wb") as f:
                pkl.dump(log_density_np, f)

        # Colour-scale ceiling: 1.1 x the 99.9th percentile of the density values.
        # quantile() on a DataFrame returns a one-element Series, not a scalar.
        vmax = pd.DataFrame(np.exp(log_density_np)).quantile(0.999) * 1.1
        cmap = cm.magma

        # plot density
        # plt.rcParams.update({"font.size": 26})
        figure, axes = plt.subplots(1, 1, figsize=(2.5, 2.5))
        # The flat density vector is reshaped to the grid's mesh shape.
        axes.pcolormesh(
            grid_dataset.X,
            grid_dataset.Y,
            np.exp(log_density_np).reshape(grid_dataset.X.shape),
            cmap=cmap,
            vmin=0,
            vmax=vmax,
        )
        axes.set_xlim(bounds[0])
        axes.set_ylim(bounds[1])
        axes.set_xticks([])
        axes.set_yticks([])
        # axes.title.set_text("R$_d$-BP $_{(-2.32{\pm .02})}$")
        # plt.tight_layout()
        path = os.path.join(
            utils.get_output_root(), "{}-density.png".format(model_name)
        )
        # Save before show()/close() so the written file contains the figure.
        plt.savefig(path, bbox_inches="tight", pad_inches=0, dpi=300)
        plt.show()
        plt.close()
        wandb.log({"density estimate": wandb.Image(path)})
        # endregion

        # region plot data
        print("plotting data...")

        # Histogram of the data, drawn with the same bounds, colour map and
        # colour ceiling as the density heatmap.
        figure, axes = plt.subplots(1, 1, figsize=(2.5, 2.5))
        axes.hist2d(
            y[:, 0],
            y[:, 1],
            range=bounds,
            bins=64,
            cmap=cmap,
            rasterized=False,
            density=True,
            vmax=vmax,  # * was 1.1 * vmax
        )
        axes.set_xlim(bounds[0])
        axes.set_ylim(bounds[1])
        axes.set_xticks([])
        axes.set_yticks([])
        # Optional title taken from the config; only applied to the data plot.
        if cfg.base.dens_title is not None:
            axes.title.set_text(cfg.base.dens_title)
        # plt.tight_layout()
        path = os.path.join(utils.get_output_root(), "{}-data.png".format(model_name))
        plt.savefig(path, bbox_inches="tight", pad_inches=0, dpi=300)
        plt.show()
        plt.close()
        wandb.log({"data": wandb.Image(path)})
        # endregion

    # %%
    if cfg.base.sample:
        # Draw samples from the fitted copula model, pickle them, and log a
        # scatter/KDE comparison against the data to wandb.
        # region sampling
        from models.main_copula_AR import smc_sample_from_copula
        from models.copula_AR_functions import init_marginals_perm

        num_samples = cfg.sample.num_samples

        start = time()
        if not cfg.sample.smc:
            # NOTE(review): in the visible code `best_perm_d` is bound only by
            # the `test` section, and `sample_copula_density` is not imported
            # here -- confirm both are available on this path.
            y_sampled, err, n_iter = sample_copula_density(
                copula_density_obj,
                cfg.sample.num_samples,
                seed=cfg.base.seed,
                best_d=best_perm_d,
            )
        else:
            # SMC starts from random particles, one set per (d_perm, n_perm) pair.
            if cfg.data.init == "uniform":
                init_samples = np.random.random_sample(
                    (
                        wandb.config.model["d_perm"],
                        cfg.model.n_perm,
                        num_samples,
                        y.shape[1],
                    )
                )
            elif cfg.data.init == "normal":
                init_samples = np.random.normal(
                    size=(
                        wandb.config.model["d_perm"],
                        cfg.model.n_perm,
                        num_samples,
                        y.shape[1],
                    )
                )

            y_sampled, logpdf_joints_sampled, n_resampl, ess = smc_sample_from_copula(
                copula_density_obj, init_samples=init_samples
            )
            # Collapse the permutation/particle axes into one list of points.
            y_sampled = y_sampled.reshape(-1, y.shape[1])
        end = time()
        print("Sampling time: {}s".format(round(end - start, 3)))
        wandb.log({"Sampling time": round(end - start, 3)})

        sample_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.sample.smc}_{num_samples}.pkl"
        with open(sample_path, "wb") as f:
            pkl.dump(y_sampled, f)

        # One panel per point cloud. The figure used to have two axes for three
        # clouds, so zip() silently dropped the "sampled" panel, and the 4-D
        # `init_samples` array was indexed as if it were 2-D. The initial
        # particles exist only on the SMC path and are flattened to points here.
        panels = [("true", y)]
        if cfg.sample.smc:
            panels.append(("init", init_samples.reshape(-1, y.shape[1])))
        panels.append(("sampled", y_sampled))

        fig, axs = plt.subplots(1, len(panels), figsize=(16, 4))

        for ax, (name, y_arr) in zip(axs.flatten(), panels):
            ax.scatter(y_arr[:, 0], y_arr[:, 1], s=4)
            # The KDE overlay uses at most the first 1000 points.
            sns.kdeplot(
                y_arr[:1000, 0],
                y_arr[:1000, 1],
                cbar=True,
                shade=True,
                ax=ax,
                alpha=0.8,
            )
            ax.set_xlabel("$y_1$")
            ax.set_ylabel("$y_2$")
            ax.set_title(name)
            # Every panel shares the data's global range on both axes.
            ax.set_ylim([y.min(), y.max()])
            ax.set_xlim([y.min(), y.max()])

        path = os.path.join(
            utils.get_output_root(), "{}-{}sampl.png".format(model_name, num_samples)
        )
        plt.savefig(path, bbox_inches="tight", pad_inches=0, dpi=300)
        wandb.log({"samples": wandb.Image(path)})

        print(sample_path)

    if cfg.base.impute:
        # Imputation experiment: mask entries of y_test at random, fill them by
        # SMC sampling from the copula model, and log the squared error next to
        # two baselines (column means of y, and IterativeImputer/BayesianRidge).
        # region sampling
        num_imps = (
            wandb.config.model["d_perm"] * cfg.model.n_perm * cfg.impute.num_samples
        )

        # `smc_sample_from_copula` is imported here as well: this section calls
        # it, but otherwise the name was only bound when the `cfg.base.sample`
        # section had run first.
        from models.main_copula_AR import impute_from_copula, smc_sample_from_copula
        from models.copula_AR_functions import init_marginals_perm

        np.random.seed(cfg.base.seed)
        # 1 marks a missing entry; only rows with at least one missing entry are imputed.
        miss_mask = np.random.binomial(1, cfg.impute.miss_perc, size=y_test.shape)
        y_toimpute = y_test[np.sum(miss_mask, axis=1) > 0]
        miss_mask_toimpute = miss_mask[np.sum(miss_mask, axis=1) > 0]

        # filename = f"{utils.get_project_root()}/tmp/{model_name}_{cfg.impute.num_samples}_{cfg.impute.miss_perc}_ymiss.npy"
        # y_miss = np.memmap(
        #     filename,
        #     dtype="float32",
        #     mode="w+",
        #     shape=(
        #         y_toimpute.shape[0],
        #         cfg.model.d_perm,
        #         cfg.model.n_perm,
        #         cfg.impute.num_samples,
        #         y_toimpute.shape[1],
        #     ),
        # )
        # Ten randomly chosen imputations are kept per row for the plot below.
        y_miss_plot = np.zeros((10 * y_toimpute.shape[0], y_toimpute.shape[1]))
        imp_error = 0
        for i, this_y in tqdm(enumerate(y_toimpute)):

            num_samples = cfg.impute.num_samples
            if cfg.data.init == "uniform":
                init_samples = np.random.random_sample(
                    (
                        wandb.config.model["d_perm"],
                        cfg.model.n_perm,
                        num_samples,
                        y.shape[1],
                    )
                )
            elif cfg.data.init == "normal":
                init_samples = np.random.normal(
                    size=(
                        wandb.config.model["d_perm"],
                        cfg.model.n_perm,
                        num_samples,
                        y.shape[1],
                    )
                )

            # Observed entries keep their true value in every particle; masked
            # entries start from the random initialisation.
            y_miss = init_samples * miss_mask_toimpute[i] + this_y * (
                1 - miss_mask_toimpute[i]
            )

            # NOTE(review): in the visible code `best_perm_d` is bound only by
            # the `test` section -- confirm it is available when test is off.
            y_sampled, logpdf_joints_sampled, n_resampl, ess = smc_sample_from_copula(
                copula_density_obj, init_samples=y_miss, best_d=best_perm_d
            )
            y_miss_plot[i * 10 : (i + 1) * 10] = y_sampled.reshape(
                num_imps, y.shape[1]
            )[np.random.choice(num_imps, 10)]

            # if i % 100 == 0:
            #     y_miss.flush()

            # Point estimate: average over the three leading axes of the samples.
            imp = y_sampled.mean(0).mean(0).mean(0)
            # Squared error on the masked entries only, normalised by the row length.
            imp_error += np.sum(miss_mask_toimpute[i] * (imp - this_y) ** 2) / np.prod(
                imp.shape
            )

        # NOTE(review): only the samples of the last imputed row are saved, and
        # the file name reuses cfg.sample.smc, so it can coincide with the file
        # written by the sampling section -- confirm this is intended.
        sample_path = f"{utils.get_checkpoint_root()}/{model_name}_{cfg.sample.smc}_{num_samples}.pkl"
        with open(sample_path, "wb") as f:
            pkl.dump(y_sampled, f)

        wandb.log({"impute error": imp_error / y_toimpute.shape[0]})

        # y_miss = np.memmap(
        #     filename,
        #     dtype="float32",
        #     mode="r",
        #     shape=(
        #         y_toimpute.shape[0],
        #         wandb.config.model['d_perm'],
        #         cfg.model.n_perm,
        #         cfg.impute.num_samples,
        #         y_toimpute.shape[1],
        #     ),
        # )

        # Baseline 1: replace every masked entry with the column mean of y.
        y_nan = np.where(miss_mask_toimpute, np.nan, y_toimpute)
        imp_error = np.sum(miss_mask_toimpute * (y.mean(0) - y_toimpute) ** 2)
        wandb.log({"mean impute error": imp_error / np.prod(y_toimpute.shape)})

        from sklearn.experimental import enable_iterative_imputer  # noqa
        from sklearn.impute import IterativeImputer
        from sklearn.linear_model import BayesianRidge

        # Baseline 2: iterative imputation with a Bayesian ridge estimator,
        # fitted on y and applied to the masked rows.
        br_estimator = BayesianRidge()
        # NOTE(review): `i` is the index left over from the imputation loop, so
        # the seed depends on how many rows were imputed -- confirm intent.
        imp = IterativeImputer(
            max_iter=200,
            random_state=cfg.base.seed + i,
            estimator=br_estimator,
            sample_posterior=False,
            # sample_posterior=True,
        )
        imp.fit(y)
        imp_error = np.sum(
            miss_mask_toimpute * (imp.transform(y_nan) - y_toimpute) ** 2
        ) / np.prod(y_nan.shape)
        # for i in tqdm(range(num_imps // 10)):
        #     y_sampled = imp.transform(y_nan)
        #     imp_error += np.sum(
        #         miss_mask_toimpute * (y_sampled - y_toimpute) ** 2
        #     ) / np.prod(y_sampled.shape)

        wandb.log({"MICE impute error": imp_error})

        fig, axs = plt.subplots(1, 2, figsize=(14, 4))

        # Left: the true rows that had entries masked; right: sampled imputations.
        for ax, name, y_arr in zip(
            axs.flatten(), ["true", "imputed"], [y_toimpute, y_miss_plot]
        ):
            ax.scatter(y_arr[:, 0], y_arr[:, 1], s=4)
            # The KDE overlay uses at most the first 1000 points.
            sns.kdeplot(
                y_arr[:1000, 0],
                y_arr[:1000, 1],
                cbar=True,
                shade=True,
                ax=ax,
                alpha=0.8,
            )
            ax.set_xlabel("$y_1$")
            ax.set_ylabel("$y_2$")
            ax.set_title(name)
            ax.set_ylim([y.min(), y.max()])
            ax.set_xlim([y.min(), y.max()])

        path = os.path.join(
            utils.get_output_root(), "{}-{}sampl.png".format(model_name, num_samples)
        )
        # NOTE(review): creates <project_root>/out, presumably the directory
        # returned by utils.get_output_root() -- confirm.
        os.makedirs(os.path.join(utils.get_project_root(), "out"), exist_ok=True)
        plt.savefig(path, bbox_inches="tight", pad_inches=0, dpi=300)
        wandb.log({"imputed": wandb.Image(path)})

        # endregion

        # endregion


# Script entry point. main() is called without arguments because the
# @hydra.main decorator on it supplies `cfg`.
if __name__ == "__main__":
    main()
