import numpy
import matplotlib.pyplot as plt
import time

def normalize_vector(v):
  """ Return v rescaled to unit Euclidean (L2) length.

  The direction of v is unchanged. A zero vector is not special-cased:
  numpy's division then yields nan entries (with a RuntimeWarning).
  """
  length = numpy.linalg.norm(v)
  return v / length

def normalize_each_row(V):
  """ Return a new 2-D array in which each row of V has unit L2 norm.

  Rows are scaled independently. All-zero rows are not special-cased
  (numpy's division yields nan entries with a RuntimeWarning).
  """
  row_lengths = numpy.linalg.norm(V, axis=1, keepdims=True)  # shape (rows, 1)
  return V / row_lengths

def softmax_prob(u, V, beta):
  """ Softmax probabilities of user u choosing each creator.

  Computes p_i proportional to exp(beta * <V[i], u>).

      u: user feature, a (d,) vector
      V: creator features, (n, d) matrix
      beta: inverse-temperature parameter (beta = 0 gives the uniform distribution)
  Returns an (n,) array of non-negative values summing to 1.

  The largest logit is subtracted before exponentiating. This leaves the
  result mathematically unchanged, but it prevents exp() from overflowing
  to inf (which would make the ratio inf/inf = nan) when beta * <V[i], u>
  is large.
  """
  logits = beta * numpy.dot(V, u)
  # Guard: numpy.max raises on an empty array (no creators).
  shift = numpy.max(logits) if logits.size > 0 else 0.0
  unnormalized_prob = numpy.exp(logits - shift)
  prob = unnormalized_prob / numpy.sum(unnormalized_prob)
  return prob

def get_dynamics(U_init, V_init, eta_u, eta_c, beta, T,
                 user_update_rule="inner_product", creator_update_rule="inner_product",
                 fixed_dimension=0):
  """ Simulate the coupled user/creator feature dynamics for T recorded steps.

  At every step t (0 <= t < T-1), all updates read the features recorded at
  step t (synchronous update):
    1. Each user j samples one creator i with probability
       softmax_prob(U[j], V, beta).
    2. Each user moves toward its sampled creator (user_update_rule) and is
       re-normalized.
    3. Each creator sampled by at least one user moves toward the mean
       (weighted per creator_update_rule) of the users that sampled it, and
       is re-normalized. Creators nobody sampled are left unchanged.

  Input:
    - U_init: m * d matrix, each row is a user feature vector
    - V_init: n * d matrix, each row is a creator feature vector
    - eta_u, eta_c: callables t -> step size for users / creators
    - beta: softmax parameter used when sampling creators
    - T: number of recorded time steps, including the initial state
    - user_update_rule: "inner_product" (weight <u, v>) or "sign"
      (weight sign(<u, v>))
    - creator_update_rule: "average", "inner_product", "sign" or "fixed"
      ("fixed" leaves creators unchanged)
    - fixed_dimension: if > 0, the first `fixed_dimension` coordinates of
      every user are never updated; the remaining coordinates are updated
      and rescaled to keep their initial norm.
  Returns a dict consisting of:
    - "U_record": a  T * m * d array
    - "V_record": a  T * n * d array
  Raises:
    - ValueError for an unsupported update rule or a negative fixed_dimension.
  """
  # Validate up front, so that a bad rule is reported even when that code
  # path would not be reached (e.g. T == 1, or no creator gets sampled).
  if user_update_rule not in ("inner_product", "sign"):
    raise ValueError(f"user update rule {user_update_rule} not supported")
  if creator_update_rule not in ("average", "inner_product", "sign", "fixed"):
    raise ValueError(f"creator update rule {creator_update_rule} not supported")
  if fixed_dimension < 0:
    raise ValueError(f"fixed_dimension must be >= 0, got {fixed_dimension}")

  n, d = V_init.shape;  m = U_init.shape[0];  assert U_init.shape[1] == d
  U_record = numpy.zeros((T, m, d))
  V_record = numpy.zeros((T, n, d))
  U_record[0] = U_init  # slice assignment copies the data
  V_record[0] = V_init

  if fixed_dimension > 0:
    # Norm of each user's free (non-fixed) coordinates; every update rescales
    # the free part back to this norm.
    non_fixed_norm = numpy.linalg.norm(U_init[:, fixed_dimension:], axis=1)

  for t in range(T-1):
    U = U_record[t]
    V = V_record[t]

    ### Sample: each user j picks creator user_to_creator[j].
    user_to_creator = numpy.empty(m, dtype=int)
    creator_to_users = [[] for _ in range(n)]
    for j in range(m):
      prob = softmax_prob(U[j], V, beta)
      i = numpy.random.choice(n, p=prob)
      user_to_creator[j] = i
      creator_to_users[i].append(j)

    ### User update:
    new_U = numpy.copy(U)
    eta_u_value = eta_u(t)
    for j in range(m):
      uj = U[j]
      vi = V[user_to_creator[j]]
      affinity = numpy.dot(uj, vi)
      # Scalar weight on the pull toward vi. For "sign", only the sign of the
      # inner product is used. It multiplies vi as a whole, in both the
      # full-dimension and the fixed-dimension case.
      weight = affinity if user_update_rule == "inner_product" else numpy.sign(affinity)
      if fixed_dimension == 0:
        new_U[j] = normalize_vector(uj + eta_u_value * weight * vi)
      else:
        new_U[j, fixed_dimension:] = normalize_vector(
            uj[fixed_dimension:] + eta_u_value * weight * vi[fixed_dimension:]
        ) * non_fixed_norm[j]

    ### Creator update:
    new_V = numpy.copy(V)
    eta_c_value = eta_c(t)
    if creator_update_rule != "fixed":
      for i in range(n):
        J = creator_to_users[i]
        if not J:
          continue  # nobody chose creator i this step
        UJ = U[J]  # (len(J), d) features of the users who chose creator i
        if creator_update_rule == "average":
          avg = numpy.mean(UJ, axis=0)
        elif creator_update_rule == "inner_product":
          avg = numpy.mean(UJ * numpy.dot(UJ, V[i])[:, None], axis=0)
        else:  # "sign"
          avg = numpy.mean(UJ * numpy.sign(numpy.dot(UJ, V[i]))[:, None], axis=0)
        new_V[i] = normalize_vector(V[i] + eta_c_value * avg)

    # record the new feature vectors
    U_record[t+1] = new_U
    V_record[t+1] = new_V

  return {"U_record": U_record, "V_record": V_record}



