from __future__ import print_function
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

from torchvision import datasets, transforms
from torch.optim import Adam

import json
import numpy as np
import copy
import argparse
import os
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.utils as vutils
import torch.nn.functional as F
import warnings
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import torchvision
from torch import nn
import logging
import matplotlib.pyplot as plt
import itertools
# import tensorflow as tf
import io
import PIL.Image
from torchvision.transforms import ToTensor
from numpy import linalg as la
from scipy.linalg import block_diag

class DataGather(object):
    """Accumulate named scalar series and summarise them.

    Every name given to the constructor owns a list of Python scalars.
    Values are appended with :meth:`insert`, summarised with the ``get_*``
    methods (one entry per name, in constructor order) and discarded with
    :meth:`flush`.
    """

    def __init__(self, *args):
        # Series names; their order fixes the order of every get_* result.
        self.keys = args
        self.data = self.get_empty_data_dict()

    def get_empty_data_dict(self):
        """Return a new ``{name: []}`` mapping for every registered name."""
        return {name: [] for name in self.keys}

    def insert(self, keys, data):
        """Append ``data[k]`` to the series ``keys[k]`` for every position k.

        Tensor values are converted with ``.item()``; any other value must
        already be a ``float`` (checked with ``assert``).
        """
        assert len(keys) == len(data)
        for name, value in zip(keys, data):
            if isinstance(value, torch.Tensor):
                scalar = value.item()
            else:
                assert isinstance(value, float)
                scalar = value
            self.data[name].append(scalar)

    def flush(self):
        """Drop every recorded value, keeping the registered names."""
        self.data = self.get_empty_data_dict()

    def _summarise(self, reducer, empty_value):
        # Apply ``reducer`` to each non-empty series, ``empty_value`` otherwise.
        return [reducer(self.data[name]) if self.data[name] else empty_value
                for name in self.keys]

    def get_mean(self):
        """Mean of each series (0 for an empty series)."""
        return self._summarise(np.mean, 0)

    def get_min(self):
        """Minimum of each series (None for an empty series)."""
        return self._summarise(np.min, None)

    def get_max(self):
        """Maximum of each series (None for an empty series)."""
        return self._summarise(np.max, None)

    def get_sum(self):
        """Sum of each series (0 for an empty series)."""
        return self._summarise(np.sum, 0)

    def get_report(self, options):
        """Pick one statistic per series.

        ``options[k]`` indexes ``(min, max, mean)`` for the k-th series, so
        0 selects the minimum, 1 the maximum and 2 the mean.
        """
        columns = list(zip(self.get_min(), self.get_max(), self.get_mean()))
        return [columns[idx][options[idx]] for idx in range(len(columns))]



def infinit_iter(loader):
    """Yield the items of ``loader`` endlessly, restarting it each time it is exhausted."""
    while True:
        yield from loader

def check_nan(names, data_lis, tb_step, args, break_flag=True):
    """Report the first entry of ``data_lis`` that contains a NaN.

    Checking only happens when ``args['check_nan_flags']`` is truthy.

    Args:
        names: labels for the message; ``names[i]`` describes ``data_lis[i]``.
        data_lis: torch tensors or NumPy array-likes to inspect.
        tb_step: step number included in the message.
        args: config dict; only ``'check_nan_flags'`` is read.
        break_flag: if True, raise when a NaN is found instead of returning.

    Returns:
        True if a NaN was found (and ``break_flag`` is False), otherwise False.

    Raises:
        RuntimeError: a NaN was found and ``break_flag`` is True.
    """
    if not args['check_nan_flags']:
        return False
    for i, value in enumerate(data_lis):
        if isinstance(value, torch.Tensor):
            has_nan = bool(torch.any(torch.isnan(value)))
        else:
            has_nan = bool(np.any(np.isnan(value)))
        if has_nan:
            message = names[i] + ' is nan at step ' + str(tb_step)
            print(message)
            if break_flag:
                # A bare ``raise`` outside an ``except`` only produced
                # "No active exception to reraise"; raise the same type with
                # a meaningful message instead.
                raise RuntimeError(message)
            return True
    return False

def check_nan_model(model, tb_step, args, break_flag=True):
    """Run ``check_nan`` on every parameter of ``model`` and on its gradient.

    Args:
        model: module whose ``named_parameters()`` are inspected.
        tb_step: step number used in the messages.
        args: config dict forwarded to ``check_nan`` (reads ``'check_nan_flags'``).
        break_flag: forwarded to ``check_nan``; if True, a NaN raises.
    """
    if not args['check_nan_flags']:
        return
    for name, parms in model.named_parameters():
        # ``args`` must be passed explicitly; previously ``break_flag`` landed
        # in the ``args`` slot and ``check_nan`` indexed a bool.
        check_nan([name], [parms], tb_step, args, break_flag)
        if parms.grad is not None:
            # Inspect the gradient itself, not the parameter a second time.
            check_nan([name + ' grad'], [parms.grad], tb_step, args, break_flag)


def multi_tb_writer(writer, names, datas, tb_step, prefix=''):
    """Log every ``datas[k]`` as a scalar tagged ``'<prefix>/<names[k]>'`` at ``tb_step``."""
    assert len(names) == len(datas)
    for tag, value in zip(names, datas):
        writer.add_scalar(prefix + '/' + tag, value, tb_step)

