import os
import pickle
import pickle as pkl

import numpy as np
import pandas as pd

import jax.numpy as jnp
import wandb

import random

import jax
from jax import lax, jit
from jax.experimental import host_callback
import collections
import warnings


def print_table(table, rows=None, cols=None):
    """Print a LaTeX ``table`` environment wrapping a slice of a DataFrame.

    :param table: pandas DataFrame to render.
    :param rows: index labels to keep; every row of ``table`` when None.
    :param cols: column labels to keep; every column of ``table`` when None.
    :return: None, the LaTeX source is written to stdout.
    """
    selected_rows = table.index if rows is None else rows
    selected_cols = table.columns if cols is None else cols

    # escape=False: cell contents are emitted verbatim, so LaTeX markup
    # already present in the cells is preserved.
    tabular = table.loc[selected_rows, selected_cols].to_latex(escape=False)

    preamble = "\n    \\begin{table}[]\n    \\centering"
    closing = (
        "\n        \\caption{Caption}"
        "\n        \\label{tab:my_label}"
        "\n    \\end{table}"
        "\n    "
    )
    print(preamble + tabular + closing)


def make_table_from_metric(
    metric,
    runs_df,
    val_metric=None,
    pm_metric="sem",
    data_name="dataset",
    latex=False,
    bold=True,
    show_group=False,
):
    """Summarise ``metric`` per (method, dataset) as a method-by-dataset table.

    Runs are grouped by the ``"method"`` column and the ``data_name`` column.
    Each cell of the result is the string ``mean ± spread`` where ``spread``
    is the ``pm_metric`` aggregation of ``metric`` within the group.

    :param metric: name of the column of ``runs_df`` to summarise.
    :param runs_df: DataFrame with one row per run; must contain the columns
        ``"method"``, ``data_name``, ``metric`` and ``val_metric``.
    :param val_metric: validation column; defaults to ``metric``. Its
        mean/std/sem are aggregated but not currently used to select rows.
    :param pm_metric: pandas aggregation name used for the ± part.
    :param data_name: name of the column identifying the dataset.
    :param latex: format cells as LaTeX math (``${m_{\\pm s}}$``).
    :param bold: mark the best (largest mean) method per dataset, with
        ``\\bm`` in LaTeX mode and a leading ``"* "`` otherwise.
    :param show_group: select a ``"group"`` column instead of ``"result"``.
        NOTE(review): no ``"group"`` column is created here, so this looks
        like it raises a KeyError -- confirm intended use.
    :return: DataFrame indexed by method with one column per dataset.
    """
    if val_metric is None:
        val_metric = metric

    # Build the spec key by key: with a dict literal, val_metric == metric
    # (the default) would collide and silently discard pm_metric.
    aggregations = {metric: ["mean", pm_metric]}
    aggregations.setdefault(val_metric, []).extend(["mean", "std", "sem"])
    # pandas requires unique aggregation names per column.
    aggregations = {
        column: list(dict.fromkeys(funcs)) for column, funcs in aggregations.items()
    }

    table = (
        runs_df.groupby(by=["method", data_name]).agg(aggregations).reset_index()
    )

    if latex:

        def format_result(row):
            return (
                f"{{{row[metric]['mean']:0.2f}_{{\\pm {row[metric][pm_metric]:0.2f}}}}}"
            )

        def bold_result(row):
            return "\\bm" + row["result"] if row["bold"].any() else row["result"]

    else:

        def format_result(row):
            return f"{row[metric]['mean']:0.2f} ± {row[metric][pm_metric]:0.2f}"

        def bold_result(row):
            return "* " + row["result"] if row["bold"].any() else row["result"]

    # A row is "best" when its mean equals the maximum mean within its dataset.
    # The string alias avoids the deprecated builtin-callable path in pandas.
    table["bold"] = (
        table.groupby(by=[data_name]).transform("max")[metric]["mean"]
        == table[metric]["mean"]
    )

    table["result"] = table.apply(format_result, axis=1)
    if bold:
        table["result"] = table.apply(bold_result, axis=1)

    if latex:
        table["result"] = table.apply(lambda row: "$" + row["result"] + "$", axis=1)

    cols = (
        ["method", data_name, "group"]
        if show_group
        else ["method", data_name, "result"]
    )

    table_flat = table[cols].pivot(index="method", columns=data_name)

    # After the pivot the columns carry two leftover levels from the
    # aggregated frame; drop both so only the dataset names remain.
    table_flat = table_flat.droplevel(level=0, axis=1)
    table_flat = table_flat.droplevel(level=0, axis=1)
    table_flat.columns.name = None
    table_flat.index.name = None

    return table_flat


