import torch
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
import scipy.stats as st
import torch.distributions as D


torch.manual_seed(1)
n = 500  # number of data points
# 1-D standard-normal sample, sorted in ascending order
x = torch.sort(torch.randn(n)).values
# placeholder only: reassigned to max|b| + 1 once the linear term b is formed
beta = 20

def sign_(x, tol=1.5e-7):
  """Elementwise sign with the convention sign(0) = +1.

  Entries with |x| <= tol (treated as zero) and NaN entries map to +1.0; every
  other entry maps to -1.0 or +1.0 according to its sign (including +/-inf).
  Non-floating inputs are promoted to the default float dtype.

  Args:
    x: input tensor.
    tol: magnitude at or below which an entry is treated as zero.

  Returns:
    Tensor of the same shape as ``x`` with values in {-1.0, +1.0}.
  """
  if not torch.is_floating_point(x):
    x = x.to(torch.get_default_dtype())
  # torch.sign avoids the x/|x| division, which produced NaN for zeros and
  # mapped -inf to +1 (inf/inf = NaN -> 1).
  return torch.where(torch.abs(x) > tol, torch.sign(x), torch.ones_like(x))


def clip(x, tol=1e-8):
  """Zero out entries whose magnitude is at or below ``tol``.

  Entries with |x| > tol are kept unchanged; all others (including NaN, for
  which the comparison is False) become 0.

  Args:
    x: input tensor.
    tol: magnitude threshold; defaults to 1e-8.

  Returns:
    Tensor of the same shape as ``x``.
  """
  return torch.where(torch.abs(x) > tol, x, 0)


# form C: quadratic term of the convex program.
# A[i, j] = |x_i - x_j|; A_t subtracts each column's mean, and A_tt stacks it
# twice, one copy per neuron family (w = +1 and w = -1).
A = torch.abs(x.unsqueeze(1) - x)
A_t = A - A.mean(axis=0)
A_tt = torch.hstack([A_t, A_t])
x_t = (x - x.mean()).unsqueeze(1)
# B = (I - P)/2 with P = x_t x_t^T / ||x_t||^2 the orthogonal projector onto
# x_t, so B is PSD and C = A_tt^T B A_tt is PSD in exact arithmetic; only
# round-off (and the small-entry clip) can push eigenvalues below zero.
B = (torch.eye(n) - x_t @ x_t.T / (x_t**2).sum()) / 2
C = clip(A_tt.T @ B @ A_tt)
if (C.T != C).any():
  # restore exact symmetry lost to round-off
  C = (C + C.T) / 2

# C is symmetric, so eigvalsh applies: real eigenvalues, cheaper and more
# stable than the general-purpose eigvals.
min_eig = torch.linalg.eigvalsh(C).min()
print('min eig: ', min_eig)
# add diagonal: shift C by t0*I with t0 = 1e-8, 1e-7, ..., 1e-2 (cumulative)
# until its smallest eigenvalue is at least 1e-5.
t0 = 1e-8
while min_eig < 1e-5:
  if t0 > 1e-2:
    raise RuntimeError(
        'C is not positive definite after a diagonal shift; min eig = %g'
        % min_eig.item())
  C = C + t0 * torch.eye(2 * n)
  t0 = t0 * 10
  min_eig = torch.linalg.eigvalsh(C).min()

# form b: linear term of the convex program.
b0 = (-n / (x_t**2).sum()) * x_t
b1 = sign_(x.unsqueeze(1) - x).sum(axis=0)   # b1[j] = sum_i sign(x_i - x_j)
b2 = sign_(-x.unsqueeze(1) + x).sum(axis=0)  # b2[j] = sum_i sign(x_j - x_i)
b = (A_tt.T @ b0).squeeze() + torch.hstack([b1, -b2])
# Weight of the l1 term; score_matching also reads this global as its
# weight-decay strength.
beta = torch.max(torch.abs(b)) + 1

# cvx program:
#   minimize  y^T C y + b^T y + beta*||y||_1 - n^2 / (2*||x_t||^2)
# cvxpy works on numpy data, so convert the torch tensors explicitly, and
# express the quadratic with quad_form (the documented DCP atom) instead of
# the product y.T @ C @ y of two non-constant expressions.
C_np = C.double().numpy()
b_np = b.double().numpy()
const = n**2 / (2 * (x_t**2).sum().item())
y = cp.Variable(2 * n)
obj = cp.quad_form(y, C_np) + b_np @ y + float(beta) * cp.norm(y, 1) - const
prob = cp.Problem(cp.Minimize(obj))
prob.solve()
if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
  # y.value is not a usable solution otherwise (it is None on failure).
  raise RuntimeError('cvx solve failed with status %r' % (prob.status,))