def multi_tb_writer_hist(writer, names, datas, tb_step, prefix=''):
    """Log every ``datas[k]`` as a histogram tagged ``'<prefix>/<names[k]>'`` at ``tb_step``."""
    assert len(names) == len(datas)
    for tag, values in zip(names, datas):
        writer.add_histogram(prefix + '/' + tag, values, tb_step)

def model_params_tb_writer(writer, model, tb_step, prefix=''):
    """Log a histogram of each parameter and the norm of its gradient.

    Histograms are tagged ``'<prefix>_parms/<name>'`` and gradient norms
    (``torch.norm(parms.grad)``) ``'<prefix>_grad/<name>'``; parameters
    without a gradient get no scalar.
    """
    for name, parms in model.named_parameters():
        try:
            writer.add_histogram(prefix + '_parms/' + name, parms, tb_step)
        except Exception:
            # Best-effort: a parameter add_histogram rejects must not stop the
            # others. Unlike the old bare ``except:``, this no longer swallows
            # KeyboardInterrupt / SystemExit.
            pass
        if parms.grad is not None:
            writer.add_scalar(prefix + '_grad/' + name, torch.norm(parms.grad), tb_step)


def im_writer(writer, images, tb_step, args, prefix='', normalization=False):
    """Write the first ``args['figure_num']`` images as one grid image.

    Args:
        writer: object with ``add_image(tag, img)``; the tag is ``'<prefix>/<tb_step>'``.
        images: image batch. A 2-D input is reshaped to
            ``(-1, args['channel_input'], args['imagesize'], args['imagesize'])``.
        tb_step: step number appended to the tag.
        args: config dict providing ``figure_num`` and ``figure_each_row``
            (and ``channel_input`` / ``imagesize`` for 2-D input).
        prefix: tag prefix.
        normalization: if True, min-max scale the whole batch to [0, 1].
    """
    assert images.shape[0] >= args['figure_num']
    if images.dim() == 2:
        images = images.view(-1, args['channel_input'], args['imagesize'], args['imagesize'])
    if normalization:
        low = images.min()
        span = images.max() - low
        if span == 0:
            # A constant batch used to give 0/0 = NaN everywhere; map it to zeros.
            images = torch.zeros_like(images)
        else:
            images = (images - low) / span
    img_grid = torchvision.utils.make_grid(images[0:args['figure_num'], :, :, :], args['figure_each_row'])
    writer.add_image(prefix + '/' + str(tb_step), img_grid)

def plot_matrix(cm, class_names='', add_value=False):
    """Show ``cm`` as a blue heat map with a colour bar.

    Args:
        cm: 2-D NumPy array to draw.
        class_names: optional tick labels applied to both axes.
        add_value: if True, write each cell's value (rounded to 2 decimals)
            in white on cells above ``cm.max() / 2`` and in black elsewhere.
    """
    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()
    if len(class_names) > 0:
        ticks = np.arange(len(class_names))
        plt.xticks(ticks, class_names, rotation=45)
        plt.yticks(ticks, class_names)
    if add_value:
        rounded = np.around(cm.astype('float'), decimals=2)
        cutoff = cm.max() / 2.
        for r in range(cm.shape[0]):
            for c in range(cm.shape[1]):
                text_color = "white" if cm[r, c] > cutoff else "black"
                plt.text(c, r, rounded[r, c], horizontalalignment="center", color=text_color)

    plt.show()

def plot_embedding(X, title=None, y=None):
    """Scatter-plot the first two columns of ``X`` after per-column min-max scaling.

    Args:
        X: 2-D array; each column is scaled to [0, 1] independently and
            columns 0 and 1 are drawn.
        title: optional figure title.
        y: optional per-row labels; label ``k`` is drawn in ``plt.cm.Set1(k / 10.)``.
    """
    X = np.asarray(X)
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    span = x_max - x_min
    # A constant column used to give 0/0 = NaN; divide it by 1 instead.
    span = np.where(span == 0, 1, span)
    X = (X - x_min) / span

    plt.figure()
    plt.subplot(111)
    if y is None:
        plt.scatter(X[:, 0], X[:, 1], alpha=0.3)
    else:
        # One vectorised call instead of one scatter call per point.
        colors = plt.cm.Set1(np.asarray(y) / 10.)
        plt.scatter(X[:, 0], X[:, 1], color=colors, alpha=0.3)

    plt.xticks([]), plt.yticks([])
    if title is not None:
        plt.title(title)
    plt.show()

def plot_matrix_to_tensor(cm, class_names='', add_value=False):
    """Render ``cm`` as a blue heat map and return the picture as an image tensor.

    The figure has a colour bar and optional tick labels ``class_names`` on
    both axes. When ``add_value`` is True, each cell's value is written,
    rounded to 2 decimals. The figure is saved to an in-memory JPEG and
    closed. The JPEG is then reopened with PIL and converted with
    torchvision's ``ToTensor``.
    """
    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()
    if len(class_names) > 0:
        positions = np.arange(len(class_names))
        plt.xticks(positions, class_names, rotation=45)
        plt.yticks(positions, class_names)
    if add_value:
        shown = np.around(cm.astype('float'), decimals=2)
        half_max = cm.max() / 2.
        for row, col in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            text_color = "white" if cm[row, col] > half_max else "black"
            plt.text(col, row, shown[row, col], horizontalalignment="center", color=text_color)

    jpeg = io.BytesIO()
    plt.savefig(jpeg, format='jpeg')
    plt.close()
    jpeg.seek(0)
    return ToTensor()(PIL.Image.open(jpeg))


