import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.legend_handler import HandlerLine2D, HandlerTuple
import matplotlib.patches as patches

import matplotlib as mpl
# Render all text through LaTeX. amsmath is loaded because the figure
# annotations use \xrightarrow, which is an amsmath command.
mpl.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{amsmath}",
})
## For Palatino and other serif fonts use (currently disabled):
# mpl.rcParams.update({"font.family": "serif", "font.serif": ["Palatino"]})

def rotation_matrix(degree):
    """Build the 2-D counter-clockwise rotation matrix for ``degree`` degrees.

    Returns the NumPy array [[cos t, -sin t], [sin t, cos t]] with t the
    angle converted to radians.
    """
    angle = np.deg2rad(degree)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    return np.array([[cos_a, -sin_a],
                     [sin_a, cos_a]])


# Arrow style shared by the FancyArrowPatch weight-vector arrows.
style = "Simple, tail_width=0.5, head_width=10, head_length=24"

n = 30  # number of samples per class, per task
# eps = 1  # alternative setting: singular-value ratio of W* equal to 1
eps = 2e-1  # small eps -> singular-value ratio 1/eps of W* is large

# --- Sample generators ---------------------------------------------------------
# Each sample is a 3-D point parametrised by a scalar k. By construction every
# positive (negative) sample x of task t satisfies
#     Wstar[t] @ (Fstar @ x) == +1 (-1)   and   What[t] @ (Fhat @ x) == +1 (-1),
# i.e. it lies exactly on that task's margin in both representations.


def source_1_pos(k):
    """Return the task-1 positive sample for parameter ``k``."""
    return [1 - k * eps, k, 1]


def source_1_neg(k):
    """Return the task-1 negative sample for parameter ``k``."""
    return [-1 - k * eps, k, -1]


def source_2_pos(k):
    """Return the task-2 positive sample for parameter ``k``."""
    return [1 + k * eps, k, (k - 1) / eps]


def source_2_neg(k):
    """Return the task-2 negative sample for parameter ``k``."""
    return [-1 + k * eps, k, (1 + k) / eps]


# Optimal task weights W* (one row per task) and optimal representation Phi*,
# which keeps the first two input coordinates.
Wstar = np.array([[1, eps], [1, -eps]])
Fstar = np.array([[1, 0, 0], [0, 1, 0]])
# Closed-form singular values of W*: W* W*^T has eigenvalues 2 and 2 eps^2, so
# this equals np.linalg.svd(Wstar, compute_uv=False).
sp_star = [np.sqrt(2), eps * np.sqrt(2)]
print("Optimal spectrum: " + str(sp_star))
print("Optimal ratio: " + str(1. / eps))

# Seeded generator so that re-running the script reproduces the same figures.
k = np.random.default_rng(seed=0).random(n)  # uniform on [0, 1)
# source[t] has shape (2n, 3): n positive rows followed by n negative rows.
source = [
    np.vstack(([source_1_pos(t) for t in k], [source_1_neg(t) for t in k])),
    np.vstack(([source_2_pos(t) for t in k], [source_2_neg(t) for t in k])),
]
# Labels shared by both tasks, aligned with the row order of source[t].
y = np.concatenate((np.ones(n), -np.ones(n)))

markers = ['+', '_']  # scatter markers for the two classes
colors = ['r', 'b']  # one colour per source task
p = []  # scatter handles collected while plotting (used for the legends)

# Alternative representation Phi-hat (keeps the last two input coordinates)
# and task weights W-hat that separate the data in that space.
What = np.array([[0, 1], [1, -eps]])
Fhat = np.array([[0, 1, 0], [0, 0, 1]])
# Closed-form sigma_max / sigma_min of W-hat: W-hat W-hat^T has trace 2 + eps^2
# and determinant 1, so the ratio tends to 1 as eps -> 0.
exact_ratio = np.sqrt((2 + eps**2 + eps*np.sqrt((eps**2+4)))/(2 + eps**2 - eps*np.sqrt(eps**2+4)))
print("Corrected ratio: "+str(exact_ratio))


# --- Figure 1: both source tasks in the optimal representation Phi* ----------
plt.figure(num=1, figsize=(19.20, 10.80))
ax = plt.gca()

# Put the left/bottom spines through the origin and hide the top/right ones.
ax.spines['left'].set_position('zero')
ax.spines['right'].set_color('none')
ax.yaxis.tick_left()
ax.spines['bottom'].set_position('zero')
ax.spines['top'].set_color('none')
ax.xaxis.tick_bottom()

# Anchor (data coordinates) of the boxed w* arrows and annotation.
coordW = [-3.5, .4]
first_handle = len(p)  # this figure's scatter handles start here in `p`

for i, X in enumerate(source):
    proj = X.dot(Fstar.T)  # (2n, 2): samples of task i in Phi* space

    # Pair labels and markers explicitly ('+' for y == +1, '_' for y == -1);
    # iterating np.unique(y) would visit -1 first and swap the markers.
    for marker, label in zip(markers, (1, -1)):
        idx = np.where(y == label)[0]
        p.append(ax.scatter(proj[idx, 0], proj[idx, 1], c=colors[i],
                            marker=marker, s=1000, linewidth=5))

    # Decision boundary w*_i . z = 0, i.e. z2 = -(w_i0 / w_i1) z1 (needs w_i1 != 0).
    xx = np.linspace(-.5, .5, 3)
    plt.plot(xx, -xx * Wstar[i, 0] / Wstar[i, 1], color=colors[i], lw=8)

    # Arrow showing the direction of w*_i, scaled by 0.8.
    kw = dict(arrowstyle=style, color=colors[i], lw=8)
    ax.add_patch(patches.FancyArrowPatch(
        (coordW[0], coordW[1]),
        (coordW[0] + .8 * Wstar[i, 0], coordW[1] + .8 * Wstar[i, 1]),
        **kw))

    if i == 0:
        ax.text(coordW[0], coordW[1] + .1, r"$\mathbf{w}^*_1$", fontsize=50)
    else:
        ax.text(coordW[0], coordW[1] - .2, r"$\mathbf{w}^*_2$", fontsize=50)

