"""Implements Extended Flow Matcher Losses."""


import math
import random
import torch
import torch.distributions as dist
from einops import rearrange, reduce, repeat
import numpy as np
from scipy.optimize import linear_sum_assignment
import sys
sys.path.append('../')
import pdb
import copy
from torchEFM import utils
import ot


class linear_regression(object):
    """Affine least-squares regression of boundary samples on their conditions.

    Solves ``[c, 1] @ M ~= x`` with the Moore-Penrose pseudo-inverse, once per batch
    element (``regmat``) and once for the batch-mean samples (``regmat_mean``).
    The last row of each fitted matrix multiplies the appended column of ones,
    i.e. it is the bias (intercept) term.
    """

    def __init__(self, device='cpu', **kwargs):
        self.regmat_mean = None  # (dimC + 1) x dimX, set by create_regmats
        self.regmat = None       # batch x (dimC + 1) x dimX, set by create_regmats
        self.device = device

    @staticmethod
    def _with_bias(ctensor):
        """Append a column of ones along the last axis (same device/dtype as ``ctensor``)."""
        return torch.cat((ctensor, torch.ones_like(ctensor[..., :1])), dim=-1)

    def create_regmats(self, xbdry, cbdry, device=None):
        '''
        Compute the batchwise/elementwise regressions  C_ext M_mean = E[X]  and  C_ext M_b = X_b.
        IN:
          xbdry:  tensor,    batch x 1 x numcset x dimX
          cbdry:  tensor,    numcset x dimC
          device: ignored; accepted so callers may use the same
                  (xbdry, cbdry, device) call form as kernel_regression.create_regmats.
                  The fitted matrices are stored on ``self.device``.
        Sets:
          self.regmat_mean: tensor,  (dimC + 1) x dimX
          self.regmat:      tensor,  batch x (dimC + 1) x dimX
        Raises:
          ValueError: if xbdry is not of shape batch x 1 x numcset x dimX.
        '''
        if xbdry.dim() != 4 or xbdry.shape[1] != 1:
            raise ValueError(
                f'xbdry must have shape batch x 1 x numcset x dimX, got {tuple(xbdry.shape)}')

        if torch.max(cbdry.detach().cpu()) > 10000:
            # Sanity check that used to drop into the debugger; warn instead so that
            # non-interactive runs are not halted.
            import warnings
            warnings.warn('linear_regression.create_regmats: condition values exceed 10000',
                          RuntimeWarning)

        # Keep both fitted matrices on the same device so compute_regression can use them together.
        self.regmat_mean = self.create_mean_regmats(xbdry, cbdry).to(self.device)

        # Elementwise regressor: pinv(C_ext) is (dimC + 1) x numcset and broadcasts
        # against the batch of targets (batch x numcset x dimX).
        cbdry_ext = self._with_bias(cbdry)
        mat_regressor = torch.linalg.pinv(cbdry_ext) @ xbdry[:, 0]

        self.regmat = mat_regressor.detach().to(self.device)

    def create_mean_regmats(self, xbdry, cbdry):
        '''
        Fit the regression of the batch-mean boundary samples on the conditions.
        IN:
          xbdry: tensor,    batch x 1 x numcset x dimX
          cbdry: tensor,    numcset x dimC
        Out:
          tensor, (dimC + 1) x dimX, detached, on xbdry's device
        '''
        xmean = torch.mean(xbdry, dim=0).squeeze(0)  # numcset x dimX
        cbdry_ext = self._with_bias(cbdry.to(xbdry.device))
        return (torch.linalg.pinv(cbdry_ext) @ xmean).detach()

    def compute_regression(self, ctensor, reshape=True):
        '''
        Evaluate the fitted regressions at arbitrary conditions.
        In:
          ctensor:   num_cset x dim_c   //   numT x num_cset x dim_c
        Out (reshape=True):
          xtensor_mean:  1 x 1 x num_cset x x_dim       //  1 x numT x num_cset x x_dim
          xtensor:       batch x 1 x num_cset x x_dim   //  batch x numT x num_cset x x_dim
        With reshape=False, xtensor_mean has no leading singleton axes and, for 2-D
        input, xtensor is batch x num_cset x x_dim.
        Raises:
          ValueError: if ctensor is not 2-D or 3-D.
        '''
        if ctensor.dim() not in (2, 3):
            raise ValueError(f'ctensor must be 2-D or 3-D, got shape {tuple(ctensor.shape)}')

        ctensor_ext = self._with_bias(ctensor)
        xtensor_mean = ctensor_ext @ self.regmat_mean

        if ctensor.dim() == 3:
            # (1, nt, nc, c) @ (b, 1, c, x) -> (b, nt, nc, x)
            xtensor = ctensor_ext.unsqueeze(0) @ self.regmat.unsqueeze(1)
            if reshape:
                xtensor_mean = xtensor_mean.unsqueeze(0)
        else:
            # (1, nc, c) @ (b, c, x) -> (b, nc, x)
            xtensor = ctensor_ext.unsqueeze(0) @ self.regmat
            if reshape:
                xtensor_mean = xtensor_mean[None, None]
                xtensor = xtensor.unsqueeze(1)

        return xtensor_mean, xtensor