# def plot_to_image(figure):
#   buf = io.BytesIO()
#   plt.savefig(buf, format='png')
#   plt.close(figure)
#   buf.seek(0)
#   image = tf.image.decode_png(buf.getvalue(), channels=4)
#   image = tf.expand_dims(image, 0)
#   return image.numpy()



def matrix_writer(writer, matrixs, tb_step, prefix=''):
    """Render each matrix as a heat-map image and log it.

    Args:
        writer: object with ``add_image(tag, img)``.
        matrixs: a 3-D tensor (one matrix per leading index) or an iterable
            of 2-D tensors / NumPy arrays.
        tb_step: step number appended to each tag.
        prefix: tag prefix; tags are ``'<prefix>/<index>/<tb_step>'``.
    """
    # ``dim`` is a method: the old ``matrixs.dim == 3`` compared a bound
    # method with an int and was therefore always False.
    if isinstance(matrixs, torch.Tensor) and matrixs.dim() == 3:
        matrixs = [to_numpy(matrixs[i]) for i in range(matrixs.shape[0])]
    for ind, matrix in enumerate(matrixs):
        if isinstance(matrix, torch.Tensor):
            matrix = to_numpy(matrix)
        image = plot_matrix_to_tensor(matrix)
        writer.add_image(prefix + '/' + str(ind) + '/' + str(tb_step), image)


def permute_dims(z_list):
    """Shuffle each column of every 2-D tensor independently along dim 0.

    For each ``z`` of shape ``(B, D)``, column ``j`` is reindexed by its own
    ``torch.randperm(B)``. The results are detached from the autograd graph.
    Only ``z_list[0]`` is checked to be 2-D; unpacking ``z.shape`` fails for
    any other non-2-D entry.

    Returns:
        A list with one shuffled tensor per input, same shapes as the inputs.
    """
    assert z_list[0].dim() == 2
    shuffled = []
    for z in z_list:
        batch, _ = z.shape
        columns = [col[torch.randperm(batch).to(z.device)].detach() for col in z.split(1, 1)]
        shuffled.append(torch.cat(columns, 1))
    return shuffled




def save_model(model_dic, args, epoch):
    """Save a checkpoint to ``<args['save_path']>/Epoch_<epoch>.pth``.

    Args:
        model_dic: mapping name -> object. Entries whose key contains
            ``'tb_step'`` are stored as-is; every other value must provide
            ``state_dict()``, which is what gets stored.
        args: config dict; ``args['save_path']`` is the output directory
            (created, including missing parents, if needed).
        epoch: used in the file name.
    """
    model_state_dic = {}
    for key, value in model_dic.items():
        model_state_dic[key] = value if 'tb_step' in key else value.state_dict()
    # makedirs creates missing parent directories (os.mkdir did not) and
    # tolerates the directory appearing between the check and the creation.
    os.makedirs(args['save_path'], exist_ok=True)
    path = args['save_path'] + '/Epoch_' + str(epoch) + '.pth'
    print('saving to {}'.format(path))
    torch.save(model_state_dic, path)


def load_model(model_dic, tb_step, args):
    """Restore ``model_dic`` from ``<reload_path>/Epoch_<reload_from>.pth``.

    Keys containing ``'tb_step'`` take their value straight from the
    checkpoint. Every other entry is restored via ``load_state_dict``.

    Returns:
        The checkpoint value of a ``'tb_step'`` key if ``model_dic`` has one,
        otherwise the ``tb_step`` argument unchanged.
    """
    ckpt_path = args['reload_path'] + '/Epoch_' + str(args['reload_from']) + '.pth'
    print('loading from {}'.format(ckpt_path))
    checkpoint = torch.load(ckpt_path)
    for key, target in model_dic.items():
        if 'tb_step' in key:
            tb_step = checkpoint[key]
        else:
            target.load_state_dict(checkpoint[key])
    return tb_step


def specified_load(model_dic, tb_step, model_path):
    """Restore the objects in ``model_dic`` from the checkpoint file ``model_path``.

    Keys containing ``'tb_step'`` read a plain value from the checkpoint;
    every other entry is restored with ``load_state_dict``.

    Returns:
        The checkpoint's step value if such a key exists, else ``tb_step``.
    """
    print('loading from {}'.format(model_path))
    checkpoint = torch.load(model_path)
    for key, target in model_dic.items():
        if 'tb_step' not in key:
            target.load_state_dict(checkpoint[key])
            continue
        tb_step = checkpoint[key]
    return tb_step


