import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


# Seed Python's `random` module (the Generate_sample1/Generate_sample2 helpers
# below draw from it). NumPy's own generator is not seeded here.
random.seed(2024)

def Generate_Pi(K):
  """Build the K x K cyclic shift matrix and its symmetrized average.

  Returns a pair ``(Pi1, Pi2)``:

  * ``Pi1`` holds ones at ``[i+1, i]`` for ``i = 0..K-2`` and at ``[0, K-1]``,
    zeros elsewhere.
  * ``Pi2`` is ``0.5 * (Pi1 + Pi1.T)``.
  """
  shift = np.zeros((K, K))
  # Sub-diagonal entries [1,0], [2,1], ..., [K-1,K-2].
  cols = np.arange(K - 1)
  shift[cols + 1, cols] = 1
  # Wrap-around entry that closes the cycle.
  shift[0, K - 1] = 1
  symmetric = 0.5 * (shift + shift.T)
  return shift, symmetric

def Generate_sample0(K, N):
  """Return the K deterministic countdown sequences of length N.

  Row ``n`` starts at value ``n`` and decreases by one per position, wrapping
  from 0 back to ``K-1``; i.e. entry ``[n, i]`` equals ``(n - i) mod K``.
  The result is a float array of shape ``(K, N)``.
  """
  X = np.zeros((K, N))
  start = np.arange(K)[:, np.newaxis]
  step = np.arange(N)[np.newaxis, :]
  # NumPy's % takes the sign of the divisor, so this is a true modulo.
  X[:] = (start - step) % K
  return X

def Generate_sample1(n_sample, K, N):
  """Draw ``n_sample`` countdown sequences of length N with random starts.

  Each row starts at ``random.randint(0, K-1)`` (one draw per row, taken in
  row order) and then decreases by one per position, wrapping from 0 back to
  ``K-1``. The result is a float array of shape ``(n_sample, N)``.
  """
  X = np.zeros((n_sample, N))
  offsets = np.arange(N)
  for row in range(n_sample):
    first = random.randint(0, K - 1)
    # Entry i of the row is (first - i) mod K.
    X[row, :] = (first - offsets) % K
  return X

def Generate_sample2(n_sample, K, N):
  """Draw ``n_sample`` nearest-neighbour random walks of length N on 0..K-1.

  Each row starts at ``random.randint(0, K-1)``. After every recorded value a
  single ``random.randint(0, 1)`` coin picks the next state:

  * state 0:      coin 0 -> 1,      coin 1 -> K-1
  * state K-1:    coin 0 -> 0,      coin 1 -> K-2
  * other states: coin 0 -> s-1,    coin 1 -> s+1

  The result is a float array of shape ``(n_sample, N)``.
  """
  X = np.zeros((n_sample, N))
  for row in range(n_sample):
    state = random.randint(0, K - 1)
    for col in range(N):
      X[row, col] = state
      # Exactly one coin flip per step, also after the last recorded value.
      coin = random.randint(0, 1)
      if state == 0:
        state = 1 if coin == 0 else K - 1
      elif state == K - 1:
        state = 0 if coin == 0 else K - 2
      else:
        state = state - 1 if coin == 0 else state + 1
  return X





