import numpy as np
import matplotlib.pyplot as plt

from numpy.linalg import inv
from scipy.stats import norm as univariate_normal


class BayesLR:
    """Bayesian linear regression with a Gaussian prior over the weights.

    The weights have prior N(prior_mu, prior_sigma); observations carry additive
    Gaussian noise with variance ``noise_var``. ``update_posterior`` computes the
    conjugate posterior N(mu, sigma) from the *current prior* and one batch of
    data, and ``replace_prior`` promotes that posterior to be the prior for the
    next update (sequential learning).

    Args:
        feat_dim: Number of raw input features. When ``projection_func`` is
            None, one extra weight is added for the bias column.
        ci_width: Multiplier meant for confidence bands (e.g. 1.96). Stored on
            the instance only; no method of this class uses it.
        noise_var: Observation noise variance.
        lamb_var: Prior variance of each weight (isotropic prior covariance).
        projection_func: Callable mapping inputs to a design matrix whose
            column count equals the number of weights. Defaults to
            ``to_design_mat``, which prepends a column of ones.
    """

    def __init__(self, feat_dim: int, ci_width: float = 1.96, noise_var: float = 1, lamb_var: float = 1,
                 projection_func=None):
        self.feat_dim = feat_dim

        if projection_func is None:
            # The default projection prepends a bias column -> one extra weight.
            self.projection_func = self.to_design_mat
            self.feat_dim += 1
        else:
            self.projection_func = projection_func

        self.sigma = np.eye(self.feat_dim) * lamb_var
        self.mu = np.zeros(self.feat_dim)

        self.prior_sigma = self.sigma.copy()
        self.prior_mu = self.mu.copy()

        self.ci_width = ci_width
        self.noise_var = noise_var
        self.lamb_var = lamb_var

    def __call__(self, feat: np.ndarray):
        """Return the posterior predictive distribution at ``feat``.

        Returns:
            ``(pred_mean, pred_sd, dist)``: the predictive mean and standard
            deviation per input point (observation noise included), and a frozen
            ``scipy.stats.norm`` with those per-point locations and scales.
        """
        feat_mat = self.projection_func(feat)

        pred_mean = feat_mat.dot(self.mu)
        # Only the diagonal of Phi @ Sigma @ Phi.T is needed; compute it row-wise
        # instead of materialising an n x n matrix. Noise adds to the variance.
        pred_var = np.sum(feat_mat.dot(self.sigma) * feat_mat, axis=1) + self.noise_var
        pred_sd = np.sqrt(pred_var)

        # scipy's ``scale`` is the standard deviation, one per point.
        return pred_mean, pred_sd, univariate_normal(loc=pred_mean.flatten(), scale=pred_sd)

    def to_design_mat(self, feat_vec: np.ndarray) -> np.ndarray:
        """Prepend a column of ones (bias) to the raw features.

        With a single raw feature, a 1-D array is treated as a column of samples.
        """
        feat_vec = np.asarray(feat_vec)
        if feat_vec.ndim == 1 and self.feat_dim == 2:
            feat_vec = feat_vec[:, np.newaxis]
        feat_mat = np.ones((len(feat_vec), self.feat_dim))
        feat_mat[:, 1:] = feat_vec
        return feat_mat

    def update_posterior(self, feat_vec: np.ndarray, y: np.ndarray):
        """Set ``mu``/``sigma`` to the posterior given the current prior and (feat_vec, y)."""
        X = self.projection_func(feat_vec)

        # Invert the prior covariance once and reuse it for both formulas.
        prior_prec = inv(self.prior_sigma)
        self.sigma = inv(X.T.dot(X) / self.noise_var + prior_prec)
        self.mu = self.sigma.dot(prior_prec.dot(self.prior_mu) + X.T.dot(y) / self.noise_var)

    def replace_prior(self):
        """Use the current posterior as the prior for subsequent updates."""
        self.prior_mu = self.mu.copy()
        self.prior_sigma = self.sigma.copy()

    def reset(self):
        """Restore both the posterior and the prior to the initial isotropic prior."""
        self.sigma = np.eye(self.feat_dim) * self.lamb_var
        self.mu = np.zeros(self.feat_dim)
        # Also undo any ``replace_prior`` so later updates start from scratch.
        self.prior_sigma = self.sigma.copy()
        self.prior_mu = self.mu.copy()