def kaiming_init(m):
    """Initialise one module in place.

    Linear/Conv2d: Kaiming-normal weights. BatchNorm1d/2d: weights set to 1.
    For those module types an existing bias is set to 0; any other module is
    left untouched. Takes a single module, so it can be used with
    ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
    else:
        return
    if m.bias is not None:
        m.bias.data.fill_(0)


def normal_init(m):
    """Initialise one module in place.

    Linear/Conv2d: weights drawn from N(0, 0.02). BatchNorm1d/2d: weights set
    to 1. For those module types an existing bias is set to 0; any other
    module is left untouched.
    """
    linear_like = (nn.Linear, nn.Conv2d)
    norm_like = (nn.BatchNorm1d, nn.BatchNorm2d)
    if not isinstance(m, linear_like + norm_like):
        return
    if isinstance(m, linear_like):
        torch.nn.init.normal_(m.weight, 0, 0.02)
    else:
        m.weight.data.fill_(1)
    if m.bias is not None:
        m.bias.data.fill_(0)


def weights_init(m):
    """Initialise one module in place, chosen by its class name.

    * class name contains ``'Conv'``: weight ~ N(0, 0.02)
    * class name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0

    Tensors that are absent (``None``), such as the weight and bias of an
    ``affine=False`` batch norm, are skipped. The previous version crashed
    on them.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)



class Add_noise(nn.Module):
    """Add Gaussian noise with the given ``mean`` and ``std`` to a tensor.

    Args:
        mean: mean of the added noise (default 0).
        std: standard deviation of the added noise (default 1.0).
    """

    def __init__(self, mean=0, std=1.0):
        # Initialise nn.Module before assigning any attributes.
        super(Add_noise, self).__init__()
        self.mean = mean
        self.std = std

    def forward(self, input, device):
        """Return ``input + std * N(0, 1) + mean``, with the noise moved to ``device``.

        ``mean`` and ``std`` used to be stored but ignored, so the noise was
        always standard normal. With the default arguments the result is
        unchanged.
        """
        assert isinstance(input, torch.Tensor)
        noise = torch.randn(size=input.shape).to(device)
        return input + noise * self.std + self.mean

def recon_loss(x_recon, x):
    """Binary cross-entropy reconstruction loss for 4-D batches.

    ``x_recon`` holds logits and ``x`` the targets; both must be 4-D with
    the same shape. The element-wise BCE-with-logits is summed over dims
    1-3 and averaged over dim 0.
    """
    assert isinstance(x_recon, torch.Tensor)
    assert x.shape == x_recon.shape
    assert x.dim() == 4 and x_recon.dim() == 4

    per_element = F.binary_cross_entropy_with_logits(x_recon, x, reduction='none')
    assert per_element.dim() == 4

    per_sample = per_element.sum([1, 2, 3])
    return per_sample.mean()


def kl_divergence(mu, logvar):
    """KL divergence from N(mu, exp(logvar)) to N(0, 1) for 4-D inputs.

    The element-wise term ``1 + logvar - mu**2 - exp(logvar)`` is summed
    over dims 1-3, averaged over dim 0 and scaled by -0.5.
    """
    assert mu.dim() == 4 and logvar.dim() == 4
    assert mu.shape == logvar.shape

    per_element = 1 + logvar - mu ** 2 - logvar.exp()
    return -0.5 * per_element.sum([1, 2, 3]).mean()

def kl_divergence_hard(mu, logvar):
    """Return ``0.5 * sum(mu**2)`` over all non-batch dims, averaged over dim 0.

    ``logvar`` must have the same shape as ``mu``, but its values are not
    used. Previously only 2-D and 4-D inputs were handled, and any other
    rank failed with UnboundLocalError. Any input with at least 2 dims is
    now accepted; results for 2-D and 4-D are unchanged.

    Raises:
        ValueError: if ``mu`` has fewer than 2 dimensions.
    """
    assert mu.shape == logvar.shape
    if mu.dim() < 2:
        raise ValueError('kl_divergence_hard expects at least 2-D inputs, got shape {}'
                         .format(tuple(mu.shape)))
    feature_dims = list(range(1, mu.dim()))
    return -0.5 * (-mu ** 2).sum(feature_dims).mean()


def correct_rate_func(out, label):
    """Fraction of rows whose arg-max along dim 1 equals ``label``, as a 0-d float tensor."""
    predicted = out.argmax(dim=1)
    hits = predicted.eq(label).float()
    return hits.mean()


def rec_feature(features, z, epoch, args, labels=None, y=None, prefix=''):
    """Collect embeddings (and optionally labels) in place for later visualisation.

    ``to_numpy(z)`` is appended to ``features`` when all of these hold:
    ``args['embed_interval']`` is truthy, ``epoch`` is a multiple of it, and
    ``len(features) * args['batch_size'] < args['embed_num']``. If, in
    addition, ``labels`` is not None and ``y`` is a tensor, the values of
    ``y`` are appended to ``labels`` in place. When ``prefix`` is non-empty
    they are appended as strings, each prefixed with ``prefix``.
    """
    interval = args['embed_interval']
    if not interval or epoch % interval != 0:
        return
    if len(features) * args['batch_size'] >= args['embed_num']:
        return
    features.append(to_numpy(z))
    if labels is None or not isinstance(y, torch.Tensor):
        return
    values = list(to_numpy(y))
    if prefix:
        labels += [prefix + str(v) for v in values]
    else:
        labels += values


