import torch
import numpy as np
import cvxpy as cp
import torch.distributions as D
import scipy.stats as st
import matplotlib.pyplot as plt

# Reproducibility: seed the global RNG before any tensor is drawn.
torch.manual_seed(0)

# Hyper-parameters.
beta = 0.5       # coefficient of the regularisation terms in the losses below
n = 1000         # number of data points
num_epoch = 300  # optimisation steps per training run
sigma = 0.1      # scale of the additive perturbation

# Synthetic data. The draw order (x first, then eps) fixes the values obtained
# under the seed above, so it must not be swapped.
x = torch.randn(n)
eps = torch.randn(n)
x_input = x + sigma * eps
label = -eps / sigma
# NOTE(review): x_input is not referenced again in the visible part of the
# file; the later stages feed the clean x to the network -- confirm intent.

# define model
class nn_type2(torch.nn.Module):
  """Two-layer ReLU network mapping a scalar input to a scalar output.

  Computes ``fc2(relu(fc1(x)))`` with ``fc1 = Linear(1, hidden)`` and
  ``fc2 = Linear(hidden, 1)``.

  Args:
    hidden: Number of hidden ReLU units. When omitted, the module-level ``n``
      is read at construction time and ``2 * n`` units are used, which is the
      width that used to be hard-coded, so ``nn_type2()`` behaves as before.
  """

  def __init__(self, hidden=None):
    super().__init__()
    if hidden is None:
      # historical default: two hidden units per data point
      hidden = 2 * n
    self.fc1 = torch.nn.Linear(1, hidden)
    self.fc2 = torch.nn.Linear(hidden, 1)

  def forward(self, x):
    """Return the network output for ``x`` whose last dimension has size 1."""
    hidden_act = torch.nn.functional.relu(self.fc1(x))
    return self.fc2(hidden_act)
  
def dsm_loss(model, samples, labels, reg=None):
  """Sum-of-squares regression loss plus an L2 penalty on both weight matrices.

  Args:
    model: Callable module exposing ``fc1.weight`` and ``fc2.weight``.
    samples: Inputs, passed to ``model`` unchanged.
    labels: 1-D targets; a trailing axis is added so they line up with
      predictions of shape ``(batch, 1)``.
    reg: Penalty coefficient. Defaults to the module-level ``beta`` so
      existing three-argument calls behave exactly as before.

  Returns:
    Scalar tensor
    ``sum((model(samples) - labels)**2) / 2
    + reg * (sum(fc1.weight**2) + sum(fc2.weight**2)) / 2``.
    Biases are not penalised.
  """
  if reg is None:
    reg = beta
  residual = model(samples) - labels.unsqueeze(1)
  fit_term = (residual**2).sum()/2
  weight_sq = (model.fc1.weight**2).sum() + (model.fc2.weight**2).sum()
  penalty_term = reg*weight_sq/2
  return fit_term + penalty_term




# solve convex program
# define tilde A: entry (i, j) of A1 is relu(x_i - x_j) and entry (i, j) of A2
# is relu(x_j - x_i)
A1_pre = x.unsqueeze(1) - x
A1 = torch.relu(A1_pre)
A2 = torch.relu(-A1_pre)
A = torch.hstack([A1, A2])
# centre every column and the targets; the scalar offset this removes is
# recovered in closed form as b0_rec below
tilde_A = A - A.mean(dim=0)
tilde_y = label - label.mean()
y = cp.Variable(2*n)
# hand cvxpy plain float64 NumPy arrays instead of relying on implicit
# torch -> cvxpy coercion
tilde_A_np = tilde_A.double().numpy()
tilde_y_np = tilde_y.double().numpy()
obj = cp.sum_squares(tilde_A_np@y + tilde_y_np)/2 + beta*cp.norm(y, 1)
prob = cp.Problem(cp.Minimize(obj))
prob.solve()
if y.value is None:
  # without this guard the tensor conversion below fails with an unrelated
  # error that hides the solver status
  raise RuntimeError(
      'convex program returned no solution (status: {})'.format(prob.status))
print(prob.value)

y_opt = torch.tensor(y.value, dtype=torch.float32)
# first n entries belong to the A1 columns, last n to the A2 columns
y0 = y_opt[:n]
y1 = y_opt[n:]
# offset that minimises the squared residual of the uncentred system
b0_rec = (A@y_opt + label).mean()

