import numpy as np
import matplotlib.pyplot as plt


class RV1d(object):
    """
    Running summary statistics of a one-dimensional random variable.

    Observations are folded into a running sum, a running sum of squares and a count, so the mean, variance and
    standard deviation are available without storing the data. The raw observations are additionally kept in
    'sample' only when the instance is created with save_data=True (needed for plotting).
    """

    def __init__(self, name='null', save_data=False):
        """
        name: label used by __str__ and as the plot title.
        save_data: when True every appended observation is also stored in self.sample; otherwise self.sample is None.
        """
        self.name = name
        self.sum = 0
        self.squared_sum = 0
        self.cnt = 0
        # None (rather than an empty list) marks "raw data is not being recorded"
        self.sample = [] if save_data else None

    def append(self, x):
        """Add one observation x to the running statistics (and to the stored sample, if any)."""
        self.sum += x
        self.squared_sum += x**2
        self.cnt += 1
        if self.sample is not None:
            self.sample.append(x)

    def merge(self, another_rv):
        """
        Fold the statistics of another_rv into this instance.

        The stored sample survives only if both sides have one: if this instance records data but another_rv does
        not, self.sample becomes None because the combined raw data is no longer complete.
        """
        self.sum += another_rv.sum
        self.squared_sum += another_rv.squared_sum
        self.cnt += another_rv.cnt
        if self.sample is not None:
            if another_rv.sample is None:
                self.sample = None
            else:
                self.sample = self.sample + another_rv.sample

    def size(self):
        """Return the number of observations seen so far."""
        return self.cnt

    def mean(self):
        """Return the sample mean, or nan when there are no observations."""
        return float('nan') if self.cnt < 1 else self.sum / self.cnt

    def var(self):
        """
        Return the unbiased (n-1 denominator) sample variance, or nan with fewer than two observations.

        The value is computed from the running sums, so for (nearly) constant data it can come out as a tiny
        negative number due to float arithmetic.
        """
        if self.cnt < 2:
            return float('nan')
        return (self.squared_sum - self.sum**2 / self.cnt) / (self.cnt - 1)

    def stdev(self):
        """
        Return the sample standard deviation, or nan with fewer than two observations.

        A slightly negative variance (float rounding when S^2 is strictly 0) is treated as 0 instead of failing.
        """
        var = self.var()
        if var != var:  # nan is the only value not equal to itself: variance is undefined
            return float('nan')
        return np.sqrt(max(var, 0.0))

    def mean_ci(self, cl=90, force=True):
        """
        Provide an approximate two-sided confidence interval of the population mean based on CLT.

        The interval has a prior probability of 'cl' percent (by default 90) to capture the true population mean,
        under the a-priori assumption that the sample is "large enough" for the sample mean to be normally
        distributed (and for the sample stdev to have converged as well).

        cl: confidence level in percent. Only 90, 95, 98 and 99 are supported; any other value raises KeyError.
        force: by default force=True, in which case an actual estimate is returned whenever possible (i.e. as long
                as sample size >= 2); when force=False, (-inf,+inf) is returned if the sample size is below 50.
        return: tuple (lower, upper); (-inf, +inf) when no estimate is produced.
        """
        if self.cnt < 2:
            return float('-inf'), float('+inf')

        if force is False:
            # TODO: come up with better method to judge whether the sample size is large enough for quality output,
            #  e.g. based on normality test
            if self.cnt < 50:
                return float('-inf'), float('+inf')

        sample_size = self.cnt
        sample_mean = self.mean()
        sample_stdev = self.stdev()
        quantile = {90: 1.645, 95: 1.960, 98: 2.326, 99: 2.576}  # two-sided confidence level
        error = sample_stdev * quantile[cl] / np.sqrt(sample_size)
        return sample_mean - error, sample_mean + error

    def __str__(self):
        """Return one tab-separated line: name, mean with its 90% interval, standard deviation and sample size."""
        name = self.name
        mean = self.mean()
        mean_interval = self.mean_interval()
        dev = self.dev()
        size = self.size()
        return '%s\tmean= %.2f (%.2e , %.2e)\tstd= %.2f\tn= %d' % (name, mean, *mean_interval, dev, size)

    def plot(self, blocking=True):
        """Draw a histogram of the stored sample; does nothing when no raw data is being recorded."""
        if self.sample is not None:
            self._plot(blocking)

    def _plot(self, blocking=True):
        """
        Draw the histogram of self.sample in a new figure titled with self.name.

        blocking: when True the call goes through plt.show(block=True); otherwise interactive mode is switched on
                  and the figure is drawn followed by a short pause.
        """
        data = self.sample

        if not blocking:
            plt.ion()

        fig, ax = plt.subplots()
        plt.hist(data, bins='auto', facecolor='blue', alpha=0.5)
        fig.suptitle(self.name)
        ax.grid()

        if not blocking:
            plt.draw()
            plt.pause(0.1)
        else:
            plt.show(block=True)

    # back-compatible interface
    def dev(self):
        """Alias of stdev(), kept for backward compatibility."""
        return self.stdev()

    def mean_interval(self):
        """Alias of mean_ci(cl=90), kept for backward compatibility."""
        return self.mean_ci(cl=90)



