# -*- coding: utf-8 -*-
"""
Taken from https://github.com/yikun-baio/sliced_opt

Created on Thu Apr 21 11:58:08 2022
@author: Yikun Bai yikun.bai@Vanderbilt.edu 
"""
import os
import numpy as np
from typing import Tuple
from scipy.stats import ortho_group
import numba as nb
from sliced_opt_utils import solve_opt, recover_indice, pot, cost_function



@nb.njit(['float64[:,:](int64,int64,int64)'],fastmath=True,cache=True)
def random_projections(d,n_projections,Type=0):
    '''
    Draw random unit-norm projection directions in R^d.

    input:
    d: int, ambient dimension.
    n_projections: int, number of directions to return.
    Type: int, sampling scheme:
        0 -- independent standard Gaussian vectors, each normalized to unit
             Euclidean norm;
        1 -- rows of the orthogonal factor Q of the QR factorization of
             d x d standard Gaussian matrices, stacked block by block and
             truncated to n_projections rows (rows within one block are
             mutually orthogonal).

    output:
    projections: (n_projections, d) float64 numpy array, one unit direction
        per row.

    raises:
    ValueError if Type is neither 0 nor 1.
    '''
    if Type==0:
        Gaussian_vector=np.random.normal(0,1,size=(d,n_projections))
        # normalize every column to unit length, then transpose so that each
        # row of the result is one direction
        projections=Gaussian_vector/np.sqrt(np.sum(np.square(Gaussian_vector),0))
        projections=projections.T

    elif Type==1:
        # number of d x d orthogonal blocks needed to cover n_projections rows
        r=np.int64(n_projections/d)+1
        projections=np.zeros((d*r,d))
        for i in range(r):
            H=np.random.randn(d,d)
            Q,R=np.linalg.qr(H)
            projections[i*d:(i+1)*d]=Q
        projections=projections[0:n_projections]

    else:
        # without this branch `projections` would be unbound at the return
        raise ValueError("Type must be 0 or 1")
    return projections





@nb.njit(['Tuple((float64,int64[:,:],float64[:,:],float64[:,:]))(float64[:,:],float64[:,:],float64[:])'],parallel=True,fastmath=True,cache=True)
def opt_plans_64(X,Y,Lambda_list):
    '''
    Sliced optimal partial transport between the point clouds X and Y.

    One random unit direction is drawn per entry of Lambda_list (through
    random_projections with Type 0). For each direction, X and Y are
    projected and sorted. solve_opt solves the 1-D partial transport problem
    with penalty Lambda_list[k] on the sorted projections. recover_indice
    maps the resulting plan back to the original (unsorted) point indices.

    input:
    X: (n, d) float64 array.
    Y: (m, d) float64 array with the same d.
    Lambda_list: float64 array, one penalty per projection. It must be
        non-empty, because the distance is an average over projections.

    output:
    sopt_dist: mean of the per-projection objective values.
    opt_plan_X_list: (n_projections, n) int64 array. Row k is the plan for
        projection k as returned by recover_indice. Presumably it holds the
        index of the matched Y point, or a negative value when unmatched --
        confirm against sliced_opt_utils.
    X_projections: (n_projections, n) projections of X.
    Y_projections: (n_projections, m) projections of Y.

    NOTE(review): opt_plans_64 is defined twice in this module with the same
    signature and body. The later definition is the one bound at import
    time, so one of the two copies should be deleted.
    '''
    n,d=X.shape
    n_projections=Lambda_list.shape[0]
    projections=random_projections(d,n_projections,0)
    X_projections=projections.dot(X.T)
    Y_projections=projections.dot(Y.T)
    opt_plan_X_list=np.zeros((n_projections,n),dtype=np.int64)
    opt_cost_list=np.zeros(n_projections)
    for (epoch,(X_theta,Y_theta,Lambda)) in enumerate(zip(X_projections,Y_projections,Lambda_list)):
        X_indice=X_theta.argsort()
        Y_indice=Y_theta.argsort()
        X_s=X_theta[X_indice]
        Y_s=Y_theta[Y_indice]
        obj,_phi,_psi,piRow,_piCol=solve_opt(X_s,Y_s,Lambda)
        opt_cost_list[epoch]=obj
        opt_plan_X_list[epoch]=recover_indice(X_indice,Y_indice,piRow)
    # averaged once after the loop (was recomputed every iteration: O(P^2))
    sopt_dist=opt_cost_list.sum()/n_projections
    return sopt_dist,opt_plan_X_list,X_projections,Y_projections