def get_data(x, y, i, batch_size, device):
    """Return the ``i``-th mini-batch of ``(x, y)`` as tensors on ``device``.

    Batches walk through the samples cyclically. Batch ``i`` covers sample
    indices ``i*batch_size`` to ``(i+1)*batch_size - 1``, taken modulo the
    number of samples, so a batch that crosses the end wraps to the start.
    The old slice ``x[a % n:b % n]`` came back empty whenever the end
    wrapped, including every batch ending exactly at ``n``.

    Returns:
        ``(x_batch, y_batch)`` with ``x_batch`` as float32 and ``y_batch``
        cast through float32 to int64.

    Raises:
        ValueError: if ``x`` has no samples.
    """
    train_num = x.shape[0]
    if train_num == 0:
        raise ValueError('get_data needs at least one sample')
    idx = np.arange(i * batch_size, (i + 1) * batch_size) % train_num
    x_batch = torch.tensor(x[idx], dtype=torch.float32).to(device)
    y_batch = torch.tensor(y[idx], dtype=torch.float32).to(device).long()
    return x_batch, y_batch


def to_numpy(tensor):
    """Detach ``tensor`` from autograd, move it to the CPU and return it as a NumPy array."""
    detached = tensor.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()





def Linear_Classifier(X, Y, out_dim, steps=2000, batch_size=256, X_test=None, Y_test=None,
                      lr=1e-3, device=None):
    """Train a single ``nn.Linear`` classifier on ``(X, Y)`` with Adam.

    Args:
        X: 2-D array of features, one row per sample.
        Y: class labels, one per row of ``X``.
        out_dim: number of output classes.
        steps: number of optimisation steps (one ``get_data`` batch each).
        batch_size: mini-batch size.
        X_test, Y_test: optional held-out data, scored every 500 steps.
        lr: Adam learning rate (previously hard-coded to 1e-3).
        device: torch device. Defaults to CUDA when available, otherwise CPU;
            previously CUDA was always required.

    Returns:
        ``(model, train_accuracy, test_history)``:
        - ``model.score(x, y)`` returns accuracy as a float.
        - ``train_accuracy`` is ``model.score(X, Y)``.
        - ``test_history`` lists the held-out accuracies recorded during
          training.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = nn.Linear(X.shape[1], out_dim)
    net.to(device)
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    class lc:
        """Wrapper exposing ``score(x, y)``: accuracy of the wrapped network."""

        def __init__(self, net):
            self.net = net

        def score(self, x, y):
            x_tensor = torch.tensor(x, dtype=torch.float32).to(device)
            y_tensor = torch.tensor(y, dtype=torch.float32).to(device).long()
            with torch.no_grad():
                out = self.net(x_tensor)
            return float(to_numpy(correct_rate_func(out, y_tensor)))

    lc1 = lc(net)
    rec = []
    for i in range(steps):
        x_tensor, y_tensor = get_data(X, Y, i, batch_size, device)
        net.zero_grad()
        out = net(x_tensor)
        try:
            loss = loss_func(out, y_tensor)
            loss.backward()
            optimizer.step()
        except (RuntimeError, ValueError, IndexError) as err:
            # Training stays best-effort as before, but a failed step is now
            # reported instead of being silently hidden by a bare ``except``.
            print('network train step {} skipped: {}'.format(i, err))
            continue

        if i % 500 == 0:
            print('network train step {} train loss {}'.format(i, to_numpy(loss)))
            print('network train step {} train correct_rate {}'
                  .format(i, to_numpy(correct_rate_func(out, y_tensor))))
            if X_test is not None:
                rec.append(lc1.score(X_test, Y_test))

    return lc1, lc1.score(X, Y), rec

def get_device(gpu_list, no_cuda=False):
    """Return ``cuda:<gpu_list[0]>`` when CUDA is available and not disabled, else the CPU."""
    if no_cuda or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device('cuda:{}'.format(gpu_list[0]))

class DataParallel(nn.DataParallel):
    """``nn.DataParallel`` that forwards unknown attribute lookups to the wrapped module.

    Lets callers write ``wrapped.attr`` instead of ``wrapped.module.attr``.
    """

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            wrapped = self.module
        return getattr(wrapped, name)

def get_lr(optimizer):
    """Learning rate of the optimizer's first parameter group, or None if it has no groups."""
    groups = optimizer.param_groups
    return groups[0]['lr'] if groups else None


def find_block(C,order,threshold):
    """Greedily partition the indices ``0..n-1`` of ``C`` into blocks of linked entries.

    An entry ``C[a, b]`` counts as a link when ``np.abs(C[a, b]) > threshold``.

    Each new block is seeded with the lowest still-unassigned index ``r``.
    It first takes every unassigned ``idx`` with a link ``C[r, idx]``. The
    block is then grown for ``order`` extra passes. Each pass adds every
    unassigned ``idx`` that is linked (``C[ind, idx]``) from any index
    ``ind`` already in the block. Indices added in a pass only act as
    ``ind`` from the next pass on. This repeats until every index belongs
    to a block.

    Args:
        C: 2-D array indexed as ``C[row, col]`` with row/col in
            ``range(C.shape[0])`` (presumably square — confirm with callers).
        order: number of extra growth passes per block.
        threshold: cutoff compared against ``np.abs`` of the entries.

    Returns:
        ``(BlockSizes, sorted_basis)``: the size of each block in the order
        the blocks were found, and all indices listed block by block (a
        permutation of ``range(n)``).
    """
    n=C.shape[0]
    BlockSizes = []
    remaining_basis = list(range(n))  # unassigned indices; removals keep ascending order
    sorted_basis = []

    while len(remaining_basis) > 0:
        # Seed a new block with the lowest unassigned index.
        current_block = [remaining_basis[0]]
        current_block_size = 1
        if len(remaining_basis) > 1:
            # Direct links from the seed.
            for idx in remaining_basis[1:]:
                if np.abs(C[remaining_basis[0], idx]) > threshold:
                    current_block.append(idx)
                    current_block_size = current_block_size + 1

        for idx in current_block:
            sorted_basis.append(idx)
            remaining_basis.remove(idx)

        # do the following in case there are zero entries inside the block:
        # indices not linked to the seed but linked to another block member
        # are pulled in over ``order`` extra passes.
        for k in range(order):
            current_block_extra = []
            if len(remaining_basis) > 0:
                for idx in remaining_basis:
                    for ind in current_block:
                        if np.abs(C[ind, idx]) > threshold:
                            current_block_extra.append(idx)
                            current_block_size = current_block_size + 1
                            break

            # Members found in this pass join the block only after the scan,
            # so they are used as link sources from the next pass on.
            for idx in current_block_extra:
                sorted_basis.append(idx)
                remaining_basis.remove(idx)
                current_block.append(idx)

        BlockSizes.append(current_block_size)
    return BlockSizes,sorted_basis

