import torch
import numpy as np
import cvxpy as cp
import torch.distributions as D
import scipy.stats as st
import matplotlib.pyplot as plt

# Global experiment settings.
beta = 0.5       # weight-decay coefficient read by dsm_loss and cvx_train
n = 1000         # number of training points (also the default hidden width of nn_type4)
num_epoch = 300  # Adam iterations per trained model
sigma = 0.1      # noise level for the single-scale toy data below

torch.manual_seed(0)
# Single-scale denoising data: clean points x, Gaussian noise eps, the perturbed
# input, and the score target -eps / sigma.
x, eps = torch.randn(n), torch.randn(n)
x_input = sigma * eps + x
label = -(eps / sigma)


# define model
class nn_type4(torch.nn.Module):
    """Two-layer ReLU network with a learnable linear skip connection.

    Computes ``fc2(relu(fc1(x))) + v * x``. Inputs must have a trailing
    dimension of size 1 (``fc1`` is ``Linear(1, hidden)``).

    Args:
        hidden_dim: number of hidden ReLU units. Defaults to the module-level
            ``n`` (one unit per training point, matching the convex
            reconstruction in ``cvx_train``).
    """

    def __init__(self, hidden_dim=None):
        super().__init__()
        width = n if hidden_dim is None else hidden_dim
        self.fc1 = torch.nn.Linear(1, width)
        self.fc2 = torch.nn.Linear(width, 1)
        # Slope of the skip term; shape (1, 1) so x * v broadcasts over (N, 1) inputs.
        self.v = torch.nn.Parameter(torch.randn(1, 1))

    def forward(self, x):
        skip_term = x * self.v
        hidden = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(hidden) + skip_term
  
def dsm_loss(model, samples, labels):
    """Return the denoising score-matching objective with L2 weight decay.

    The objective is ``0.5 * sum((model(samples) - labels)**2)`` plus
    ``beta / 2`` times the squared norms of ``fc1.weight`` and ``fc2.weight``.
    Biases and the skip weight ``v`` are not penalised. ``beta`` is read from
    module scope.

    Args:
        model: network exposing ``fc1`` and ``fc2`` linear layers.
        samples: model inputs (presumably shape (N, 1)).
        labels: targets, presumably 1-D of length N; a trailing axis is added
            so they line up with the (N, 1) prediction.
    """
    residual = model(samples) - labels[:, None]
    fit_term = (residual ** 2).sum() / 2
    weight_norm_sq = (model.fc1.weight ** 2).sum() + (model.fc2.weight ** 2).sum()
    return fit_term + beta * weight_norm_sq / 2

# # solve convex program
# # define tilde A
# A_pre = x.unsqueeze(1) - x
# A = torch.abs(A_pre)
# tilde_A = A - A.mean(axis=0)
# tilde_y = label - label.mean()
# x_t = (x - x.mean()).unsqueeze(1)
# B = (torch.eye(n)-x_t@x_t.T/((x_t**2).sum()))/2
# C = tilde_A.T@B@tilde_A
# if (C.T!=C).any():
#   C = (C+C.T)/2
# # add diagonal
# t0 = 1e-8
# while (torch.real(torch.linalg.eigvals(C))<1e-5).any():
#   C = C + t0*torch.eye(n)
#   t0 = t0*10
# #   assert t0<=1e-2
# assert (torch.real(torch.linalg.eigvals(C))>=1e-5).all()
# a = 2*tilde_A.T@B@tilde_y
# c = tilde_y@B@tilde_y


# y = cp.Variable(n)
# obj =  y@C@y + a@y + c + 2*beta*cp.norm(y,1)
# prob = cp.Problem(cp.Minimize(obj))
# prob.solve()
# print(prob.value)

# # reconstruct nn
# W_rec = torch.ones(n)
# b_rec = -x
# y_opt = torch.Tensor(y.value)
# alpha_rec = -torch.Tensor(y_opt)
# gamma_rec = torch.sqrt(torch.abs(2*alpha_rec/W_rec))
# W_rec = W_rec*gamma_rec
# b_rec = b_rec*gamma_rec
# alpha_rec = alpha_rec/gamma_rec
# y2 = -(tilde_A@y_opt+tilde_y)@x_t/((x_t**2).sum())
# b0_rec = (A@y_opt + x*y2 + label).mean()
# v_rec = -y2

# alpha_rec = 2*alpha_rec
# v_rec = v_rec - ((W_rec*alpha_rec).sum())/2
# b0_rec = b0_rec - ((b_rec*alpha_rec).sum())/2