def rename_cols_for_pd_wide_to_long(col_names):
    """Rewrite column names into a ``;``-separated, reversed-token form.

    Each name is processed independently:

    1. ``" classification AUC"`` becomes ``"_classification_AUC"``.
    2. ``" test "`` becomes ``" test_"``.
    3. For each of the markers ``"test"``, ``"_classification_AUC"`` and
       ``"time"`` (in that order), if the marker occurs in the current name
       its space-separated tokens are reversed and joined with ``";"``.

    :param col_names: iterable of column-name strings.
    :return: list of rewritten names, in the input order.
    """
    substitutions = (
        (" classification AUC", "_classification_AUC"),
        (" test ", " test_"),
    )
    reversal_markers = ("test", "_classification_AUC", "time")

    renamed = []
    for name in col_names:
        for old, new in substitutions:
            if old in name:
                name = name.replace(old, new)
        for marker in reversal_markers:
            if marker in name:
                # "a b c" -> "c;b;a"; a name without spaces is left unchanged.
                name = ";".join(name.split(" ")[::-1])
        renamed.append(name)
    return renamed


def print_df_duplicates(df, columns):
    """Return every row of ``df`` whose ``columns`` values occur more than once.

    Despite the name, nothing is printed: the matching rows are returned.
    All members of each duplicated group are kept (``keep=False``).

    :param df: pandas DataFrame to inspect.
    :param columns: column label(s) used to decide whether rows are duplicates.
    :return: the sub-frame of ``df`` holding the duplicated rows.
    """
    is_repeated = df.duplicated(subset=columns, keep=False)
    return df[is_repeated]


def slice_df(df, value_dict):
    """Keep only the rows of ``df`` whose columns take allowed values.

    :param df: pandas DataFrame to filter.
    :param value_dict: mapping ``column -> collection of accepted values``;
        a row survives only if it satisfies every entry.
    :return: the filtered frame (``df`` itself when ``value_dict`` is empty).
    """
    filtered = df
    for column, allowed in value_dict.items():
        keep = filtered[column].isin(allowed)
        filtered = filtered[keep]
    return filtered


def load_or_run(path, fun, args):
    """Return a cached result from ``path`` or compute and cache it.

    The pickle at ``path`` is reused only when ``wandb.config.base["reload"]``
    is falsy and the file exists. Otherwise ``fun(*args)`` is evaluated and
    its result is pickled to ``path``, overwriting any previous file.

    :param path: location of the pickle cache file.
    :param fun: callable producing the result.
    :param args: positional arguments for ``fun``.
    :return: the cached or freshly computed result.
    """
    force_recompute = wandb.config.base["reload"]
    cache_hit = not force_recompute and os.path.exists(path)

    if cache_hit:
        with open(path, "rb") as cache_file:
            return pkl.load(cache_file)

    res = fun(*args)
    with open(path, "wb") as cache_file:
        pkl.dump(res, cache_file)
    return res


def calc_bits_per_pixel(lp, lp_err, n_dims):
    """Convert a log-probability (in nats) and its error to bits per pixel.

    :param lp: log-probability value(s).
    :param lp_err: error on ``lp``.
    :param n_dims: number of dimensions the log-probability is spread over.
    :return: tuple ``(bpp, bpp_err)`` with ``bpp = -lp / (n_dims * ln 2) + 8``
        and ``bpp_err = lp_err / (n_dims * ln 2)``.
    """
    nats_per_bit_all_dims = n_dims * np.log(2)
    return -lp / nats_per_bit_all_dims + 8, lp_err / nats_per_bit_all_dims


