import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from tqdm import tqdm
from scipy.stats import beta, entropy
from scipy.special import expit as sigmoid
from torch.utils.data import TensorDataset

# Small numerical-stability constant added to denominators (e.g. in beta_uncertainty)
# so that an all-zero evidence mass does not cause a division by zero.
EPS = 1e-5


class RBF(nn.Module):
    """Gaussian radial-basis-function kernel with a learnable bandwidth.

    Evaluates ``exp(-||query - keys||^2 / (2 * sigma^2))`` where the norm is
    taken over the last dimension; ``query`` and ``keys`` only need to be
    broadcast-compatible (e.g. one query row against many key rows).
    """

    def __init__(self, sigma=1.0):
        super().__init__()
        # Registered as an nn.Parameter so the bandwidth belongs to the module
        # (and is trainable by default).
        self.sigma = nn.Parameter(data=torch.tensor(sigma))

    def forward(self, query, keys):
        diff = query - keys
        sq_dist = torch.linalg.norm(diff, dim=-1) ** 2
        denom = 2 * self.sigma ** 2
        return torch.exp(-sq_dist / denom)


def beta_uncertainty(score, y_ref):
    """Return the variance of the Beta distribution implied by weighted label counts.

    The weight mass that ``score`` places on references with label 1 is used as
    the Beta ``alpha`` parameter, and the mass on label-0 references as ``beta``
    (``y_ref`` is presumably 0/1 -- confirm with callers).

    Args:
        score: tensor of non-negative weights, reduced over ``dim=1``
            (presumably (num_queries, num_refs) -- confirm with callers).
        y_ref: reference labels, broadcastable against ``score``.

    Returns:
        Tensor of ``alpha*beta / ((alpha+beta)^2 * (alpha+beta+1) + EPS)``;
        ``EPS`` keeps the result finite (zero) when a row has no weight at all.
    """
    # Named pos/neg mass rather than alpha/beta so the module-level
    # ``scipy.stats.beta`` import is not shadowed.
    pos_mass = (score * y_ref).sum(dim=1)
    neg_mass = (score * (1 - y_ref)).sum(dim=1)
    total = pos_mass + neg_mass
    return (pos_mass * neg_mass) / (total ** 2 * (total + 1) + EPS)


def nonparametric_estimate(X_q, X_ref, y_ref, sigma=0.05):
    """Estimate a per-query Beta posterior from RBF-weighted reference labels.

    Every query is compared with every reference point through an ``RBF``
    kernel of bandwidth ``sigma``. The summed kernel weight on references with
    ``y_ref == 1`` becomes the Beta ``a`` parameter and the weight on
    references with ``y_ref == 0`` becomes ``b``; references with any other
    label value are ignored.

    Args:
        X_q: query points (torch tensor).
        X_ref: reference points (torch tensor), broadcast against ``X_q`` by RBF.
        y_ref: reference labels used as a boolean mask over dim 1 of the kernel
            scores (presumably 1-D with one entry per reference -- confirm).
        sigma: RBF bandwidth.

    Returns:
        ``(means, stds, score)``: Beta posterior means and standard deviations
        as NumPy arrays, plus the raw kernel score tensor. If a query has no
        weight on one of the classes the corresponding shape parameter is 0 and
        SciPy reports NaN for that query.
    """
    with torch.no_grad():
        score = RBF(sigma=sigma)(X_q, X_ref)

    # Kernel mass on positive / negative references -> Beta(a, b) pseudo-counts.
    pos_counts = score[:, y_ref == 1].sum(dim=-1).detach().numpy()
    neg_counts = score[:, y_ref == 0].sum(dim=-1).detach().numpy()

    # ``beta.stats`` defaults to moments="mv", i.e. it returns mean and *variance*.
    means, variances = beta.stats(pos_counts, neg_counts)

    return means, np.sqrt(variances), score