# model_rec = nn_type4()
# model_rec.fc1.weight = torch.nn.Parameter(W_rec.unsqueeze(1))
# model_rec.fc1.bias = torch.nn.Parameter(b_rec)
# model_rec.fc2.weight = torch.nn.Parameter(alpha_rec.unsqueeze(1).T)
# model_rec.fc2.bias = torch.nn.Parameter(b0_rec.unsqueeze(0).unsqueeze(1))
# model_rec.v = torch.nn.Parameter(v_rec.unsqueeze(0).unsqueeze(1))

# opt_nn = dsm_loss(model_rec, x.unsqueeze(1), label).item()
# print('reconstruct nn loss: ', opt_nn)


# # non-convex nn
# num_epoch = 300
# trials = 10
# loss_list = torch.zeros(trials, num_epoch)
# for i in range(trials):
#   model_torch = nn_type4()
#   optimizer = torch.optim.Adam(model_torch.parameters(), lr=1e-2, weight_decay=0)
#   for t in range(num_epoch):
#     optimizer.zero_grad()
#     loss = dsm_loss(model_torch, x.unsqueeze(1), label)
#     loss_list[i][t] = loss.item()
#     # print(t, loss.item())
#     loss.backward()
#     optimizer.step()

# print(torch.min(loss_list).item())
# plt.axhline(y = opt_nn, color = 'b', linestyle = 'dashed', linewidth=5.0)
# for i in range(trials):
#   plt.plot(torch.linspace(0,200,200), loss_list[i][0:200], linewidth=5.0)
# plt.xlabel('# Epochs',fontsize=30)
# plt.ylabel('Training Loss',fontsize=30)
# plt.tight_layout()
# plt.savefig('dsm_type4_training.pdf')
# plt.show()

# Annealed Langevin schedule: L noise levels, T Langevin steps per level.
L = 10
T = 10
# Noise levels decrease linearly from 1 down to 0.01.
sigmas = torch.linspace(1, 0.01, L)
# Per-level step size eps_0 * sigma_i^2 / sigma_L^2: largest at the largest
# noise level, equal to eps_0 at the smallest one.
eps_0 = 2e-5
epss = eps_0 * sigmas**2 / sigmas[-1]**2
# model_rec = nn_type2()
def cvx_train(x, label):
    """Fit an ``nn_type4`` score model by solving its convex reformulation.

    Builds a Lasso-type program over ``y`` (one coefficient per data point):
    a quadratic form plus an L1 penalty. It solves the program with cvxpy and
    maps the solution back to network weights, with one hidden ReLU unit per
    data point. ``beta`` is read from module scope.

    Args:
        x: 1-D tensor of training inputs.
        label: 1-D tensor of regression targets, same length as ``x``.

    Returns:
        A fresh ``nn_type4`` whose parameters are replaced by the reconstruction.

    Raises:
        RuntimeError: if the solver returns no solution.
    """
    num_pts = x.shape[0]  # number of data points == number of hidden units

    # Pairwise distance features |x_i - x_j| (column-centred) and centred targets.
    A = torch.abs(x.unsqueeze(1) - x)
    tilde_A = A - A.mean(dim=0)
    tilde_y = label - label.mean()
    x_t = (x - x.mean()).unsqueeze(1)
    # Half the projector onto the complement of the centred-x direction; that
    # direction is fitted separately by the skip weight v (recovered below).
    B = (torch.eye(num_pts) - x_t @ x_t.T / (x_t**2).sum()) / 2
    C = tilde_A.T @ B @ tilde_A
    if (C.T != C).any():
        C = (C + C.T) / 2  # remove round-off asymmetry
    # Add a growing diagonal shift until the smallest eigenvalue reaches 1e-3,
    # so the quadratic term is strictly convex. C is symmetric at this point,
    # so the symmetric eigensolver (real eigenvalues, cheaper) is appropriate.
    t0 = 1e-8
    while torch.linalg.eigvalsh(C).min() < 1e-3:
        C = C + t0 * torch.eye(num_pts)
        t0 = t0 * 10
    a = 2 * tilde_A.T @ B @ tilde_y
    c = tilde_y @ B @ tilde_y

    y = cp.Variable(num_pts)
    obj = y @ C @ y + a @ y + c + 2 * beta * cp.norm(y, 1)
    prob = cp.Problem(cp.Minimize(obj))
    prob.solve()
    print(prob.value)
    if y.value is None:
        raise RuntimeError(f"convex solve failed (status: {prob.status})")

    # ---- Map the convex solution back to network weights. ----
    y_opt = torch.as_tensor(y.value, dtype=x.dtype)
    alpha_rec = -y_opt
    W_rec = torch.ones_like(x)
    b_rec = -x
    # Rescale unit j by gamma_j so that, after the doubling below,
    # |alpha_j| == |W_j| (balanced layers). Positive homogeneity of |.| / ReLU
    # keeps each unit's contribution unchanged.
    gamma_rec = torch.sqrt(torch.abs(2 * alpha_rec / W_rec))
    W_rec = W_rec * gamma_rec
    b_rec = b_rec * gamma_rec
    # Units switched off by the L1 penalty (alpha_j == 0) have gamma_j == 0.
    # Dividing there would give 0/0 = NaN, so their output weight stays 0.
    alpha_rec = torch.where(
        gamma_rec > 0, alpha_rec / gamma_rec, torch.zeros_like(alpha_rec)
    )
    # Skip weight: least-squares coefficient on the centred x; then the bias.
    y2 = -(tilde_A @ y_opt + tilde_y) @ x_t / (x_t**2).sum()
    b0_rec = (A @ y_opt + x * y2 + label).mean()
    v_rec = -y2

    # Convert |u| units to ReLU units via |u| = 2*relu(u) - u. This doubles
    # alpha and moves the linear part into the skip weight and output bias.
    alpha_rec = 2 * alpha_rec
    v_rec = v_rec - (W_rec * alpha_rec).sum() / 2
    b0_rec = b0_rec - (b_rec * alpha_rec).sum() / 2

    model_rec = nn_type4()
    model_rec.fc1.weight = torch.nn.Parameter(W_rec.unsqueeze(1))        # (num_pts, 1)
    model_rec.fc1.bias = torch.nn.Parameter(b_rec)                       # (num_pts,)
    model_rec.fc2.weight = torch.nn.Parameter(alpha_rec.unsqueeze(0))    # (1, num_pts)
    # nn.Linear expects a 1-D bias; v must match nn_type4's (1, 1) skip weight,
    # otherwise forward() broadcasts (N, 1) outputs up to (1, N, 1).
    model_rec.fc2.bias = torch.nn.Parameter(b0_rec.reshape(1))
    model_rec.v = torch.nn.Parameter(v_rec.reshape(1, 1))
    return model_rec