@nb.njit(['(float64[:,:],float64[:,:],float64[:,:],float64[:])'],cache=True)
def X_correspondence(X,Y,projections,Lambda_list):
    '''
    Move the points of X, in place, toward their 1-D partial-transport
    matches in Y, one projection direction at a time.

    For each row theta of projections, X and Y (X as already updated by the
    previous directions) are projected onto theta. solve_opt solves the 1-D
    problem with penalty Lambda_list[i] on the sorted projections, and
    recover_indice maps the plan back to original indices. Every X point i
    with a non-negative plan entry L[i] is shifted along theta by
    (Y_theta[L[i]] - X_theta[i]). Points with a negative entry are not moved.

    input:
    X: (n, d) float64 array, modified in place.
    Y: (m, d) float64 array.
    projections: (N, d) float64 array, one direction per row.
    Lambda_list: float64 array indexed by direction; presumably length >= N.

    output: none (X is updated in place).
    '''
    n=X.shape[0]
    all_indices=np.arange(0,n)
    for i in range(projections.shape[0]):
        theta=projections[i]
        X_theta=np.dot(theta,X.T)
        Y_theta=np.dot(theta,Y.T)
        X_indice=X_theta.argsort()
        Y_indice=Y_theta.argsort()
        X_s=X_theta[X_indice]
        Y_s=Y_theta[Y_indice]
        _obj,_phi,_psi,piRow,_piCol=solve_opt(X_s,Y_s,Lambda_list[i])
        L=recover_indice(X_indice,Y_indice,piRow)
        # boolean indexing already returns a copy; compute the mask once
        matched=L>=0
        Lx=all_indices[matched]
        if Lx.shape[0]>=1:
            Ly=L[matched]
            X_take=X_theta[Lx]
            Y_take=Y_theta[Ly]
            X[Lx]+=np.expand_dims(Y_take-X_take,1)*theta
            


 

@nb.njit(['(float64[:,:],float64[:,:],float64[:,:])'],cache=True)
def X_correspondence_pot(X,Y,projections):
    '''
    Shift every point of X along each projection direction toward the point
    of Y that `pot` pairs it with. X is updated in place and also returned.

    For each row `direction` of projections, X (as updated so far) and Y are
    projected and sorted. pot is called on the sorted projections, and its
    plan is mapped back to original indices with recover_indice. X point i
    then moves by (Y_proj[plan[i]] - X_proj[i]) * direction.

    input:
    X: (n, d) float64 array, modified in place.
    Y: (m, d) float64 array.
    projections: (N, d) float64 array, one direction per row.

    output:
    X, the same array object that was passed in.
    '''
    for k in range(projections.shape[0]):
        direction=projections[k]
        x_proj=np.dot(direction,X.T)
        y_proj=np.dot(direction,Y.T)
        x_order=x_proj.argsort()
        y_order=y_proj.argsort()
        _cost,plan=pot(x_proj[x_order],y_proj[y_order])
        plan=recover_indice(x_order,y_order,plan)
        shift=y_proj[plan]-x_proj
        X+=np.expand_dims(shift,1)*direction
    return X


    