class BayesMultiLR:
    """Multi-output Bayesian linear regression with a shared weight covariance.

    The weight mean ``mu`` is a (feat_dim, out_dim) matrix, one column per
    output. A single (feat_dim, feat_dim) covariance ``sigma`` is shared by all
    outputs, since every output sees the same design matrix and noise variance.

    Args:
        feat_dim: Number of raw input features. When ``projection_func`` is
            None, one extra weight is added for the bias column.
        out_dim: Number of outputs (columns of ``y``).
        ci_width: Multiplier meant for confidence bands. Stored on the instance
            only; no method of this class uses it.
        noise_var: Observation noise variance (shared by all outputs).
        lamb_var: Prior variance of each weight.
        projection_func: Callable mapping inputs to a design matrix. Defaults to
            ``to_design_mat``, which prepends a column of ones.
    """

    def __init__(self, feat_dim: int, out_dim: int = 1, ci_width: float = 1.96, noise_var: float = 1, lamb_var: float = 1,
                 projection_func=None):
        self.feat_dim = feat_dim
        # Kept so that ``reset`` can rebuild ``mu`` with the right shape.
        self.out_dim = out_dim

        if projection_func is None:
            # The default projection prepends a bias column -> one extra weight.
            self.projection_func = self.to_design_mat
            self.feat_dim += 1
        else:
            self.projection_func = projection_func

        self.sigma = np.eye(self.feat_dim) * lamb_var
        self.mu = np.zeros((self.feat_dim, out_dim))

        self.prior_sigma = self.sigma.copy()
        self.prior_mu = self.mu.copy()

        self.ci_width = ci_width
        self.noise_var = noise_var
        self.lamb_var = lamb_var

    def __call__(self, feat: np.ndarray):
        """Return ``(pred_mean, pred_sd, None)`` for inputs ``feat``.

        ``pred_mean`` has one column per output. ``pred_sd`` holds one predictive
        standard deviation per input point (noise included); it applies to every
        output because the covariance is shared.
        """
        feat_mat = self.projection_func(feat)

        pred_mean = feat_mat.dot(self.mu)
        # Diagonal of Phi @ Sigma @ Phi.T without building the n x n matrix.
        pred_var = np.sum(feat_mat.dot(self.sigma) * feat_mat, axis=1) + self.noise_var
        pred_sd = np.sqrt(pred_var)

        return pred_mean, pred_sd, None

    def to_design_mat(self, feat_vec: np.ndarray) -> np.ndarray:
        """Prepend a column of ones (bias) to the raw features.

        With a single raw feature, a 1-D array is treated as a column of samples.
        """
        feat_vec = np.asarray(feat_vec)
        if feat_vec.ndim == 1 and self.feat_dim == 2:
            feat_vec = feat_vec[:, np.newaxis]
        feat_mat = np.ones((len(feat_vec), self.feat_dim))
        feat_mat[:, 1:] = feat_vec
        return feat_mat

    def update_posterior(self, feat_vec: np.ndarray, y: np.ndarray):
        """Set ``mu``/``sigma`` to the posterior given the current prior and (feat_vec, y)."""
        X = self.projection_func(feat_vec)

        # A 1-D target is a single output column. Without this, it would
        # broadcast against the (feat_dim, out_dim) prior mean into a wrong shape.
        y = np.asarray(y)
        if y.ndim == 1:
            y = y[:, np.newaxis]

        prior_prec = inv(self.prior_sigma)
        self.sigma = inv(X.T.dot(X) / self.noise_var + prior_prec)
        self.mu = self.sigma.dot(prior_prec.dot(self.prior_mu) + X.T.dot(y) / self.noise_var)

    def replace_prior(self):
        """Use the current posterior as the prior for subsequent updates."""
        self.prior_mu = self.mu.copy()
        self.prior_sigma = self.sigma.copy()

    def reset(self):
        """Restore both the posterior and the prior to the initial isotropic prior."""
        self.sigma = np.eye(self.feat_dim) * self.lamb_var
        self.mu = np.zeros((self.feat_dim, self.out_dim))
        self.prior_sigma = self.sigma.copy()
        self.prior_mu = self.mu.copy()