# score = []
# for i in range(L):
#   print('training ', i, ' model')
#   x = torch.randn(n)
#   noise = torch.randn(n)
#   input = x + sigmas[i]*noise
#   label = -noise/sigmas[i]
#   score.append(cvx_train(input,label))

# Train one independent score model per noise level with denoising score matching.
score = []
for i in range(L):
    x = torch.randn(n)
    noise = torch.randn(n)
    x_noisy = x + sigmas[i] * noise
    # DSM target: grad log N(x_noisy; x, sigma^2) = -(x_noisy - x) / sigma^2
    # = -noise / sigma. The minus sign matters: +noise/sigma would teach the
    # negated score and make the Langevin sampler drift away from the data.
    label = -noise / sigmas[i]
    model = nn_type4()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)
    for _ in range(num_epoch):
        optimizer.zero_grad()
        loss = dsm_loss(model, x_noisy.unsqueeze(1), label)
        loss.backward()
        optimizer.step()
    score.append(model)


n_samples = 10000
prior = D.Uniform(-1, 1)
z = prior.sample((n_samples, )).unsqueeze(1)  # (n_samples, 1) initial chains
# Annealed Langevin dynamics: T steps at each noise level, largest level first,
# using that level's score model and step size. Sampling needs no gradients;
# no_grad avoids building an autograd graph through all L*T model calls.
with torch.no_grad():
    for i in range(L):
        step = epss[i]
        for _ in range(T):
            z = z + (step / 2) * score[i](z) + torch.sqrt(step) * torch.randn(size=z.shape)

s = z.detach().squeeze()
# Keep only finite samples: diverged chains can end at +/-inf as well as NaN,
# and either would corrupt min/max and the KDE used for plotting.
s = s[torch.isfinite(s)]
if s.numel() == 0:
    raise RuntimeError("all Langevin chains diverged; no finite samples to plot")
print('s=', s)
print('min s, max s', torch.min(s), torch.max(s))


# Histogram of the final samples with a Gaussian KDE curve overlaid.
plt.hist(s, density=True, bins=50)
# Pin the x-range chosen for the histogram so the KDE curve does not change it.
lo, hi = plt.xlim()
plt.xlim(lo, hi)
density_estimate = st.gaussian_kde(s)
grid = np.linspace(lo, hi, 300)
plt.plot(grid, density_estimate.pdf(grid), linewidth=5.0)
label_size = 30
plt.xlabel("Sample", fontsize=label_size)
# NOTE(review): with density=True the y-axis is a probability density, not a count.
plt.ylabel("Frequency", fontsize=label_size)
plt.tight_layout()
plt.savefig('dsm_type4_noncvx_sample.pdf')
plt.show()



  