# The frame and the annotation are shared by both tasks: draw them once.
# R_sigma stands for sigma_max / sigma_min of the weight matrix.
ax.add_patch(patches.Rectangle((coordW[0] - .45, coordW[1] - 0.25), 2.5, .8,
                               facecolor='none', edgecolor='black', alpha=1, linewidth=2))
phi_star = r"$R_\sigma(\mathbf{W}^*)\xrightarrow{\varepsilon \rightarrow 0} +\infty$"
ax.text(coordW[0] - .3, coordW[1] + .3, phi_star, fontsize=50)

plt.xticks([], fontsize=30)
plt.yticks([], fontsize=30)
plt.ylim(-.3, 1.1)
plt.xlim(-4.1, 2)

star_handles = p[first_handle:]
plt.legend([(star_handles[0], star_handles[1]), (star_handles[2], star_handles[3])],
           [r"\bf Source task 1 in $\Phi^*$ space", r"\bf Source task 2 in $\Phi^*$ space"],
           numpoints=1, handler_map={tuple: HandlerTuple(ndivide=None)}, fontsize=42,
           bbox_to_anchor=(-.15, 1, 1.3, 0.2), mode="expand", borderaxespad=0,
           loc='upper center', ncol=2, fancybox=True, framealpha=1, shadow=True,
           borderpad=.4)

plt.savefig("./example_3D_phistar.pdf", dpi=1200, bbox_inches='tight')
#plt.show()

# --- Figure 2: both source tasks in the alternative representation Phi-hat ----
plt.figure(num=2, figsize=(19.20, 10.80))
ax = plt.gca()

# Put the left/bottom spines through the origin and hide the top/right ones.
ax.spines['left'].set_position('zero')
ax.spines['right'].set_color('none')
ax.yaxis.tick_left()
ax.spines['bottom'].set_position('zero')
ax.spines['top'].set_color('none')
ax.xaxis.tick_bottom()

# Anchor (data coordinates) of the boxed w-hat arrows and annotation.
coordW = [-1, 3]
# This figure's scatter handles start here in `p`; the legend must use these,
# not the handles created for figure 1.
first_handle = len(p)

for i, X in enumerate(source):
    proj = X.dot(Fhat.T)  # (2n, 2): samples of task i in Phi-hat space

    # Pair labels and markers explicitly ('+' for y == +1, '_' for y == -1);
    # iterating np.unique(y) would visit -1 first and swap the markers.
    for marker, label in zip(markers, (1, -1)):
        idx = np.where(y == label)[0]
        p.append(ax.scatter(proj[idx, 0], proj[idx, 1], c=colors[i],
                            marker=marker, s=1000, linewidth=5))

    # Decision boundary w_i . z = 0, i.e. z2 = -(w_i0 / w_i1) z1 (needs w_i1 != 0).
    # For w_i0 == 0 this is the horizontal line z2 = 0.
    xx = np.linspace(-.1, 1, 3)
    plt.plot(xx, -xx * What[i, 0] / What[i, 1], color=colors[i], lw=8)

    # Arrow showing the direction of w-hat_i. The x component is scaled by 0.3
    # and the y component by 3, presumably to offset the very different x/y
    # ranges set below.
    kw = dict(arrowstyle=style, color=colors[i], lw=8)
    ax.add_patch(patches.FancyArrowPatch(
        (coordW[0], coordW[1]),
        (coordW[0] + .3 * What[i, 0], coordW[1] + 3 * What[i, 1]),
        **kw))

    if i == 0:
        ax.text(coordW[0], coordW[1] + 3.5, r"$\hat{\mathbf{w}}_1$", fontsize=50)
    else:
        ax.text(coordW[0] + .15, coordW[1] - 2, r"$\hat{\mathbf{w}}_2$", fontsize=50)

# The frame and the annotation are shared by both tasks: draw them once.
# R_sigma stands for sigma_max / sigma_min of the weight matrix.
ax.add_patch(patches.Rectangle((coordW[0] - .1, coordW[1] - 2.5), .9, 8,
                               facecolor='none', edgecolor='black', alpha=1, linewidth=2))
phi_hat = r"$R_\sigma(\widehat{\mathbf{W}})\xrightarrow{\varepsilon \rightarrow 0} 1$"
ax.text(coordW[0] + .1, coordW[1] + .8, phi_hat, fontsize=50)

plt.xticks([])
plt.yticks([])
plt.ylim(-5, 10)
plt.xlim(-1.2, 1.1)

hat_handles = p[first_handle:]
plt.legend([(hat_handles[0], hat_handles[1]), (hat_handles[2], hat_handles[3])],
           [r"\bf Source task 1 in $\hat{\Phi}$ space", r"\bf Source task 2 in $\hat{\Phi}$ space"],
           numpoints=1, handler_map={tuple: HandlerTuple(ndivide=None)}, fontsize=42,
           bbox_to_anchor=(-.15, 1, 1.3, 0.2), mode="expand", borderaxespad=0,
           loc='upper center', ncol=2, fancybox=True, framealpha=1, shadow=True,
           borderpad=.3)

plt.savefig("./example_3D_phihat.pdf", dpi=1200, bbox_inches='tight')
#plt.show()
