"""Implements Guided Flow Matcher Losses."""

import math
import random
import torch
import torch.distributions as dist
from einops import rearrange, reduce, repeat
import numpy as np
from scipy.optimize import linear_sum_assignment
import sys
sys.path.append('../')
import pdb
import copy
from torchEFM import utils
from torchEFM import extended_flow_matching as efm
import ot


class CondFlowMatcher(object):
    '''
    Conditional flow matcher built on a straight-line ("Linear Path")
    interpolation between source samples and boundary samples:

        psi(t) = (1 - t) * source + t * xbdry

    Call order implied by the attributes each method reads:
        create_psi  ->  sample_psi_tc(prepareForDeriv=True)  ->  compute_jacobian
    '''

    def __init__(self, device='cpu', **kwargs):
        '''
        IN:
            device: stored as self.device. No method of this class reads it.
            kwargs: accepted and ignored, so subclasses can forward their own
                    keyword arguments unchanged.
        '''
        self.device = device

    def create_psi(self, xbdry, cbdry, source):
        '''
        Store the boundary data that defines the path.

        IN:
            xbdry:  tensor, batch x 1 x numC x dimX
            cbdry:  tensor, numC x dimC
            source: not used by this class; part of the shared signature.
        '''
        self.xbdry = xbdry
        self.cbdry = cbdry

    def create_model_input(self, psival, tset, cset):
        '''
        Creates the model input [psival, xi].
        IN:
            psival: tensor, b x nT x nCsample x dimX
            tset :  tensor, nT
            cset :  tensor, nCsample x dimC

        Out:
            model_input:  tensor,  (b x nT x nCsample) x (dimX + dimC + 1)
        '''
        batch = psival.shape[0]

        # nT x nCsample x (dimC + 1), then one copy per batch element.
        xiset = self.create_xi(tset, cset)
        xiset_extended = repeat(xiset, 'nT nC d -> b nT nC d', b=batch)

        # b x nT x nCsample x (dimX + dimC + 1)
        model_input = torch.cat([psival, xiset_extended], dim=-1)

        # Flatten every leading axis into a single sample axis.
        return rearrange(model_input, 'b nT nC d-> (b nT nC) d')

    def create_source(self, config, mydata):
        '''
        Draw standard-normal source samples.

        IN:
            config: object exposing batch_size and loader.num_samplec
            mydata: object exposing x_dim
        Out:
            source: tensor, batch_size x 1 x num_samplec x x_dim.
                    Created on torch's default device; self.device is not
                    consulted here.
        '''
        shape = (config.batch_size, 1, config.loader.num_samplec, mydata.x_dim)
        return torch.normal(0, 1, size=shape)

    def create_xi(self, tset, cset):
        '''
        Pair every time value with every condition vector.

        IN:
            tset:  tensor, numT
            cset:  tensor, numCsample x dimC
        Out:
            xiset: tensor, numT x numCsample x (dimC + 1), where
                   xiset[i, j] = [tset[i], cset[j, 0], ..., cset[j, dimC-1]]
        '''
        numT = len(tset)
        numC = cset.shape[0]

        t_grid = tset[:, None, None].expand(-1, numC, 1)   # numT x numC x 1
        c_grid = cset[None, :, :].expand(numT, -1, -1)     # numT x numC x dimC

        return torch.cat((t_grid, c_grid), dim=-1)

    def sample_psi_tc(self, tset=None, cset=None, source=None, prepareForDeriv=False):
        '''
        From xbdry (stored by create_psi) and source, compute the linear path
        psi(t) = (1 - t) * source + t * xbdry.

        IN:
            tset:   tensor, numT (reshaped to 1 x numT x 1 x 1), or any shape
                    that broadcasts against source, e.g. numT x numC x 1.
                    Required despite the None default.
            cset:   not used by the linear path.
            source: tensor, batch x 1 x numCsample x dimX. Required despite
                    the None default.
            prepareForDeriv: if truthy, xbdry and source are saved as
                    self.xtensor / self.ztensor for compute_jacobian.
        Out:
            psival: tensor, batch x numT x numCsample x dimX
        Raises:
            ValueError: if tset or source is None.
        '''
        if tset is None or source is None:
            raise ValueError("sample_psi_tc requires both 'tset' and 'source'")

        if tset.dim() == 1:
            tset = rearrange(tset, 't -> 1 t 1 1')

        ztensor = source
        xtensor = self.xbdry
        psival = (1 - tset) * ztensor + tset * xtensor

        if prepareForDeriv:
            self.xtensor = xtensor
            self.ztensor = ztensor

        return psival

    def compute_jacobian(self, tset, cset):
        '''
        Manual computation of the Jacobian for the "Linear Path".

        Only the derivative with respect to t is produced: d psi / dt =
        xbdry - source, which does not depend on t and is therefore simply
        repeated along the time axis. No columns for the derivative with
        respect to c are included.

        Requires a prior sample_psi_tc(..., prepareForDeriv=True) call, which
        sets self.xtensor and self.ztensor.

        IN:
            tset:  tensor, numT (only its length is used)
            cset:  not used by the linear path.
        Out:
            jacobian_val: tensor, batch x numT x numC x dimX x 1
        Raises:
            ValueError: if tset is not one-dimensional.
        '''
        if tset.dim() != 1:
            raise ValueError(
                f"compute_jacobian expects a 1-D tset, got shape {tuple(tset.shape)}"
            )
        nt = tset.shape[0]

        partial_t = self.xtensor - self.ztensor   # b x 1 x nc x dimX
        return repeat(partial_t, 'b 1 nc x -> b nt nc x 1', nt=nt)