def update(n_sample, K, N, M, X, V, W12, W22, eta, epsilon):
  """Apply one averaged ascent step to ``V``, ``W12`` and ``W22``.

  For every row ``n < n_sample`` of ``X`` the last entry is the target token
  and the first ``N-1`` entries are the context tokens. The attention logits
  are

    ``W12[X[n,i], N-1]*sqrt(M) + W22[i, N-1]*M``   for context positions i,
    ``W22[N-1, N-1]*M``                            for the last position,

  and ``S`` is their softmax. With ``t = sum_i S[i] * V[target, X[n,i]]``
  over the context positions, the terms accumulated per sample are the
  derivatives of ``log(t + epsilon)`` with respect to ``V[target, .]`` and
  to the context logits (mapped onto column ``N-1`` of ``W12`` and ``W22``).
  The accumulated terms are scaled by ``eta / n_sample`` and added.

  The last-position logit is scaled by ``M`` here so that it is the same
  forward pass that ``accuracy`` evaluates.

  Args:
    n_sample: number of rows of ``X`` to use; also the averaging divisor.
    K: number of token values; the ``V`` increment has shape ``(K, K)``.
    N: sequence length.
    M: logit scale (``sqrt(M)`` on the ``W12`` term, ``M`` on ``W22``).
    X: integer array indexed as ``X[n, i]``; entries are used as indices.
    V, W12, W22: current parameters; increments of shape ``(K, K)``,
      ``(K, N)`` and ``(N, N)`` are added to them.
    eta: step size.
    epsilon: constant added to ``t`` in every denominator.

  Returns:
    ``(V, W12, W22)`` as new arrays; the inputs are not modified.
  """
  diffV = np.zeros((K, K))
  diffW12 = np.zeros((K, N))
  diffW22 = np.zeros((N, N))

  last = N - 1
  sqrt_M = np.sqrt(M)

  for n in range(n_sample):
    context = X[n, :last]
    target = X[n, last]

    # Attention logits for all N positions.
    logits = np.empty(N)
    logits[:last] = W12[context, last] * sqrt_M + W22[:last, last] * M
    logits[last] = W22[last, last] * M

    # Softmax shifted by the maximum: same value, no overflow in exp.
    shifted = np.exp(logits - logits.max())
    S = shifted / shifted.sum()

    weights = S[:last]
    values = V[target, context]
    t = np.dot(weights, values)
    denom = t + epsilon

    # Tokens repeat within a context, so accumulate with np.add.at rather
    # than fancy-index "+=", which would drop repeated indices.
    np.add.at(diffV, (target, context), weights / denom)

    # Shared factor of the W12 and W22 increments for each context position.
    logit_grad = (values - t) * weights / denom
    np.add.at(diffW12, (context, last), logit_grad * sqrt_M)
    # Positions are unique, so a plain slice update is enough here.
    diffW22[:last, last] += logit_grad * M

  scale = eta / n_sample
  return V + diffV * scale, W12 + diffW12 * scale, W22 + diffW22 * scale

def accuracy(n_test, K, N, M, V, W12, W22, c):
  """Estimate next-token accuracy on freshly drawn test sequences.

  Draws ``n_test`` sequences with ``Generate_sample1`` (``c == 1``) or
  ``Generate_sample2`` (``c == 2``); both consume values from the ``random``
  module. For each sequence the attention weights ``S`` are the softmax of

    ``W12[X[n,i], N-1]*sqrt(M) + W22[i, N-1]*M``   for context positions i,
    ``W22[N-1, N-1]*M``                            for the last position,

  and the prediction is the index ``j`` maximizing
  ``sum_i V[j, X[n,i]] * S[i]`` over the ``N-1`` context positions (the
  first maximum wins on ties). The prediction is compared with the last
  token of the sequence.

  Args:
    n_test: number of test sequences to draw.
    K: number of token values / candidate predictions.
    N: sequence length.
    M: logit scale.
    V, W12, W22: model parameters, indexed as described above.
    c: data generator selector, 1 or 2.

  Returns:
    Fraction of sequences whose last token was predicted correctly.

  Raises:
    ValueError: if ``c`` is neither 1 nor 2.
  """
  if c == 1:
    X_test = Generate_sample1(n_sample=n_test, K=K, N=N)
  elif c == 2:
    X_test = Generate_sample2(n_sample=n_test, K=K, N=N)
  else:
    raise ValueError("c must be 1 or 2, got %r" % (c,))

  # Generators return floats; tokens are used as indices below.
  X = X_test.astype(int)

  last = N - 1
  sqrt_M = np.sqrt(M)

  num_correct = 0
  for n in range(n_test):
    context = X[n, :last]

    # Attention logits for all N positions.
    logits = np.empty(N)
    logits[:last] = W12[context, last] * sqrt_M + W22[:last, last] * M
    logits[last] = W22[last, last] * M

    # Softmax shifted by the maximum: same value, no overflow in exp.
    shifted = np.exp(logits - logits.max())
    S = shifted / shifted.sum()

    # Score of every candidate token j: sum_i V[j, X[n,i]] * S[i].
    f = (V[:K, context] * S[:last]).sum(axis=1)

    # np.argmax returns the first maximum, matching a strict ">" scan.
    if np.argmax(f) == X[n, last]:
      num_correct += 1

  return num_correct / n_test