def cluster(A, max_size=None, min_size=None, max_iter=500, order=2):
    """Randomised search for a basis that block-diagonalises the matrices in ``A``.

    Each trial draws two independent random Gaussian combinations of the
    symmetrised matrices ``A[p] + A[p].T``:
    - The eigenvectors ``V`` of the first combination are the candidate basis.
    - The second combination ``D`` is rotated into that basis:
      ``C = V.T @ D @ V``.

    ``find_block`` then groups the basis vectors of ``C`` with threshold
    ``mean + eta * std``, where the statistics are taken over the non-zero
    off-diagonal entries of ``C``. A trial scores ``sum(log(size))`` over its
    blocks; sizes outside ``[min_size, max_size]`` count as 1 and so add
    nothing.

    The search runs a coarse sweep of ``eta`` over ``[-1, 4]``, then a fine
    sweep over ``[eta_best - 0.5, eta_best + 0.5]``. Each sweep runs
    ``max_iter`` trials.

    Args:
        A: sequence of ``n x n`` arrays.
        max_size, min_size: block-size bounds used in the score; ``None``
            means unbounded on that side (previously ``None`` raised TypeError).
        max_iter: trials per sweep; must be at least 1.
        order: growth passes passed to ``find_block``.

    Returns:
        ``(P, BlockSizes_best, res, threshold_best)`` for the best trial:
        - ``P``: ``V`` with its columns reordered block by block.
        - ``res``: ``P.T @ D @ P``.
        - ``BlockSizes_best``: the block sizes of that trial.
        - ``threshold_best``: the threshold of that trial.

    Raises:
        ValueError: if ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1, got {}'.format(max_iter))

    n = len(A[0])  # size of the matrices to be simultaneously block diagonalized
    m = len(A)     # number of matrices to be simultaneously block diagonalized

    def _random_symmetric_mix():
        # Random Gaussian combination of the symmetrised matrices A[p] + A[p].T.
        mix = np.zeros((n, n))
        for p in range(m):
            mix = mix + np.random.normal() * (A[p] + A[p].transpose())
        return mix

    def _score(block_sizes):
        # sum(log(size)) with sizes outside [min_size, max_size] counted as 1.
        sizes = np.asarray(block_sizes)
        if min_size is not None:
            sizes[sizes < min_size] = 1
        if max_size is not None:
            sizes[sizes > max_size] = 1
        return np.sum(np.log(sizes))

    object_best = -1
    BlockSizes_best = threshold_best = eta_best = P = res = None

    def _sweep(etas):
        # Run one randomised trial per eta and keep the best-scoring one.
        nonlocal object_best, BlockSizes_best, threshold_best, eta_best, P, res
        for eta in etas:
            # Same random-draw order as before: m draws for V, then m for D.
            _, V = la.eigh(_random_symmetric_mix())
            D = _random_symmetric_mix()
            C = V.transpose() @ D @ V

            off_diag = C.copy()
            np.fill_diagonal(off_diag, 0)
            nonzero = off_diag[off_diag != 0]
            threshold = np.mean(nonzero) + eta * np.std(nonzero)

            BlockSizes, sorted_basis = find_block(C, order, threshold)
            score = _score(BlockSizes)
            print('eta: %.2f threshold: %.3f max block size: %d block num: %d object: %.3f object_best: %.3f'
                  % (eta, threshold, np.max(BlockSizes), len(BlockSizes), score, object_best))

            if score > object_best:
                object_best = score
                BlockSizes_best = BlockSizes
                threshold_best = threshold
                eta_best = eta
                P = V[:, sorted_basis]
                res = P.T @ D @ P

    # Coarse sweep; the first trial always wins because scores are >= 0 > -1.
    _sweep(np.linspace(-1, 4, max_iter))
    # Fine sweep around the best eta of the coarse sweep.
    _sweep(np.linspace(eta_best - 0.5, eta_best + 0.5, max_iter))

    return P, BlockSizes_best, res, threshold_best


def cluster_single(C, max_size=None, min_size=None, max_iter=500, order=2):
    """Search for a permutation that block-diagonalises the single matrix ``C``.

    ``C`` is kept in its given basis (``V`` is the identity). For each
    ``eta``, ``find_block`` is run with threshold ``mean + eta * std``, where
    the statistics are taken over the non-zero off-diagonal entries of ``C``.
    The grouping scores ``sum(log(size))``, with sizes outside
    ``[min_size, max_size]`` counted as 1.

    The search runs a coarse sweep over ``[-1, 4]``, then a fine sweep over
    ``[eta_best - 0.2, eta_best + 0.2]``, each with ``max_iter`` values.
    No randomness is involved.

    Args:
        C: square 2-D array.
        max_size, min_size: block-size bounds used in the score; ``None``
            means unbounded on that side (previously ``None`` raised TypeError).
        max_iter: number of eta values per sweep; must be at least 1.
        order: growth passes passed to ``find_block``.

    Returns:
        ``(P, BlockSizes_best, res, threshold_best)`` for the best eta:
        - ``P``: the identity with its columns reordered block by block.
        - ``res``: ``P.T @ C @ P``.
        - ``BlockSizes_best``: the block sizes at that eta.
        - ``threshold_best``: the threshold at that eta.

    Raises:
        ValueError: if ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1, got {}'.format(max_iter))

    n = C.shape[0]
    V = np.eye(n)

    off_diag = C.copy()
    np.fill_diagonal(off_diag, 0)
    nonzero = off_diag[off_diag != 0]
    mean = np.mean(nonzero)
    std = np.std(nonzero)

    def _score(block_sizes):
        # sum(log(size)) with sizes outside [min_size, max_size] counted as 1.
        sizes = np.asarray(block_sizes)
        if min_size is not None:
            sizes[sizes < min_size] = 1
        if max_size is not None:
            sizes[sizes > max_size] = 1
        return np.sum(np.log(sizes))

    object_best = -1
    BlockSizes_best = threshold_best = eta_best = P = res = None

    def _sweep(etas):
        # Evaluate each eta and keep the best-scoring grouping.
        nonlocal object_best, BlockSizes_best, threshold_best, eta_best, P, res
        for eta in etas:
            threshold = mean + eta * std
            BlockSizes, sorted_basis = find_block(C, order, threshold)
            score = _score(BlockSizes)
            print('eta: %.2f threshold: %.3f max block size: %d block num: %d object: %.3f object_best: %.3f'
                  % (eta, threshold, np.max(BlockSizes), len(BlockSizes), score, object_best))

            if score > object_best:
                object_best = score
                BlockSizes_best = BlockSizes
                threshold_best = threshold
                eta_best = eta
                P = V[:, sorted_basis]
                res = P.T @ C @ P

    # Coarse sweep; the first eta always wins because scores are >= 0 > -1.
    _sweep(np.linspace(-1, 4, max_iter))
    # Fine sweep around the best eta of the coarse sweep.
    _sweep(np.linspace(eta_best - 0.2, eta_best + 0.2, max_iter))

    return P, BlockSizes_best, res, threshold_best