def rbf_kernel(X1, X2, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of X1 and X2.

    K[..., i, j] = exp(-||X1[..., i, :] - X2[..., j, :]||^2 / (2 * sigma^2)).
    Leading batch axes are handled by ``torch.cdist``.
    """
    denom = 2 * sigma ** 2
    sq_dists = torch.cdist(X1, X2) ** 2
    return torch.exp(-sq_dists / denom)

class kernel_regression(object):
    """RBF kernel ridge regression of boundary samples on their conditions.

    One regressor is fitted per batch element (``weight``) and one on the batch-mean
    boundary samples (``weight_mean``); all share the boundary conditions ``cbdry``.
    """

    def __init__(self, **kwargs):
        # Never set by this class; kept so it exposes the same attributes as linear_regression.
        self.regmat_mean = None
        self.regmat = None
        self.lmbda = kwargs.get('lmbda', 1e-4)  # Ridge reg. param. (default matches ExtendedFlowMatcher)
        self.sigma = kwargs.get('sigma', 1.0)   # RBF kernel param. (default matches ExtendedFlowMatcher)

    def compute_weight(self, X_train, y_train):
        '''
        Solve (K + lmbda * I) weight = y_train, where K is the RBF Gram matrix of X_train.
        Input:
            X_train: input of regressed function, [batch_size, num_sample, dim_x] (or unbatched)
            y_train: output of regressed function, [batch_size, num_sample, dim_y] (or unbatched)
        Output:
            weight: weight of kernel regressor, [batch_size, num_sample, dim_y]
        '''
        num_sample = X_train.shape[-2]
        K = rbf_kernel(X_train, X_train, sigma=self.sigma)
        A = K + self.lmbda * torch.eye(num_sample, device=X_train.device)
        L, info = torch.linalg.cholesky_ex(A)
        if bool((info != 0).any()):
            # cholesky_ex does not raise: a nonzero info means the factorisation failed
            # (matrix not numerically positive definite) and L is unusable. Fall back to
            # the pseudo-inverse rather than solving with a partial factor.
            return torch.linalg.pinv(A) @ y_train
        return torch.cholesky_solve(y_train, L)

    def common_regression(self, X_train, X_test, weight):
        '''Evaluate the kernel regressor: K(X_test, X_train) @ weight.'''
        K_test = rbf_kernel(X_test, X_train, sigma=self.sigma)
        y_pred = K_test @ weight
        return y_pred

    def create_regmats(self, xbdry, cbdry, device='cpu'):
        '''
        Compute the batchwise/elementwise kernel regression.
        In:
          xbdry:  tensor,    batch x 1 x numcset x dimX
          cbdry:  tensor,    numcset x dimC
          device: unused; kept for call compatibility.
        Sets:
          weight_mean: tensor,  numcset x dimX
          weight:      tensor,  batch x numcset x dimX
        '''
        bs = xbdry.shape[0]
        xbdry_ext = rearrange(xbdry, 'b 1 c x -> b c x')
        self.cbdry_ext = repeat(cbdry, 'n d -> b n d', b=bs)
        self.cbdry_ext_mean = cbdry
        self.weight_mean = self.create_mean_weight(xbdry_ext)
        self.weight = self.create_weight(xbdry_ext)

    def create_mean_weight(self, xbdry):
        '''Fit the regressor on the batch mean of xbdry (batch x numcset x dimX).'''
        xbdry_ext = reduce(xbdry, 'b n d -> n d', 'mean')
        weight_regressor_mean = self.compute_weight(self.cbdry_ext_mean, xbdry_ext)
        return weight_regressor_mean.detach().to(xbdry.device)

    def create_weight(self, xbdry):
        '''Fit one regressor per batch element of xbdry (batch x numcset x dimX).'''
        weight_regressor = self.compute_weight(self.cbdry_ext, xbdry)
        return weight_regressor.detach().to(xbdry.device)

    def compute_regression(self, ctensor, reshape=True):
        '''
        Compute the regression for arbitrary samples of c.
        In:
          ctensor:   num_cset x dim_c  // numT x num_cset x dim_c
        Out (reshape=True):
          xtensor:       batch x 1 x num_cset x x_dim   //  batch x numT x num_cset x x_dim
          xtensor_mean:  1 x 1 x num_cset x x_dim       //  1 x numT x num_cset x x_dim
        With reshape=False, 3-D input is returned flattened over (numT, num_cset).
        '''
        # Remember the input rank *before* flattening; the reshape below depends on it.
        is_3d = len(ctensor.shape) > 2
        if is_3d:
            nt = ctensor.shape[0]
            ctensor = rearrange(ctensor, 't c d -> (t c) d')

        ctensor_ext = repeat(ctensor, 'c d -> b c d', b=self.cbdry_ext.shape[0])
        xtensor = self.common_regression(self.cbdry_ext, ctensor_ext, self.weight)
        xtensor_mean = self.common_regression(self.cbdry_ext_mean, ctensor, self.weight_mean)
        if reshape:
            if is_3d:
                xtensor = rearrange(xtensor, 'b (t nc) x -> b t nc x', t=nt)
                xtensor_mean = rearrange(xtensor_mean, '(t nc) x -> 1 t nc x', t=nt)
            else:
                xtensor = rearrange(xtensor, 'b nc x -> b 1 nc x')
                xtensor_mean = rearrange(xtensor_mean, 'nc x -> 1 1 nc x')

        return xtensor_mean, xtensor
        

class ExtendedFlowMatcher(object):
    """Extended flow matching along regression-based interpolation paths.

    Boundary samples ``xbdry`` observed at conditions ``cbdry`` are fitted by a
    regressor ``self.lm`` (affine least squares by default, RBF kernel ridge when
    ``kernel=True``). For a source sample and a query (t, c) the path is

        psi(t, c) = (1 - t) * (source + x_mean(c)) + t * x(c),

    where ``x(c)`` is the per-batch-element regression and ``x_mean(c)`` the
    regression fitted on the batch-mean boundary samples.
    """

    def __init__(self, kernel=False, autoJac=False, sigma=1.0, lmbda=1e-4, device='cpu', **kwargs):
        self.kernel = kernel
        if self.kernel:
            print('use kernel ridge regression')
            self.lm = kernel_regression(device=device,
                                        sigma=sigma,
                                        lmbda=lmbda
                                        )
        else:
            print('use linear regression')
            self.lm = linear_regression(device=device)
        # When True, the c-Jacobian is obtained with autograd even for the linear regressor.
        self.autoJac = autoJac
        self.xtensor = None  # endpoint of the paths, style transfer (set by sample_psi_tc)
        self.ztensor = None  # endpoint of the paths, generation (set by sample_psi_tc)
        self.device = device

    def create_psi(self, xbdry, cbdry):
        '''
        Fit the regressor on the boundary data.
        IN:
            xbdry: tensor, batch x 1 x numC x dimX
            cbdry: tensor, numC x dimC
        '''
        # Use the two-argument form, which both regressors' create_regmats accept.
        self.lm.create_regmats(xbdry, cbdry)

    # cset is just the boundary
    def sample_psi_xi(self, xiset, source):
        '''
        From xbdry and cbdry, compute psi(t, c | xbdry, cbdry, source).

        IN:
            xiset:  tensor, numT x numCsample x (1 + dimC)
            source: tensor, batch x 1 x numCsample x dimX
        Out:
            psivals: tensor, batch x numT x numSample x dimX
        '''
        tset = xiset[..., [0]]  # keep the trailing axis: numT x numCsample x 1
        cset = xiset[..., 1:]
        return self.sample_psi_tc(tset, cset, source)

    def create_model_input(self, psival, tset, cset):
        '''
        Creates the model input [psival, xi].
        IN:
            psival: tensor, b x nT x nCsample x dimX
            tset : tensor, nT
            cset : tensor, nCsample x dimC

        Out:
            model_input:  tensor,  (b x nT x nCsample) x (dimX + dimC + 1)
        '''
        (batch, nT, nC, dimx) = psival.shape

        xiset = self.create_xi(tset, cset)  # nT x nCsample x (dimC + 1)

        # b x nT x nCsample x (dimC + 1)
        xiset_extended = repeat(xiset, 'nT nC d -> b nT nC d', b=batch)

        # b x nT x nC x (dimX + dimC + 1)
        model_input = torch.cat([psival, xiset_extended], dim=-1)
        # (b nT nC) x (dimX + dimC + 1)
        model_input = rearrange(model_input, 'b nT nC d-> (b nT nC) d')

        return model_input

    def create_source(self, config, mydata):
        '''Standard normal source, one draw per batch element shared across the Csamples slots.'''
        source = torch.normal(0, 1, size=(config.batch_size, 1, 1, mydata.x_dim))
        source = source.expand(config.batch_size, 1, config.Csamples, mydata.x_dim)
        return source

    def create_xi(self, tset, cset):
        '''
        IN:
            tset:  tensor, numT
            cset:  tensor, numCsample x dimC
        Out:
            xiset: tensor, numT x numCsample x (dimC + 1)
        '''
        numT = len(tset)
        numC = cset.shape[0]

        tset = tset[:, None, None].expand(-1, numC, 1)  # numT, numC, 1
        cset = cset[None, :, :].expand(numT, -1, -1)    # numT, numCsample, dimC

        return torch.cat((tset, cset), dim=-1)  # numT, numCsample, (dimC + 1)

    def sample_psi_tc(self, tset, cset, source, prepareForDeriv=False):
        '''
        From xbdry and cbdry, compute psi(t, c | xbdry, cbdry, source).

        IN:
            tset:  tensor, numT   //   numT x numC x 1
            cset:  tensor, numCsample x dimC    //  numT x numCsample x dimC
            source: tensor, batch x 1 x numCsample x dimX
        Out:
            psival: tensor, batch x numT x numCsample x dimX

        If prepareForDeriv is True, the path endpoints are stored in self.xtensor and
        self.ztensor for a later compute_jacobian call.
        '''
        if len(tset.shape) == 1:
            tset = rearrange(tset, 't -> 1 t 1 1')

        xtensor_mean, xtensor = self.lm.compute_regression(cset)

        ztensor = source + xtensor_mean

        psival = (1 - tset) * ztensor + tset * xtensor

        if prepareForDeriv:
            self.xtensor = xtensor
            self.ztensor = ztensor

        return psival

    def deform_source(self, source=None, xbdry=None):
        '''Shift source by the batch mean (dim 0) of xbdry.'''
        xmean = torch.mean(xbdry, dim=0, keepdim=True)
        return source + xmean

    def batch_jacobian_for_KRR(self, cset):
        '''
        Jacobians of the fitted regressions w.r.t. the conditions, via autograd.
        Input:
            cset: [numsample_c, dim_c]
        Output:
            jac_mean: [1, 1, numsample_c, dim_x, dim_c]
            jac     : [bs, 1, numsample_c, dim_x, dim_c]
        '''
        # Row n of each regression output depends only on cset[n], so summing over the
        # sample axis keeps every per-point derivative separate in the Jacobian.
        def mean_regression(c):
            xtensor_mean, _ = self.lm.compute_regression(c, reshape=False)  # nc x dim_x
            return xtensor_mean.sum(dim=0)  # dim_x

        def batch_regression(c):
            _, xtensor = self.lm.compute_regression(c, reshape=False)  # bs x nc x dim_x
            return xtensor.sum(dim=1)  # bs x dim_x

        jac_mean = torch.autograd.functional.jacobian(mean_regression, cset, vectorize=True)
        jac = torch.autograd.functional.jacobian(batch_regression, cset, vectorize=True)

        jac_mean = rearrange(jac_mean, 'x nc d -> 1 1 nc x d')
        jac = rearrange(jac, 'b x nc d -> b 1 nc x d')

        return jac_mean, jac

    def compute_jacobian(self, tset, cset):
        '''
        Jacobian of psi(t, c) = (1 - t) * (source + x_mean(c)) + t * x(c) w.r.t. (t, c).

        Requires a preceding sample_psi_tc(..., prepareForDeriv=True) call with a 2-D
        cset, so that self.xtensor / self.ztensor hold the path endpoints.

        IN:
            tset:  tensor, numT
            cset:  tensor, numC x dimC
        Out:
            jacobian_val: tensor, batch x numT x numC x dimX x (1 + dimC);
                          [..., 0] is d psi / dt and [..., 1:] is d psi / dc.
        '''
        if self.kernel or self.autoJac:
            jac_mean, jac = self.batch_jacobian_for_KRR(cset)  # 1 1 nc x c,  b 1 nc x c
        else:
            # Linear path: d x / d c is the regression matrix (bias row removed below).
            jac = rearrange(self.lm.regmat, 'b d x -> b 1 1 d x')
            jac_mean = rearrange(self.lm.regmat_mean, 'd x -> 1 1 1 d x')
        nt = tset.shape[0]
        nc = cset.shape[0]

        tset = rearrange(tset, 'nt -> 1 nt 1 1 1')

        # The mean regression enters psi with weight (1 - t), the per-element one with t.
        partial_c = (1 - tset) * jac_mean + tset * jac
        if not self.kernel and not self.autoJac:
            partial_c = partial_c[:, :, :, :-1, :]   # b nt 1 c x  (drop bias row)
            partial_c = repeat(partial_c, 'b nt 1 c x -> b nt nc x c', nc=nc)

        partial_t = self.xtensor - self.ztensor  # b 1 nc x
        partial_t = repeat(partial_t, 'b 1 nc x -> b nt nc x 1', nt=nt)

        # b nt nc x (c + 1)
        return torch.cat([partial_t, partial_c], dim=-1)
    

class BatchSinkhornFlowMatcher(ExtendedFlowMatcher):
    """Extended flow matcher that aligns boundary columns with Sinkhorn OT before regressing.

    The condition columns of ``xbdry`` are visited in a random order and each one is
    re-ordered along the batch axis to follow the previously visited column.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Always use the linear regressor, whatever the parent constructor selected.
        self.lm = linear_regression(self.device)
        self.xtensor = None  # endpoint of the paths, style transfer
        self.ztensor = None  # endpoint of the paths, generation
        self.sink_iter = kwargs['sink_iter']
        self.sinkhorn = utils.SinkhornSolver(epsilon=0.1, iterations=self.sink_iter)

    def create_psi(self, xbdry, cbdry, source=None):
        '''
        Match the boundary columns, then fit the regressor.
        IN:
            xbdry: tensor, batch x 1 x numC x dimX   (re-ordered in place)
            cbdry: tensor, numC x dimC
            source: unused here
        '''
        matched = self.sinkhorn_matching(xbdry)
        self.lm.create_regmats(matched, cbdry)

    def sinkhorn_matching(self, xbdry):
        '''
        Chain-match the condition columns of xbdry in a random order.
        IN:
            xbdry: tensor, batch x 1 x numC x dimX   (modified in place and returned)
        '''
        _, _, n_cols, _ = xbdry.shape
        order = torch.randperm(n_cols)
        for prev_col, col in zip(order[:-1], order[1:]):
            _, plan = self.sinkhorn.forward(xbdry[:, 0, prev_col, :], xbdry[:, 0, col, :])
            # For each row of the previous column, the row of this column it is sent to.
            best = torch.argmax(plan, axis=1)
            xbdry[:, 0, col, :] = xbdry[best, 0, col, :]

        return xbdry
    

# Uses exact EMD (optimal transport) in place of Sinkhorn.
class BatchEMDFlowMatcher(BatchSinkhornFlowMatcher):
    """Batch flow matcher that aligns boundary columns with exact OT (``ot.emd``) instead of Sinkhorn."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        batch = kwargs['batch']
        # Uniform marginals for the configured batch size, cached on self.device.
        a, b = torch.ones((batch,)) / batch, torch.ones((batch,)) / batch
        self.a = a.to(self.device)
        self.b = b.to(self.device)

    def _uniform_marginals(self, n, device):
        """Return uniform marginals of length ``n``.

        Reuses the cached ``self.a``/``self.b`` when ``n`` equals the configured batch
        size; otherwise (e.g. a smaller final batch) builds fresh ones on ``device``.
        """
        if self.a.shape[0] == n:
            return self.a, self.b
        w = torch.ones((n,), device=device) / n
        return w, w.clone()

    @torch.no_grad()
    def sinkhorn_matching(self, xbdry):
        '''
        Re-order boundary samples so that, along a random chain of condition columns,
        each column is matched to the previously visited one by exact OT.
        IN:
            xbdry: tensor, batch x 1 x numC x dimX   (modified in place and returned)
        '''
        (batch, _, nC, dimX) = xbdry.shape
        # Marginals must match the actual batch size, which may differ from the configured one.
        a, b = self._uniform_marginals(batch, xbdry.device)
        permuted_cidx = torch.randperm(nC)

        for k in range(1, nC):
            fromdat = xbdry[:, 0, permuted_cidx[k - 1], :]
            todat = xbdry[:, 0, permuted_cidx[k], :]
            M = ot.dist(fromdat, todat).detach()
            pi = ot.emd(a, b, M).detach()
            # For each row of the previous column, the row of this column it is sent to.
            todat_idx = torch.argmax(pi, axis=1)
            xbdry[:, 0, permuted_cidx[k], :] = xbdry[todat_idx, 0, permuted_cidx[k], :]

        return xbdry

# Also matches the source noise to the boundary samples ("vertical" matching).
class BatchEMDFlowMatcher2(BatchEMDFlowMatcher):
    """EMD batch matcher that additionally matches the source noise to the boundary samples."""

    def create_psi(self, xbdry, cbdry, source=None):
        '''
        Match source and boundary columns, then fit the regressor.
        IN:
            xbdry:  tensor, batch x 1 x numC x dimX   (re-ordered in place)
            cbdry:  tensor, numC x dimC
            source: tensor, batch x 1 x numCsample x dimX, or None to skip source matching
        Out:
            source re-ordered along the batch axis so that source[i] is paired with
            xbdry[i]; None when no source was given.
        '''
        xbdry, source = self.sinkhorn_matching(xbdry, source)
        self.lm.create_regmats(xbdry, cbdry)
        return source

    @torch.no_grad()
    def sinkhorn_matching(self, xbdry, source=None):
        '''
        IN:
            xbdry:  tensor, batch x 1 x numC x dimX   (modified in place)
            source: tensor, batch x 1 x numCsample x dimX, or None
        Out:
            (xbdry, source) with source re-ordered along the batch axis (or None).
        '''
        (batch, _, nC, dimX) = xbdry.shape
        a, b = self.a, self.b
        if a.shape[0] != batch:
            # Batch size differs from the configured one: use matching uniform marginals.
            a = torch.ones((batch,), device=xbdry.device) / batch
            b = a.clone()
        permuted_cidx = torch.randperm(nC)

        if source is not None:
            # Vertical matching: pair each source sample (first Csample slot), shifted by
            # the batch mean of the first visited column, with a sample of that column.
            basesource = source[:, :, [0]]
            shifted_source = self.deform_source(source=basesource,
                                                xbdry=xbdry[:, :, [permuted_cidx[0]]])
            fromdat = xbdry[:, 0, permuted_cidx[0], :]
            todat = shifted_source[:, 0, 0, :]

            M = ot.dist(fromdat, todat)
            pi = ot.emd(a, b, M)
            source = source[torch.argmax(pi, axis=1)]

        for k in range(1, nC):
            fromdat = xbdry[:, 0, permuted_cidx[k - 1], :]
            todat = xbdry[:, 0, permuted_cidx[k], :]

            M = ot.dist(fromdat, todat)
            pi = ot.emd(a, b, M)
            todat_idx = torch.argmax(pi, axis=1)

            xbdry[:, 0, permuted_cidx[k], :] = xbdry[todat_idx, 0, permuted_cidx[k], :]

        return xbdry, source


def greedy_matching(pi):
    '''
    Greedy one-to-one assignment from a score (e.g. transport plan) matrix.

    Rows are processed in order; each row takes its highest-scoring column that no
    earlier row has already taken.

    IN:
        pi: tensor, nRows x nCols with nRows <= nCols
    Out:
        matching: int64 tensor of length nRows; matching[i] is the column assigned to row i.
    Raises:
        ValueError: if pi has more rows than columns (no one-to-one assignment exists).
    '''
    n_rows, n_cols = pi.shape
    if n_rows > n_cols:
        raise ValueError(f'greedy_matching needs nRows <= nCols, got shape {tuple(pi.shape)}')

    used = set()  # O(1) membership instead of scanning a list for every candidate
    matching = []
    for row in pi:
        for index in torch.argsort(row, descending=True).tolist():
            if index not in used:
                used.add(index)
                matching.append(index)
                break

    return torch.tensor(matching, dtype=torch.long)


if __name__ == "__main__":
    # Library module: running it directly intentionally does nothing.
    ...