def rbf_beta_estimate(X_query, X_ref, y_ref, sigma=0.05):
    """Beta-posterior mean and standard deviation from near-binary RBF neighbour weights.

    The inputs are torch tensors that are converted to NumPy. A Gaussian kernel
    (with ``1e-5`` added to ``2 * sigma**2``) is computed over the last axis and
    then squashed with ``2 * sigmoid(1e4 * k) - 1 + 1e-2``, which maps any
    non-negligible kernel value to about ``1.01`` and a zero kernel value to
    ``0.01`` (a small prior pseudo-count). The squashed weights on label-1
    references and on label-0 references are summed over the last axis and
    used as Beta ``(a, b)``.

    Returns:
        ``(means, stds)`` of the resulting Beta distribution(s).
    """
    q = X_query.detach().numpy()
    ref = X_ref.detach().numpy()
    labels = y_ref.detach().numpy().squeeze()

    sq_dist = np.linalg.norm(q - ref, axis=-1) ** 2
    kernel = np.exp(-sq_dist / (2 * sigma ** 2 + 1e-5))
    # Steep sigmoid turns the kernel into an (almost) indicator, plus a 1e-2 floor.
    weights = 2 * sigmoid(kernel * 1e4) - 1 + 1e-2

    pos_weight = np.sum(weights * labels, axis=-1)
    neg_weight = np.sum(weights * (1. - labels), axis=-1)

    mean, variance = beta.stats(pos_weight, neg_weight)
    return mean, np.sqrt(variance)


def bandit_eval(model, lr, X_test, y_test, num_points=1000, noise=0., multi=False, show_max_reg=True):
    """Run an online UCB bandit comparison between ``model`` and a Bayesian LR baseline.

    Test samples are visited one at a time in shuffled order. For each sample,
    Gaussian noise of scale ``noise`` is added to the context and both learners
    pick the arm maximising ``mean + stderr`` (UCB). Regret is 1 when the chosen
    arm differs from ``y.argmax(dim=-1)`` and 0 otherwise. After each step,
    ``model`` stores the observation via ``store_buffer`` and ``lr`` updates its
    posterior and promotes it to the new prior.

    Args:
        model: object providing ``infer_batch(X) -> (mean, stderr)``,
            ``store_buffer(X, y)`` and ``_get_name()``.
        lr: callable returning ``(mean, stderr, _)`` and providing
            ``update_posterior(X, y)`` and ``replace_prior()``.
        X_test, y_test: contexts and per-arm targets; only the first
            ``num_points`` rows are used.
        num_points: maximum number of steps.
        noise: std of the Gaussian noise added to each context.
        multi: if True, ``model``'s stderr is indexed as ``stderr[:, action]``,
            otherwise as ``stderr[action]``.
        show_max_reg: draw the worst-case (regret every step) reference line.

    Returns:
        dict with per-step regrets, cumulative regrets (starting at 0) and the
        recorded uncertainties for both learners. A new matplotlib figure is
        also created as a side effect.
    """
    dataset = TensorDataset(X_test[:num_points], y_test[:num_points])
    test_dl = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    # The slice silently truncates when fewer than num_points samples exist, so
    # count the steps actually taken (drives the progress bar and max-regret line).
    num_steps = len(dataset)

    all_regrets = list()
    all_regrets_lr = list()

    cum_regret_beta = [0]
    cum_regret_lr = [0]

    all_stderr = list()
    all_stderr_lr = list()

    stderr_index_fn = (lambda sd, a: sd[:, a]) if multi else (lambda sd, a: sd[a])

    for X, y in tqdm(test_dl, total=num_steps):
        X = X + noise * torch.randn(X.size())
        mean, stderr = model.infer_batch(X)
        mean_lr, stderr_lr, _ = lr(X)

        # UCB: optimistic choice = argmax of mean plus uncertainty bonus.
        action = np.argmax(mean + stderr)
        action_lr = np.argmax(mean_lr + stderr_lr)  # TODO: maybe can use softmax (GPClassification)

        regret = (y.argmax(dim=-1) != action).float().item()
        regret_lr = (y.argmax(dim=-1) != action_lr).float().item()

        all_regrets.append(regret)
        all_regrets_lr.append(regret_lr)

        cum_regret_beta.append(cum_regret_beta[-1] + regret)
        cum_regret_lr.append(cum_regret_lr[-1] + regret_lr)

        all_stderr.append(stderr_index_fn(stderr, action))
        # NOTE(review): unlike the model's stderr, the LR stderr is stored without
        # indexing by the chosen action -- confirm this is intended.
        all_stderr_lr.append(stderr_lr)

        model.store_buffer(X, y)  # TODO: Should we store the true label??

        lr.update_posterior(X.numpy(), y.numpy())
        lr.replace_prior()

    plt.figure(figsize=(12, 6))

    ax = plt.subplot(121)
    ax.plot(cum_regret_beta, label=model._get_name())
    ax.plot(cum_regret_lr, label="Bayes LR")
    if show_max_reg:
        ax.plot([0, num_steps], [0, num_steps], "--", label="Maximum cumulative regret")
    ax.set_ylabel("Cumulative regret")
    ax.set_xlabel("Steps")
    ax.legend()

    ax1 = plt.subplot(222)
    ax1.plot(all_stderr, label=f"Uncertainty by {model._get_name()}")
    ax1.legend()

    ax2 = plt.subplot(224)
    # NOTE(review): the LR uncertainty trace is shifted down by 1; the reason is
    # not visible here -- confirm.
    ax2.plot(np.array(all_stderr_lr) - 1, label="Uncertainty by BayesLR", c="tab:orange")
    ax2.legend()

    return {
        "all_regrets": all_regrets,
        "all_regrets_lr": all_regrets_lr,
        "cum_regret_beta": cum_regret_beta,
        "cum_regret_lr": cum_regret_lr,
        "all_stderr": all_stderr,
        "all_stderr_lr": all_stderr_lr
    }