def logistic(x):
    """
    Apply the logistic sigmoid ``1 / (1 + exp(-x))`` to every element.
    :param x: numpy array
    :return: numpy array
    """
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator


def one_hot_encode(labels, n_labels):
    """
    Turn integer class labels into a matrix of 1-hot rows.

    NOTE: this module defines ``one_hot_encode`` a second time further down;
    the later definition is the one bound at import time.

    :param labels: numpy integer array with values in 0, 1, ..., n_labels-1
    :param n_labels: number of distinct classes (width of the output)
    :return: float array of shape (labels.size, n_labels)
    """

    assert np.min(labels) >= 0 and np.max(labels) < n_labels

    n_rows = labels.size
    encoded = np.zeros([n_rows, n_labels])
    # Row i gets a single 1 in the column given by labels[i].
    encoded[range(n_rows), labels] = 1
    return encoded


def _print_consumer(arg, transform):
    """Host-side printer: report the current iteration out of the total.

    :param arg: pair ``(iter_num, num_samples)``.
    :param transform: unused; present to match the callback signature used
        by ``host_callback.id_tap`` in this module.
    """
    current, total = arg
    message = f"Iteration {current:,} / {total:,}"
    print(message)


@jit
def progress_bar(arg, result):
    """
    Report progress of a scan/loop whenever the iteration number is a
    multiple of the print rate, and hand ``result`` back unchanged.

    Usage: `carry = progress_bar((iter_num + 1, num_samples, print_rate), carry)`
    Pass in `iter_num + 1` so that counting starts at 1 and ends at `num_samples`

    :param arg: triple ``(iter_num, num_samples, print_rate)``.
    :param result: value threaded through the call and returned.
    """
    iter_num, num_samples, print_rate = arg

    def _report(_):
        # id_tap forwards (iter_num, num_samples) to the host-side printer
        # and yields `result` as its output.
        return host_callback.id_tap(
            _print_consumer, (iter_num, num_samples), result=result
        )

    def _skip(_):
        return result

    should_report = iter_num % print_rate == 0
    return lax.cond(should_report, _report, _skip, operand=None)


def progress_bar_scan(num_samples):
    """
    Decorator factory adding progress reporting to a `lax.scan` body function.

    The decorated `body_fun` must be scanned over `jnp.arange(num_samples)`,
    so that its second argument is the current iteration number.

    The print rate is read once, at decoration time, from
    `wandb.config.base["print_rate"]`; when that entry is the string "None"
    it falls back to 5000 if `wandb.config.model["diff"] == "net"` and to
    50000 otherwise.
    """

    def _progress_bar_scan(func):
        configured_rate = wandb.config.base["print_rate"]
        if configured_rate != "None":
            print_rate = configured_rate
        elif wandb.config.model["diff"] == "net":
            print_rate = 5000
        else:
            print_rate = 50000

        def wrapper_progress_bar(carry, iter_num):
            # Report with a 1-based counter; progress_bar returns iter_num itself.
            progress_arg = (iter_num + 1, num_samples, print_rate)
            iter_num = progress_bar(progress_arg, iter_num)
            return func(carry, iter_num)

        return wrapper_progress_bar

    return _progress_bar_scan


# def progress_bar_scan(num_samples, message=None):
#     "Progress bar for a JAX scan"
#     if message is None:
#         message = f"Running for {num_samples:,} iterations"
#     tqdm_bars = {}

#     if num_samples > 20:
#         print_rate = int(num_samples / 20)
#     else:
#         print_rate = 1  # if you run the sampler for less than 20 iterations
#     remainder = num_samples % print_rate

#     def _define_tqdm(arg, transform):
#         tqdm_bars[0] = tqdm(range(num_samples))
#         tqdm_bars[0].set_description(message, refresh=False)

