import pickle
import sys

import numpy as np
from numpy.linalg import qr
from numpy.random import permutation as randperm
from pylab import *
from scipy.stats import chi

import putil
from rutil import bernoulli

# Half-width of the plotting window: the plotting helpers and main() use the
# range [-M, M] on each axis. (Unrelated to the `M` parameter of mcmc().)
M = 3


def featmap(feats, data):
    """Map data points to their cosine feature representation.

    Args:
        feats: dict with "w" (frequency matrix, one row per feature) and
            "b" (one phase offset per feature).
        data: array whose rows are the points to transform.

    Returns:
        Array ``cos(data @ w.T + b) * sqrt(2 / K)`` with one column per
        feature, where ``K = len(b)``.
    """
    freqs = feats["w"]
    phases = feats["b"]
    n_feats = len(phases)
    # Explicit np.* calls so the function does not depend on the pylab star import.
    projected = np.dot(data, freqs.T) + phases
    return np.cos(projected) * np.sqrt(2 / n_feats)


def plik(w, c, feats, data, negs):
    """Return the total log-likelihood: poslik on `data` plus neglik on `negs`."""
    positive_term = poslik(w, c, feats, data)
    negative_term = neglik(w, c, feats, negs)
    return positive_term + negative_term


def poslik(w, c, feats, data):
    """Summed log-sigmoid of (model score - reference score) over `data`.

    The model score of a point is ``log(relu(f)) - c`` where ``f`` is the
    feature-map output dotted with ``w``; the reference score is ``lgauss``.

    Args:
        w: weight vector, one entry per feature.
        c: scalar offset subtracted from the model log-score.
        feats: feature dict accepted by ``featmap``.
        data: array whose rows are the points to score.

    Returns:
        Scalar ``sum(log(sigmoid(lik - lik0)))``.
    """
    # GP likelihoods
    Z = featmap(feats, data)
    f = np.dot(Z, w)
    lik = np.log(relu(f)) - c
    # prior likelihoods
    lik0 = lgauss(data)
    # joint: log(sigmoid(t)) == -log(1 + exp(-t)); logaddexp evaluates this
    # without overflowing exp() or taking log(0) for extreme t.
    return -np.sum(np.logaddexp(0, lik0 - lik))


def neglik(w, c, feats, data):
    """Summed log-sigmoid of (reference score - model score) over `data`.

    Mirror image of ``poslik``: the model score of a point is
    ``log(relu(f)) - c`` and the reference score is ``lgauss``; here the
    sigmoid argument has the opposite sign.

    Args:
        w: weight vector, one entry per feature.
        c: scalar offset subtracted from the model log-score.
        feats: feature dict accepted by ``featmap``.
        data: array whose rows are the points to score (mcmc passes the
            negative samples here).

    Returns:
        Scalar ``sum(log(sigmoid(lik0 - lik)))``.
    """
    # GP likelihoods
    Z = featmap(feats, data)
    f = np.dot(Z, w)
    lik = np.log(relu(f)) - c
    # prior likelihoods
    lik0 = lgauss(data)
    # joint: log(sigmoid(t)) == -log(1 + exp(-t)); logaddexp evaluates this
    # without overflowing exp() or taking log(0) for extreme t.
    return -np.sum(np.logaddexp(0, lik - lik0))


def lgauss(xx):
    """Evaluate the ``lgauss1`` log-score for every row of `xx` at once.

    Computes ``-log(2*pi)/2 - ||x||^2 / 2`` per row with array operations
    instead of a Python-level loop (this is called for every MCMC proposal).

    Args:
        xx: sequence of points; a 2-D array with one point per row, or a
            1-D sequence of scalar points.

    Returns:
        1-D float array with one value per point.
    """
    pts = np.asarray(xx, dtype=float)
    if pts.ndim < 2:
        # A flat sequence is a list of scalar points, one per entry.
        pts = pts.reshape(-1, 1)
    sq_norms = np.einsum("nd,nd->n", pts, pts)
    return -np.log(2 * np.pi) / 2 - sq_norms / 2


def lgauss1(x):
    """Return ``-log(2*pi)/2 - (x . x)/2`` for a single point `x`.

    NOTE(review): the constant term is the same for any length of `x`; a
    D-dimensional standard normal would use ``-D*log(2*pi)/2``. Confirm the
    offset is meant to be absorbed elsewhere before changing it.
    """
    half_sq_norm = np.dot(x, x) / 2
    log_const = -np.log(2 * np.pi) / 2
    return log_const - half_sq_norm


def mh(lik, liknew):
    """Metropolis-Hastings accept/reject decision for a proposal.

    Args:
        lik: log-likelihood of the current state.
        liknew: log-likelihood of the proposed state.

    Returns:
        Whatever ``bernoulli(r)`` returns for the acceptance ratio ``r``
        (presumably a truthy value with probability ``r`` -- rutil is not
        visible here).
    """
    # Cap the log-ratio at 0 so exp() cannot overflow when the proposal is
    # much better; the ratio then saturates at 1 (always accept).
    # np.minimum (unlike the builtin min) propagates NaN unchanged.
    r = np.exp(np.minimum(0.0, liknew - lik))
    return bernoulli(r)