def bandit_eval_gp(model, gp, X_test, y_test, num_points=1000, noise=0., X_data=None, y_data=None, multi=False, regressor=True, show_max_reg=True):
    """Run an online UCB bandit comparison between ``model`` and a Gaussian process.

    Test samples are visited one at a time in shuffled order. For each sample,
    Gaussian noise of scale ``noise`` is added to the context and both learners
    pick the arm maximising ``mean + bonus`` (UCB). Regret is 1 when the chosen
    arm differs from ``y.argmax(dim=-1)`` and 0 otherwise. After each step,
    ``model`` stores the observation and ``gp`` is refit on the full
    accumulated history ``(X_data, y_data)``.

    Args:
        model: object providing ``infer_batch(X) -> (mean, stderr)``,
            ``store_buffer(X, y)`` and ``_get_name()``.
        gp: estimator providing ``fit``, plus ``predict(X, return_std=True)``
            when ``regressor`` is True or ``predict_proba(X)`` otherwise.
            For classifiers, the bonus is the entropy of the predicted class
            probabilities.
        X_test, y_test: contexts and per-arm targets; only the first
            ``num_points`` rows are used.
        num_points: maximum number of steps.
        noise: std of the Gaussian noise added to each context.
        X_data, y_data: initial GP training history (not modified in place).
            If ``None``, the history starts empty, so ``gp`` must already be
            usable for the first prediction. A 1-D ``y_data`` receives the
            argmax label of each new target. A 2-D ``y_data`` with more than one
            column receives the full target row. When ``y_data`` is ``None`` it
            starts as target rows for a regressor and as labels otherwise.
        multi: if True, ``model``'s stderr is indexed as ``stderr[:, action]``,
            otherwise as ``stderr[action]``.
        regressor: choose ``predict`` (True) or ``predict_proba`` (False).
        show_max_reg: draw the worst-case (regret every step) reference line.

    Returns:
        dict with per-step regrets, cumulative regrets (starting at 0) and the
        recorded uncertainties for both learners (GP entries keep the ``_lr``
        key names). A new matplotlib figure is also created as a side effect.
    """
    dataset = TensorDataset(X_test[:num_points], y_test[:num_points])
    test_dl = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    # The slice silently truncates when fewer than num_points samples exist, so
    # count the steps actually taken (drives the progress bar and max-regret line).
    num_steps = len(dataset)

    all_regrets = list()
    all_regrets_lr = list()

    cum_regret_beta = [0]
    cum_regret_lr = [0]

    all_stderr = list()
    all_stderr_lr = list()

    stderr_index_fn = (lambda sd, a: sd[:, a]) if multi else (lambda sd, a: sd[a])

    def _gp_predict_clf(X):
        # Classifier: class probabilities as the mean, their entropy as the bonus.
        prob_gp = gp.predict_proba(X)
        mean_gp = prob_gp
        stderr_gp = entropy(prob_gp, axis=-1)
        return mean_gp, stderr_gp

    def _gp_predict_reg(X):
        return gp.predict(X, return_std=True)

    gp_predict_fn = _gp_predict_reg if regressor else _gp_predict_clf

    def _append_observation(X_hist, y_hist, x_row, y_row):
        """Return the GP history extended by one (context, target) observation."""
        label = y_row[0].argmax().item()
        X_hist = x_row if X_hist is None else np.vstack([X_hist, x_row])
        if y_hist is None:
            y_hist = y_row if regressor else np.array([label])
        elif np.ndim(y_hist) == 2 and np.shape(y_hist)[1] > 1:
            # Per-arm targets: np.append would flatten them and break gp.fit.
            y_hist = np.vstack([y_hist, y_row])
        else:
            y_hist = np.append(y_hist, label)
        return X_hist, y_hist

    for X, y in tqdm(test_dl, total=num_steps):
        X = X + noise * torch.randn(X.size())
        mean, stderr = model.infer_batch(X)

        mean_gp, stderr_gp = gp_predict_fn(X)

        # UCB: optimistic choice = argmax of mean plus uncertainty bonus.
        action = np.argmax(mean + stderr)
        action_lr = np.argmax(mean_gp + stderr_gp)

        regret = (y.argmax(dim=-1) != action).float().item()
        regret_lr = (y.argmax(dim=-1) != action_lr).float().item()

        all_regrets.append(regret)
        all_regrets_lr.append(regret_lr)

        cum_regret_beta.append(cum_regret_beta[-1] + regret)
        cum_regret_lr.append(cum_regret_lr[-1] + regret_lr)

        all_stderr.append(stderr_index_fn(stderr, action))
        all_stderr_lr.append(stderr_gp)

        model.store_buffer(X, y)

        X_data, y_data = _append_observation(X_data, y_data, X.numpy(), y.numpy())

        gp.fit(X_data, y_data)

    plt.figure(figsize=(12, 6))

    ax = plt.subplot(121)
    ax.plot(cum_regret_beta, label=model._get_name())
    ax.plot(cum_regret_lr, label="Gaussian Process")
    if show_max_reg:
        ax.plot([0, num_steps], [0, num_steps], "--", label="Maximum cumulative regret")
    ax.set_ylabel("Cumulative regret")
    ax.set_xlabel("Steps")
    ax.legend()

    ax1 = plt.subplot(222)
    ax1.plot(all_stderr, label=f"Uncertainty by {model._get_name()}")
    ax1.legend()

    ax2 = plt.subplot(224)
    ax2.plot(np.array(all_stderr_lr), label="Uncertainty by GP", c="tab:orange")
    ax2.legend()

    return {
        "all_regrets": all_regrets,
        "all_regrets_lr": all_regrets_lr,
        "cum_regret_beta": cum_regret_beta,
        "cum_regret_lr": cum_regret_lr,
        "all_stderr": all_stderr,
        "all_stderr_lr": all_stderr_lr
    }
