import os

import matplotlib as mpl
import numpy as np
import ot
import scipy as sp
from matplotlib import cm
from matplotlib import pyplot as plt

from src import roc_mc_general as roc

# Problem size: a square 5 x 5 instance shared by every panel of the figure.
n = m = 5

# Positive kernel matrix. Later code shows it in panel 0 and uses -log(M) as the cost.
M = np.array([
    [0.21127662, 0.04463152, 0.03679133, 0.47742201, 0.01048634],
    [0.03686981, 0.24643091, 0.00583015, 0.18072592, 0.4106865],
    [0.22275042, 0.00166351, 0.10119858, 0.13773912, 0.13874616],
    [0.05261907, 0.47245945, 0.29284259, 0.13475065, 0.17071307],
    [0.47648407, 0.23481461, 0.56333735, 0.0693623, 0.26936792],
])

# Marginal weight vectors; each sums to 1 up to rounding.
a = np.array([0.14028758, 0.22526308, 0.20529575, 0.37823114, 0.05092246])
b = np.array([0.13875259, 0.10392514, 0.24320017, 0.14342897, 0.37069313])

# Second name bound to the original kernel. M is later rebound (not mutated in
# place), so M_p keeps pointing at these values.
M_p = M

# 2-D drawing directions of the three cube axes, one row per axis.
basis = np.array([
    [0, 1.],        # epsilon
    [-0.7, -0.4],   # tau_1
    [0.98, -0.1],   # tau_2
])


def coord3tocoord2(coord3):
    """Project a 3-D cube coordinate onto the 2-D drawing plane.

    ``coord3`` holds one weight per row of the module-level ``basis`` matrix
    (epsilon, tau_1, tau_2). The result is the matching linear combination of
    those rows, i.e. a 2-D point on the canvas.
    """
    weights = np.asarray(coord3)
    return weights @ basis


# Main canvas: a frameless axes onto which the cube is drawn in projection.
fig = plt.figure(figsize=np.array([3.1, 2.5]) * 2.5)
base_ax = fig.subplots()
base_ax.set_axis_off()
base_ax.set_xlim(-1.3, 1.8)
base_ax.set_ylim(-1, 1.5)

# Projected cube corners used as translation offsets below:
#   0 origin, 1-3 the basis rows (epsilon, tau_1, tau_2),
#   4 tau_1 + tau_2, 5 epsilon + tau_1, 6 epsilon + tau_2.
pts = [
    np.array([0, 0.]),
    *basis,
    basis[1] + basis[2],
    basis[0] + basis[1],
    basis[0] + basis[2],
]

# Unit segments from the origin; their tips equal the rows of ``basis``.
edge_1 = np.array([[0, 0], [0, 1]])
edge_2 = np.array([[0, 0], [-0.7, -0.4]])
edge_3 = np.array([[0, 0], [0.98, -0.1]])
unit_edges = (edge_1, edge_2, edge_3)

# Faint edges leaving the origin.
for seg in unit_edges:
    base_ax.plot(seg[:, 0], seg[:, 1], alpha=0.15, c="k")

# Remaining cube edges: (segment, corners it is translated to, extra format args).
# Solid edges come first, then the dashed ones, in the same drawing order as before.
for seg, corner_ids, fmt in (
    (edge_1, (2, 3, 4), ()),
    (edge_3, (2,), ()),
    (edge_2, (3,), ()),
    (edge_3, (1, 5), ("--",)),
    (edge_2, (1, 6), ("--",)),
):
    for idx in corner_ids:
        shift = pts[idx]
        base_ax.plot(seg[:, 0] + shift[0], seg[:, 1] + shift[1], *fmt, color='k')

# Extend each axis past its unit tip, out to 1.4 units.
for seg in unit_edges:
    tip_x, tip_y = seg[1]
    base_ax.plot(tip_x * np.array([1, 1.4]), tip_y * np.array([1, 1.4]), c='k')

# Blue axis captions, placed at points beyond the cube along each axis.
for coord3, dx, caption in (
    ([1.4, 0, 0], -0.25, "sigmoid(log($\\epsilon_P$))"),
    ([0.03, 1.7, 0], 0.0, "sigmoid(log($\\epsilon_\\eta$))"),
    ([0, 0, 1.4], 0.0, "sigmoid(log($\\epsilon_\\theta$))"),
):
    c2 = coord3tocoord2(coord3)
    base_ax.text(c2[0] + dx, c2[1], caption, fontsize="large", c='b')