def mcmc(feats, data, iters, M=100, eps=1):
    """Sample the weights ``w`` and offset ``c`` with Metropolis-Hastings.

    Each sweep draws a fresh set of negative samples, proposes a Gaussian
    random-walk move for ``c`` and then for every coordinate of ``w`` (in
    random order); each proposal is accepted or rejected through ``mh``.

    Args:
        feats: feature dict accepted by ``featmap`` ("w" and "b").
        data: 2-D array of observed points, one per row.
        iters: number of sweeps; must be at least 1.
        M: accepted for interface compatibility but ignored -- the number of
            negative samples per sweep is always the number of data rows.
        eps: standard deviation of the random-walk proposals.

    Returns:
        Tuple ``(w, c, negs)``: the mean of the retained weight samples
        (sweeps after the first fifth), the final ``c``, and the negative
        samples of the last sweep.

    Raises:
        ValueError: if ``iters`` is less than 1.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")
    N, D = data.shape
    K = len(feats["b"])
    # body
    w = np.random.randn(K)
    c = 0  # logZ
    results = []
    # The M argument is deliberately overridden (original behaviour).
    M = N
    for sweep in range(iters):
        accept = 0
        negs = np.random.randn(M, D) * 2
        lik = plik(w, c, feats, data, negs)

        cnew = c + eps * np.random.randn()
        liknew = plik(w, cnew, feats, data, negs)
        if mh(lik, liknew):
            c, lik = cnew, liknew
            accept += 1

        for k in randperm(K):
            # Keep the old value so a rejected move is undone exactly
            # (adding and subtracting the step can leave rounding residue).
            old = w[k]
            w[k] = old + eps * np.random.randn()
            liknew = plik(w, c, feats, data, negs)
            if mh(lik, liknew):
                # `lik` always tracks the current state, so it does not need
                # to be recomputed before the next proposal.
                lik = liknew
                accept += 1
            else:
                w[k] = old

        if (sweep > iters / 5) or (iters == 1):
            # Copy: w is mutated in place, so storing the array itself would
            # make every retained sample alias the final state.
            results.append(w.copy())
        print(
            "iter[%3d] : lik = %.2f (accept=%.1f%%)"
            % ((sweep + 1), lik, accept / (K + 1) * 100)
        )

    w = np.mean(results, axis=0)
    return w, c, negs


def draw(w, feats, N):
    """Plot the weighted feature expansion on an N x N grid over [-M, M]^2.

    Args:
        w: weight vector, one entry per feature.
        feats: dict with "w" (frequencies, one row per feature, two columns)
            and "b" (one phase per feature).
        N: number of grid points along each axis.
    """
    F = np.asarray(feats["w"])
    b = np.asarray(feats["b"])
    K = len(b)
    xy1, xy2 = np.meshgrid(np.linspace(-M, M, N), np.linspace(-M, M, N))
    xy = np.vstack((xy1.ravel(), xy2.ravel())).T
    # body
    # NOTE(review): the scaling here is sqrt(2) / K, whereas featmap() uses
    # sqrt(2 / K); kept as-is so the plotted values do not change -- confirm
    # which normalisation is intended.
    Z = np.sqrt(2) * np.cos(np.dot(xy, F.T) + b) / K
    zz = np.dot(Z, w)
    # Values below the threshold are shown as exactly zero.
    zz[zz < 1e-7] = 0
    plot2(xy1, xy2, zz)


def plot2(xy1, xy2, zz):
    """Draw `zz` as a pseudocolour map over the square grid (xy1, xy2).

    `zz` is flat and is reshaped to side x side, where side is the length
    of the first row of `xy1`; a colour bar is added.
    """
    side = len(xy1[0])
    grid_values = zz.reshape(side, side)
    pcolor(xy1, xy2, grid_values, cmap="jet")
    putil.aspect_ratio(1)
    colorbar()


def plot_data(data, marker="o"):
    """Scatter the first two columns of `data`.

    With marker "x" the points are drawn as black crosses; any other value
    draws hollow black circles.
    """
    if marker != "x":
        scatter(
            data[:, 0], data[:, 1], s=120, marker="o", facecolor="none", edgecolors="k"
        )
        return
    scatter(data[:, 0], data[:, 1], s=120, marker=marker, color="k")


def plot_feats(feats, N):
    """Plot each feature ``sqrt(2) * cos(w_k * x + b_k)`` over [-M, M].

    Args:
        feats: dict with "w" (one frequency per feature; shape (K,) or
            (K, 1)) and "b" (one phase per feature).
        N: number of sample points along the x axis.

    Raises:
        ValueError: if the features have more than one input dimension, since
            they cannot be drawn as curves over a single axis.
    """
    b = np.asarray(feats["b"], dtype=float)
    F = np.asarray(feats["w"], dtype=float)
    if F.ndim == 1:
        F = F[:, np.newaxis]
    if F.ndim != 2 or F.shape[1] != 1:
        raise ValueError("plot_feats requires features with a single input dimension")
    xx = np.linspace(-M, M, N)
    for k in range(len(b)):
        # One fresh array per curve, computed for all x at once; avoids
        # storing length-1 arrays into scalar slots.
        yy = np.sqrt(2) * np.cos(F[k, 0] * xx + b[k])
        plot(xx, yy, label=("Feature %d" % (k + 1)))
    legend(bbox_to_anchor=(1.05, 1))


def rff(D, K=10):
    """Draw K random cosine features for D-dimensional inputs.

    Each frequency row is a standard normal vector scaled by 2 and each
    phase is uniform on [0, 2*pi). Frequencies and phases are drawn
    alternately, feature by feature.

    Returns:
        dict with "w" (K x D frequencies) and "b" (K phases).
    """
    freqs = np.empty((K, D), dtype=float)
    phases = np.empty(K, dtype=float)
    for idx in range(K):
        freqs[idx, :] = randn(D) * 2
        phases[idx] = rand() * 2 * np.pi
    return {"w": freqs, "b": phases}


def off(D, K=10):  # orthogonal Fourier features
    """Draw K cosine features whose directions come from orthogonal matrices.

    Directions are collected from the transposed Q factors of QR
    decompositions of D x D Gaussian matrices until K are available; each
    direction is then scaled by a chi(D) draw times 2. Phases are uniform
    on [0, 2*pi).

    Returns:
        dict with "w" (K x D frequencies) and "b" (K phases).
    """
    # assume K > D.
    directions = []
    while len(directions) < K:
        q_factor, _ = qr(randn(D, D))
        directions.extend(q_factor.T)
    directions = np.array(directions[:K])
    radii = chi.rvs(D, size=K)
    freqs = np.dot(np.diag(radii), directions) * 2
    # biases
    phases = rand(K) * 2 * np.pi
    return {"w": freqs, "b": phases}


def load(file):
    """Load a whitespace-separated numeric text file as a 2-D float array.

    Args:
        file: path (or file object) accepted by ``np.loadtxt``.

    Returns:
        Array with one row per line of the file and one column per field.
        A single-column file yields shape (N, 1) and a single-line file
        yields shape (1, D).
    """
    # ndmin=2 keeps the row/column layout of the file; reshaping a squeezed
    # 1-D result cannot tell a single row apart from a single column.
    return np.loadtxt(file, dtype=float, ndmin=2)


def relu(xx, eps=1e-7):
    """Clamp every element of `xx` from below at `eps`.

    The floor is strictly positive by default so that the result can be
    passed to ``log`` safely.

    Args:
        xx: array-like of numbers.
        eps: lower bound applied element-wise.

    Returns:
        New float array with ``max(x, eps)`` for each element.
    """
    return np.maximum(np.asarray(xx, dtype=float), eps)


def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, evaluated without overflow.

    Args:
        x: scalar or array.

    Returns:
        Value(s) in [0, 1] with the same shape as `x`.
    """
    # 1 / (1 + exp(-x)) == exp(-log(1 + exp(-x))); logaddexp computes the
    # inner log without overflowing exp() for very negative x.
    return np.exp(-np.logaddexp(0, -x))