def joint_block_diag(matrix_lis, plot=False, max_size_1=100, min_size_1=50, max_iter_1=50, order_1=3,
                     max_size_2=150, min_size_2=100, max_iter_2=50, order_2=10):
    """Two-stage block diagonalisation of a list of matrices.

    Stage 1 runs ``cluster`` on ``matrix_lis`` with the ``*_1`` settings.
    Entries of the resulting ``res`` with ``|value| <= threshold`` are then
    zeroed. Stage 2 refines this sparsified matrix with ``cluster_single``
    using the ``*_2`` settings. When ``plot`` is True, each stage plots two
    figures:
    - its thresholded 0/1 link mask plus 0.5 inside each found block;
    - the transformed matrix itself.

    Returns:
        ``(P @ P3, BlockSizes3, threshold)``: the combined basis, the stage-2
        block sizes, and the stage-1 threshold.
    """
    def _show(matrix, cutoff, sizes):
        # Plot the |entry| > cutoff mask overlaid with the block layout, then the matrix.
        mask = matrix.copy()
        mask[np.abs(matrix) > cutoff] = 1
        mask[np.abs(matrix) <= cutoff] = 0
        layout = block_diag(*[np.ones([s, s]) for s in sizes])
        plot_matrix(mask + layout * 0.5)
        plot_matrix(matrix)

    P, BlockSizes, res, threshold = cluster(matrix_lis,
                                            max_size=max_size_1,
                                            min_size=min_size_1,
                                            max_iter=max_iter_1,
                                            order=order_1)
    if plot:
        _show(res, threshold, BlockSizes)

    sparse = res.copy()
    sparse[np.abs(res) <= threshold] = 0

    P3, BlockSizes3, res3, threshold3 = cluster_single(sparse,
                                                       max_size=max_size_2,
                                                       min_size=min_size_2,
                                                       max_iter=max_iter_2,
                                                       order=order_2)
    if plot:
        _show(res3, threshold3, BlockSizes3)

    return P @ P3, BlockSizes3, threshold