# Inset axes for the heat-map panels. The comment on each rectangle names the
# panel that is drawn into vis_ax[i] afterwards.
vis_ax = [
    fig.add_axes(rect)
    for rect in (
        [0.05, 0.8, 0.14, 0.14],  # 0: kernel M with the two marginals
        [0.06, 0.3, 0.1, 0.1],    # 1: (0, inf, inf) Discriminative
        [0.55, 0.81, 0.1, 0.1],   # 2: (inf, inf, inf) eta (x) theta
        [0.8, 0.25, 0.1, 0.1],    # 3: (1, inf, inf) Cooperative Communication
        [0.85, 0.64, 0.1, 0.1],   # 4: (1, 0, inf) Bayesian
        [0.85, 0.81, 0.1, 0.1],   # 5: (inf, 0, inf) Tyrant
        [0.85, 0.46, 0.1, 0.1],   # 6: (0, 0, inf) Column Greedy
        [0.06, 0.48, 0.1, 0.1],   # 7: (0, inf, 0) Row Greedy
        [0.06, 0.65, 0.1, 0.1],   # 8: (inf, inf, 0) Frequentist
    )
]
for ax in vis_ax:
    ax.set_axis_off()

# Panel 0: the kernel M framed by its (rescaled) marginals on a 7 x 7 canvas.
M = M_p.copy()
# ``cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9. The
# colormap registry is the supported lookup; copy before mutating it.
cmap = mpl.colormaps["viridis"].copy()
cmap.set_bad('white')  # NaN cells (the padding around M) render as white
aug_M = np.full((7, 7), np.nan)  # 7 = 5 entries + marginal strip + 1-cell gap
aug_M[2:, 2:] = M.T
# NOTE(review): rows of M.T follow b's index, yet the row margin below holds
# the a-marginal (and vice versa for the columns). Confirm the intended
# orientation; the fixed 7 x 7 frame also assumes n == m == 5.

# Rescale both marginals so their largest entry matches max(M), putting them
# on the same colour scale as the matrix.
ta = a * (np.max(M) / np.max(a))
tb = b * (np.max(M) / np.max(b))
aug_M[0, 2:] = tb
aug_M[2:, 0] = ta

# From here on M holds the cost C = -log(M) used by the solver panels
# (the inset caption reads M = exp(-C)).
M = -np.log(M)
print(M)
vis_ax[0].text(-2, 3.8, "$\\eta$", fontsize="large")
vis_ax[0].text(2.8, -0.6, "$\\theta$", fontsize="large")
vis_ax[0].text(0.5, 7.6, "$M=\\exp(-C)$", fontsize="large")
vis_ax[0].imshow(aug_M, cmap, vmin=0)

def _mark_point(coord3, label, label_offset, xytext, connectionstyle):
    """Mark one cube point on ``base_ax`` with an arrow, a caption and a dot.

    Args:
        coord3: cube coordinate, projected with ``coord3tocoord2``.
        label: short caption written next to the point (fontsize "large").
        label_offset: (dx, dy) shift of the caption from the point.
        xytext: tail of the red arrow, in ``base_ax`` data coordinates.
        connectionstyle: Matplotlib connection style for the arrow path.

    Returns:
        The projected 2-D point.
    """
    point = coord3tocoord2(coord3)
    # Empty annotation text: only the arrow is wanted, not a caption at xytext.
    base_ax.annotate(
        "",
        xy=point,
        xytext=np.array(xytext),
        arrowprops=dict(arrowstyle="->", color="red", connectionstyle=connectionstyle),
    )
    base_ax.text(point[0] + label_offset[0], point[1] + label_offset[1], label,
                 fontsize="large")
    base_ax.scatter(point[0], point[1], marker='o', c='k')
    return point


# Each panel below plots roc.sinkhorn_knopp_unbalanced(a, b, M, ...) for the
# parameter triple in its title. Large values (2000, 1e9) stand in for infinity
# and small ones (0.0051, 0.01) for 0.
# NOTE(review): the argument order (eps_P, eps_eta, eps_theta) is inferred from
# the titles; confirm against roc_mc_general.
# Cube coordinates follow the axis captions, sigmoid(log(eps)):
# eps -> 0 maps to 0, eps = 1 to 0.5, eps -> inf to 1.

