#!/usr/bin/env python
# coding: utf-8

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import ortho_group
from scipy.linalg import block_diag, logm, expm, schur
import time
import random

#%% Functions

def translation(x_dim):
    """Return the ``x_dim`` x ``x_dim`` cyclic shift (translation) matrix.

    Entry ``(i, (i + 1) % x_dim)`` is 1 and every other entry is 0, so
    ``translation(n) @ v`` equals ``np.roll(v, -1)`` for a length-n vector v.
    """
    rows = np.arange(x_dim)
    shift = np.zeros((x_dim, x_dim))
    # One 1 per row, placed one column to the right (wrapping around).
    shift[rows, (rows + 1) % x_dim] = 1
    return shift

def bivec(a,b):
    """Return the antisymmetric matrix ``a b^T - b a^T`` built from vectors a and b."""
    # np.outer(b, a) is exactly np.outer(a, b).T (same element-wise products).
    return np.outer(a, b) - np.outer(b, a)

def cosine_similarity(a,b):
    """Return the cosine of the angle between 1-D vectors a and b.

    Computed as ``dot(a, b) / |a| / |b|``; a zero-norm input yields nan/inf
    (NumPy division semantics) rather than raising.
    """
    inner = np.dot(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    # Divide by each norm in turn (same operation order as a single expression).
    return inner / norm_a / norm_b

# Anti-symm SVD online by deflation, a-la GHA (Sanger 89)
# from Matlab: skewGhaXYStep
def xpowerStep(x,y,u,v,eta):
    """Perform one online step of the deflated anti-symmetric SVD update.

    Components are processed in column order. For component k, the projections
    of x and y onto ``u[:, k]`` and ``v[:, k]`` drive the update of that pair.
    The parts of x and y already reconstructed by components ``0..k-1`` are
    subtracted first (deflation, in the spirit of Sanger's GHA).

    Parameters
    ----------
    x, y : 1-D arrays of length n, the two vectors of one sample pair.
    u, v : (n, n_comp) arrays; column k of u and of v together form pair k.
    eta : learning rate.

    Returns
    -------
    u1, v1 : (n, n_comp) float arrays with the updated estimates.
        The inputs u and v are not modified.
    """
    n_dim, n_comp = u.shape
    # Size the outputs from the inputs. The original used the module globals
    # x_dim / p_dim, which broke for any other shape or outside this script.
    u1 = np.zeros((n_dim, n_comp))
    v1 = np.zeros((n_dim, n_comp))
    # Running reconstructions of x and y from the components processed so far.
    x_cumul = np.zeros(n_dim)
    y_cumul = np.zeros(n_dim)
    for k in range(n_comp):
        uk = u[:, k]
        vk = v[:, k]
        ax = np.dot(uk, x)
        ay = np.dot(uk, y)
        bx = np.dot(vk, x)
        by = np.dot(vk, y)
        z = ay * bx - ax * by
        # Deflated residuals: what earlier components have not yet explained.
        x_res = x - x_cumul
        y_res = y - y_cumul
        u1[:, k] = uk + eta * (bx * y_res - by * x_res - z * uk)
        v1[:, k] = vk - eta * (ax * y_res - ay * x_res + z * vk)
        x_cumul = x_cumul + ax * uk + bx * vk
        y_cumul = y_cumul + ay * uk + by * vk

    return u1, v1

#%%
# Seed NumPy's global RNG as well. Every random draw in this script comes
# from np.random, which random.seed() does not affect, so seeding only the
# `random` module left the runs unreproducible.
random.seed(2022)
np.random.seed(2022)
x_dim = 15
# Cyclic shift on x_dim sites and its matrix logarithm `gen`.
tr = translation(x_dim)
gen = logm(tr)
plt.imshow(gen)
plt.show()

# SVD of the generator; the columns of Q are plotted in consecutive pairs below.
Q, G, Vh = np.linalg.svd(gen)
print(G)
# Fourier?
import math  # kept here: later cells of this script also use `math`
for i in range(x_dim // 2):
    plt.figure(figsize=(4, 2))
    plt.plot(Q[:, 2*i:2*i+2])
plt.show()
#%% Generating the timeseries
# Build `samples` pairs (X[:, t], X_rotated[:, t]). X[:, t] is standard normal,
# and X_rotated[:, t] = expm(Theta[t] * gen) @ X[:, t], where the per-sample
# amount Theta[t] is sigma * N(0, 1) + 0.2.
tbeg = time.time()

samples = 5*10**5; sigma = .5

X = np.random.randn(x_dim,samples); Theta = sigma*np.random.randn(samples)+.2
X_rotated = np.zeros((x_dim,samples))
# NOTE(review): one expm call per sample (samples = 5e5) dominates the run
# time; a single eigendecomposition of gen could replace it if needed.
for t in range(samples):
    X_rotated[:,t] = expm(Theta[t]*gen)@X[:,t]

# Per-sample displacement produced by the transformation.
X_dot = X_rotated - X

# Covariance over the x_dim coordinates (np.cov treats each row as a variable).
C_dot = np.cov(X_dot)

# u: left singular vectors of C_dot, ordered by singular value s (descending).
# NOTE: `u` is reassigned later in the script (online SVD cell).
u, s, vh = np.linalg.svd(C_dot); 
# p_dim pairs of output units, i.e. y_dim = 2 * p_dim outputs.
p_dim = 4; y_dim = 2*p_dim
# Orthogonal projector onto the span of the top y_dim singular vectors of C_dot.
proj = u[:,0:y_dim]@u[:,0:y_dim].T

print("Generation Time", time.time()-tbeg)

# Plot the columns of u in consecutive pairs.
for i in range(math.floor(x_dim/2)):
    plt.figure(figsize=(4,2));plt.plot(u[:,2*i:2*i+2])
plt.show()

#%%  true subspaces
# Row i of Q_ss is the flattened antisymmetric matrix q_a q_b^T - q_b q_a^T
# built from the SVD column pair q_a = Q[:, 2i], q_b = Q[:, 2i+1] of gen.
Q_ss = np.zeros((p_dim, x_dim * x_dim))
for i in range(p_dim):
    q_a, q_b = Q[:, 2 * i], Q[:, 2 * i + 1]
    Q_ss[i] = (np.outer(q_a, q_b) - np.outer(q_b, q_a)).flatten()

#%% Spectrum of C_dot against an analytic reference curve
# NOTE(review): the reference formula uses /5 and ignores the 0.2 mean of
# Theta; confirm it matches x_dim = 15 and the sampling above.
plt.bar(range(s.shape[0]), s, align='edge')
freqs = range(-2, 3)
reference = [2 - 2 * np.exp(-2 * (np.pi * j * sigma / 5) ** 2) for j in freqs]
plt.bar(freqs, reference)
plt.show()

#%% ### Online PCA
# Learn filters F = M^{-1} W online from the displacement x_rot - x, and
# record at every step how far F.T @ Lam^{-2} @ F is from `proj`.
tbeg = time.time()

eta = 1e-4
epochs = 5

# Diagonal weights p_dim, p_dim, p_dim-1, p_dim-1, ..., 1, 1: one value
# shared by each pair of output units (2i, 2i+1).
Lam = np.eye(y_dim)
for i in range(p_dim):
    Lam[2*i, 2*i] = p_dim - i
    Lam[2*i+1, 2*i+1] = p_dim - i

# Lam never changes, so invert Lam**2 (element-wise square; Lam is diagonal)
# once here instead of once per sample inside the training loop.
Lam_sq_inv = np.linalg.inv(Lam**2)

error = np.zeros(epochs*samples)
# dot[i, t]: cosine between output pair i of y and of y_rot. Each epoch
# overwrites it, so only the last epoch is kept.
dot = np.zeros((p_dim, samples))

W = np.random.randn(y_dim, x_dim)
M = np.eye(y_dim)

for epoch in range(epochs):

    for t in range(samples):

        F = np.linalg.inv(M) @ W

        x = X[:, t]
        x_rot = X_rotated[:, t]
        y = F @ x
        y_rot = F @ x_rot
        dy = y_rot - y

        W = W + eta*(np.outer(dy, x_rot - x) - W)
        M = M + eta*(np.outer(dy, dy) - Lam @ M @ Lam)

        for i in range(p_dim):
            pair = slice(2*i, 2*i + 2)
            dot[i, t] = np.dot(y[pair], y_rot[pair])/(np.linalg.norm(y[pair])*np.linalg.norm(y_rot[pair]))

        error[epoch*samples + t] = np.linalg.norm(F.T @ Lam_sq_inv @ F - proj)

print("PCA online Time", time.time()-tbeg)
#%% 
plt.loglog(error)
plt.xlim((1e2, epochs*samples))
plt.show()

#%%  filters: plot
# One subplot per output pair: rows 2i and 2i+1 of the final filter matrix F.
# squeeze=False keeps axs 2-D, so axs[i, 0] also works when p_dim == 1.
fig, axs = plt.subplots(p_dim, sharex=True, figsize=(6,2*p_dim), squeeze=False)
#fig.suptitle('PCA')
for i in range(p_dim):
    axs[i, 0].plot(F[2*i:2*i+2,:].T)
plt.savefig('tr1d_filters.png', dpi=300, transparent=True, bbox_inches='tight')
#%%  diagonalize gtor? No
# Generator expressed in the learned filter coordinates.
plt.imshow(F @ gen @ F.T)
plt.show()

#%%  Fourier
# Real and imaginary parts of the rfft of each filter row, with one subplot
# per output pair (rows 2i, 2i+1 of F).
fF = np.fft.rfft(F)
for part, title in ((np.real, 'fft real'), (np.imag, 'fft imag')):
    fig, axs = plt.subplots(p_dim, sharex=True, figsize=(6, 2 * p_dim))
    fig.suptitle(title)
    for i in range(p_dim):
        axs[i].plot(part(fF[2 * i:2 * i + 2, :]).T)
#%% power
# Power spectrum of each filter. fF = rfft(F) along its last axis, so powF
# has one row per filter and one column per frequency bin.
powF = np.abs(fF)**2
# squeeze=False keeps axs 2-D, so axs[i, 0] also works when p_dim == 1.
fig, axs = plt.subplots(p_dim, sharex=True, figsize=(3,2*p_dim), squeeze=False)
#fig.suptitle('fft power')
for i in range(p_dim):
    # Bars span the frequency bins (columns). range(powF.shape[0]) counted the
    # filters instead and only matched because y_dim == x_dim//2 + 1 here.
    axs[i, 0].bar(range(powF.shape[1]), powF[2*i, :])
plt.savefig('tr1d_power.png', dpi=300, transparent=True, bbox_inches='tight')
#%% ### Online SVD
# Online estimate of p_dim column pairs (u[:, k], v[:, k]) from the (x, x_rot)
# sample pairs via xpowerStep, plus a running mean of bivec(x_rot, x).
tbeg = time.time()
eta = 5e-3; epochs = 10

# Random initial columns, each normalized to unit norm (norm over axis=0).
u=np.random.randn(x_dim,p_dim); u=u/np.linalg.norm(u,axis=0)
v=np.random.randn(x_dim,p_dim); v=v/np.linalg.norm(v,axis=0)
# fitLog[i, k]: cosine similarity between learned pair k and Q_ss[k] at step i.
fitLog=np.zeros((samples*epochs,p_dim));
# speedest is written only by the commented-out speed estimate below.
speedest=np.zeros((samples*epochs,p_dim));
# Running mean of bivec(x_rot, x) over all steps.
aveChi=np.zeros((x_dim,x_dim));

for epoch in range(epochs):
    for t in range(samples) :
        x = X[:,t]
        x_rot = X_rotated[:,t]
        # Global step index across epochs.
        i=epoch*samples+t
        
        [u1,v1]=xpowerStep( x, x_rot, u, v, eta );
        u=u1;v=v1
        # Incremental mean: after step i, aveChi averages the i+1 terms seen so far.
        aveChi = aveChi*(i/(i+1)) + bivec(x_rot, x)/(i+1);
    
        for k in range(p_dim) :        #- plane fit
            # NOTE(review): transposing an antisymmetric matrix negates it, so this
            # compares -bivec(u_k, v_k) with Q_ss[k]; confirm the intended orientation.
            fitLog[i,k] = cosine_similarity( bivec(u[:,k],v[:,k]).T.flatten(), Q_ss[k,:] )  # targetBivec[:,:,k].flatten() );
    
        # #- evaluate speed
        # a=u.T @ x_rot; b=v.T @ x_rot; 
        # ylen=np.sqrt(a**2+b**2);
        # atau=u.T @ x; btau=v.T @ x; ytaulen=np.sqrt(atau**2+btau**2);
        # speedest[i,:]=np.arcsin( (a*btau - b*atau) /ylen/ytaulen );

#plt.plot(fitLog); plt.show()
print("SVD online Time", time.time()-tbeg)
# Mean of the sampled transformation amounts Theta.
print(np.mean(Theta));
#%%  filters: online SVD estimates (u[:, i], v[:, i]) per subplot
fig, axs = plt.subplots(p_dim, sharex=True, figsize=(6, 2 * p_dim))
fig.suptitle('SVD')
for i in range(p_dim):
    pair = np.stack([u[:, i], v[:, i]], axis=1)
    axs[i].plot(pair)
#%%
# Offline reference: SVD of the running-mean bivector aveChi. Print its
# singular values Gs (printing G only repeated the generator's spectrum).
Qs, Gs, Vhs = np.linalg.svd(aveChi)
print(Gs)
#%%  filters
# Consecutive column pairs of Qs (left singular vectors of aveChi), one subplot per pair.
fig, axs = plt.subplots(p_dim, sharex=True, figsize=(6, 2 * p_dim))
fig.suptitle('SVD offline')
for i in range(p_dim):
    axs[i].plot(Qs[:, 2 * i:2 * i + 2])
#%%  Fourier
# Real and imaginary parts of the rfft of each column of Qs (transform along axis 0).
fQs = np.fft.rfft(Qs, axis=0)
for part, title in ((np.real, 'SVD offline fft real'), (np.imag, 'SVD offline fft imag')):
    fig, axs = plt.subplots(p_dim, sharex=True, figsize=(6, 2 * p_dim))
    fig.suptitle(title)
    for i in range(p_dim):
        axs[i].plot(part(fQs[:, 2 * i:2 * i + 2]))

