import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import rotate
from einops import rearrange, reduce, repeat
import pdb
import io

# Implement some helper functions

def lp(tensor, p=1):
    """Return a scalar magnitude summary of ``tensor``.

    For ``p < 100`` this is the mean of ``|x|**p`` over all elements (no
    p-th root is taken, so it is not a true L^p norm). For any other ``p``
    it is the largest absolute value.

    Args:
        tensor: input tensor of any shape; it is flattened first.
        p: exponent; values of 100 or more select the max-abs branch.

    Returns:
        0-dim tensor holding the summary value.
    """
    magnitudes = torch.flatten(tensor).abs()
    if p < 100:
        return (magnitudes ** p).mean()
    return magnitudes.max()

# Used in notebooks for Sinkhorn debugging.
def show_directional_assignments(zero, one, ax=None, colors=None):
    """Draw an arrow from every origin point to its assigned destination.

    Row ``i`` of ``zero`` is connected to row ``i`` of ``one``; both point
    sets are then scattered on the same axes and the view is fixed to
    ``[-2, 2] x [-2, 2]``.

    Args:
        zero: 2-D tensor of origin points; columns 0 and 1 are used as x/y
            (the scatter call unpacks ``zero.t()``, so exactly two columns
            are expected).
        one: 2-D tensor of destination points, row-aligned with ``zero``.
        ax: axes-like object to draw on. When ``None`` the current pyplot
            axes (``plt.gca()``) are used.
        colors: pair ``(origin_color, destination_color)``; arrows use the
            origin color. Defaults to ``("tab:blue", "tab:orange")``.
    """
    # Resolve the defaults *before* drawing: the arrows need a real Axes, and
    # the pyplot module itself has no set_xlim/set_ylim.
    if ax is None:
        ax = plt.gca()
    if colors is None:
        colors = ("tab:blue", "tab:orange")

    for i in range(zero.shape[0]):
        ax.arrow(
            zero[i, 0], zero[i, 1],
            one[i, 0] - zero[i, 0],
            one[i, 1] - zero[i, 1],
            alpha=0.5, color=colors[0],
            head_width=0.03, length_includes_head=True,
        )

    ax.scatter(*zero.t(), color=colors[0], label='Origin(Not Permuted) ')
    ax.scatter(*one.t(), color=colors[1], label='Destination(Permuted)')
    ax.legend()
    ax.set_xlim([-2, 2])
    ax.set_ylim([-2, 2])


def batch_jacobian(func, x, batchsize=0, create_graph=False):
    '''
    ##THIS IS DEFUNCT! ONLY USED IN THE OLD CODES, TO BE REFACTORED OUT!

    Jacobian of ``func`` with respect to every row of ``x`` in one autograd
    call, using the "sum the outputs" trick from
    https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771

    IN:
        func: callable applied to a tensor of shape (bs, T, d); its result
              is reshaped to (-1, d) and summed over the leading axis.
        x:    tensor, bs x T x d
        batchsize: unused by the current implementation.
        create_graph: forwarded to ``torch.autograd.functional.jacobian``.
    OUT:
        tensor, bs x T x d x d. Entry [b, t, k, j] is the derivative of the
        row-summed output component k with respect to x[b, t, j]; this
        equals the per-row Jacobian when ``func`` treats rows independently.
    '''
    n_outer, n_inner, width = x.shape

    def _row_summed(flat_x):
        # Restore the 3-D layout func expects, then collapse all rows so that
        # a single jacobian call covers the whole batch.
        out = func(flat_x.reshape(n_outer, n_inner, width))
        return out.reshape(-1, width).sum(dim=0)

    flat_jac = jacobian(_row_summed, x.reshape(-1, width), create_graph=create_graph)
    # (out, row, in) -> (row, out, in)
    per_row = flat_jac.permute(1, 0, 2)
    return per_row.reshape(n_outer, n_inner, *per_row.shape[1:])