# Re-seed so the data below is the same regardless of earlier draws from `random`.
random.seed(2024)

# case 1 large K

# Experiment configuration.
n_sample = 1000  # number of training sequences
K = 20           # number of token values (tokens are 0..K-1)
N = 101          # sequence length; the last position is the one predicted
M = 1000         # logit scale used inside update() and accuracy()
eta = 1          # step size passed to update()
epsilon = 0.1    # constant added to the denominators inside update()
T = 50           # number of update() iterations
n_test = 1000    # test sequences drawn per accuracy() call

# X2 is not used by the training loop below, but generating it consumes
# values from the seeded `random` stream, so it stays in place.
X1 = Generate_sample1(n_sample=n_sample, K=K, N=N)
X2 = Generate_sample2(n_sample=n_sample, K=K, N=N)
# The generators return float arrays; tokens are used as indices later.
X1 = X1.astype(int)
X2 = X2.astype(int)

# All parameters start at zero.
# Alternative kept for reference: small Gaussian initialisation, e.g.
#   V = np.random.normal(0, 0.001, size=(K, K))
#   W12 = np.random.normal(0, 0.001, size=(K, N))
#   W22 = np.random.normal(0, 0.001, size=(N, N))
V = np.zeros((K, K))
W12 = np.zeros((K, N))
W22 = np.zeros((N, N))

# One accuracy value per iteration; iterations are numbered from 1.
acc = np.zeros(T)
iterations = np.arange(1, T + 1)

for step in range(T):
  V, W12, W22 = update(n_sample=n_sample, K=K, N=N, M=M, X=X1, V=V, W12=W12, W22=W22, eta=eta, epsilon=epsilon)
  # Each evaluation draws a fresh test set with Generate_sample1 (c=1).
  acc[step] = accuracy(n_test=n_test, K=K, N=N, M=M, V=V, W12=W12, W22=W22, c=1)

plt.plot(iterations, acc)
plt.xlabel('iteration')
plt.ylabel('accuracy')
plt.ylim(0, 0.3)
plt.title('Accuracy')
plt.show()





# Grey-scale heat map of the matrix V, colour range fixed to [0, 2].
heatmap_style = {'cmap': 'binary_r', 'interpolation': 'nearest', 'vmin': 0, 'vmax': 2}
plt.imshow(V, **heatmap_style)
plt.colorbar()
# Remove the tick marks on both axes.
for set_ticks in (plt.xticks, plt.yticks):
  set_ticks([])
plt.show()




# Re-seed so the sequences drawn for this plot are repeatable.
random.seed(2024)

# NOTE(review): the weights used here were trained on Generate_sample1 data
# (X1), while this plot samples from Generate_sample2 - confirm this is
# intended.
X_test = Generate_sample2(n_sample=n_test, K=K, N=N)
X = X_test.astype(int)

# Sum of the attention weights over all test sequences, per position.
S_total = np.zeros(N)
for n in range(n_test):
  # Attention logits: context positions first, then the last position.
  XWx = np.empty(N)
  XWx[:N-1] = W12[X[n, :N-1], N-1] * np.sqrt(M) + W22[:N-1, N-1] * M
  XWx[N-1] = W22[N-1, N-1] * M

  # Softmax shifted by the maximum. The unshifted exp(x)/sum(exp(x)) form
  # overflows to inf/inf = NaN once a logit exceeds roughly 709.
  shifted = np.exp(XWx - XWx.max())
  S = shifted / shifted.sum()
  S_total += S

S_average = S_total / n_test
# Positions are numbered from 1 on the x axis.
dimension = np.arange(1, N + 1)

plt.bar(dimension, S_average)

plt.title('Token Selection')
plt.xlabel('Position')
plt.ylabel('Softmax')
plt.ylim(0,0.02)
plt.xticks(range(1,N+1,10))

plt.show()