import numpy as np
from numpy.random import multivariate_normal
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D


def gen_data_2d(n=500, positive_r=.9, rng=None):
    """Sample an imbalanced two-class 2-D Gaussian dataset.

    Both classes share the covariance ``Spl`` whose entries are
    ``0.8 ** |i - j|`` (i.e. ``[[1, 0.8], [0.8, 1]]``). The positive class is
    centred at ``(2, 2)`` and the negative class at the origin.

    Side effect: the class means and the shared covariance are published as
    the module-level globals ``mu_negative``, ``mu_positive`` and ``Spl``,
    because the script code reads them to build the decision boundary.

    Args:
        n: Total number of samples to draw.
        positive_r: Fraction of the ``n`` samples that belong to the positive
            class; must lie in ``[0, 1]``.
        rng: Optional ``numpy.random.Generator``. When ``None`` (the default)
            the legacy global NumPy RNG is used, exactly as before.

    Returns:
        ``(data, label)``:
        - ``data`` has shape ``(n, 2)``, with all positive samples first.
        - ``label`` is a float array of shape ``(n, 1)`` holding 1.0 for
          positive rows and 0.0 for negative rows.

    Raises:
        ValueError: if ``positive_r`` is outside ``[0, 1]``.
    """
    global mu_negative, mu_positive, Spl
    if not 0 <= positive_r <= 1:
        raise ValueError(f"positive_r must be in [0, 1], got {positive_r!r}")

    mu_negative = np.zeros(2)
    mu_positive = np.array([2, 2])
    Spl = np.fromfunction(lambda i, j: np.power(0.8, abs(i - j)), shape=(2, 2)).astype(float)

    # Round rather than truncate: 100 * 0.29 == 28.999999999999996 in floating
    # point. Taking the negatives as the remainder guarantees the two classes
    # sum to exactly n; the old int(n * (1 - positive_r)) gave 49 for n=500,
    # positive_r=.9, losing one sample.
    n_positive = int(round(n * positive_r))
    n_negative = n - n_positive

    sample = multivariate_normal if rng is None else rng.multivariate_normal
    data_positive = sample(mu_positive, cov=Spl, size=n_positive)
    data_negative = sample(mu_negative, cov=Spl, size=n_negative)

    data = np.concatenate((data_positive, data_negative))
    label = np.concatenate((np.ones((n_positive, 1)), np.zeros((n_negative, 1))))
    return data, label


data_2d, label_2d = gen_data_2d()
label_2d = label_2d.astype("int")

# Plotting
fig, ax = plt.subplots(figsize=(6, 5))

# Build a dense 1000x1000 grid covering the data plus a one-unit margin on
# every side; the axes limits are pinned to the same range.
x_min, x_max = data_2d[:, 0].min() - 1, data_2d[:, 0].max() + 1
y_min, y_max = data_2d[:, 1].min() - 1, data_2d[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 1000),
                     np.linspace(y_min, y_max, 1000))
ax.set_xlim([x_min, x_max])
ax.set_ylim([y_min, y_max])

# Linear discriminant for two Gaussian classes sharing the covariance Spl:
#   w* = Spl^{-1} (mu+ - mu-)
#   b* = 1/2 (mu-^T Spl^{-1} mu- - mu+^T Spl^{-1} mu+) + log(p+ / p-)
# mu_negative, mu_positive and Spl are globals set by gen_data_2d().
Spl_inv = np.linalg.inv(Spl)  # invert once instead of once per use
# NOTE(review): these class priors (0.8 / 0.2) do not match the default
# positive_r=.9 that gen_data_2d() is called with at the top of this block;
# confirm which ratio is intended for the plotted boundary.
prior_positive, prior_negative = 0.8, 0.2
w_star = Spl_inv @ (mu_positive - mu_negative)
b_star = (0.5 * (mu_negative @ Spl_inv @ mu_negative
                 - mu_positive @ Spl_inv @ mu_positive)
          + np.log(prior_positive / prior_negative))
print(w_star)
print(b_star)


def plot_boundary(ax, w, b, c, linestyle, linelabel):
    """Draw the linear decision boundary ``w[0]*x + w[1]*y + b = 0`` on *ax*.

    The segment spans the axes' current x-limits, or the y-limits when the
    boundary is vertical, so set the limits before calling.

    Args:
        ax: Axes object exposing ``get_xlim``, ``get_ylim`` and ``plot``.
        w: Length-2 normal vector of the boundary.
        b: Bias term.
        c: Line colour forwarded to ``ax.plot``.
        linestyle: Line style forwarded to ``ax.plot``.
        linelabel: Label attached to the plotted line (used by legends built
            from the axes' artists).

    Raises:
        ValueError: if both components of ``w`` are zero, so no boundary exists.
    """
    if w[1] != 0:
        x = np.linspace(*ax.get_xlim(), 2)
        y = -(w[0] * x + b) / w[1]
    elif w[0] != 0:
        # Vertical boundary: solving for y would divide by zero, so fix x instead.
        y = np.linspace(*ax.get_ylim(), 2)
        x = np.full_like(y, -b / w[0])
    else:
        raise ValueError("w must have a non-zero component to define a boundary")
    ax.plot(x, y, c=c, linestyle=linestyle, label=linelabel)


# Score every grid point with the discriminant w*.p + b*, then shade the two
# half-planes: royalblue where the score is negative, red where it is positive.
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
Z = (grid_points @ w_star + b_star).reshape(xx.shape)
ax.contourf(xx, yy, Z, levels=[-np.inf, 0, np.inf], colors=['royalblue', 'red'], alpha=0.15)

# Draw the samples: hollow navy circles for label 0, red crosses for label 1.
cdict = {0: 'navy', 1: 'red'}
ldict = {0: 'negative', 1: 'positive'}
for g in np.unique(label_2d):
    ix = np.where(label_2d == g)[0]
    marker = 'o' if g == 0 else 'x'
    ax.plot(data_2d[ix, 0], data_2d[ix, 1], marker,
            mec=cdict[g], mfc='none', label=ldict[g], markersize=5)

# Overlay the boundary line and build the legend by hand from region-colour
# patches plus a line sample, instead of from the plotted artists.
plot_boundary(ax, w_star, b_star, 'k', '-.', 'Decision boundary')
decision_boundary_legend = Line2D([0], [0], color='black', linestyle='-.', label='Decision Boundary')
legend_elements = [
    Patch(facecolor='red', edgecolor='red', label='Positive', alpha=0.5),
    Patch(facecolor='navy', edgecolor='navy', label='Negative', alpha=0.5),
    decision_boundary_legend,
]

# Uncomment to hide the tick marks as well.
# ax.set_yticks([])
# ax.set_xticks([])
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)
ax.legend(handles=legend_elements)


plt.savefig("2d-data.pdf", format='pdf', bbox_inches='tight')