def usage():
    """Print the command-line synopsis and terminate the process.

    K and output are shown as required arguments because main() refuses to
    run with fewer than four arguments; only feats_path is optional.
    The exit status stays 0, as before.
    """
    print("usage: rffreg-mcmc.py data.dat iters K output [feats_path]")
    sys.exit(0)


def main():
    """Command-line entry point: fit the model, save the weights, plot the fit.

    Arguments (from sys.argv): data file, number of sweeps, number of
    features K, output figure name, and optionally a pickled feature dict.
    When no feature file is given, new 2-D features are drawn with rff()
    and pickled to ``feats_D-2_K-<K>.pkl``.
    """
    if len(sys.argv) < 5:
        usage()  # prints the synopsis and exits
        return

    data = load(sys.argv[1])
    iters = int(sys.argv[2])
    K = int(sys.argv[3])
    output = sys.argv[4]

    if len(sys.argv) > 5:
        # NOTE: unpickling can execute arbitrary code -- only pass feature
        # files from a trusted source.
        with open(sys.argv[5], "rb") as fh:
            feats = pickle.load(fh)
    else:
        feats = rff(2, K)
        with open(f"feats_D-2_K-{K}.pkl", "wb") as fh:
            pickle.dump(feats, fh)

    # Seeded after the features are drawn, so only the sampler is reproducible.
    np.random.seed(12345)
    w, c, negs = mcmc(feats, data, iters)
    print("w =", w)
    print("c =", c)
    with open(f'weights_{output.split(".")[0]}.pkl', "wb") as fh:
        pickle.dump(w, fh)

    draw(w, feats, 100)
    plot_data(data)
    axis([-M, M, -M, M])

    putil.savefig(output, dpi=200)
    show()


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