def generate(xrange: tuple[float, float], func: callable = None, num_points: int = 30,
             noise_var: float = 1, seed: int = None):
    """Sample a noisy 1-D regression dataset.

    Args:
        xrange: ``(low, high)`` bounds for uniformly sampled inputs.
        func: Function mapping inputs to noise-free targets; defaults to
            ``2 * x + 1``.
        num_points: Number of samples.
        noise_var: Variance of the Gaussian noise added to the targets.
        seed: Seed for NumPy's global RNG; ``None`` leaves the RNG untouched.

    Returns:
        ``(x, y_noisy, y_true)``: arrays of length ``num_points`` sorted by ``x``.
    """
    # ``is not None`` so that seed=0 is honoured (0 is falsy).
    if seed is not None:
        np.random.seed(seed)

    xs = np.random.uniform(*xrange, size=num_points)
    if func is None:
        def func(x):
            return 2 * x + 1

    ys_true = func(xs)

    order = np.argsort(xs)
    xs = xs[order]
    ys_true = ys_true[order]

    noise = np.random.normal(0, np.sqrt(noise_var), num_points)
    return xs, ys_true + noise, ys_true


def cubic_graph(x):
    """Build the cubic polynomial design matrix for a 1-D input.

    Each row is ``[1, x, x**2, x**3]`` for the corresponding sample, so a linear
    model over these columns fits a cubic in ``x``.

    Args:
        x: 1-D array of scalar inputs.

    Returns:
        Float array of shape ``(len(x), 4)``.
    """
    design = np.ones((len(x), 4))
    # Columns 1..3 hold the successive powers; column 0 stays as the bias.
    design[:, 1:] = np.stack([x ** power for power in (1, 2, 3)], axis=-1)
    return design


if __name__ == '__main__':
    # Demo: fit a cubic with Bayesian linear regression on two disjoint batches
    # of observations and compare strategies for incorporating the second batch.

    # Which update strategy the right-hand panel shows:
    #   1 -> replace the prior with the first posterior, then update with batch 2
    #   2 -> keep the original prior, update with both batches at once
    #   3 -> keep the original prior, update with batch 2 only
    SCENARIO = 1

    model = BayesLR(4, projection_func=cubic_graph)

    def cubic_fn(x):
        return -0.2 * x ** 3 + x ** 2 - x + 2

    xinterval = np.linspace(-6, 8, num=100)

    def plot_fit(ax, mean, sd):
        """Draw the true curve, the predictive mean and its confidence band on ``ax``."""
        half_width = model.ci_width * sd
        ax.plot(xinterval, cubic_fn(xinterval), label="True function", c="tab:orange")
        ax.plot(xinterval, mean, label="Predicted function", c="tab:blue")
        ax.fill_between(xinterval, mean - half_width, mean + half_width, alpha=0.4)

    xrange1, yrange1, true_yrange1 = generate((-4, -3), noise_var=5, func=cubic_fn, seed=42, num_points=10)

    model.update_posterior(xrange1, yrange1)
    mean1, sd1, _ = model(xinterval)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 5))

    plot_fit(ax1, mean1, sd1)
    ax1.scatter(xrange1, yrange1, label="Observed points", c="tab:orange")
    ax1.legend()

    # NOTE: reusing seed=42 replays the same random stream, so this batch gets the
    # same noise vector as the first batch (and affinely related x positions).
    xrange2, yrange2, true_yrange2 = generate((5, 7), noise_var=5, func=cubic_fn, seed=42, num_points=10)

    if SCENARIO == 1:
        title = "(1) Replaced prior with posterior and updated with only second set of observations"
        model.replace_prior()
        model.update_posterior(xrange2, yrange2)
    elif SCENARIO == 2:
        title = "(2) Prior unchanged but updated with all observations"
        model.update_posterior(np.hstack((xrange1, xrange2)), np.hstack((yrange1, yrange2)))
    else:
        title = "(3) Prior unchanged and updated with only second set of observations"
        model.update_posterior(xrange2, yrange2)

    mean2, sd2, _ = model(xinterval)

    plot_fit(ax2, mean2, sd2)
    ax2.scatter(xrange2, yrange2, label="Observed points", c="tab:orange")
    ax2.scatter(xrange1, yrange1, label="Previous observed points", c="tab:gray")
    ax2.legend()

    fig.suptitle(title)

    plt.tight_layout()
    # Without this, running the script never displays the figure.
    plt.show()