# TODO: the following is legacy code for a multivariate RV class that is not actively used for the moment; need new impl.
class RV(object):
    """
    Running summary statistics for several named variables that are always observed together.

    Per variable the class keeps the running sum, the running sum of squares, the sample count and the raw
    observations (in self.data, used for plotting). Variables are addressed by the names given at construction.
    """

    # standard normal quantile for a two-sided 90% confidence level
    _QUANTILE_90 = 1.645

    def __init__(self, *variable_names):
        """variable_names: one name per variable; with no names a single variable called 'null' is created."""
        self.name = variable_names
        if len(self.name) == 0:
            self.name = ('null',)
        # name -> position of that variable in the per-variable arrays below
        self.idx = {name: position for position, name in enumerate(self.name)}

        self.data = [[] for _ in self.name]
        # builtin float/int are used as dtypes: the np.float / np.int aliases were removed from NumPy
        self.sum = np.zeros(len(self), dtype=float)
        self.squared_sum = np.zeros(len(self), dtype=float)
        self.sample_size = np.zeros(len(self), dtype=int)

    def __len__(self):
        """Return the number of variables (not the number of observations)."""
        return len(self.name)

    def append(self, *x):
        """Add one joint observation; exactly one value per variable, in construction order."""
        assert(len(x) == len(self))
        values = np.array(x)
        self.sum = self.sum + values
        self.squared_sum = self.squared_sum + values**2
        self.sample_size = self.sample_size + 1

        for column, value in zip(self.data, x):
            column.append(value)

    def size(self):
        """Return the number of observations (taken from the first variable's counter)."""
        return self.sample_size[0]

    def mean(self, name=None):
        """
        Return the sample mean.

        name: when given, the mean of that variable is returned as a scalar. Otherwise the means of all variables
              are returned as an array, or as a scalar if there is only one variable.
        The result is nan for a variable without observations.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            if name is not None:
                idx = self.idx[name]
                return self.sum[idx] / self.sample_size[idx]

            mean = self.sum / self.sample_size

            if len(self.name) == 1:
                return mean[0]
            else:
                return mean

    def dev(self, name=None):
        """
        Return the sample standard deviation (n-1 denominator).

        name: when given, the deviation of that variable is returned as a scalar. Otherwise the deviations of all
              variables are returned as an array, or as a scalar if there is only one variable.
        The result is 0 whenever the variance is not a positive number, i.e. when it is undefined (fewer than two
        observations) or slightly negative due to float arithmetic.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            size = self.sample_size
            mean = self.sum / size
            var = (self.squared_sum - size * mean**2) / (size - 1)
            # element-wise guard: nan and negative variances are replaced by 0 before taking the root
            dev = np.sqrt(np.where(var > 0, var, 0.0))

        if name is not None:
            return dev[self.idx[name]]
        if len(self.name) == 1:
            return dev[0]
        return dev

    def _mean_interval_at(self, position):
        """Return the 90% CLT interval (lower, upper) of the mean of the variable stored at index 'position'."""
        name = self.name[position]
        sample_size = self.sample_size[position]
        sample_mean = self.mean(name)
        sample_dev = self.dev(name)
        with np.errstate(divide='ignore', invalid='ignore'):
            # an empty sample gives 0/0 here; the resulting nan is intended, the warning is not
            mean_error = self._QUANTILE_90 * sample_dev / np.sqrt(sample_size)
        return (sample_mean - mean_error, sample_mean + mean_error)

    def mean_interval(self, name=None):
        """
        Return the approximate two-sided 90% confidence interval of the population mean.

        name: when given, the (lower, upper) tuple of that variable is returned. Otherwise a list with one tuple
              per variable is returned, or a single tuple if there is only one variable.
        """
        if name is not None:
            return self._mean_interval_at(self.idx[name])

        ret = [self._mean_interval_at(position) for position in range(len(self))]
        return ret[0] if len(self.name) == 1 else ret

    def __str__(self):
        """
        provide an interval estimate for the population mean, and a point estimate for the population std.dev.,
        one line per variable.
        the interval of the mean comes with a conditional a-priori confidence:
        the interval has a prior probability of 90% to capture the ground truth, conditioned on the sample size
        being large enough to make the sample mean normally distributed
        """
        lines = []
        for position, name in enumerate(self.name):
            sample_size = self.sample_size[position]
            sample_mean = self.mean(name)
            sample_dev = self.dev(name)
            mean_interval = self._mean_interval_at(position)
            lines.append('%s\tmean= %.2f (%.2e , %.2e)\tstd= %.2f\tn= %d'
                         % (name, sample_mean, *mean_interval, sample_dev, sample_size))
        return '\n'.join(lines)

    def plot(self, name=None, blocking=True):
        """Plot the histogram of the named variable, or of every variable in turn when name is None."""
        if name is not None:
            self._plot(name, blocking)
        else:
            for var in self.name:
                self._plot(var, blocking)

    def _plot(self, name, blocking=True):
        """Draw the histogram of the stored observations of one variable into the current (cleared) figure."""
        data = self.data[self.idx[name]]
        plt.clf()
        if not blocking:
            plt.ion()
        plt.hist(data, bins='auto', facecolor='blue', alpha=0.5)
        plt.title(name)
        plt.grid()
        plt.subplots_adjust(left=0.15)
        plt.show()
        if not blocking:
            plt.pause(1.0)