def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Smooth ramp from ``low`` (at ``iter_num == 0``) towards ``high``.

    Computes
    ``2 * (high - low) / (1 + exp(-alpha * iter_num / max_iter)) - (high - low) + low``
    and returns it as a Python float. The builtin ``float`` replaces
    ``np.float``, an alias that was removed in NumPy 1.24.
    """
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)


def Entropy(input_, dim=1):
    """Shannon entropy (natural log) along ``dim`` after normalising to sum 1.

    ``input_`` is divided by its sum along ``dim``, and 1e-8 is added inside
    the log for stability. The sum is now kept as a broadcastable dimension
    (``keepdim=True``). Previously the reduced sum was broadcast against the
    last axis, which raised a shape error or normalised by the wrong totals
    for batches with more than one row.

    Returns:
        ``input_`` with ``dim`` reduced away, holding the entropies.
    """
    epsilon = 1e-8
    probs = input_ / input_.sum(dim=dim, keepdim=True)
    entropy = -probs * torch.log(probs + epsilon)
    return torch.sum(entropy, dim=dim)


def grl_hook(coeff):
    """Build a gradient hook that returns ``-coeff`` times the incoming gradient.

    The incoming gradient is cloned first, so the hook never modifies it in
    place.
    """
    def reverse_gradient(grad):
        flipped = grad.clone()
        return -coeff * flipped
    return reverse_gradient


class ConsensusLoss(nn.Module):
    """Divergence between the softmax distributions of two logit tensors.

    Args:
        nClass: number of classes (stored; not used by ``forward``).
        div: one of ``'kl'``, ``'kl_d'``, ``'l1'``, ``'l2'``, ``'neg_cos'``.
    """

    def __init__(self, nClass, div):
        super(ConsensusLoss, self).__init__()
        self.nClass = nClass
        self.div = div

    def forward(self, x, y):
        """Compare ``softmax(x)`` with ``softmax(y)`` taken along dim 1.

        * ``'kl'``: ``F.kl_div(log_softmax(y), softmax(x), reduction='batchmean')``
        * ``'kl_d'``: the same with ``softmax(x)`` detached (no gradient to ``x``)
        * ``'l1'``: mean over rows of the mean absolute difference
        * ``'l2'``: mean over rows of the Euclidean distance
        * ``'neg_cos'``: ``0.5 * mean(1 - cosine similarity)`` over rows

        Raises:
            ValueError: for an unrecognised ``div``. Previously ``forward``
                silently returned None.
        """
        if self.div in ('kl', 'kl_d'):
            p_x = F.softmax(x, dim=1)
            if self.div == 'kl_d':
                p_x = p_x.detach()
            return F.kl_div(F.log_softmax(y, dim=1), p_x, reduction='batchmean')

        if self.div in ('l1', 'l2', 'neg_cos'):
            p_x = F.softmax(x, dim=1)
            p_y = F.softmax(y, dim=1)
            if self.div == 'l1':
                return (p_x - p_y).abs().mean(1).mean()
            if self.div == 'l2':
                return (p_x - p_y).pow(2).sum(1).sqrt().mean()
            cos = (p_x * p_y).sum(1) / p_x.norm(2, dim=1) / p_y.norm(2, dim=1)
            return 0.5 * (1 - cos).mean()

        raise ValueError("Unknown divergence '{}'; expected one of "
                         "'kl', 'kl_d', 'l1', 'l2', 'neg_cos'".format(self.div))




def JS_divergence(P, Q):
    """Column-wise symmetric divergence between two non-negative tensors, as NumPy.

    Each column of ``P`` and ``Q`` (summing over dim 0) is normalised to sum
    1, with 1e-10 guards. The function then returns
    ``KL(P || M) + KL(Q || M)`` with ``M = (P + Q) / 2``.

    NOTE(review): this is twice the textbook Jensen-Shannon divergence, which
    carries a factor 1/2. The scale is kept as-is.
    """
    eps = 1e-10
    p = P / (P.sum(0, keepdim=True) + eps)
    q = Q / (Q.sum(0, keepdim=True) + eps)
    log_m = torch.log((p + q) / 2 + eps)
    kl_pm = (p * ((p + eps).log() - log_m)).sum(0)
    kl_qm = (q * ((q + eps).log() - log_m)).sum(0)
    return (kl_pm + kl_qm).detach().cpu().numpy()


def project_onto_unit_simplex(prob):
    """Euclidean projection of a 1-D array onto the probability simplex.

    The simplex is ``{x : x >= 0, sum(x) = 1}``. The entries are sorted in
    descending order and a shift ``theta`` is located such that
    ``max(0, prob - theta)`` sums to one.

    :param prob: 1-D numpy array of real values.
    :return: numpy array of the same length lying on the simplex.
    """
    n = len(prob)
    desc = -np.sort(-prob)
    running = 0

    for i in range(1, n):
        running = running + desc[i - 1]
        theta = (running - 1) / i
        if theta >= desc[i]:
            break
    else:
        # No early stop: the shift uses all n entries.
        theta = (running + desc[n - 1] - 1) / n

    return np.maximum(0, prob - theta)

def project_onto_unit_simplex_matrix(F):
    """Project every column of the 2-D ``F`` onto the probability simplex.

    ``F`` may be a torch tensor (detached and copied to the CPU first) or a
    NumPy array. Each column is projected with ``project_onto_unit_simplex``
    and the columns are stacked back side by side into a NumPy array.
    """
    F_np = F.detach().cpu().numpy() if isinstance(F, torch.Tensor) else F
    columns = [project_onto_unit_simplex(F_np[:, col]).reshape([-1, 1])
               for col in range(F.shape[1])]
    return np.concatenate(columns, axis=1)