# Panel 1: (0, inf, inf), Discriminative.
vis_ax[1].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 0.0051, 1e9, 1e9), vmin=0)
vis_ax[1].text(-1, 5.8, "$(0,\\infty,\\infty)$:", fontsize="large")
vis_ax[1].text(-2, 7, "Discriminative", fontsize="large")
_mark_point([0, 1, 1], "Discr", (0.04, -0.06), (-1.4, -0.3),
            "angle,angleA=-20,angleB=35,rad=15")

# Panel 2: (inf, inf, inf), the independent coupling eta (x) theta.
vis_ax[2].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 2000, 1e9, 1e9), vmin=0)
vis_ax[2].text(-3, 5.5, "$(\\infty,\\infty,\\infty): \\eta\\otimes\\theta$", fontsize="large")
_mark_point([1, 1, 1], "$\\eta\\otimes\\theta$", (0.03, -0.04), (0.6, 1.4),
            "angle,angleA=20,angleB=-85,rad=15")

# Panel 3: (1, inf, inf), Cooperative Communication.
vis_ax[3].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 1, 1e9, 1e9), vmin=0)
vis_ax[3].text(-1, 5.5, "$(1,\\infty,\\infty)$:", fontsize="large")
vis_ax[3].text(-8, 6.8, "Cooperative Communication", fontsize="large")
_mark_point([0.5, 1, 1], "CC", (0.02, 0.0), (1.6, -0.33),
            "angle,angleA=-5,angleB=-70,rad=15")

# Panel 4: (1, 0, inf), Bayesian.
vis_ax[4].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 1, 0, 1e9), vmin=0)
vis_ax[4].text(-1, 5.5, "$(1,0,\\infty)$:", fontsize="large")
vis_ax[4].text(-1, 6.8, "Bayesian", fontsize="large")
_mark_point([0.5, 0, 1], "Bayesian", (0.02, -0.06), (1.8, 0.85),
            "angle,angleA=5,angleB=77,rad=15")

# Panel 5: (inf, 0, inf), Tyrant.
vis_ax[5].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 2000, 0, 1e9), vmin=0)
vis_ax[5].text(-1, 5.7, "$(\\infty,0,\\infty)$:", fontsize="large")
vis_ax[5].text(0, 6.9, "Tyrant", fontsize="large")
_mark_point([1, 0, 1], "Tyrant", (0.02, -0.06), (1.8, 1.5),
            "angle,angleA=2,angleB=79,rad=15")

# Panel 6: (0, 0, inf), Column Greedy.
vis_ax[6].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 0.0051, 0, 1e9), vmin=0)
vis_ax[6].text(-1, 5.5, "$(0,0,\\infty)$:", fontsize="large")
vis_ax[6].text(-2.5, 6.8, "Column Greedy", fontsize="large")
_mark_point([0, 0, 1], "CG", (0.02, -0.06), (1.8, 0.3),
            "angle,angleA=5,angleB=79,rad=15")

# Panel 7: (0, inf, 0), Row Greedy.
# NOTE(review): uses 0.01 rather than the 0.0051 of the other eps_P ~ 0
# panels; confirm this is intentional.
vis_ax[7].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 0.01, 1e9, 0), vmin=0)
vis_ax[7].text(-1, 5.7, "$(0,\\infty,0)$:", fontsize="large")
vis_ax[7].text(-2, 7, "Row Greedy", fontsize="large")
_mark_point([0, 1, 0], "RG", (0.02, -0.06), (-1.3, 0.3),
            "angle,angleA=2,angleB=-79,rad=15")

# Panel 8: (inf, inf, 0), Frequentist.
vis_ax[8].imshow(roc.sinkhorn_knopp_unbalanced(a, b, M, 2000, 1e9, 0), vmin=0)
vis_ax[8].text(-3, 5.5, "$(\\infty,\\infty,0)$: Frequentist", fontsize="large")
_mark_point([1, 1, 0], "Freq", (0.02, -0.10), (-1.3, 0.9),
            "angle,angleA=-2,angleB=-79,rad=15")

# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)
fig.savefig("figs/cube_figure.png", dpi=200)
