import numpy as np
import matplotlib.pyplot as plt
import os

# Number of spins (rows of the state array, size of the coupling matrix).
N = 40

# Random symmetric coupling matrix: Gaussian entries, symmetrised by adding
# the transpose, then scaled by 1/sqrt(N).
J = np.random.randn(N, N)
J = J + J.T
J = J / np.sqrt(N)

# Number of state columns allocated per spin. The step functions in this
# file read/write columns 0..2 at most.
D = 5

# Number of time steps simulated in each experiment.
T = 100

# Update rule used by step(); one of "AIM", "SBM", "CAC", "CIM".
method = "AIM"

# Figures are written below this directory, one sub-directory per method.
outpath = f"demo/{method}"

# exist_ok avoids the check-then-create race of a separate os.path.exists()
# test and is a no-op when the directory is already there.
os.makedirs(outpath, exist_ok=True)


def step_CIM(x, h):
    """Advance the "CIM" dynamics by one explicit Euler step.

    Only column 0 of ``x`` is used. It is updated as
    ``x0 <- x0 + dt * (-a*x0 - x0**3 + h)`` with ``dt = 0.1`` and ``a = 0.4``.

    Args:
        x: 2-D state array, one row per spin. Modified in place.
        h: Field added to the drift; a scalar or an array that broadcasts
            against ``x[:, 0]``.

    Returns:
        The same array object ``x`` that was passed in, after the update.
    """
    dt = 0.1
    a = 0.4

    amplitude = x[:, 0]
    # The right-hand side is fully evaluated into a new array before the
    # write-back, so reading through the view is safe.
    x[:, 0] = amplitude + dt * (-a * amplitude - amplitude**3 + h)
    return x

def step_CAC(x, h):
    """Advance the "CAC" dynamics by one explicit Euler step.

    Column 0 of ``x`` holds the amplitude and column 1 an auxiliary variable
    ``e`` that rescales the field through ``exp(e)``:

        x0 <- x0 + dt * (-a*x0 - x0**3 + exp(e)*h)
        e  <- e  + dt * beta * (1 - x0**2)

    Both updates are computed from the values held *before* the step.

    Args:
        x: 2-D state array, one row per spin. Modified in place.
        h: Field; a scalar or an array that broadcasts against ``x[:, 0]``.

    Returns:
        The same array object ``x`` that was passed in, after the update.
    """
    dt, a, beta = 0.1, 0.4, 0.5

    amplitude = x[:, 0]
    aux = x[:, 1]

    # Both new columns are materialised before anything is written back, so
    # the auxiliary update still sees the pre-step amplitude.
    next_amplitude = amplitude + dt*(-a*amplitude - amplitude**3 + np.exp(aux)*h)
    next_aux = aux + dt*beta*(1 - amplitude**2)

    x[:, 0], x[:, 1] = next_amplitude, next_aux
    return x

def step_AIM(x, h):
    """Advance the "AIM" dynamics by one step.

    Column layout of ``x``:
        column 2: current internal state ``z``
        column 1: internal state from the previous step
        column 0: ``tanh`` of the current internal state

    Update (``dt = 0.2``, ``a = 0.5``, ``lamb = 1.0``):

        z_new = z + dt * (-a*z + 0.5*h + lamb*(z - z_prev))

    Args:
        x: 2-D state array with at least three columns, one row per spin.
            Modified in place.
        h: Field; a scalar or an array that broadcasts against ``x[:, 2]``.

    Returns:
        The same array object ``x`` that was passed in, after the update.
    """
    dt = 0.2
    a = 0.5
    lamb = 1.0

    # x[:, 2] is a view into x. Without the copy, writing the new state into
    # column 2 would also change this "previous" value, so column 1 would
    # always equal column 2 and the lamb*(z - z_prev) term would be zero on
    # every step after the first.
    z_prev = x[:, 2].copy()
    z_new = z_prev + dt*(-a*z_prev + 0.5*h + lamb*(z_prev - x[:, 1]))

    x[:, 2] = z_new
    x[:, 1] = z_prev
    # Column 0 reflects the updated state, which is what the earlier aliased
    # code effectively stored and what the other step functions do.
    x[:, 0] = np.tanh(z_new)
    return x