print('cvx optimal value: ', prob.value)

y_opt = torch.as_tensor(y.value, dtype=torch.float32)
y0 = y_opt[:n]   # output weights of the w = +1 neurons
y1 = y_opt[-n:]  # output weights of the w = -1 neurons
# closed-form skip weight (y2, shape (1,)) and output bias (y3, 0-d)
y2 = -(x_t.T @ A_t @ (y0 + y1) + n) / (x_t**2).sum()
y3 = -(A @ (y0 + y1) + x.unsqueeze(1) @ y2).mean()
# indices of the active (non-negligible) neurons
print((torch.abs(y_opt) > 1e-6).nonzero(as_tuple=True)[0])

# define model
class net(torch.nn.Module):
  """Two-layer network with absolute-value activation plus a linear skip term.

  Computes ``fc2(|fc1(x)|) + x * v`` for inputs of shape (..., 1).
  """

  def __init__(self, hidden_dim=None):
    """Create the layers.

    Args:
      hidden_dim: number of hidden neurons. Defaults to ``2 * n`` (the
        module-level sample count), matching the 2n neurons of the convex
        program.
    """
    super().__init__()
    if hidden_dim is None:
      hidden_dim = 2 * n
    self.fc1 = torch.nn.Linear(1, hidden_dim)
    self.fc2 = torch.nn.Linear(hidden_dim, 1)
    self.v = torch.nn.Parameter(torch.randn(1, 1))  # skip-connection weight

  def forward(self, x):
    skip_term = x * self.v
    hidden = torch.abs(self.fc1(x))
    return self.fc2(hidden) + skip_term
  
# define loss
def score_matching(model, samples, reg=None):
  """Score-matching objective with weight decay for a ``net`` model.

  Args:
    model: a ``net`` instance (its ``fc1``, ``fc2`` and ``v`` are used).
    samples: data points, shape (N, 1).
    reg: weight-decay strength on the fc1/fc2 weights; defaults to the
      module-level ``beta``.

  Returns:
    0-d tensor: 0.5 * sum_i s(x_i)^2 + sum_i s'(x_i)
    + reg/2 * (||fc1.weight||^2 + ||fc2.weight||^2), where s is the model.
  """
  if reg is None:
    reg = beta
  num_samples = samples.shape[0]
  w1 = model.fc1.weight.squeeze()
  w2 = model.fc2.weight.squeeze()
  pred_score = model(samples)
  loss_1 = (pred_score**2).sum() / 2
  # sum over samples of d/dx model(x): each abs unit contributes
  # w2_j * w1_j * sign(fc1_j(x)); the skip term contributes v per sample.
  loss_2 = (sign_(model.fc1(samples)) * (w1 * w2)).sum() + num_samples * model.v
  loss_3 = reg * ((model.fc1.weight**2).sum() + (model.fc2.weight**2).sum()) / 2
  return (loss_1 + loss_2 + loss_3).squeeze()

# Map the convex solution (y0, y1, y2, y3) back to network parameters.
# Neurons 1..n use w = +1, b = -x_j and neurons n+1..2n use w = -1, b = x_j,
# so both families compute |x - x_j| with output weights y0 and y1.
def _rescale_neurons(W, bias, alpha):
  """Balance each abs neuron so that |W*gamma| == |alpha/gamma|.

  Uses gamma = sqrt(|alpha/W|) >= 0; since |.| is positively homogeneous the
  neuron's contribution alpha*|W x + bias| is unchanged. alpha/gamma is
  written as sign(alpha)*sqrt(|alpha*W|) so neurons with alpha == 0 get zero
  weights instead of 0/0 = NaN.

  Returns:
    (W*gamma, bias*gamma, alpha/gamma)
  """
  gamma = torch.sqrt(torch.abs(alpha / W))
  alpha_scaled = torch.sign(alpha) * torch.sqrt(torch.abs(alpha * W))
  return W * gamma, bias * gamma, alpha_scaled


W1_rec, b1_rec, alpha1_rec = _rescale_neurons(torch.ones(n), -x, y0)
W2_rec, b2_rec, alpha2_rec = _rescale_neurons(-torch.ones(n), x, y1)
v_rec = y2
b0_rec = y3

model_rec = net()
model_rec.fc1.weight = torch.nn.Parameter(torch.cat([W1_rec, W2_rec]).unsqueeze(1))
model_rec.fc1.bias = torch.nn.Parameter(torch.cat([b1_rec, b2_rec]))
model_rec.fc2.weight = torch.nn.Parameter(torch.cat([alpha1_rec, alpha2_rec]).unsqueeze(0))
# Match the shapes of nn.Linear's bias (1,) and net.v (1, 1) so the model's
# output stays (N, 1) instead of broadcasting to (1, N, 1).
model_rec.fc2.bias = torch.nn.Parameter(b0_rec.reshape(1))
model_rec.v = torch.nn.Parameter(v_rec.reshape(1, 1))