# Rebuild a two-layer ReLU network from the convex solution (y0, y1, b0_rec).
def _balance_neurons(weight, bias, alpha):
  """Shift scale between a hidden unit and its output weight.

  With ``gamma = sqrt(|alpha / weight|)`` this returns
  ``(weight * gamma, bias * gamma, alpha / gamma)``. For ``gamma > 0`` the
  unit ``alpha * relu(weight * t + bias)`` is unchanged, because ReLU is
  positively homogeneous.

  Units whose ``alpha`` is exactly zero give ``gamma == 0``; the plain
  ``alpha / gamma`` is then 0/0 = NaN and would poison the network output.
  Those units are returned with an output weight of exactly zero instead.
  """
  gamma = torch.sqrt(torch.abs(alpha / weight))
  # alpha is 0 wherever gamma is 0, so dividing by 1 there yields an exact 0
  safe_gamma = torch.where(gamma > 0, gamma, torch.ones_like(gamma))
  return weight * gamma, bias * gamma, alpha / safe_gamma


# parameter 1...n: units relu(t - x_j), output weights -y0
W1_rec, b1_rec, alpha1_rec = _balance_neurons(torch.ones(n), -x, -y0)
# parameter n+1...2n: units relu(x_j - t), output weights -y1
W2_rec, b2_rec, alpha2_rec = _balance_neurons(-torch.ones(n), x, -y1)
# concat
W_rec = torch.concat([W1_rec, W2_rec])
b_rec = torch.concat([b1_rec, b2_rec])
alpha_rec = torch.concat([alpha1_rec, alpha2_rec])
model_rec = nn_type2()
model_rec.fc1.weight = torch.nn.Parameter(W_rec.unsqueeze(1))
model_rec.fc1.bias = torch.nn.Parameter(b_rec)
model_rec.fc2.weight = torch.nn.Parameter(alpha_rec.unsqueeze(0))
# Linear documents its bias as shape (out_features,), i.e. (1,) here, not (1, 1)
model_rec.fc2.bias = torch.nn.Parameter(b0_rec.reshape(1))

# loss of the reconstructed network on the same inputs and targets
opt_nn = dsm_loss(model_rec, x.unsqueeze(1), label).item()
print('reconstruct nn loss: ', opt_nn)


# non-convex nn: train `trials` freshly initialised networks with Adam and
# record the full-batch loss measured before each parameter update
num_epoch = 300
trials = 10
loss_list = torch.zeros(trials, num_epoch)  # row = trial, column = epoch
for i in range(trials):
  model_torch = nn_type2()
  optimizer = torch.optim.Adam(
      model_torch.parameters(), lr=1e-2, weight_decay=0)
  for t in range(num_epoch):
    loss = dsm_loss(model_torch, x.unsqueeze(1), label)
    loss_list[i, t] = loss.item()
    # clear stale gradients right before computing the new ones
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# smallest loss seen across all trials and epochs
print(loss_list.min().item())

# Training curve of every trial, with the reconstructed network's loss as a
# dashed horizontal reference line.
plt.axhline(y=opt_nn, color='b', linestyle='dashed', linewidth=5.0)
# show at most the first 200 epochs; clamping keeps x and y the same length
# when fewer epochs were recorded
n_shown = min(200, loss_list.shape[1])
# integer epoch indices 0..n_shown-1 (linspace(0, 200, 200) is spaced by
# 200/199, so its points did not sit on the epochs they were plotted against)
epochs = torch.arange(n_shown)
for i in range(trials):
  plt.plot(epochs, loss_list[i, :n_shown], linewidth=5.0)
# plt.legend(['cvx'])
plt.xlabel('# Epochs', fontsize=30)
plt.ylabel('Training Loss', fontsize=30)
plt.tight_layout()
plt.savefig('dsm_type2_training.pdf')
plt.show()
# spot check: reconstructed network evaluated at five uniform draws from [0, 1)
print(model_rec((torch.rand(5)).unsqueeze(1)))

