import tensorflow as tf
import numpy as np
from experiments.target_lnpdfs.Lnpdf import LNPDF
from math import pi

class LogisticRegression(LNPDF):
    """Unnormalised log-posterior of Bayesian logistic regression.

    For a batch of weight vectors ``x`` (one per row) this evaluates

        log p(x | D) = sum_n log p(y_n | d_n, x) + sum_d log N(x_d; prior_mean, prior_std^2)

    Data are read from text files under ``experiments/target_lnpdfs/datasets``.
    Each feature column is divided by its standard deviation and a constant
    column of ones is prepended as a bias term, so ``num_dimensions`` equals
    the number of raw features plus one.

    Sign convention: a datapoint with label 1 is scored with
    ``log sigmoid(-<d_n, x>)`` and every other label with ``log sigmoid(<d_n, x>)``.
    This is standard logistic regression with the weights negated. The prior
    is a zero-mean isotropic Gaussian, so the resulting posterior is the mirror
    image (x -> -x) of the usual one.

    Args:
        dataset_id: ``"breast_cancer"`` or ``"german_credit"``.

    Raises:
        ValueError: if ``dataset_id`` is not a supported dataset.
    """

    def __init__(self, dataset_id):
        # 0.5 * log(2*pi): normalising constant of a univariate Gaussian log-density.
        self.const_term = tf.constant(0.5 * np.log(2. * pi), dtype=tf.float32)
        features, labels = self._load_dataset(dataset_id)
        self.data = tf.cast(features, tf.float32)
        # Labels stay a NumPy array, as in the original implementation.
        self.labels = labels
        self.num_dimensions = self.data.shape[1]
        self._prior_std = tf.constant(10., dtype=tf.float32)
        self.prior_mean = tf.constant(0., dtype=tf.float32)

    @staticmethod
    def _load_dataset(dataset_id):
        """Load ``dataset_id`` and return ``(features, labels)`` as NumPy arrays.

        ``features`` is ``[1 | raw_features / std(raw_features)]``. ``labels``
        is the dataset's label column (shifted by -1 for german_credit).

        Raises:
            ValueError: if ``dataset_id`` is not a supported dataset.
        """
        if dataset_id == "breast_cancer":
            data = np.loadtxt("experiments/target_lnpdfs/datasets/breast_cancer.data")
            raw_features = data[:, 2:]
            labels = data[:, 1]
        elif dataset_id == "german_credit":
            data = np.loadtxt("experiments/target_lnpdfs/datasets/german.data-numeric")
            raw_features = data[:, :-1]
            # NOTE(review): presumably maps file labels {1, 2} to {0, 1} — confirm.
            labels = data[:, -1] - 1
        else:
            raise ValueError(
                f"Unknown dataset_id {dataset_id!r}; "
                "expected 'breast_cancer' or 'german_credit'")
        scale = np.std(raw_features, axis=0)
        # A constant column has zero std; dividing by it would poison the data with NaN/inf.
        scale = np.where(scale > 0, scale, 1.0)
        standardized = raw_features / scale[np.newaxis, :]
        bias = np.ones((len(standardized), 1))
        return np.hstack((bias, standardized)), labels

    def get_num_dimensions(self):
        """Return the dimensionality of the weight vector (features + bias)."""
        return self.num_dimensions

    @property
    def prior_std(self):
        """Standard deviation of the isotropic Gaussian prior (a float32 scalar)."""
        return self._prior_std

    def _pointwise_log_likelihood(self, x):
        """Return per-datapoint log-likelihoods of shape (num_samples, num_datapoints)."""
        # logits[s, n] = <x[s], data[n]>; computed in the output layout to avoid transposes.
        logits = tf.matmul(x, self.data, transpose_b=True)
        # log_sigmoid(logits) equals the algebraic form log_sigmoid(-logits) + logits,
        # but evaluating it directly avoids cancellation error.
        return tf.where(self.labels == 1,
                        tf.math.log_sigmoid(-logits),
                        tf.math.log_sigmoid(logits))

    def _log_prior(self, x):
        """Return the isotropic Gaussian prior log-density for each row of ``x``."""
        standardized = (x - self.prior_mean) / self.prior_std
        return tf.reduce_sum(
            -tf.math.log(self.prior_std) - self.const_term - 0.5 * tf.math.square(standardized),
            axis=1)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def log_likelihood(self, x):
        """Per-datapoint log-likelihoods.

        Args:
            x: float32 tensor of shape (num_samples, num_dimensions), one weight vector per row.

        Returns:
            float32 tensor of shape (num_samples, num_datapoints).
        """
        return self._pointwise_log_likelihood(x)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def log_density(self, x):
        """Unnormalised log-posterior (full-data log-likelihood + log-prior).

        Args:
            x: float32 tensor of shape (num_samples, num_dimensions), one weight vector per row.

        Returns:
            float32 tensor of shape (num_samples,).
        """
        log_likelihoods = tf.reduce_sum(self._pointwise_log_likelihood(x), axis=1)
        return log_likelihoods + self._log_prior(x)