rec_loss = score_matching(model_rec, x.unsqueeze(1))
print('reconstruct loss: ', rec_loss)
opt_nn = rec_loss.item()

# Train `trials` randomly initialised networks on the nonconvex
# score-matching objective with Adam, recording the loss at every epoch.
num_epoch = 500
trials = 10
loss_list = torch.zeros(trials, num_epoch)
samples = x.unsqueeze(1)
for trial in range(trials):
  model_torch = net()
  optimizer = torch.optim.Adam(model_torch.parameters(), lr=1e-2, weight_decay=0)
  for epoch in range(num_epoch):
    optimizer.zero_grad()
    loss = score_matching(model_torch, samples)
    loss_list[trial, epoch] = loss.item()
    loss.backward()
    optimizer.step()

print('nonconvex loss: ', loss_list.min())

# Training curves from epoch 100 onward; the dashed line is the loss of the
# network reconstructed from the convex solution.
plt.figure()  # fresh figure: plt.show() does not clear it on non-interactive backends
plt.axhline(y=opt_nn, color='b', linestyle='dashed', linewidth=5.0)
# x-values are the actual epoch indices of the plotted slice
epochs = torch.arange(100, num_epoch)
for i in range(trials):
  plt.plot(epochs, loss_list[i][100:num_epoch], linewidth=5.0)
plt.xlabel('# Epochs', fontsize=30)
plt.ylabel('Training Loss', fontsize=30)
plt.tight_layout()
plt.savefig('abs_skip_%i.pdf' % n)
plt.show()

# Training curves for the first 200 epochs against the reconstructed loss.
plt.figure()  # fresh figure: plt.show() does not clear it on non-interactive backends
plt.axhline(y=opt_nn, color='b', linestyle='dashed', linewidth=4.0)
for i in range(trials):
  plt.plot(loss_list[i][:200], linewidth=2.0)
plt.xlabel('# Epochs', fontsize=18)
plt.ylabel('Training Loss', fontsize=18)
plt.tight_layout()
plt.savefig('abs_skip.pdf')
plt.show()


# Score predicted by the reconstructed network on a grid over [-10, 10].
plt.figure()  # fresh figure: plt.show() does not clear it on non-interactive backends
x_hat = torch.linspace(-10, 10, 10000)
with torch.no_grad():  # evaluation only; no autograd graph needed
  y_hat = model_rec(x_hat.unsqueeze(1)).squeeze()
plt.plot(x_hat, y_hat, linewidth=5.0)
plt.xlabel('x', fontsize=30)
plt.ylabel("Score", fontsize=30)
plt.tight_layout()
plt.savefig('abs_skip_score.pdf')
plt.show()

eps = torch.Tensor([0.5])
steps = 100
n_samples = 100000


def force(x):
  """Reconstructed model's score estimate (d/dx log p) at each point of 1-D x."""
  # no_grad: otherwise every call builds an autograd graph over an
  # (n_samples, 2n) activation matrix that is immediately discarded.
  with torch.no_grad():
    return model_rec(x.unsqueeze(1)).squeeze()


# Prior distribution from which the initial positions are sampled
prior = D.Uniform(-20, 20)
# Unadjusted Langevin dynamics: z <- z + eps*score(z) + sqrt(2*eps)*noise,
# run as 5 rounds of `steps` iterations, halving eps after each round.
z = prior.sample((n_samples,))
for j in range(5):
  for i in range(steps):
    z = z + eps * force(z) + torch.sqrt(2 * eps) * torch.randn(size=z.shape)
  eps = eps / 2
s = z.detach()

# Histogram of the Langevin samples with a Gaussian KDE overlaid.
s_np = s.numpy()
plt.figure()  # fresh figure: plt.show() does not clear it on non-interactive backends
plt.hist(s_np, density=True, bins=50, label='Langevin samples')
# freeze the histogram's x-range so the KDE curve does not change it
mn, mx = plt.xlim()
plt.xlim(mn, mx)
kde_xs = np.linspace(mn, mx, 300)
kde = st.gaussian_kde(s_np)
plt.plot(kde_xs, kde.pdf(kde_xs), linewidth=5.0, label='KDE')
plt.legend(loc="upper left")
plt.xlabel("Sample", fontsize=30)
# density=True normalises the histogram, so the y-axis is a density
plt.ylabel("Density", fontsize=30)
plt.tight_layout()
plt.savefig('abs_skip_hist.pdf')
plt.show()