#     def _update_tqdm(arg, transform):
#         tqdm_bars[0].update(arg)

#     def _update_progress_bar(iter_num):
#         "Updates tqdm progress bar of a JAX scan or loop"
#         _ = lax.cond(
#             iter_num == 0,
#             lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
#             lambda _: iter_num,
#             operand=None,
#         )

#         _ = lax.cond(
#             # update tqdm every multiple of `print_rate` except at the end
#             (iter_num % print_rate == 0) & (iter_num != num_samples - remainder),
#             lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=iter_num),
#             lambda _: iter_num,
#             operand=None,
#         )

#         _ = lax.cond(
#             # update tqdm by `remainder`
#             iter_num == num_samples - remainder,
#             lambda _: host_callback.id_tap(_update_tqdm, remainder, result=iter_num),
#             lambda _: iter_num,
#             operand=None,
#         )

#     def _close_tqdm(arg, transform):
#         tqdm_bars[0].close()

#     def close_tqdm(result, iter_num):
#         return lax.cond(
#             iter_num == num_samples - 1,
#             lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
#             lambda _: result,
#             operand=None,
#         )

#     def _progress_bar_scan(func):
#         """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
#         Note that `body_fun` must either be looping over `np.arange(num_samples)`,
#         or be looping over a tuple whose first element is `np.arange(num_samples)`
#         This means that `iter_num` is the current iteration number
#         """

#         def wrapper_progress_bar(carry, x):
#             if type(x) is tuple:
#                 iter_num, *_ = x
#             else:
#                 iter_num = x
#             _update_progress_bar(iter_num)
#             result = func(carry, x)
#             return close_tqdm(result, iter_num)

#         return wrapper_progress_bar

#     return _progress_bar_scan


class ObjectView:
    """Expose the entries of a dict as attributes of an object.

    The dict is not copied: it becomes the instance's ``__dict__``, so
    attribute writes on the view are visible in the dict and vice versa.
    """

    def __init__(self, d):
        # Equivalent to ``self.__dict__ = d``; ``d`` must be a real dict.
        setattr(self, "__dict__", d)


def turn_dict_to_filename(d):
    """Encode a dict as a compact string usable in a file name.

    Each entry contributes the first character of its key followed by its
    value; the pieces are joined with underscores in the dict's order.

    :param d: mapping with string (sliceable) keys.
    :return: e.g. ``{"alpha": 1, "beta": "x"}`` -> ``"a1_bx"``.
    """
    pieces = []
    for key, value in d.items():
        initial = key[:1]
        pieces.append(f"{initial}{value}")
    return "_".join(pieces)


def save(data, file):
    """
    Pickle ``data`` to the path ``file``, overwriting any existing file.

    :param data: any picklable object
    :param file: path of the file to write
    """

    # Pickle produces bytes, so the file must be opened in binary mode;
    # the context manager closes it even if pickling fails.
    with open(file, "wb") as f:
        pickle.dump(data, f)


def load(file):
    """
    Load and return the pickled object stored at the path ``file``.

    :param file: path of a pickle file
    :return: the unpickled object
    """

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # Pickle data is binary, so the file must be opened with "rb".
    with open(file, "rb") as f:
        return pickle.load(f)


def make_folder(folder):
    """
    Creates given folder (or path, including missing parents) if it doesn't exist.

    :param folder: directory path to create
    """

    if not os.path.exists(folder):
        # exist_ok=True: another process may create the directory between
        # the check above and this call; that must not raise.
        os.makedirs(folder, exist_ok=True)