def kNN(clist, ctilde, num_nhd=2):
    """Return, for each query row, the indices of its nearest rows in ``clist``.

    Distances are Euclidean (2-norm over the last axis).

    Args:
        clist: tensor of shape (num_c, dim_c) holding the candidate points.
        ctilde: tensor of shape (bs, dim_c) holding the query points.
        num_nhd: number of neighbours requested; capped at ``num_c``.

    Returns:
        Integer tensor of shape (bs, min(num_nhd, num_c)) with row indices
        into ``clist``, ordered from nearest to farthest.
    """
    # Broadcast (1, num_c, dim_c) against (bs, 1, dim_c). The query axis must
    # be inserted at position 1 regardless of dim_c; using dim_c as the axis
    # index only happened to work for dim_c == 1.
    offsets = clist.unsqueeze(0) - ctilde.unsqueeze(1)
    distances = torch.norm(offsets, p=2, dim=2)
    k = min(num_nhd, len(clist))
    return distances.topk(k, largest=False).indices


def eight_normal_sample(n, dim, scale=1, var=1):
    """Sample ``n`` points from a mixture of eight Gaussians on a circle.

    Each sample is a uniformly chosen 2-D centre (eight points on the unit
    circle, multiplied by ``scale``) plus zero-mean Gaussian noise drawn
    from a ``dim``-dimensional normal distribution.

    Args:
        n: number of samples.
        dim: dimension of the noise; the centres are 2-D, so the sum must
            broadcast against a length-2 vector.
        scale: multiplier applied to the centres.
        var: noise parameter. NOTE(review): the covariance passed to
            MultivariateNormal is ``sqrt(var) * I`` rather than ``var * I``
            -- confirm this is intended.

    Returns:
        Tensor with one row per sample.
    """
    inv_sqrt2 = 1.0 / np.sqrt(2)
    mode_centers = torch.tensor([
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (inv_sqrt2, inv_sqrt2),
        (inv_sqrt2, -inv_sqrt2),
        (-inv_sqrt2, inv_sqrt2),
        (-inv_sqrt2, -inv_sqrt2),
    ]) * scale

    gaussian = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(dim), math.sqrt(var) * torch.eye(dim)
    )
    # Draw the noise before the mode indices so the RNG stream is consumed in
    # the same order as before.
    noise = gaussian.sample((n,))
    mode_ids = torch.multinomial(torch.ones(8), n, replacement=True)

    samples = [mode_centers[mode_ids[k]] + noise[k] for k in range(n)]
    return torch.stack(samples)


def sample_moons(n):
    """Return ``n`` two-moons points, rescaled as ``3 * x - 1``.

    NOTE(review): ``generate_moons`` is neither defined nor imported in the
    visible part of this module -- confirm where it is expected to come from
    (it is called with ``noise=0.2`` and its second return value is dropped).
    """
    points, _unused = generate_moons(n, noise=0.2)
    return points * 3 - 1


def sample_8gaussians(n):
    """Return ``n`` float32 samples from the 2-D eight-Gaussians mixture.

    Uses ``eight_normal_sample`` with ``scale=5`` and ``var=0.1``.
    """
    points = eight_normal_sample(n, 2, scale=5, var=0.1)
    return points.float()


class torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format.

    The wrapped module is called as ``f(t, x)``: the model output at
    ``(x, gamma(t))`` is contracted with ``gammadot(t)``.
    """

    def __init__(self, model, gamma, gammadot, guidance=None):
        """Store the model, the path and its derivative.

        Args:
            model: callable applied to ``cat([x, xi], dim=-1)``; its output
                is contracted as ``'bij,j->bi'`` in ``forward``.
            gamma: callable t -> gamma(t), a point of the conditioning space.
            gammadot: callable t -> derivative of gamma at t.
            guidance: optional weight blending the conditioned and the
                "unguided" model outputs; ``None`` disables the blend.
        """
        super().__init__()
        self.model = model
        self.gamma = gamma
        self.gammadot = gammadot
        self.guidance = guidance

    def forward(self, t, x, args=None):
        """Evaluate the vector field at time ``t`` for samples ``x``.

        ``x`` has the shape [data_num, x_dim]. ``args`` is accepted but not
        used.
        """
        t = torch.tensor([t])
        n_samples = x.shape[0]
        # Same conditioning vector for every sample in the batch.
        xi = self.gamma(t).unsqueeze(dim=0).expand(n_samples, -1)
        field = self.model(torch.cat([x, xi], dim=-1)).detach()

        if self.guidance is not None:
            # Conditioning for the unguided evaluation: keep the first
            # ("time") coordinate, zero the rest, set the last one to 1.
            xi_free = torch.zeros(xi.shape).to(field.device)
            xi_free[:, 0] = xi_free[:, 0] + xi[:, 0]
            xi_free[:, -1] = 1.
            free_field = self.model(torch.cat([x, xi_free], dim=-1)).detach()
            field = (1 - self.guidance) * free_field + self.guidance * field

        direction = self.gammadot(t)
        return torch.einsum('bij,j->bi', field, direction)

class torch_wrapper_lfm(torch.nn.Module):
    """Expose ``model`` as ``f(t, x)`` by appending ``gamma(t)`` to ``x``.

    Unlike ``torch_wrapper`` the model output is returned as is.
    """

    def __init__(self, model, gamma, gammadot):
        """Store the model and the path.

        Args:
            model: callable applied to ``cat([x, xi], dim=-1)``.
            gamma: callable t -> gamma(t), a point of the conditioning space.
            gammadot: accepted but not stored or used by this wrapper.
        """
        super().__init__()
        self.model = model
        self.gamma = gamma

    def forward(self, t, x, args=None):
        """Return the model output for ``x`` conditioned on ``gamma(t)``.

        ``args`` is accepted but not used.
        """
        t = torch.tensor([t])
        n_samples = x.shape[0]
        # Broadcast the single conditioning vector over the batch.
        xi = self.gamma(t).unsqueeze(dim=0).expand(n_samples, -1)
        model_input = torch.cat([x, xi], dim=-1)
        return self.model(model_input)



def plot_trajectories(traj, returnFig=False, title='', legendlist=None, flim=3., fsize=20):
    """Plot trajectories of some selected samples.

    The first slice along axis 0 is drawn in black, every slice in olive and
    the last slice in blue, using coordinates 0 and 1 of at most the first
    2000 samples.

    Args:
        traj: tensor indexed as ``traj[step, sample, coord]``.
        returnFig: when truthy, draw on a new figure/axes pair, label the
            three scatters, show the figure and return it; otherwise draw
            with the pyplot state machine and return ``None``.
        title: axes title (used only when ``returnFig`` is truthy).
        legendlist: three legend labels for the pyplot branch; any other
            length (or ``None``) falls back to the built-in labels.
        flim: half-width of the square view, applied to both axes.
        fsize: legend font size for the pyplot branch.

    Returns:
        The figure when ``returnFig`` is truthy, else ``None``.
    """
    n = 2000
    # detach() so tensors that still carry autograd history can be converted.
    traj = traj.detach().to('cpu').numpy()

    if returnFig:
        fig0, ax0 = plt.subplots(figsize=(6, 6))
        ax0.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='source')
        ax0.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive", label='path')
        ax0.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='Generated')
        # Honour flim here too; this branch used to hard-code +/-3.
        ax0.set_xlim(-flim, flim)
        ax0.set_ylim(-flim, flim)
        ax0.set_title(title)
        fig0.legend()
        fig0.show()
        return fig0

    plt.figure(figsize=(6, 6))
    plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black")
    plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive")
    plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue")
    if legendlist is not None and len(legendlist) == 3:
        plt.legend(legendlist, fontsize=fsize)
    else:
        plt.legend(["Prior sample z(S)", "Flow", "z(0)"], fontsize=fsize)
    plt.xlim(-flim, flim)
    plt.ylim(-flim, flim)

def show_generated_img(traj, returnFig=False, title=''):
    """Show the last slice of ``traj`` as a row of 28x28 grayscale images.

    Args:
        traj: sequence whose last element is a (batch, 784) tensor; each row
            is reshaped to 28 x 28.
        returnFig: when truthy the figure is returned, otherwise it is shown
            and ``None`` is returned.
        title: title placed above every panel.
    """
    final_frames = rearrange(traj[-1], 'b (h w) -> b h w', h=28, w=28).to('cpu').numpy()
    n_imgs = final_frames.shape[0]

    fig, axes = plt.subplots(1, n_imgs, figsize=(n_imgs * 2, 2))
    # plt.subplots hands back a bare Axes for one panel and an array otherwise.
    panels = [axes] if n_imgs == 1 else axes
    for panel, frame in zip(panels, final_frames):
        panel.imshow(frame, cmap="gray")
        panel.axis("off")
        panel.set_title(title)

    if not returnFig:
        fig.show()
        return None
    return fig

# def show_generated_img(traj, returnFig=False, title=''):
#     gen_img = rearrange(traj[-1], 'b (h w) -> b h w', h=28, w=28).to('cpu').numpy()
#     sample_num = gen_img.shape[0]
    
#     fig, axes = plt.subplots(1, sample_num, figsize=(sample_num*2, 2))
#     if sample_num == 1:
#         axes.imshow(gen_img[0], cmap="gray")
#         axes.axis("off")
#         axes.set_title(title) 
#     else:
#         for i in range(sample_num):
#             axes[i].imshow(gen_img[i], cmap="gray")
#             axes[i].axis("off")
#             axes[i].set_title(title) 
    
#     fig.show()
#     if returnFig:
#         return fig

def plot_MNIST_trajectories(traj_yoko, returnFig=False, num_snapshot=5):
    """Plot a grid of 28x28 images: one row per sample, one column per snapshot.

    Args:
        traj_yoko: tensor of shape (steps, batch, 784); each vector is
            reshaped to 28 x 28.
        returnFig: when truthy the figure is returned, otherwise it is shown
            and ``None`` is returned.
        num_snapshot: stride between plotted steps (steps 0, num_snapshot,
            2 * num_snapshot, ... are shown).
    """
    trans_img = rearrange(traj_yoko, 't b (h w) -> t b h w', h=28, w=28).detach().to('cpu')
    # Every num_snapshot-th step, shape (n_cols, b, h, w).
    snapshots = trans_img[::num_snapshot]
    n_cols, b = snapshots.shape[0], snapshots.shape[1]

    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so axs[i][j] stays valid when the batch or the snapshot count is 1.
    fig, axs = plt.subplots(b, n_cols, figsize=(2 * n_cols, 2 * b), squeeze=False)
    for i in range(b):
        for j in range(n_cols):
            axs[i][j].imshow(snapshots[j][i], cmap='gray')
            axs[i][j].axis("off")

    if returnFig:
        return fig
    fig.show()

def show_generated_WAEimg(traj,decoder, returnFig=False, title=''):
    """Decode the last slice of ``traj`` and show the images in one row.

    Args:
        traj: sequence whose last element is passed to ``decoder``.
        decoder: callable returning a mapping with a ``'reconstruction'``
            entry of shape (batch, 1, h, w).
        returnFig: when truthy the figure is returned, otherwise it is shown
            and ``None`` is returned.
        title: title placed above every panel.
    """
    decoded = decoder(traj[-1])['reconstruction']
    images = rearrange(decoded, 'b 1 h w -> b h w').detach().cpu().numpy()
    count = images.shape[0]

    fig, axes = plt.subplots(1, count, figsize=(count * 2, 2))
    for idx in range(count):
        # A single panel comes back as a bare Axes, not an array.
        target = axes if count == 1 else axes[idx]
        target.imshow(images[idx], cmap="gray")
        target.axis("off")
        target.set_title(title)

    if returnFig:
        return fig
    fig.show()

def plot_MNIST_WAEtrajectories(traj_yoko, decoder, returnFig=False, num_snapshot=1):
    """Decode a latent trajectory and plot it: rows are samples, columns are snapshots.

    Args:
        traj_yoko: tensor of shape (steps, batch, latent_dim).
        decoder: callable applied to the flattened (steps * batch, latent_dim)
            latents; must return a mapping with a ``'reconstruction'`` entry
            of shape (steps * batch, 1, h, w).
        returnFig: when truthy the figure is returned, otherwise it is shown
            and ``None`` is returned.
        num_snapshot: stride between plotted steps.
    """
    nT, nB, _ = traj_yoko.shape
    flatten_batch = rearrange(traj_yoko, 't b l -> (t b) l')
    traj_img = decoder(flatten_batch)['reconstruction']
    trans_img = rearrange(traj_img, '(t b) 1 h w -> t b h w', t=nT, b=nB).detach().cpu()

    # Every num_snapshot-th step, shape (n_cols, b, h, w).
    snapshots = trans_img[::num_snapshot]
    n_cols, b = snapshots.shape[0], snapshots.shape[1]

    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so axs[i, j] stays valid when the batch or the snapshot count is 1.
    fig, axs = plt.subplots(nrows=b, ncols=n_cols, figsize=(2 * n_cols, 2 * b), squeeze=False)
    for i in range(b):
        for j in range(n_cols):
            axs[i, j].imshow(snapshots[j][i], cmap='gray')
            axs[i, j].axis("off")

    if returnFig:
        return fig
    fig.show()

#https://gist.github.com/wohlert/8589045ab544082560cc5f8915cc90bd

class SinkhornSolver(nn.Module):
    """
    Optimal Transport solver under entropic regularisation.
    Based on the code of Gabriel Peyré.

    Both point clouds carry uniform weights. Inputs may be unbatched,
    ``(n, d)``, or batched, ``(batch, n, d)``; the iterations run in the log
    domain and reduce over the last two axes of the cost matrix.
    """
    def __init__(self, epsilon, iterations=100, ground_metric=lambda x: torch.pow(x, 2),
                 threshold=1e-1):
        """
        Args:
            epsilon: entropic regularisation strength.
            iterations: maximum number of Sinkhorn iterations.
            ground_metric: element-wise function applied to ``x_i - y_j``;
                its values are summed over the feature axis to give the cost.
            threshold: the loop stops early once the batch-mean L1 change of
                the potentials ``(u, v)`` drops below this value.
        """
        super(SinkhornSolver, self).__init__()
        self.epsilon = epsilon
        self.iterations = iterations
        self.ground_metric = ground_metric
        self.threshold = threshold

    def forward(self, x, y):
        """Compute the entropic transport plan between ``x`` and ``y``.

        Args:
            x: tensor of shape (num_x, d) or (batch, num_x, d).
            y: tensor of shape (num_y, d) or (batch, num_y, d).

        Returns:
            ``(cost, pi)`` where ``pi`` has shape (..., num_x, num_y) and
            ``cost`` is ``sum(pi * C)`` over the last two axes.
        """
        num_x = x.size(-2)
        num_y = y.size(-2)

        # Marginal densities are empirical (uniform) measures. Their shape
        # follows the inputs directly, so size-1 batch or point axes are never
        # collapsed the way a blanket squeeze() would.
        a = x.new_ones(x.shape[:-1]) / num_x
        b = y.new_ones(y.shape[:-1]) / num_y

        # Initialise approximation vectors in log domain
        u = torch.zeros_like(a)
        v = torch.zeros_like(b)

        # Cost matrix
        C = self._compute_cost(x, y)

        # Sinkhorn iterations
        for _ in range(self.iterations):
            u0, v0 = u, v

            # u^{l+1} = a / (K v^l): reduce over the y axis (the last one),
            # which is also correct when a leading batch axis is present.
            K = self._log_boltzmann_kernel(u, v, C)
            u_ = torch.log(a + 1e-8) - torch.logsumexp(K, dim=-1)
            u = self.epsilon * u_ + u

            # v^{l+1} = b / (K^T u^(l+1)): after the transpose the last axis
            # is the x axis.
            K_t = self._log_boltzmann_kernel(u, v, C).transpose(-2, -1)
            v_ = torch.log(b + 1e-8) - torch.logsumexp(K_t, dim=-1)
            v = self.epsilon * v_ + v

            # Size of the change we have performed on (u, v)
            diff = torch.sum(torch.abs(u - u0), dim=-1) + torch.sum(torch.abs(v - v0), dim=-1)
            if torch.mean(diff).item() < self.threshold:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        K = self._log_boltzmann_kernel(u, v, C)
        pi = torch.exp(K)

        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        return cost, pi

    def _compute_cost(self, x, y):
        """Return the pairwise cost ``C[..., i, j] = sum_d ground_metric(x_i - y_j)``."""
        x_ = x.unsqueeze(-2)
        y_ = y.unsqueeze(-3)
        C = torch.sum(self.ground_metric(x_ - y_), dim=-1)
        return C

    def _log_boltzmann_kernel(self, u, v, C=None, x=None, y=None):
        """Return ``(-C + u_i + v_j) / epsilon``.

        When ``C`` is omitted it is computed from ``x`` and ``y``.

        Raises:
            ValueError: if neither ``C`` nor both ``x`` and ``y`` are given.
        """
        if C is None:
            if x is None or y is None:
                raise ValueError(
                    "either the cost matrix C or both point clouds x and y must be given"
                )
            C = self._compute_cost(x, y)
        kernel = -C + u.unsqueeze(-1) + v.unsqueeze(-2)
        return kernel / self.epsilon

def plot_to_tensor(fig):
    """Converts a matplotlib figure to a tensor.

    The figure is rendered to an in-memory PNG, read back as an image array,
    passed through torchvision's ``ToTensor`` and then closed.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png')
    png_buffer.seek(0)  # rewind so imread starts at the PNG header
    pixels = plt.imread(png_buffer)
    as_tensor = ToTensor()(pixels)
    plt.close(fig)  # the figure is no longer needed once converted
    return as_tensor

