# -*- coding: utf-8 -*-
"""CarsTestCondor.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/11sJ-MaXrIITFDT4F_-8W4eNOlW2aDo55
"""

import numpy as np
from sklearn.cluster import SpectralClustering
from sklearn.cluster import KMeans
import scipy.io
import matplotlib.pyplot as plt
import sys
from itertools import permutations
from sklearn.cluster import DBSCAN
import collections

def grassmannian_fusion(X:np.ndarray, Omega:np.ndarray, r:int, lamb = 5, max_iter = 10, step_size = 0.1, g_threshold = 0.15, init_U = None, bound_zero = 1e-10, singular_value_bound = 1e-2, g_column_norm_bound = 1e-5, U_manifold_bound = 1e-2):
  """Iteratively fit one m x r matrix U_i per column of X.

  The objective tracked in ``info['obj_record']`` is, summed over columns i,
  ``1 - s_max(X0_i^T U_i)^2`` plus, summed over pairs i < j,
  ``lamb / 2 * sqrt(sum_k arccos(s_k(U_i^T U_j))^2)``, where ``X0_i`` holds the
  normalized i-th column of X followed by one unit vector per row that is NOT
  observed in that column.

  Parameters
  ----------
  X : (m, n) array. Each column is normalized and used as the first column of
    U_i (default init) and of X0_i. Entries outside Omega must be zero,
    otherwise the structural assertion on X0_i fails.
  Omega : 1-D integer array of flat indices of observed entries; index p
    refers to row ``p // n`` and column ``p % n``.
  r : number of columns of every U_i.
  lamb : weight of the pairwise term.
  max_iter : maximum number of outer iterations.
  step_size : initial step length; halved whenever the objective increases.
  g_threshold : stop once the Frobenius norm of the summed update directions
    falls below this value.
  init_U : optional sequence (list or stacked ndarray) of n arrays of shape
    (m, r) used as the starting point instead of the random initialisation.
    NOTE(review): presumably each must have orthonormal columns - this is not
    checked before the first update.
  bound_zero : tolerance for the initialisation and column-normalisation
    assertions.
  singular_value_bound : how far a singular value may exceed 1 before it is
    treated as an error rather than round-off.
  g_column_norm_bound : columns of G_ij with a norm below this are left
    un-normalised.
  U_manifold_bound : tolerance on ``||U^T U - I||`` after each update.

  Returns
  -------
  (U_array, info) : the sequence of n matrices, and a dict with the keys
    'gradient_record' and 'obj_record' (one entry per iteration / per
    accepted iteration respectively).

  Raises
  ------
  Exception : if a singular value exceeds 1 by more than
    ``singular_value_bound``.
  AssertionError : if one of the internal sanity checks fails.
  """
  [m,n] = X.shape

  #`is None` (not `== None`): comparing an ndarray with `==` is elementwise
  #and cannot be used as an `if` condition
  if init_U is None:
    #init: random m x r matrices whose first column is the normalized data column
    U_array = [np.random.randn(m,r) for i in range(n)]
    for i in range(n):
      U_array[i][:,0] = X[:,i] / np.linalg.norm(X[:,i])
      q_i,r_i = np.linalg.qr(U_array[i])
      #the first column has unit norm, so r_i[0,0] is +1 or -1;
      #multiplying by it restores the sign of the first column
      U_array[i] = q_i * r_i[0,0]

      #make sure the first col is x_i
      assert np.linalg.norm(U_array[i][:,0] - X[:,i] / np.linalg.norm(X[:,i])) < bound_zero
      #make sure its orthogonal
      assert np.linalg.norm(U_array[i].T @ U_array[i] - np.identity(r)) < bound_zero
      #make sure its normal
      assert  np.linalg.norm( np.linalg.norm(U_array[i], axis = 0) - np.ones(r) )  < bound_zero
  else:
    U_array = init_U

  #construct X^0_i
  #observed row indices of column i (flat index p -> row p // n, column p % n)
  Omega_i = [np.sort(Omega[Omega % n == i]) // n for i in range(n)]
  #find the complement of Omega_i (rows not observed in column i)
  Omega_i_compliment = [sorted(set(range(m)) - set(o_i)) for o_i in Omega_i]

  #number of observed rows per column
  len_Omega = [o.shape[0] for o in Omega_i]

  #init X^0: one column for the data plus one per unobserved row
  X0 = [np.zeros((m, m - len_Omega[i] + 1)) for i in range(n)]

  for i in range(n):
    #fill in the first column with the normalized data column
    X0[i][:,0] = X[:,i] / np.linalg.norm(X[:,i])

    for col_index,row_index in enumerate(Omega_i_compliment[i]):
      #fill in the "identity matrix"
      X0[i][row_index, col_index+1] = 1

    #each row should have exactly one nonzero component, unless it is an
    #observed row whose data value happens to be zero
    for j in range(X0[i].shape[0]):
      assert np.count_nonzero(X0[i][j,:]) == 1 or (j in Omega_i[i] and X0[i][j,0] == 0)

  new_U_array = [np.random.randn(m,r) for i in range(n)]
  gradient_record = []
  obj_record = []
  #main algo
  for iteration in range(max_iter):
    gradient = np.zeros((m,r))
    for i in range(n):

      #SVD
      A = X0[i] @ X0[i].T @ U_array[i]
      U_A,s_A,VT_A = np.linalg.svd(A)
      #leading vector
      u = U_A[:,0]
      vt = VT_A[0,:]

      G = -2 * s_A[0] * np.outer(u,vt)

      #compute V_ij, D_ij, W_ij
      V_i = []
      D_i = []
      WT_i = []
      for j in range(n):
        v, d, wt = np.linalg.svd(U_array[i].T @ U_array[j])
        V_i.append(v)
        D_i.append(d)
        WT_i.append(wt)

      G_ij = [ (U_array[j] @ WT_i[j].T - U_array[i] @ V_i[j] @ np.diag(D_i[j])) for j in range(n) ]

      #normalize every column of every G_ij, leaving (near-)zero columns alone
      for g in range(len(G_ij)):
        for col in range(G_ij[g].shape[1]):
          if np.linalg.norm(G_ij[g][:,col]) < g_column_norm_bound:
            continue
          else:
            G_ij[g][:,col] = G_ij[g][:,col] / np.linalg.norm(G_ij[g][:,col])

      #assure that G_ij is column normalized
      for g_ij in G_ij:
        for col in range(g_ij.shape[1]):
          assert abs(np.linalg.norm(g_ij[:,col]) - 1) < bound_zero or np.linalg.norm(g_ij[:,col]) < g_column_norm_bound

      #hard cap:some values really close to 1 can be larger than 1
      #(needed so that arccos below stays defined)
      for j in range(n):
        if np.sum(D_i[j] - 1 > singular_value_bound) > 0:
          print(iteration, i)
          print(D_i[j] )
          raise Exception("D_i entry > 1")
        D_i[j][D_i[j] > 1] = 1

      H_i = (np.identity(m) - U_array[i] @ U_array[i].T) @ G - lamb / 2 * sum([ G_ij[j] @ np.diag(np.arccos(D_i[j])) @ V_i[j].T  for j in range(n)])
      gradient += H_i

      #move U_i along the cos/sin curve built from the thin SVD of -H_i
      Gamma_i, Del_i, ET_i = np.linalg.svd(-1 * H_i, full_matrices= False)
      first_term = np.concatenate((U_array[i]@ET_i.T, Gamma_i), axis = 1)
      #sin of a diagonal matrix is still diagonal because sin(0) = 0
      second_term = np.concatenate((np.diag(np.cos(step_size * Del_i)), np.sin(step_size * np.diag(Del_i))), axis = 0)

      new_U_array[i] = first_term @ second_term @ ET_i

      #check if U is still in the manifold
      assert np.linalg.norm(new_U_array[i].T @ new_U_array[i] - np.identity(new_U_array[i].shape[1])) < U_manifold_bound

      #every 10th iteration, re-orthonormalize to remove accumulated round-off
      if iteration % 10 == 0:
        u,s,vt = np.linalg.svd(new_U_array[i], full_matrices = False)
        new_U_array[i] = u@vt


    #calculate the objective
    obj = 0
    for i in range(n):
      u,s,vt = np.linalg.svd(X0[i].T @ U_array[i])
      if s[0]> 1 and s[0] - 1 < singular_value_bound:
        s[0] = 1
      elif s[0] > 1:
        raise Exception('s[0] = ', s[0])

      if (i == 1 and iteration == 99):
        print("Test:", 1 - s[0]**2)
      obj += 1 - s[0]**2

      for j in range(i+1,n):
        u,s,vt = np.linalg.svd(U_array[i].T @ U_array[j])
        for r_index in range(r):
          if s[r_index]> 1 and s[r_index] - 1 < singular_value_bound:
            s[r_index] = 1
          elif s[r_index] > 1:
            raise Exception('Ui^T Uj, s[0] = ', s[r_index])
        sum_s = sum([np.arccos(s[i_r])**2 for i_r in range(r)])
        obj += lamb / 2 * np.sqrt(sum_s)

    #NOTE(review): once a step has been accepted, U_array and new_U_array are
    #the same list object, so the writes to new_U_array[i] above are already
    #visible through U_array. The else branch therefore only shrinks the step
    #size; it does not restore the previous matrices. Confirm this is intended.
    if len(obj_record) == 0 or obj <= obj_record[-1]:
      U_array = new_U_array
      obj_record.append(obj)
    else:
      step_size = step_size * 0.5

    #Frobenius norm of the summed update directions
    g_total = np.sqrt(np.trace(gradient.T @ gradient))
    gradient_record.append(g_total)

    if iteration % 100 == 0:
      print('iter', iteration, ', total gradient = ', g_total)
      print('Obj value:', obj)

    if g_total < g_threshold:
      print('iter', iteration, ', total gradient = ', g_total)
      print('Obj value:', obj)
      break

    if step_size < 1e-6:
      print('iter', iteration, ', total gradient = ', g_total)
      print('Obj value:', obj)
      print('step_size: ', step_size)
      break

  info = {'gradient_record':gradient_record, "obj_record": obj_record}

  return U_array, info

def dUU(U_1, U_2, r):
  """Return sqrt(sum_k arccos(s_k)^2) over the first r singular values of U_1^T U_2.

  Singular values that exceed 1 by at most 1e-5 are clamped to 1 (round-off);
  a larger excess raises Exception.
  """
  _, sigma, _ = np.linalg.svd(U_1.T @ U_2)

  for idx, value in enumerate(sigma):
    if value - 1 > 1e-5:
      raise Exception('s[',idx,'] = ', value)
    if value > 1:
      #tiny overshoot: clamp so that arccos stays defined
      sigma[idx] = 1

  squared_angles = sum(np.arccos(sigma[k])**2 for k in range(r))

  assert squared_angles >= 0
  return np.sqrt(squared_angles)

def evaluate(predict, truth, cluster):
  """Return the lowest error rate over all relabelings of the predicted clusters.

  Every permutation of the labels 0..cluster-1 is applied to `predict` and
  compared against `truth`; each permutation and its error rate are printed.
  Predicted labels outside 0..cluster-1 are mapped to 0.
  """
  candidate_mappings = permutations(range(cluster))

  predict = np.array(predict)
  truth = np.array(truth)
  assert predict.shape == truth.shape

  best_error = 1
  for mapping in candidate_mappings:
    print("Permutation:", mapping)

    #relabel: original label `source` becomes `target`
    relabeled = np.zeros(len(predict), dtype = int)
    for source, target in enumerate(mapping):
      relabeled[predict == source] = int(target)

    error_rate = np.sum(relabeled != truth) / len(predict)
    best_error = min(best_error, error_rate)
    print("Error Rate:", error_rate)

  return best_error

def read_file(file_name, downsample = False, downsample_size = 100):
  """Load the 'x' and 's' variables from a MATLAB .mat file.

  Parameters
  ----------
  file_name : path of the .mat file; it must contain the variables 'x' and 's'.
    'x' is indexed along its second axis by the same indices as the flattened
    's' (presumably x is (3, points, frames) and s holds one label per point,
    as the printed message suggests - confirm against the data set).
  downsample : when true, keep only a random subset of the points, drawn
    without replacement from numpy's global random state.
  downsample_size : number of points to keep when downsampling. It is capped
    at the number of available points, because sampling without replacement
    cannot return more points than exist.

  Returns
  -------
  (x_source, s_source) : the (possibly downsampled) 'x' array and the
    flattened 's' array.
  """
  mat = scipy.io.loadmat(file_name)
  data_x = mat['x']

  print('raw x shape (3, points, frames):', data_x.shape)
  s = mat['s'].reshape(-1)

  if downsample:
    #np.random.choice(..., replace=False) raises ValueError if asked for more
    #samples than there are points, so never request more than len(s)
    sample_size = min(downsample_size, len(s))
    index_downsample = np.random.choice(len(s), size = sample_size, replace = False)
    s_source = s[index_downsample]
    x_source = data_x[:,index_downsample,:]

  else:
    s_source = s
    x_source = data_x

  print('s_source shape:', np.shape(s_source))
  print('x_source shape:', np.shape(x_source))

  #label histogram
  print(collections.Counter(s_source))

  return (x_source, s_source)

def main():
  """Run the full experiment: load data, mask entries, fit the U_i, cluster, report.

  Steps: read the .mat file, stack the first two coordinates of every frame
  into X, hide a random 30% of the entries, run grassmannian_fusion on the
  masked matrix, build the pairwise dUU distance matrix and print the accuracy
  of spectral clustering, k-means and DBSCAN fitted on that matrix.
  """
  #change the file_name to the one you want to try
  file_name = 'cars4_truth.mat'

  x_source, s_source = read_file(file_name)
  n = len(s_source)
  m = 2 * int(x_source.shape[2])

  #construct X: row 2*frame + position holds coordinate `position` of every point in `frame`
  X = np.zeros((m, n))
  for frame in range(m // 2):
    for position in range(2):
      X[2 * frame + position, :] = x_source[position, :n, frame]

  print('X shape (position, points): ', X.shape)

  r = 3
  missing_rate = 0.3

  #observed index (flat: row = p // n, column = p % n)
  Omega = np.random.choice(m*n, size = int(m*n * (1-missing_rate) ), replace= False )

  #create observed matrix: copy observed entries, leave the rest at zero
  observed_rows = Omega // n
  observed_cols = Omega % n
  X_omega = np.zeros((m,n))
  X_omega[observed_rows, observed_cols] = X[observed_rows, observed_cols]

  lambda_in = 1
  print('Paramter: lambda = ',lambda_in,', m = ', m, ', n = ',n,', r = ',r,', missing_rate =', missing_rate)
  print('||X - X_omega||_2 = ', np.linalg.norm(X - X_omega))

  print('########################################\nGradient Descent Begin')
  U_array,info = grassmannian_fusion(X_omega, Omega, r, lamb = lambda_in, max_iter= 1000, step_size = 1, g_threshold= 1e-3, bound_zero = 1e-10, singular_value_bound = 1e-5, g_column_norm_bound = 1e-5, U_manifold_bound = 1e-5)
  print('########################################\nGradient Descent End')

  #calculate the distance between every pair of fitted matrices (0 on the diagonal)
  d_matrix = np.array([
    [0 if i == j else dUU(U_array[i], U_array[j], r) for j in range(n)]
    for i in range(n)
  ])

  #number of clusters = largest label
  K = max(s_source)

  print('\n########################################\n Classify Accuracy:')
  sc = SpectralClustering(n_clusters=K,assign_labels="discretize",random_state=0).fit(d_matrix)
  km = KMeans(n_clusters=K).fit(d_matrix)

  #labels in the file start at 1; shift them to start at 0
  truth = s_source - 1

  print('Spectral:', 1 - evaluate(sc.labels_, truth , K) )
  print()

  print('Kmeans:', 1 - evaluate(km.labels_, truth , K) )
  print()

  db = DBSCAN(eps=0.4, min_samples=25).fit(d_matrix)

  #score each DBSCAN label as a one-vs-rest binary split, in both polarities
  acc = 0
  for label in set(db.labels_):
    print(label)
    one_vs_rest = db.labels_.copy()
    is_label = db.labels_ == label
    one_vs_rest[is_label] = 1
    one_vs_rest[~is_label] = 0

    matches = sum(truth == one_vs_rest)
    acc = max(matches, len(one_vs_rest) - matches, acc)

  print('DB:', acc / len(one_vs_rest) )
  print()

# Run the experiment only when this file is executed directly as a script.
if __name__ == "__main__":
  main()