def drop_corr_and_constant(
    y, y_test=None, threshold=0.98, return_columns=False
):
    """Drop highly correlated and (near-)constant columns from ``y``.

    A column is dropped when its absolute correlation with any earlier
    column exceeds ``threshold``, or when its standard deviation is below
    1e-3.

    :param y: 2-D data accepted by ``pd.DataFrame`` (rows are samples).
    :param y_test: optional iterable of further data sets from which the
        same columns are removed.
    :param threshold: absolute-correlation cut-off.
    :param return_columns: if True, return only the list of dropped column
        labels.
    :return: the list of dropped labels if ``return_columns``; otherwise the
        reduced array, or ``(reduced array, list of reduced test arrays)``
        when ``y_test`` is given.
    """
    data = pd.DataFrame(y)
    corr_matrix = data.corr().abs()

    # Keep only the strict upper triangle so each pair is considered once and
    # only the later column of a correlated pair is dropped.
    # (Builtin ``bool``: the ``np.bool`` alias was removed from NumPy.)
    upper_mask = np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
    upper = corr_matrix.where(upper_mask)
    stds = data.std()

    # Columns correlated above `threshold` with an earlier one, or near-constant.
    to_drop = [
        col
        for col in upper.columns
        if (upper[col] > threshold).any() or stds[col] < 1e-3
    ]

    if return_columns:
        return to_drop

    y = data.drop(columns=to_drop).values

    if y_test is not None:
        y_tuple = [pd.DataFrame(y_i).drop(columns=to_drop).values for y_i in y_test]
        return y, y_tuple

    return y


def one_hot_encode(labels, n_labels):
    """
    Build the 1-hot encoding of numeric labels.

    NOTE: this redefinition shadows an identical ``one_hot_encode`` defined
    earlier in this module; this is the version in effect after import.

    :param labels: numpy integer array, each entry in 0, 1, ..., n_labels-1
    :param n_labels: total number of classes
    :return: array of shape (labels.size, n_labels) with one 1 per row
    """

    assert np.min(labels) >= 0 and np.max(labels) < n_labels

    one_hot = np.zeros([labels.size, n_labels])
    row_ids = range(labels.size)
    one_hot[row_ids, labels] = 1

    return one_hot


def sigmoid(x):
    """
    Elementwise logistic sigmoid, ``1 / (1 + exp(-x))``.
    :param x: numpy array
    :return: numpy array
    """
    exp_neg_x = np.exp(-x)
    return 1.0 / (1.0 + exp_neg_x)


def logit(x):
    """
    Elementwise logit, ``log(x / (1 - x))``, the inverse of the logistic sigmoid.
    :param x: numpy array
    :return: numpy array
    """
    odds = x / (1.0 - x)
    return np.log(odds)


def bits_per_pixel(log_likelihood, n_dims, lamba, x):
    """Convert a log-likelihood in logit space to bits per pixel.

    :param log_likelihood: log-likelihood in nats.
    :param n_dims: number of dimensions; the log-likelihood is divided by
        ``n_dims * ln 2``.
    :param lamba: the constant entering the ``log(1 - 2 * lamba)`` term.
    :param x: numpy array passed through ``sigmoid``; the base-2 logs of
        ``sigmoid(x)`` and ``1 - sigmoid(x)`` are summed over all elements.
    :return: bits-per-pixel value.

    NOTE(review): the ``(1 - 2 * lamba)`` term uses the natural log and the
    sigmoid correction is summed rather than averaged over dimensions --
    confirm against the intended formula.
    """

    Dlog2 = n_dims * np.log(2)
    sigmoid_x = sigmoid(x)
    # np.log has no base argument (its second positional parameter is `out`),
    # so the base-2 logarithm must be taken with np.log2.
    logit_correction = (np.log2(sigmoid_x) + np.log2(1 - sigmoid_x)).sum()
    bpp = -log_likelihood / Dlog2 - np.log(1 - 2 * lamba) + 8 + logit_correction

    return bpp


