import jax
import jax.scipy as jsc
from jax import random
import jax.numpy as jnp
import numpy as np
from utils import invgamma_logp
class OutlierRegression:
    """Synthetic grouped linear regression whose training set contains outliers.

    Covariates and group assignments are drawn once in ``__init__`` from fixed
    PRNG keys. ``data`` then simulates responses from per-group slopes,
    intercepts and noise scales, inflating the noise of the first
    ``outlier_N`` training points by ``outlier_scale``.

    The likelihood methods evaluate a *single* Gaussian regression line
    ``y ~ Normal(theta[0] * x + theta[1], exp(theta[2]))``; they do not use the
    group assignments.

    ``y`` and ``test_y`` only exist after ``data`` has been called, so the
    likelihood methods need either an explicit ``y`` argument or a prior call
    to ``data``.
    """

    # Number of entries of theta read by the methods below
    # (slope, intercept, log sigma) -- presumably the parameter dimension.
    d = 3
    # Covariates/noise are always drawn at this length and then truncated to N,
    # so the first N draws are the same whatever N is.
    max_N = 100000
    # Size of the test and validation sets.
    test_N = 10000
    # Number of per-group parameter draws; one extra "shared" draw is appended.
    max_g = 100
    # Seed used by ``data`` when no key is supplied (0-5 are used in __init__).
    _DEFAULT_DATA_SEED = 6

    def __init__(self, N = 100, g = 5, outlier_scale = 10, outlier_N=10, alpha = 0.5, *args, **kwargs):
        """Draw covariates and group assignments for train/test/validation.

        Args:
            N: number of training points kept (truncation of ``max_N`` draws).
            g: number of groups; assignments are drawn from ``range(g)``.
            outlier_scale: factor applied to the noise of the outlier points.
            outlier_N: number of leading training points treated as outliers.
            alpha: mixing weight between a group's own parameter draw (weight
                ``alpha``) and the shared draw (weight ``1 - alpha``).
            *args, **kwargs: accepted and ignored.
        """
        self.N = N
        self.g = g

        # Training split: draw max_N values and truncate to N.
        self.x = random.normal(random.PRNGKey(0), (self.max_N,))[:self.N]
        self.assignment = random.choice(random.PRNGKey(1), self.g, (self.max_N,))[:self.N]

        # Test split.
        self.test_x = random.normal(random.PRNGKey(2), (self.test_N,))
        self.test_assignment = random.choice(random.PRNGKey(3), self.g, (self.test_N,))

        # Validation split; its responses are filled in by ``data``.
        self.valid_x = random.normal(random.PRNGKey(4), (self.test_N,))
        self.valid_assignment = random.choice(random.PRNGKey(5), self.g, (self.test_N,))
        self.valid_y = None

        self.outlier_scale = outlier_scale
        self.outlier_N = outlier_N
        self.alpha = alpha

    @staticmethod
    def _unpack_theta(theta):
        """Split ``theta`` into ``(slope, intercept, sigma)``; sigma = exp(theta[..., 2])."""
        return theta[..., 0], theta[..., 1], jnp.exp(theta[..., 2])

    def _pointwise_logpdf(self, theta, x, y):
        """Gaussian log-density of each ``y`` around the line defined by ``theta``."""
        slope, intercept, sigma = self._unpack_theta(theta)
        return jsc.stats.norm.logpdf(y, slope * x + intercept, sigma)

    @staticmethod
    def _simulate(noise, x, assignment, slope, intercept, sigma):
        """Build responses from standard-normal ``noise`` and per-group parameters."""
        return noise * sigma[assignment] + slope[assignment] * x + intercept[assignment]

    def log_prior(self, theta):
        """Return the summed standard-normal log-density of ``theta``."""
        return jnp.sum(jsc.stats.norm.logpdf(theta, 0, 1))

    def log_likelihood(self, theta, y  = None):
        """Return the total training log-likelihood (sum of ``log_likelihoods``).

        ``y`` defaults to the training responses stored by ``data``.
        """
        return jnp.sum(self.log_likelihoods(theta, y))

    def log_likelihoods(self, theta, y  = None):
        """Return the per-point training log-likelihoods.

        ``y`` defaults to the training responses stored by ``data``.
        """
        if y is None:
            y = self.y
        return self._pointwise_logpdf(theta, self.x, y)

    def test_log_likelihoods(self, theta, y = None):
        """Return the per-point log-likelihoods on the test covariates.

        ``y`` defaults to the test responses stored by ``data``.
        """
        if y is None:
            y = self.test_y
        return self._pointwise_logpdf(theta, self.test_x, y)

    def valid_log_likelihoods(self, theta, y = None):
        """Return the per-point log-likelihoods on the validation covariates.

        ``y`` defaults to ``self.valid_y``, which is ``None`` until ``data``
        has been called.
        """
        if y is None:
            y = self.valid_y
        return self._pointwise_logpdf(theta, self.valid_x, y)

    def sample_datapoint(self, key, theta):
        """Sample ``N`` responses at the training covariates under ``theta``."""
        slope, intercept, sigma = self._unpack_theta(theta)
        return random.normal(key, (self.N, )) * sigma + slope * self.x + intercept

    def sample_test_datapoint(self, key, theta):
        """Sample ``test_N`` responses at the test covariates under ``theta``."""
        slope, intercept, sigma = self._unpack_theta(theta)
        return random.normal(key, (self.test_N, )) * sigma + slope * self.test_x + intercept

    def data(self, key = None):
        """Simulate and store training, test and validation responses.

        Args:
            key: JAX PRNG key. When ``None`` a fixed default key is used, so
                repeated calls without a key return the same data.

        Returns:
            ``(y, test_y)``. The same arrays are stored on ``self.y`` and
            ``self.test_y``; validation responses are stored on
            ``self.valid_y`` only.
        """
        if key is None:
            # The signature advertises the key as optional; random.split(None)
            # would raise, so fall back to a deterministic key instead.
            key = random.PRNGKey(self._DEFAULT_DATA_SEED)
        coef_key, scale_key, train_key, test_key, valid_key = random.split(key, 5)

        n_draws = self.max_g + 1  # max_g group draws plus one shared draw (last)

        def mix(draws):
            # Blend every draw with the shared (last) one: alpha weights the
            # group's own draw, (1 - alpha) the shared draw.
            return (1 - self.alpha) * draws[-1] + self.alpha * draws

        coef = random.normal(coef_key, (n_draws * 2, ))
        slope = mix(coef[:n_draws])
        intercept = mix(coef[n_draws:])
        sigma = jnp.exp(mix(random.normal(scale_key, (n_draws, ),)))

        # Training noise: the first outlier_N points get inflated noise.
        train_noise = random.normal(train_key, (self.max_N,))[:self.N]
        train_noise = train_noise.at[:self.outlier_N].set(
            train_noise[:self.outlier_N] * self.outlier_scale)

        self.y = self._simulate(
            train_noise, self.x, self.assignment, slope, intercept, sigma)
        # Test and validation responses contain no outliers.
        self.test_y = self._simulate(
            random.normal(test_key,  (self.test_N,)),
            self.test_x, self.test_assignment, slope, intercept, sigma)
        self.valid_y = self._simulate(
            random.normal(valid_key, (self.test_N,)),
            self.valid_x, self.valid_assignment, slope, intercept, sigma)
        return self.y, self.test_y

    def likelihood_parameters(self, theta):
        """Not implemented for this model."""
        raise NotImplementedError()

    def M(self, theta = None):
        """Return ``1 / (sigma * sqrt(2 * pi))`` with ``sigma = exp(theta[..., 2])``.

        This is the value of the Gaussian density at its mean, i.e. the
        largest value the per-point likelihood can take for this ``theta``.
        ``theta`` must be provided despite the ``None`` default.
        """
        sigma = jnp.exp(theta[..., 2])
        return 1 / sigma / jnp.sqrt(2 * jnp.pi)