def _plot_image_grid(images, title, n_rows, fig_width):
    """Draw ``images[col][row]`` on an ``n_rows x len(images)`` grid and return the figure."""
    n_cols = len(images)
    # squeeze=False keeps axs two-dimensional even for a single row or column.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
                            figsize=(fig_width, 2 * n_rows), squeeze=False)
    fig.suptitle(title, fontsize=24)
    for row in range(n_rows):
        for col in range(n_cols):
            axs[row, col].imshow(images[col][row], cmap='gray')
            axs[row, col].axis("off")
    return fig


def transfer_img_generator(gamma_yoko,traj_yoko,configs,decoder):
    """Compare decoded trajectory snapshots with rotated copies of the first frame.

    Every ``configs.evaluation.num_snapshot``-th step of the decoded
    trajectory is paired with a "ground truth" obtained by rotating the
    decoded step-0 images by ``rad2deg(gamma_yoko(s)[1])``.

    Args:
        gamma_yoko: callable evaluated on ``torch.linspace(0, 1, 100)``;
            component 1 of its output is read as an angle in radians.
        traj_yoko: latent trajectory of shape (steps, batch, latent_dim).
        configs: object exposing ``evaluation.num_snapshot`` (the stride).
        decoder: callable applied to the CPU latents; must return a mapping
            with a ``'reconstruction'`` entry of shape (steps * batch, 1, h, w).

    Returns:
        ``([gt, result, diff_plot], diff)``: the three figures converted with
        ``plot_to_tensor`` and the raw ``result - ground_truth`` tensor of
        shape (snapshots, batch, h, w).
    """
    # NOTE(review): gamma is sampled on a fixed 100-point grid and indexed by
    # trajectory step below, so this presumably expects a trajectory of at
    # most 100 steps on the same time grid -- confirm against the caller.
    gammas = torch.stack([gamma_yoko(k) for k in torch.linspace(0, 1, 100)])
    nT, nB, _ = traj_yoko.shape
    flatten_batch = rearrange(traj_yoko, 't b l -> (t b) l')
    traj_img = decoder(flatten_batch.detach().cpu())['reconstruction']
    trans_img = rearrange(traj_img, '(t b) 1 h w -> t b h w', t=nT, b=nB).detach().cpu()
    b = trans_img.shape[1]

    steps = range(0, nT, configs.evaluation.num_snapshot)
    # Both stacks have shape (snapshots, b, h, w).
    result_image = torch.stack([trans_img[i] for i in steps])
    gt_image = torch.stack([
        rotate(trans_img[0], torch.rad2deg(gammas[i][1]).item()) for i in steps
    ])
    diff = result_image - gt_image

    fig_width = 2 * len(steps)
    fig_gt = _plot_image_grid(gt_image, "ground truth", b, fig_width)
    fig_result = _plot_image_grid(result_image, "transfer results", b, fig_width)
    fig_diff = _plot_image_grid(diff, "diff", b, fig_width)

    return [plot_to_tensor(fig_gt), plot_to_tensor(fig_result), plot_to_tensor(fig_diff)], diff