class SinkhornFlowMatcher(CondFlowMatcher):
    '''
    Flow matcher that re-pairs the boundary samples with the source samples,
    using the transport plan returned by utils.SinkhornSolver, before the
    path is stored.
    '''

    def __init__(self, **kwargs):
        '''
        IN:
            kwargs: must contain 'sink_iter', the iteration count handed to
                    utils.SinkhornSolver (KeyError otherwise). All kwargs are
                    also forwarded to the parent constructor.
        '''
        super().__init__(**kwargs)
        self.sink_iter = kwargs['sink_iter']
        self.sinkhorn = utils.SinkhornSolver(epsilon=0.1, iterations=self.sink_iter)

    def create_psi(self, xbdry, cbdry, source):
        '''
        Match xbdry to source, then store the boundary data.

        IN:
            xbdry:  tensor, batch x 1 x numC x dimX  (reordered IN PLACE along
                    the batch axis, see sinkhorn_matching)
            cbdry:  tensor, numC x dimC
            source: tensor indexed as source[:, 0, k, :] for k < numC
        '''
        xbdry, source = self.sinkhorn_matching(source=source, xbdry=xbdry)
        self.xbdry = xbdry
        self.cbdry = cbdry

    def sinkhorn_matching(self, xbdry, source):
        '''
        For every condition index k, pair each source sample with the xbdry
        sample that receives the largest mass in its row of the transport
        plan, and reorder xbdry along the batch axis accordingly.

        NOTE: xbdry is modified in place; the returned xbdry is the same
        tensor object the caller passed in. A row-wise argmax is not
        guaranteed to be a permutation, so one xbdry sample may be selected
        for several source samples.

        IN:
            xbdry:  tensor, batch x 1 x numC x dimX
            source: tensor indexed as source[:, 0, k, :]; returned unchanged
        Out:
            (xbdry, source)
        '''
        _, _, nC, _ = xbdry.shape

        # Re-pair the batch independently for each condition column.
        for k in range(nC):
            fromdat = source[:, 0, k, :]
            todat = xbdry[:, 0, k, :]

            # Only the argmax indices are used, so no autograd graph is needed
            # for the solver iterations.
            with torch.no_grad():
                _, pi = self.sinkhorn.forward(fromdat, todat)
                todat_idx = torch.argmax(pi, dim=1)

            # Advanced indexing on the right-hand side yields a copy, so the
            # in-place write cannot read rows it has already overwritten.
            xbdry[:, 0, k, :] = xbdry[todat_idx, 0, k, :]

        return xbdry, source

class EmdFlowMatcher(SinkhornFlowMatcher):
    '''
    Variant of SinkhornFlowMatcher whose matching step uses the exact
    transport plan from ot.emd instead of the Sinkhorn solver.

    The constructor is inherited, so 'sink_iter' is still a required keyword
    argument even though the matching below does not use self.sinkhorn.
    '''

    def sinkhorn_matching(self, xbdry, source):
        '''
        For every condition index k, solve the transport problem between
        source[:, 0, k, :] and xbdry[:, 0, k, :] with uniform weights, and
        reorder xbdry along the batch axis by the row-wise argmax of the plan.

        NOTE: xbdry is modified in place; the returned xbdry is the same
        tensor object the caller passed in.

        IN:
            xbdry:  tensor, batch x 1 x numC x dimX
            source: tensor indexed as source[:, 0, k, :]; returned unchanged
        Out:
            (xbdry, source)
        '''
        batch, _, nC, _ = xbdry.shape

        # Uniform marginals, built directly on the data's device. The same
        # weights serve as both the source and the target histogram.
        weights = torch.ones((batch,), device=xbdry.device) / batch

        # Re-pair the batch independently for each condition column.
        for k in range(nC):
            fromdat = source[:, 0, k, :]
            todat = xbdry[:, 0, k, :]

            with torch.no_grad():
                # ot.dist on tensor inputs already yields a tensor; wrapping
                # it in torch.tensor() again would copy it and emit a
                # copy-construct UserWarning on every iteration.
                M = ot.dist(fromdat, todat)
                pi = ot.emd(weights, weights, M)
                todat_idx = torch.argmax(pi, dim=1)

            # Advanced indexing on the right-hand side yields a copy, so the
            # in-place write cannot read rows it has already overwritten.
            xbdry[:, 0, k, :] = xbdry[todat_idx, 0, k, :]

        return xbdry, source