import numpy as np
import plotly.express as px
import plotly.graph_objects as go

def show_dynamics_3d(record, show_every=1, show_array=None, title_string="", out_dir="./fig"):
  """ Plot users and creators on the unit sphere at selected time steps.

  record is a dictionary of U_record, V_record, where
      U_record: T * m * d
      V_record: T * n * d
  Only the first three coordinates are plotted, so d must be >= 3.

  A step t is shown if it is the first or last step. Otherwise, when
  `show_array` is given, t is shown iff it is contained in it; when it is
  not given, t is shown iff (t+1) is a multiple of `show_every`. Each shown
  figure is displayed with fig.show() and saved as
  f"{out_dir}/{title_string}-t={t+1}.png". `out_dir` is created if missing.
  Titles and file names use the 1-based step number t+1.
  """
  import os

  U_record, V_record = record["U_record"], record["V_record"]
  T, m, d = U_record.shape;  n = V_record.shape[1]
  assert T == V_record.shape[0];  assert d == V_record.shape[2]
  if d < 3:
    raise ValueError(f"show_dynamics_3d needs d >= 3, got d={d}")

  # Loop-invariant figure pieces: the unit-sphere mesh and the camera.
  phi, theta = np.mgrid[0 : 2*np.pi : 30j, 0 : np.pi : 20j]
  sphere_x = np.cos(phi) * np.sin(theta)
  sphere_y = np.sin(phi) * np.sin(theta)
  sphere_z = np.cos(theta)
  camera = dict(up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=2, y=0.65, z=0.7))

  # write_image fails if the target directory does not exist.
  os.makedirs(out_dir, exist_ok=True)

  for t in range(T):
    if t == 0 or t == T-1:
      to_show = True
    elif show_array is not None:  # `!= None` is ambiguous for numpy arrays
      to_show = t in show_array
    else:
      to_show = (t+1) % show_every == 0
    if not to_show:
      continue

    U, V = U_record[t], V_record[t]
    data = [
        go.Surface(x=sphere_x, y=sphere_y, z=sphere_z, opacity=0.3, showscale=False),
        go.Scatter3d(x=U[:, 0], y=U[:, 1], z=U[:, 2], name="users", mode="markers",
                     marker=dict(size=6, color='red')),
        go.Scatter3d(x=V[:, 0], y=V[:, 1], z=V[:, 2], name="creators", mode="markers",
                     marker=dict(size=4, color='blue', symbol='x')),
    ]
    fig = go.Figure(data=data)
    fig.update_layout(scene_camera=camera, title=f"t={t+1}")
    fig.update_layout(scene=dict(xaxis_title='', yaxis_title='', zaxis_title=''))
    fig.show()
    fig.write_image(f"{out_dir}/{title_string}-t={t+1}.png")


# --- Experiment configuration ---------------------------------------------
d = 3   # feature dimension (3 so features can be drawn on a sphere)
n = 30  # number of creators
m = 60  # number of users


def eta_u(t):
  """Constant user step size (ignores t)."""
  return 0.08


def eta_c(t):
  """Constant creator step size (ignores t)."""
  return 0.08


T = 200      # recorded time steps per run
n_rep = 1    # repetitions per beta; only used to size U_all / V_all here

beta_list = numpy.array([0, 1, 3, 5])
n_beta = beta_list.size

# Preallocated storage indexed by (beta, repetition, t, row, coord);
# not filled in this section.
U_all = numpy.zeros((n_beta, n_rep, T, m, d))
V_all = numpy.zeros((n_beta, n_rep, T, n, d))

labels = []

# One shared random initialization for every beta (creators drawn first).
V_init = normalize_each_row(numpy.random.normal(0, 1, size=(n, d)))
U_init = normalize_each_row(numpy.random.normal(0, 1, size=(m, d)))

# --- Run and plot one simulation per beta -----------------------------------
for i, beta in enumerate(beta_list):
  print(f"Working on {(i, beta)}")
  labels.append(f'$\\beta={beta}$')
  UV_record = get_dynamics(
      U_init, V_init, eta_u, eta_c, beta, T,
      user_update_rule="inner_product",
      creator_update_rule="inner_product",
      fixed_dimension=0,
  )
  show_dynamics_3d(UV_record, show_every=100, title_string=f"beta={beta}")