if __name__ == "__main__":
    # Demo 1: accumulate 0..9 while keeping the raw sample, then report every statistic.
    X = RV1d('X', save_data=True)
    for value in range(10):
        X.append(value)
    for report in (
        X,
        X.mean(),
        X.mean_interval(),
        X.mean_ci(cl=90),
        X.mean_ci(force=False),
        X.mean_ci(cl=95),
        X.dev(),
        X.stdev(),
        X.var(),
    ):
        print(report)
    X.plot(blocking=False)
    X.plot()

    # Demo 2: merging two recorders that both kept their raw data.
    Y = RV1d('Y', save_data=True)
    for value in range(10, 20):
        Y.append(value)
    print(Y)
    Y.merge(X)
    print(Y)
    Y.plot()  # the empirical distribution is not uniform because the bin width is 20/6=3.3

    # Demo 3: merging in a recorder that did not keep its raw data.
    Y = RV1d('Y without data')
    for value in range(10, 20):
        Y.append(value)
    print(Y)
    X.merge(Y)  # X should lose its original data because Y's data was never recorded
    X.plot()  # should see nothing
    '''
    X	mean= 4.50 (2.93e+00 , 6.07e+00)	std= 3.03	n= 10
    4.5
    (2.925032407740824, 6.074967592259176)
    (2.925032407740824, 6.074967592259176)
    (-inf, inf)
    (2.6234428687975773, 6.376557131202423)
    3.0276503540974917
    3.0276503540974917
    9.166666666666666
    Y	mean= 14.50 (1.29e+01 , 1.61e+01)	std= 3.03	n= 10
    Y	mean= 9.50 (7.32e+00 , 1.17e+01)	std= 5.92	n= 20
    Y without data	mean= 14.50 (1.29e+01 , 1.61e+01)	std= 3.03	n= 10
    '''


    '''
    # multivariate case will fail at the moment
    X = RV('X1','X2')
    print(X.mean())
    print(X.dev())
    X.append(1,2)
    print(X.mean())
    print(X.dev())
    X.append(2,3)
    print(X.mean())
    print(X.dev())

    Y = RV('Y1', 'Y2')
    for i in range(100):
        Y.append(i,i+1)
    print(Y.mean())
    print(Y.dev())
    print(Y.mean('Y2'))
    print(Y.dev('Y2'))
    print(Y)
    '''