def step_SBM(x, h):
    """Advance the "SBM" dynamics by one step.

    Column layout of ``x``:
        column 2: continuous variable, kept inside [-1, 1]
        column 1: conjugate variable ``y``
        column 0: sign of the continuous variable (before clamping)

    Update (``dt = 0.1``, ``a = 1.0``):

        y_next = y + dt * (-a*x2 + h)
        x_next = x2 + dt * y_next

    Any entry of ``x_next`` that leaves [-1, 1] is clamped to the boundary
    and its ``y`` is reset to zero.

    Args:
        x: 2-D state array with at least three columns, one row per spin.
            Modified in place.
        h: Field; a scalar or an array that broadcasts against ``x[:, 1]``.

    Returns:
        The same array object ``x`` that was passed in, after the update.
    """
    dt = 0.1
    a = 1.0

    y_next = x[:, 1] + dt*(-a*x[:, 2] + 1.0*h)
    x_next = x[:, 2] + dt*y_next

    # Sign and wall test both use the unclamped value.
    signs = np.sign(x_next)
    within_walls = np.abs(x_next) <= 1

    x[:, 2] = np.clip(x_next, -1, 1)
    # Multiplying by the boolean mask zeroes y wherever the wall was hit.
    x[:, 1] = y_next*within_walls
    x[:, 0] = signs
    return x

def step(x, h):
    """Advance ``x`` by one step of the update rule named by ``method``.

    ``method`` is the module-level string selecting one of the step
    functions in this file ("AIM", "SBM", "CAC" or "CIM").

    Args:
        x: State array, passed through to the selected step function.
        h: Field, passed through to the selected step function.

    Returns:
        Whatever the selected step function returns (the updated ``x``).

    Raises:
        ValueError: If ``method`` is not one of the supported names. The
            previous behaviour was to fall through and return ``None``,
            which only failed later with an unrelated error.
    """
    steppers = {
        "AIM": step_AIM,
        "SBM": step_SBM,
        "CAC": step_CAC,
        "CIM": step_CIM,
    }
    try:
        stepper = steppers[method]
    except KeyError:
        raise ValueError(
            f"unknown method {method!r}; expected one of {sorted(steppers)}"
        ) from None
    return stepper(x, h)

# Coupled N-spin run: start from small random values in every state column.
x = np.random.randn(N, D)*0.1

# Per-step records: one trajectory row per spin, one energy value per step.
x_rec = np.zeros((N, T))
E_rec = np.zeros(T)

# For "SBM" the recorded trajectory is column 2 of the state; every other
# method records column 0.
record_col = 2 if method == "SBM" else 0

for t in range(T):
    # Energy is sampled from the signs of column 0 *before* the update;
    # the trajectory below is sampled *after* it.
    signs = np.sign(x[:, 0])
    E_rec[t] = np.sum(np.dot(J, signs)*signs)

    # Field fed to the step function: minus the coupling matrix applied to
    # column 0 of the current state.
    h = -np.dot(J, x[:, 0])

    x = step(x, h)
    x_rec[:, t] = x[:, record_col]


# Figure: recorded trajectory of every spin (one line per row of x_rec).
plt.figure(figsize=(5, 4))

time_axis = range(T)
for trajectory in x_rec:
    plt.plot(time_axis, trajectory)

plt.xlabel("time step")
plt.ylabel("spin amplitude (x)")

plt.tight_layout()

# Save before show(): the figure is written to disk first, then displayed,
# then released.
plt.savefig(f"{outpath}/spin_traj.png")
plt.show()
plt.close()


# Figure: recorded energy at each time step.
plt.figure(figsize=(5, 4))

plt.plot(time_axis, E_rec)

plt.ylabel("Ising energy")
plt.xlabel("time step")
plt.tight_layout()

plt.savefig(f"{outpath}/energy_traj.png")
plt.show()
plt.close()


# Single-spin run driven by a prescribed field instead of the coupling matrix.
# The field is a piecewise-linear curve through the control values below
# (scaled by 1.4), with the control points spread evenly over [0, 1] and
# sampled at i/T for each time step i.
testfunc = [-1,-1,1,1,1,2,2,-1,-1,0]
testfunc = np.array(testfunc)*1.4
h_field = np.interp(
    np.arange(T)/T,
    np.arange(len(testfunc))/(len(testfunc) - 1),
    testfunc,
)

x = np.random.randn(1, D)*0.1
x_rec = np.zeros((1, T))

# Record the same quantity as the coupled run does: for "SBM" the
# continuous amplitude lives in column 2 (column 0 only holds its sign);
# every other method exposes the amplitude in column 0.
record_col = 2 if method == "SBM" else 0

for i in range(T):
    h = h_field[i]
    x = step(x, h)
    x_rec[:, i] = x[:, record_col]


# Figure: driving field and recorded response of the single spin, with
# dashed gray guide lines at +1 and -1.
plt.figure(figsize=(5, 4))

plt.xlabel("time step")

for level in (1, -1):
    plt.plot([0, T], [level, level], color="gray", dashes=[3, 3])

time_axis = range(T)
plt.plot(time_axis, h_field, label="coupling field")
plt.plot(time_axis, x_rec[0], label="spin amplitude")

plt.legend()

plt.tight_layout()

plt.savefig(f"{outpath}/dynamics.png")
plt.show()
plt.close()