def make_model_name(cfg, model_keys, wo_data=None):
    """Build a model identifier string from the run configuration.

    The name is ``<cfg.model.model_class>_<encoded settings>`` where the
    settings, encoded by ``turn_dict_to_filename``, are taken in this order:

    1. entries of ``wandb.config.model`` whose key is in ``model_keys``;
    2. entries of ``wandb.config.data`` except the batching/memory keys
       listed below and any key in ``wo_data``;
    3. ``seed``, ``regr`` and ``class`` taken from ``cfg.base``.

    :param cfg: config object exposing ``model.model_class``, ``base.seed``,
        ``base.regress`` and ``base["class"]``.
    :param model_keys: container of model-config keys to include.
    :param wo_data: optional iterable of extra data-config keys to leave out.
    :return: the model name string.
    """
    excluded_data_keys = [
        "num_points_per_axis",
        "batch_size",
        "batch_size_test",
        "batching",
        "batching_test",
        "batch_size_train_test",
        "low_mem",
    ]
    if wo_data is not None:
        excluded_data_keys.extend(wo_data)

    # Later updates overwrite values but keep a key's first position, exactly
    # like merging the three groups with ``{**a, **b, **c}``.
    name_fields = {k: v for k, v in wandb.config.model.items() if k in model_keys}
    name_fields.update(
        {k: v for k, v in wandb.config.data.items() if k not in excluded_data_keys}
    )
    name_fields.update(
        {"seed": cfg.base.seed, "regr": cfg.base.regress, "class": cfg.base["class"]}
    )

    return f"{cfg.model.model_class}_{turn_dict_to_filename(name_fields)}"