# L = 10
# sigmas = torch.linspace(1,0.01,L)
# eps_0 = 2e-5
# epss = eps_0*sigmas**2/sigmas[-1]**2
# # print(epss)
# T = 10
# # model_rec = nn_type2()
# def cvx_train(x, label):
#     # model_rec = nn_type2()
#     A1_pre = x.unsqueeze(1) - x
#     A1 = torch.where(A1_pre>=0, A1_pre, 0)
#     A2 = torch.where(-A1_pre>=0, -A1_pre, 0)
#     A = torch.hstack([A1, A2])
#     tilde_A = A - A.mean(axis=0)
#     tilde_y = label - label.mean()
#     y = cp.Variable(2*n)
#     obj = cp.sum_squares(tilde_A@y+tilde_y)/2 + beta*cp.norm(y,1)
#     prob = cp.Problem(cp.Minimize(obj))
#     prob.solve()
#     print(prob.value)

#     y_opt = torch.Tensor(y.value)
#     y0 = y_opt[:n]
#     y1 = y_opt[-n:]
#     b0_rec = (A@y_opt + label).mean()

#     # parameter 1...n
#     W1_rec = torch.ones(n)
#     b1_rec = -x
#     alpha1_rec = -torch.Tensor(y0)
#     gamma1_rec = torch.sqrt(torch.abs(alpha1_rec/W1_rec))
#     W1_rec = W1_rec*gamma1_rec
#     b1_rec = b1_rec*gamma1_rec
#     alpha1_rec = alpha1_rec/gamma1_rec
#     # parameter n+1...2n
#     W2_rec = -torch.ones(n)
#     b2_rec = x
#     alpha2_rec = -torch.Tensor(y1)
#     gamma2_rec = torch.sqrt(torch.abs(alpha2_rec/W2_rec))
#     W2_rec = W2_rec*gamma2_rec
#     b2_rec = b2_rec*gamma2_rec
#     alpha2_rec = alpha2_rec/gamma2_rec
#     # concat
#     W_rec = torch.concat([W1_rec, W2_rec])
#     b_rec = torch.concat([b1_rec, b2_rec])
#     model_rec = nn_type2()
#     alpha_rec = torch.concat([alpha1_rec, alpha2_rec])
#     model_rec.fc1.weight = torch.nn.Parameter(W_rec.unsqueeze(1))
#     model_rec.fc1.bias = torch.nn.Parameter(b_rec)
#     model_rec.fc2.weight = torch.nn.Parameter(alpha_rec.unsqueeze(1).T)
#     model_rec.fc2.bias = torch.nn.Parameter(b0_rec.unsqueeze(0).unsqueeze(1))
#     print(model_rec((torch.Tensor([1,2])).unsqueeze(1)))
#     return model_rec

# # score = []
# # for i in range(L):
# #   print('training ', i, ' model')
# #   x = torch.randn(n)
# #   noise = torch.randn(n)
# #   input = x + sigmas[i]*noise
# #   label = -noise/sigmas[i]
# #   score.append(cvx_train(input,label))

# # train score functions
# score = []
# for i in range(L):
#   x = torch.randn(n)
#   noise = torch.randn(n)
#   input = x + sigmas[i]*noise
#   label = noise/sigmas[i]
#   model = nn_type2()
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)
#   for t in range(num_epoch):
#     optimizer.zero_grad()
#     loss = dsm_loss(model, input.unsqueeze(1), label)
#     loss.backward()
#     optimizer.step()
#   score.append(model)



# n_samples = 10000
# prior = D.Uniform(-1, 1)
# z = prior.sample((n_samples, )).unsqueeze(1)
# for i in range(L):
#   for t in range(T):
#     z = z + (epss[i]/2) * score[i](z) + torch.sqrt(epss[i]) * torch.randn(size=z.shape)
#     print('z=',score[i](z) )


# s = z.detach().squeeze()
# s = s[~torch.isnan(s)]
# print('s=',s)
# print('min s, max s', torch.min(s), torch.max(s))


# plt.hist(s, density=True, bins=50)
# mn, mx = plt.xlim()
# plt.xlim(mn, mx) 
# # plt.xlim(-10,10)
# kde_xs = np.linspace(mn, mx, 300)
# kde = st.gaussian_kde(s)
# plt.plot(kde_xs, kde.pdf(kde_xs), linewidth=5.0)
# # plt.legend(loc="upper left")
# plt.xlabel("Sample",fontsize=30)
# plt.ylabel("Frequency",fontsize=30)
# plt.tight_layout()
# plt.savefig('dsm_type2_noncvx_sample.pdf')
# plt.show()


