import jax
import jax.scipy as jsc
from jax import random
import jax.numpy as jnp
import numpy as np
from utils import invgamma_logp
class Y2Regression:
    """Synthetic grouped-regression benchmark observed through a sum of squares.

    Each datapoint has ``m`` Gaussian linear-regression outputs
    ``y_j = x_j . beta + sigma * eps_j`` (each ``x_j`` has ``x_dim`` entries),
    and only ``sum_j y_j ** 2`` is observed. Datapoints are assigned to one of
    ``g`` groups; every group gets its own ``beta`` / ``sigma`` in :meth:`data`.

    A model parameter vector ``theta`` has length ``self.d = x_dim + 1``:
    ``theta[:x_dim]`` is ``beta`` and ``theta[x_dim]`` is ``log(sigma)``.

    Covariates and group assignments are drawn once in ``__init__`` from fixed
    PRNG keys at the maximal size ``max_N`` and then truncated to ``N``, so a
    smaller ``N`` sees a prefix of the rows that a larger ``N`` would see.
    """

    max_N = 100000  # largest supported number of training datapoints
    test_N = 10000  # size of both the test set and the validation set
    max_g = 100  # largest supported number of groups

    def __init__(self, N=100, g=10, m=5, d=5, alpha=0.0, *args, **kwargs):
        """Draw the fixed covariates and group assignments.

        Args:
            N: Number of training datapoints, ``0 <= N <= max_N``.
            g: Number of groups, ``1 <= g <= max_g``.
            m: Number of regression outputs summed (squared) per datapoint.
            d: Number of covariates per output (stored as ``self.x_dim``).
            alpha: Mixing weight used by :meth:`data`: each group's parameters
                are ``(1 - alpha) * shared + alpha * group_specific``.
            *args, **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``N`` or ``g`` is outside its supported range.
        """
        # Covariates are drawn at size max_N and truncated; a larger N would
        # silently give fewer rows than self.N claims and break the shapes
        # used by sample_datapoint.
        if not 0 <= N <= self.max_N:
            raise ValueError(f"N must be in [0, {self.max_N}], got {N}")
        # data() draws max_g + 1 parameter rows and uses the last one as the
        # shared component. JAX clamps out-of-range gather indices instead of
        # raising, so g > max_g would silently reuse the wrong rows.
        if not 1 <= g <= self.max_g:
            raise ValueError(f"g must be in [1, {self.max_g}], got {g}")
        self.x_dim = d
        self.d = d + 1  # length of theta: x_dim regression weights + log(sigma)
        self.g = g
        self.N = N
        self.m = m
        self.x = random.normal(random.PRNGKey(0), (self.max_N, m, d))[:self.N]
        self.test_x = random.normal(random.PRNGKey(1), (self.test_N, m, d))
        self.assignment = random.choice(random.PRNGKey(2), self.g, (self.max_N,))[:self.N]
        self.test_assignment = random.choice(random.PRNGKey(3), self.g, (self.test_N,))
        self.valid_x = random.normal(random.PRNGKey(4), (self.test_N, m, d))
        self.valid_assignment = random.choice(random.PRNGKey(5), self.g, (self.test_N,))
        self.valid_y = None  # filled in by data()
        self.beta = None  # filled in by data()
        self.alpha = alpha

    def log_prior(self, theta):
        """Return the standard-normal log prior: sum of N(0, 1) log densities over ``theta``."""
        return jnp.sum(jsc.stats.norm.logpdf(theta, 0, 1))

    def log_likelihood(self, theta, y):
        """Not implemented in this class; raises ``NotImplementedError``."""
        raise NotImplementedError()

    def log_likelihoods(self, theta, y):
        """Not implemented in this class; raises ``NotImplementedError``."""
        raise NotImplementedError()

    def test_log_likelihoods(self, theta, y):
        """Not implemented in this class; raises ``NotImplementedError``."""
        raise NotImplementedError()

    def _simulate_y2(self, key, theta, x):
        """Draw ``y = x . beta + sigma * eps`` under one ``theta`` and return ``sum(y ** 2)`` over outputs.

        ``x`` has shape ``(rows, m, x_dim)``; for a 1-D ``theta`` the result
        has shape ``(rows,)``.
        """
        beta = theta[..., :self.x_dim]
        sigma = jnp.exp(theta[..., self.x_dim])
        y = random.normal(key, x.shape[:-1]) * sigma + jnp.sum(beta * x, axis=-1)
        return jnp.sum(y * y, axis=-1)

    def _crps(self, theta1, theta2, key, x, y_obs):
        """Summed two-sample CRPS estimate of ``y_obs`` on covariates ``x``.

        One dataset is simulated from ``theta1`` (``y1``) and one from
        ``theta2`` (``y2``). The result is
        ``sum(|y1 - y_obs| / 2 + |y2 - y_obs| / 2 - |y1 - y2| / 2)``.
        """
        key1, key2 = random.split(key)
        y1 = self._simulate_y2(key1, theta1, x)
        y2 = self._simulate_y2(key2, theta2, x)
        return jnp.sum(jnp.abs(y1 - y_obs) / 2 + jnp.abs(y2 - y_obs) / 2 - jnp.abs(y1 - y2) / 2)

    def sample_datapoint(self, key, theta):
        """Simulate observations for all ``N`` training covariates from a single ``theta``.

        Group assignments are not used: every row shares ``theta``.

        Returns:
            For a 1-D ``theta``, an array of shape ``(N,)`` holding each row's
            sum of squared outputs.
        """
        return self._simulate_y2(key, theta, self.x)

    def sample_prior(self, key):
        """Draw ``theta`` of shape ``(d,)`` from the standard-normal prior."""
        return random.normal(key, (self.d,))

    def data(self, key=None):
        """Draw per-group parameters and the train / test / validation observations.

        Sets ``self.beta`` (group weights, shape ``(g, 1, x_dim)``) and
        ``self.valid_y`` (validation observations, shape ``(test_N,)``).

        Args:
            key: A ``jax.random`` PRNG key. Required despite the ``None`` default.

        Returns:
            ``(train_y, test_y)`` with shapes ``(N,)`` and ``(test_N,)``.

        Raises:
            ValueError: If ``key`` is None.
        """
        if key is None:
            raise ValueError("data() requires an explicit jax.random PRNG key")
        key1, key2, key3, key4, key5 = random.split(key, 5)
        # The last of the max_g + 1 rows is the shared component; every row is
        # pulled toward it with weight (1 - alpha). Rows 0..g-1 are the groups.
        beta = random.normal(key1, (self.max_g + 1, 1, self.x_dim))
        beta = (1 - self.alpha) * beta[-1] + self.alpha * beta
        self.beta = beta[:self.g]
        logsigma = random.normal(key2, (self.max_g + 1,))
        logsigma = (1 - self.alpha) * logsigma[-1] + self.alpha * logsigma
        sigma = jnp.expand_dims(jnp.exp(logsigma), -1)

        def observe(noise, assignment, x):
            # Per-row group parameters, then the sum of squared outputs.
            y = noise * sigma[assignment] + jnp.sum(beta[assignment] * x, axis=-1)
            return jnp.sum(y * y, axis=-1)

        # Training noise is drawn at size max_N and truncated so the first N
        # rows agree across different N (matching how self.x is drawn).
        train_y = observe(
            random.normal(key3, (self.max_N, self.m))[:self.N], self.assignment, self.x
        )
        test_y = observe(
            random.normal(key4, (self.test_N, self.m)), self.test_assignment, self.test_x
        )
        self.valid_y = observe(
            random.normal(key5, (self.test_N, self.m)), self.valid_assignment, self.valid_x
        )
        return train_y, test_y

    def likelihood_parameters(self, theta):
        """Not implemented in this class; raises ``NotImplementedError``."""
        raise NotImplementedError()

    def M(self, theta=None):
        """Not implemented in this class; raises ``NotImplementedError``."""
        raise NotImplementedError()

    def validate_crps(self, theta1, theta2, key):
        """Summed two-sample CRPS of the validation observations under ``theta1`` / ``theta2``.

        Raises:
            RuntimeError: If :meth:`data` has not been called yet (no ``valid_y``).
        """
        if self.valid_y is None:
            raise RuntimeError("validate_crps() requires data() to have been called first")
        return self._crps(theta1, theta2, key, self.valid_x, self.valid_y)

    def test_crps(self, theta1, theta2, key, test_y):
        """Summed two-sample CRPS of ``test_y`` (observations for ``self.test_x``) under ``theta1`` / ``theta2``."""
        return self._crps(theta1, theta2, key, self.test_x, test_y)
