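#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Toy comparison of the discrepancies provided by ``src.sse_sgw_utils`` and
``src.risgw`` (sgw, sse, sse_gromov, distrib_sse, min_sse, distrib_min_sse,
risgw).

With ``--data spiral`` the target is a fixed noisy spiral rotated by angles in
[0, pi/2]. With ``--data gaussian`` a 3-D Gaussian source is compared with a
10-D Gaussian target whose mean is shifted by values in [0, 2]. For every
setting, the model selected with ``--model`` is evaluated ``nb_iter`` times.
Distances and CPU times are saved under ``./result/toy/``, and the mean
distances are plotted to ``./figure/``.
"""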
import os
# Pin the native thread pools (OpenMP, OpenBLAS, MKL, Accelerate/vecLib,
# numexpr) to a single thread. These variables are typically read when the
# native libraries initialise, so they must be set before numpy/torch are
# imported below. Equivalent to `export <VAR>=1` in the shell.
for _thread_var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS",
                    "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS"):
    os.environ[_thread_var] = "1"


# Make packages that live next to the current working directory (e.g. the
# `src` package imported below) importable by appending the parent of the
# *current working directory* to sys.path.
# NOTE: this is relative to where the script is launched, not to this file's
# location; presumably the script is meant to be run from its own directory
# (the ./result and ./figure output paths below are CWD-relative too).
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import torch
from src.sse_sgw_utils import sgw_gpu, sse_gpu,ssegromov_gpu, distrib_sse, min_sse, distributional_min_sse
from src.risgw import risgw_gpu
from src.gwgan.model_mlp import  weights_init_generator

import pylab as plt
import numpy as np
import torch.nn as nn
from torch import optim
from time import process_time as time
from numpy import pi
#import copy


def make_spiral(n_samples, noise=.5):
    """Sample ``n_samples`` noisy points from a 2-D spiral.

    The angle t, which also serves as the radius, is ``sqrt(U)`` times
    780 degrees converted to radians, with ``U ~ Uniform[0, 1)``. Each point is
    ``(-cos(t) * t, sin(t) * t)`` plus independent per-coordinate jitter
    ``noise * Uniform[0, 1)``. The jitter is one-sided, not zero-mean.

    Returns
    -------
    numpy.ndarray of shape (n_samples, 2)
    """
    t = np.sqrt(np.random.rand(n_samples, 1)) * 780 * (2 * np.pi) / 360
    x_coord = -np.cos(t) * t + np.random.rand(n_samples, 1) * noise
    y_coord = np.sin(t) * t + np.random.rand(n_samples, 1) * noise
    return np.concatenate([x_coord, y_coord], axis=1)

def get_rot(theta):
    """Return the 2x2 rotation matrix ``[[cos, -sin], [sin, cos]]`` for ``theta`` (radians).

    The script applies it to row-vector point clouds as ``X.dot(get_rot(theta))``.
    That right-multiplication rotates the points by ``-theta``.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])

def get_data(n_samples, theta, scale=1, transla=0):
    """Return a (source, target) pair of independent noisy spirals.

    Both spirals come from ``make_spiral(..., noise=1)``, and the source is
    drawn first. The source is shifted by ``-transla``. The target is
    right-multiplied by ``get_rot(theta)``, scaled by ``scale`` and shifted by
    ``+transla``.

    Returns
    -------
    (Xs, Xt) : tuple of numpy arrays
    """
    source = make_spiral(n_samples=n_samples, noise=1) - transla
    target = make_spiral(n_samples=n_samples, noise=1)
    rotation = get_rot(theta)
    target = target.dot(rotation) * scale + transla
    return source, target



def rand_projections(dim, num_projections=1000):
    """Draw ``num_projections`` random unit vectors in R^dim.

    Standard-normal samples are normalised row by row to unit L2 norm, which
    gives directions uniformly distributed on the unit sphere.

    Returns
    -------
    torch.Tensor of shape (num_projections, dim)
    """
    gaussian = torch.randn((num_projections, dim))
    row_norms = gaussian.pow(2).sum(dim=1, keepdim=True).sqrt()
    return gaussian / row_norms

def spiral(N=100, phi=0):
    """Sample ``N`` points from a noisy 2-D spiral rotated by ``phi``.

    The parameter is ``t = sqrt(U) * 4 * pi`` with ``U ~ Uniform[0, 1)``. The
    radius is ``t / 2 + pi`` and the polar angle is ``t + phi``. Isotropic
    Gaussian noise with std 0.3 is then added to each point.

    Returns
    -------
    torch.FloatTensor of shape (N, 2)
    """
    t = np.sqrt(np.random.rand(N)) * 4 * pi
    radius = t / 2 + pi
    angle = t + phi
    clean = np.stack((np.cos(angle) * radius, np.sin(angle) * radius), axis=1)
    noisy = clean + np.random.randn(N, 2) * 0.3
    return torch.from_numpy(noisy).float()

    
class TransformNet(nn.Module):
    """Single linear map ``R^size -> R^size`` whose outputs are rescaled to unit L2 norm.

    Intended for reshaping the distribution of random projection directions:
    each row of the output lies on the unit sphere.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size
        # Kept as a Sequential named `net` so state_dict keys are `net.0.*`.
        self.net = nn.Sequential(nn.Linear(size, size))

    def forward(self, input):
        # Assumes `input` is (batch, size); normalisation is along dim 1.
        projected = self.net(input)
        row_norms = projected.pow(2).sum(dim=1, keepdim=True).sqrt()
        return projected / row_norms

class TransformLatenttoOrig(nn.Module):
    """Map latent vectors of size ``dim_latent`` to unit vectors in ``R^dim_orig``.

    Used to map a random projection vector into the ambient space of the
    distribution. The network is an MLP with three ReLU hidden layers of width
    ``dim_hidden`` and a Sigmoid output layer. Its outputs are rescaled to unit
    L2 norm per row.

    NOTE(review): the Sigmoid makes every output coordinate positive, so the
    resulting directions are confined to the positive orthant of the sphere.
    Confirm this is intended.
    """

    def __init__(self, dim_latent, dim_orig, dim_hidden=10):
        super().__init__()
        self.dim_latent = dim_latent
        self.dim_orig = dim_orig
        self.dim_hidden = dim_hidden
        # Layers are created in the same order as a hand-written list
        # (Linear, ReLU) x 3 followed by Linear, Sigmoid, so parameter
        # initialisation and state_dict keys (`net.0.*` ... `net.6.*`) match.
        widths = [dim_latent, dim_hidden, dim_hidden, dim_hidden]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers += [nn.Linear(dim_hidden, dim_orig), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        # Assumes `input` is (batch, dim_latent); normalisation is along dim 1.
        mapped = self.net(input)
        row_norms = mapped.pow(2).sum(dim=1, keepdim=True).sqrt()
        return mapped / row_norms

import argparse

# Command-line options. `choices` rejects unknown data/model names up front.
# Without it, an unknown model silently leaves every result array at zero, and
# an unknown data set crashes later with a NameError.
parser = argparse.ArgumentParser(description='Toy Compare')
parser.add_argument('--n_samples', type=int, default=500,
                    help='number of points in each source/target sample')
parser.add_argument('--nproj', type=int, default=20,
                    help='number of projections for sgw, sse, sse_gromov and risgw')
parser.add_argument('--nproj_dist', type=int, default=10,
                    help='number of projections for distrib_sse and min_sse')
parser.add_argument('--nproj_dist_d', type=int, default=10,
                    help='number of projections for distrib_min_sse')
parser.add_argument('--n_iter_inner', type=int, default=50,
                    help='inner iterations: max_iter of distrib_sse/min_sse, '
                         'num_epochs of distrib_min_sse')
parser.add_argument('--data', type=str, default='spiral',
                    choices=['spiral', 'gaussian'])
parser.add_argument('--model', type=str, default='distrib_min_sse',
                    choices=['sgw', 'sse', 'sse_gromov', 'distrib_sse',
                             'min_sse', 'distrib_min_sse', 'risgw'])

args = parser.parse_args()
print(args)


n_samples = args.n_samples
nproj = args.nproj                # nb of projections for sgw / sse / sse_gromov / risgw
nproj_dist = args.nproj_dist      # nb of projections for distrib_sse / min_sse
nproj_dist_d = args.nproj_dist_d  # nb of projections for distrib_min_sse
# Inner iterations for distrib_sse, min_sse and distrib_min_sse.
# risgw does not use this; it is called with a fixed max_iter=500.
n_iter_inner = args.n_iter_inner
data = args.data
model = args.model

# Fixed experiment settings (not exposed on the command line).
nb_iter = 20      # repetitions per setting (columns of the result arrays)
lr = 0.01         # learning rate of the fs/ft optimizers and of risgw
dim_latent = 10   # latent size fed to fp/fq/fdistrib in the min-SSE variants
device = 'cpu'
nb_compare = 10   # number of settings compared (rows of the result arrays)
if data == 'spiral':
    # 2-D source and target. param_vec holds the rotation angles applied to
    # the target; it is saved with the results and used as the plot x-axis.
    dim_s, dim_t = 2, 2
    param_vec = np.linspace(0, pi / 2, nb_compare)
elif data == 'gaussian':
    # 3-D source vs 10-D target. param_vec holds the shifts added to the target.
    dim_s = 3
    dim_t = 10
    param_vec = torch.linspace(0, 2, nb_compare)
else:
    # Fail fast instead of hitting a NameError on dim_s/param_vec later.
    raise ValueError(f"Unknown data {data!r}; expected 'spiral' or 'gaussian'")



filename = f"toy_{data}_{model}_nsample{n_samples:d}_{nproj}_{nproj_dist}_{nproj_dist_d}_nb{nb_iter}_iterinner{n_iter_inner:d}"
pathres = './result/toy/'
# np.savez does not create missing directories. Without this, the first
# checkpoint save would fail after a whole setting had already been computed.
os.makedirs(pathres, exist_ok=True)

# Result buffers, one per model: entry [i, j] is setting param_vec[i],
# repetition j. Only the buffers of the selected `model` get filled.
alldist_sgw = np.zeros((nb_compare, nb_iter))
alldist_distrib_sse = np.zeros((nb_compare, nb_iter))
alldist_min_sse = np.zeros((nb_compare, nb_iter))
alldist_sse = np.zeros((nb_compare, nb_iter))
alldist_ssegromov = np.zeros((nb_compare, nb_iter))
alldist_risgw = np.zeros((nb_compare, nb_iter))
alldist_distrib_min_sse = np.zeros((nb_compare, nb_iter))

# CPU time (time.process_time) spent per evaluation, same layout as above.
time_sgw = np.zeros((nb_compare, nb_iter))
time_sse = np.zeros((nb_compare, nb_iter))
time_ssegromov = np.zeros((nb_compare, nb_iter))
time_distrib_sse = np.zeros((nb_compare, nb_iter))
time_min_sse = np.zeros((nb_compare, nb_iter))
time_risgw = np.zeros((nb_compare, nb_iter))
time_distrib_min_sse = np.zeros((nb_compare, nb_iter))

#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed both RNGs for reproducibility. numpy drives the spiral sampling
# (make_spiral). torch drives the Gaussian data in the main loop and the
# default nn.Linear initialisation of the networks built there. The order of
# the draws below matters for reproducing earlier runs.
np.random.seed(0)
torch.manual_seed(0)
scale=1
# Draw the spiral pair once. Xs is the source for every spiral setting, and
# Xt0 is the un-rotated target that the main loop rotates by param_vec[i].
# For data == 'gaussian', Xs is overwritten in the main loop and Xt0 is unused.
Xs,Xt0 = get_data(n_samples,0,scale=scale)
Xs = torch.from_numpy(Xs).float()

for i in range(nb_compare):
    print(i)
    # Build the (source, target) pair for setting i.
    if data == 'gaussian':
        # Fresh Gaussian clouds (dim_s vs dim_t); every target coordinate is
        # shifted by param_vec[i].
        Xs = torch.randn(n_samples, dim_s)
        Xt = torch.randn(n_samples, dim_t) + param_vec[i]
    elif data == 'spiral':
        # Xs (drawn once before the loop) is reused. The target is the fixed
        # spiral Xt0 right-multiplied by the rotation for param_vec[i], i.e.
        # np.linspace(0, pi/2, nb_compare)[i], so no per-iteration recompute.
        Xt = torch.from_numpy(Xt0.dot(get_rot(param_vec[i]))).float()

    for j in range(nb_iter):
        # Fresh networks and optimizers for every repetition. They are built
        # whatever `model` is. Their initialisation draws from the torch RNG,
        # so moving them into the branches below would change the Gaussian
        # data generated for later settings.
        fs = TransformNet(dim_s).to(device)
        ft = TransformNet(dim_t).to(device)
        fs_optim = optim.Adam(fs.parameters(), lr=lr, betas=(0.5, 0.999))
        ft_optim = optim.Adam(ft.parameters(), lr=lr, betas=(0.5, 0.999))

        fp = TransformLatenttoOrig(dim_latent, dim_s).to(device)
        fq = TransformLatenttoOrig(dim_latent, dim_t).to(device)
        fp_optim = optim.Adam(fp.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.005)
        fq_optim = optim.Adam(fq.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.005)
        fq.apply(weights_init_generator)
        fp.apply(weights_init_generator)

        # Mapping function from S^{d-1} to S^{d-1}
        fdistrib = TransformNet(dim_latent).to(device)
        fdistrib_optim = optim.Adam(fdistrib.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.005)

        # Evaluate the selected model; record its distance and CPU time.
        tic = time()
        if model == 'sgw':
            alldist_sgw[i, j] = sgw_gpu(Xs, Xt, device=device, nproj=nproj)
            time_sgw[i, j] = time() - tic
        elif model == 'sse':
            alldist_sse[i, j] = sse_gpu(Xs, Xt, device=device, nproj=nproj)
            time_sse[i, j] = time() - tic
        elif model == 'sse_gromov':
            alldist_ssegromov[i, j] = ssegromov_gpu(Xs, Xt, device=device, nproj=nproj)
            time_ssegromov[i, j] = time() - tic
        elif model == 'distrib_sse':
            alldist_distrib_sse[i, j] = distrib_sse(Xs, Xt, fs, ft, fs_optim, ft_optim,
                                                    nproj=nproj_dist, max_iter=n_iter_inner)
            time_distrib_sse[i, j] = time() - tic
        elif model == 'min_sse':
            alldist_min_sse[i, j] = min_sse(Xs, Xt, fp, fq, fp_optim, fq_optim, dim_latent,
                                            nproj=nproj_dist, max_iter=n_iter_inner)
            time_min_sse[i, j] = time() - tic
        elif model == 'distrib_min_sse':
            alldist_distrib_min_sse[i, j] = distributional_min_sse(
                Xs, Xt, fdistrib, fp, fq, fdistrib_optim, fp_optim, fq_optim,
                dim_latent, nproj=nproj_dist_d, num_epochs=n_iter_inner,
                num_sup_iter=5, num_inf_iter=5, verbose=False, lam_ap=10)
            time_distrib_min_sse[i, j] = time() - tic
        elif model == 'risgw':
            # NOTE: risgw uses a fixed max_iter=500, not n_iter_inner.
            alldist_risgw[i, j] = risgw_gpu(Xs, Xt, device=device, nproj=nproj,
                                            lr=lr, max_iter=500)
            time_risgw[i, j] = time() - tic
        else:
            # Previously an unknown model was skipped silently, leaving zeros.
            raise ValueError(f"Unknown model {model!r}")

    # Checkpoint after every setting (same file each time), so results of the
    # settings completed so far are on disk if the run stops early.
    np.savez(pathres + filename,
             alldist_distrib_sse=alldist_distrib_sse,
             alldist_sgw=alldist_sgw,
             alldist_sse=alldist_sse,
             alldist_ssegromov=alldist_ssegromov,
             alldist_distrib_min_sse=alldist_distrib_min_sse,
             alldist_min_sse=alldist_min_sse,
             alldist_risgw=alldist_risgw,
             time_distrib_sse=time_distrib_sse,
             time_sgw=time_sgw,
             time_sse=time_sse,
             time_ssegromov=time_ssegromov,
             time_min_sse=time_min_sse,
             time_risgw=time_risgw,
             time_distrib_min_sse=time_distrib_min_sse,
             param_vec=param_vec)


#%%
# Mean distance over the nb_iter repetitions: one value per param_vec entry.
# Only the arrays of the selected `model` were filled; the others are zeros.
# (The buffers are already float ndarrays, so no conversion is needed.)
alldist_distrib_sse_m = alldist_distrib_sse.mean(axis=1)
alldist_min_sse_m = alldist_min_sse.mean(axis=1)
alldist_distrib_min_sse_m = alldist_distrib_min_sse.mean(axis=1)
alldist_sgw_m = alldist_sgw.mean(axis=1)
alldist_sse_m = alldist_sse.mean(axis=1)
alldist_ssegromov_m = alldist_ssegromov.mean(axis=1)
alldist_risgw_m = alldist_risgw.mean(axis=1)


#%%
# Final save. This overwrites the per-setting checkpoint file, so it must
# carry every array, including alldist_min_sse / time_min_sse. Without them,
# min_sse results would be dropped from the saved file.
np.savez(pathres + filename,
         alldist_distrib_sse=alldist_distrib_sse,
         alldist_sgw=alldist_sgw,
         alldist_sse=alldist_sse,
         alldist_ssegromov=alldist_ssegromov,
         alldist_distrib_min_sse=alldist_distrib_min_sse,
         alldist_min_sse=alldist_min_sse,
         alldist_risgw=alldist_risgw,
         time_distrib_sse=time_distrib_sse,
         time_sgw=time_sgw,
         time_sse=time_sse,
         time_ssegromov=time_ssegromov,
         time_min_sse=time_min_sse,
         time_risgw=time_risgw,
         time_distrib_min_sse=time_distrib_min_sse,
         param_vec=param_vec)
#%%

plt.plot(param_vec, alldist_sgw_m, label=f"SGW {nproj}")
#plt.plot(param_vec,alldist_ssegromov_m,label=f"SSE Gromov {nproj}")
plt.plot(param_vec, alldist_risgw_m, label=f"RISGW {nproj}")
plt.plot(param_vec, alldist_distrib_sse_m, label=f"Distrib SSE {nproj_dist}")
plt.plot(param_vec, alldist_min_sse_m, label=f"min SSE {nproj_dist}")
plt.plot(param_vec, alldist_distrib_min_sse_m, label=f"Distrib min SSE {nproj_dist_d}")
plt.plot(param_vec, alldist_sse_m, label=f"SSE {nproj}")
plt.legend()

# savefig does not create missing directories.
os.makedirs('./figure', exist_ok=True)
plt.savefig(f"./figure/{filename}_2.png")
#%%