def generate_ar_masks(net_params):
    """Build one connectivity mask per layer of a stack of weight matrices.

    ``net_params`` is iterated in order and each value's ``"w"`` entry is
    treated as one layer, with ``shape[-2]`` as its input width and
    ``shape[-1]`` as its output width; consecutive layers must chain
    (asserted below).

    Masks are built from transposed ``jnp.tri`` matrices whose rows or
    columns are repeated to match the layer widths. The product of all masks
    is asserted to be non-zero exactly on the strictly upper-triangular
    pattern of size ``degrees[0]``, with each column repeated
    ``degrees[-1] // degrees[0]`` times.

    NOTE(review): the last mask's width is
    ``degrees[0] * (degrees[-1] // degrees[0])`` -- presumably the final
    output width is a multiple of the input width; confirm with callers.

    :param net_params: mapping ``layer name -> {"w": weight array, ...}``.
    :return: dict mapping each key of ``net_params`` to its mask.
    """

    # degrees = [input width, width after layer 1, ..., final output width]
    degrees = [list(net_params.values())[0]["w"].shape[-2]]
    for v in net_params.values():
        v = v["w"]
        # each layer's input width must equal the previous layer's output width
        assert degrees[-1] == v.shape[-2]
        degrees.append(v.shape[-1])

    old_degree = degrees[0]
    masks = []
    # k is the diagonal offset given to jnp.tri: -1 excludes the diagonal,
    # 0 includes it. It starts strict and is set to 0 after each hidden layer.
    k = -1
    for new_degree in degrees[1:-1]:

        if new_degree > old_degree:
            # Widening layer: always strict (k reset to -1); columns are
            # repeated, the last group absorbing the remainder.
            k = -1
            new_mask = jnp.repeat(
                jnp.tri(old_degree, k=k).T,
                jnp.array(
                    [new_degree // old_degree for _ in range(old_degree - 1)]
                    + [new_degree // old_degree + new_degree % old_degree]
                ),
                axis=1,
            )
        else:
            # Narrowing / equal-width layer: uses the current k; rows are
            # repeated, the last group absorbing the remainder.
            new_mask = jnp.repeat(
                jnp.tri(new_degree, k=k).T,
                jnp.array(
                    [old_degree // new_degree for _ in range(new_degree - 1)]
                    + [old_degree // new_degree + old_degree % new_degree]
                ),
                axis=0,
            )
        k = 0
        masks.append(new_mask)

        old_degree = new_degree

    # Output layer: rows map the last hidden width onto degrees[0] groups,
    # then every column is repeated degrees[-1] // degrees[0] times.
    # k is whatever the loop left behind (-1 if there were no hidden layers).
    new_degree = degrees[0]
    new_mask = jnp.repeat(
        jnp.tri(new_degree, k=k).T,
        jnp.array(
            [old_degree // new_degree for _ in range(new_degree - 1)]
            + [old_degree // new_degree + old_degree % new_degree]
        ),
        axis=0,
    )
    new_mask = jnp.repeat(new_mask, degrees[-1] // new_degree, axis=1)
    masks.append(new_mask)

    # Chain the masks: res[i, j] > 0 iff some path connects input i to output j.
    res = masks[0]

    for mask in masks[1:]:
        res = res @ mask

    # Sanity check: end-to-end connectivity must be exactly strictly upper
    # triangular (per repeated output column).
    assert (
        (res > 0)
        == (jnp.repeat(jnp.tri(degrees[0], k=-1).T, degrees[-1] // degrees[0], axis=1))
    ).all()

    # One mask per layer, keyed like net_params.
    masks = {k: mask for k, mask in zip(net_params.keys(), masks)}

    return masks


def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Sets ``PYTHONHASHSEED`` in ``os.environ`` as well; that only affects
    processes started afterwards, not the running interpreter.

    :param seed: seed value applied to every generator.
    """
    # torch is imported lazily: it is not a module-level dependency of this file.
    import torch

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune its algorithm choice, which can vary
    # between runs; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False


class EarlyStopping:
    """Patience-based early stopping on a validation objective.

    ``callback`` evaluates ``obj_fun`` on the validation data, logs the score
    to wandb, remembers the parameters with the lowest score seen so far and
    raises the ``early_stop`` flag after ``patience`` consecutive evaluations
    that were strictly worse than the best score.
    """

    def __init__(
        self,
        obj_fun,
        y_val,
        patience=10,
        miniter=0,
        verbose=True,
        d_perm_i=0,
        helper=None,
        d_perm_inds=None,
        n_perm_inds=None,
    ):
        # Configuration.
        self.patience = patience
        self.obj_fun = obj_fun
        self.y_val = y_val
        self.miniter = miniter
        self.verbose = verbose
        self.d_perm_i = d_perm_i
        self.helper = helper
        self.d_perm_inds = d_perm_inds
        self.n_perm_inds = n_perm_inds

        # Extra objective inputs; left as None unless set from outside.
        self.vn_perm = None
        self.y_perm = None

        # Running state (lower score is better).
        self.iter = 0
        self.counter = 0
        self.early_stop = False
        self.best_params = None
        self.best_score = float("inf")
        self.last_score = self.best_score

    def callback(self, params):
        """Score ``params`` on the validation data and update the stop state.

        The first ``miniter`` calls only advance the iteration counter.

        :param params: current parameters, forwarded to ``obj_fun``.
        :return: False during the warm-up calls, otherwise ``early_stop``.
        """
        self.iter += 1
        if self.iter <= self.miniter:
            return False

        # The objective's keyword interface depends on the optimiser in use.
        if wandb.config.model["scipy_opt"]:
            objective_kwargs = dict(
                hyperparam=params,
                y_perm=self.y_val,
                d_perm_inds=self.d_perm_inds,
                n_perm_inds=self.n_perm_inds,
                helper=self.helper,
            )
        else:
            objective_kwargs = dict(
                rho_lengths_opt=params,
                y_test=self.y_val,
                vn_perm=self.vn_perm,
                y_perm=self.y_perm,
                d_perm_inds=self.d_perm_inds,
                n_perm_inds=self.n_perm_inds,
                bern=False,
                helper=self.helper,
            )
        score = self.obj_fun(**objective_kwargs)

        wandb.log({"val_loss": score})
        wandb.log({"val_loglik_mean": -score})
        self.last_score = score

        got_worse = score > self.best_score
        if not got_worse:
            # New best (ties included): remember it and reset the patience counter.
            self.best_score = score
            self.save_checkpoint(params)
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
            wandb.log({f"total_iter_{self.d_perm_i}": self.iter})
        return self.early_stop

    def save_checkpoint(self, model):
        """Remember ``model`` as the best parameters seen so far."""
        self.best_params = model