class LogisticRegression_minibatch(LogisticRegression):
    """Logistic-regression target whose likelihood uses a random minibatch.

    Every call to ``log_likelihood`` or ``log_density`` draws a fresh set of
    ``batchsize`` distinct datapoints (a uniformly shuffled prefix of the
    dataset). If ``batchsize`` exceeds the dataset size, all datapoints are used.

    NOTE(review): ``log_density`` sums the minibatch log-likelihood without
    rescaling by ``num_datapoints / batchsize``. It is therefore not an unbiased
    estimate of the full-data log-posterior (the prior is relatively
    over-weighted). Confirm this is intended.

    Args:
        dataset_id: ``"breast_cancer"`` or ``"german_credit"``.
        batchsize: number of datapoints per minibatch.
    """

    def __init__(self, dataset_id, batchsize):
        super(LogisticRegression_minibatch, self).__init__(dataset_id)
        self.batchsize = batchsize

    def _sample_minibatch(self):
        """Return ``(data_mb, labels_mb)`` with shapes (B, num_dimensions) and (B,)."""
        # Shuffle features and labels together so rows stay paired.
        labels_column = tf.expand_dims(tf.cast(self.labels, tf.float32), 1)
        labelled = tf.concat((self.data, labels_column), axis=1)
        batch = tf.random.shuffle(labelled)[:self.batchsize]
        data_mb, labels_mb = tf.split(batch, [self.num_dimensions, 1], axis=1)
        # Squeeze only the label axis so a batch of size 1 stays 1-D.
        return data_mb, tf.squeeze(labels_mb, axis=1)

    def _minibatch_log_likelihood(self, x):
        """Per-datapoint log-likelihoods on a fresh minibatch, shape (num_samples, B)."""
        data_mb, labels_mb = self._sample_minibatch()
        # logits[s, b] = <x[s], data_mb[b]>
        logits = tf.matmul(x, data_mb, transpose_b=True)
        # Same sign convention as the full-batch parent: label 1 -> log sigmoid(-logit).
        # log_sigmoid(logits) is the stable form of log_sigmoid(-logits) + logits.
        return tf.where(labels_mb == 1,
                        tf.math.log_sigmoid(-logits),
                        tf.math.log_sigmoid(logits))

    def log_likelihood(self, x):
        """Per-datapoint log-likelihoods on a fresh random minibatch.

        Args:
            x: float32 tensor of shape (num_samples, num_dimensions).

        Returns:
            float32 tensor of shape (num_samples, min(batchsize, num_datapoints)).
        """
        return self._minibatch_log_likelihood(x)

    def log_density(self, x):
        """Minibatch log-likelihood (summed, not rescaled) plus the Gaussian log-prior.

        Args:
            x: float32 tensor of shape (num_samples, num_dimensions).

        Returns:
            float32 tensor of shape (num_samples,).
        """
        log_likelihoods = tf.reduce_sum(self._minibatch_log_likelihood(x), axis=1)
        standardized = (x - self.prior_mean) / self.prior_std
        log_prior = tf.reduce_sum(
            -tf.math.log(self.prior_std) - self.const_term - 0.5 * tf.math.square(standardized),
            axis=1)
        return log_likelihoods + log_prior

def make_breast_cancer():
    """Build the full-batch logistic-regression target on the breast-cancer dataset."""
    target = LogisticRegression("breast_cancer")
    return target

def make_german_credit():
    """Build the full-batch logistic-regression target on the German-credit dataset."""
    target = LogisticRegression("german_credit")
    return target

def make_breast_cancer_mb(batchsize):
    """Build the minibatch logistic-regression target on the breast-cancer dataset.

    Args:
        batchsize: number of datapoints drawn per likelihood evaluation.
    """
    target = LogisticRegression_minibatch("breast_cancer", batchsize=batchsize)
    return target

def make_german_credit_mb(batchsize):
    """Build the minibatch logistic-regression target on the German-credit dataset.

    Args:
        batchsize: number of datapoints drawn per likelihood evaluation.
    """
    target = LogisticRegression_minibatch("german_credit", batchsize=batchsize)
    return target