@nb.njit(['Tuple((float64,int64[:,:],float64[:,:],float64[:,:]))(float64[:,:],float64[:,:],float64[:])'],parallel=True,fastmath=True,cache=True)
def opt_plans_64(X,Y,Lambda_list):
    '''
    Sliced optimal partial transport between the point clouds X and Y.

    One random unit direction is drawn per entry of Lambda_list (through
    random_projections with Type 0). For each direction, X and Y are
    projected and sorted. solve_opt solves the 1-D partial transport problem
    with penalty Lambda_list[k] on the sorted projections. recover_indice
    maps the resulting plan back to the original (unsorted) point indices.

    input:
    X: (n, d) float64 array.
    Y: (m, d) float64 array with the same d.
    Lambda_list: float64 array, one penalty per projection. It must be
        non-empty, because the distance is an average over projections.

    output:
    sopt_dist: mean of the per-projection objective values.
    opt_plan_X_list: (n_projections, n) int64 array. Row k is the plan for
        projection k as returned by recover_indice. Presumably it holds the
        index of the matched Y point, or a negative value when unmatched --
        confirm against sliced_opt_utils.
    X_projections: (n_projections, n) projections of X.
    Y_projections: (n_projections, m) projections of Y.

    NOTE(review): opt_plans_64 is defined twice in this module with the same
    signature and body. This later definition is the one bound at import
    time; the earlier copy should be deleted.
    '''
    n,d=X.shape
    n_projections=Lambda_list.shape[0]
    projections=random_projections(d,n_projections,0)
    X_projections=projections.dot(X.T)
    Y_projections=projections.dot(Y.T)
    opt_plan_X_list=np.zeros((n_projections,n),dtype=np.int64)
    opt_cost_list=np.zeros(n_projections)
    for (epoch,(X_theta,Y_theta,Lambda)) in enumerate(zip(X_projections,Y_projections,Lambda_list)):
        X_indice=X_theta.argsort()
        Y_indice=Y_theta.argsort()
        X_s=X_theta[X_indice]
        Y_s=Y_theta[Y_indice]
        obj,_phi,_psi,piRow,_piCol=solve_opt(X_s,Y_s,Lambda)
        opt_cost_list[epoch]=obj
        opt_plan_X_list[epoch]=recover_indice(X_indice,Y_indice,piRow)
    # averaged once after the loop (was recomputed every iteration: O(P^2))
    sopt_dist=opt_cost_list.sum()/n_projections
    return sopt_dist, opt_plan_X_list, X_projections, Y_projections

def opt_cost_from_plans(X_projections,Y_projections,Lambda_list,opt_plan_X_list,cache=True):
    """
    Re-evaluate the per-projection partial-transport objective for given plans.

    For projection k, a non-negative entry plan[i] pairs X point i with Y
    point plan[i]; negative entries leave X point i unmatched. The cost of
    projection k is

        sum(cost_function(X_theta[matched], Y_theta[plan[matched]]))
        + Lambda_list[k] * (m + n - 2 * number_of_matched_points)

    where n = X_projections.shape[1] and m = Y_projections.shape[1].

    cache: unused; accepted only so existing call sites keep working.

    Returns a float array of length Y_projections.shape[0]. Rows beyond the
    shortest of the four inputs are left at 0.
    """
    _, n = X_projections.shape
    n_proj, m = Y_projections.shape
    costs = np.zeros(n_proj)
    rows = zip(X_projections, Y_projections, Lambda_list, opt_plan_X_list)
    for k, (x_row, y_row, lam, plan) in enumerate(rows):
        keep = plan >= 0
        transport = np.sum(cost_function(x_row[keep], y_row[plan[keep]]))
        unmatched_mass = m + n - 2 * keep.sum()
        costs[k] = transport + lam * unmatched_mass
    return costs


def _to_numpy(t):
    """Return t as a NumPy array; torch tensors are moved to CPU and detached."""
    if hasattr(t, "detach"):
        return t.cpu().detach().numpy()
    return np.asarray(t)


def reprocess_support(a, x):
    """
    Expand a weighted support into an unweighted one by repeating points.

    Row x[i] appears int(a[i]) times in the result. Rows whose integer count
    is below 1 are dropped. The layout is: the first copy of every kept row,
    in original order, followed by the extra copies grouped by row (x[0]'s
    extras, then x[1]'s, ...).

    Unlike the previous buffer-based version, fractional weights no longer
    leave trailing all-zero rows, and zero counts no longer raise.

    Parameters
    ----------
    a : torch.Tensor or array-like, shape (n,)
        Per-point multiplicities, truncated toward zero to integers.
    x : torch.Tensor or array-like, shape (n, d)
        Support points.

    Returns
    -------
    numpy.ndarray of float64 with sum(int(a[i])) rows (over non-negative
    counts).
    """
    a = _to_numpy(a)
    x = _to_numpy(x).astype(np.float64)
    counts = a.astype(np.int64)
    first = x[counts >= 1]
    extra = np.repeat(x, np.maximum(counts - 1, 0), axis=0)
    return np.concatenate((first, extra), axis=0)


# def reprocess_support(a, x):
#     a, x = a.cpu().detach().numpy(), x.cpu().detach().numpy()
#     y = np.zeros((int(np.sum(a)), x.shape[1]))
#     y[:x.shape[0]] = x
#     for i in range(len(a)):
#         for j in range(int(a[i])-1):
#             supp_extra.append(np.asarray(x[i]))
#     return np.concatenate((x, np.array(supp_extra)), axis=0)