import torch
import numpy as np

#torch.manual_seed(3)
#np.random.seed(3)




def generate_dia_dominate_matrix(num_classes=10):
    """Return a random row-stochastic matrix whose diagonal dominates each row.

    Off-diagonal entries are drawn as ``np.random.rand(...) / num_classes``,
    so each lies in ``[0, 1 / num_classes)``. Each diagonal entry is set to
    ``1 - (sum of the off-diagonal entries in its row)``. As a result every row
    sums to 1, and every diagonal entry exceeds every off-diagonal entry of
    its row.

    Randomness comes from NumPy's global RNG (``np.random``).

    Args:
        num_classes: Side length of the square matrix (default 10).

    Returns:
        A ``torch.Tensor`` of shape ``(num_classes, num_classes)`` with dtype
        float64 (the matrix is built in NumPy's default float64).
    """
    offsets = np.random.rand(num_classes, num_classes) / num_classes
    row_sums = offsets.sum(axis=1)
    # Shift each diagonal entry by (1 - row sum) so every row sums to 1; the
    # raw diagonal draw is part of row_sums, so it cancels out.
    matrix = np.eye(num_classes) * (1 - row_sums) + offsets
    # Convert explicitly instead of relying on Tensor/ndarray operator
    # fallbacks to produce a tensor.
    return torch.from_numpy(matrix)


def generate_random_matrix():
    """Return a random 10x10 matrix whose rows each sum to 1.

    Entries are drawn from ``torch.rand`` (torch's global RNG) and then
    divided by their row total. The dtype is torch's default floating dtype
    (float32 unless changed via ``torch.set_default_dtype``).
    """
    raw = torch.rand(10, 10)
    row_totals = raw.sum(dim=-1, keepdim=True)
    return raw.div(row_totals)

def generate_20_random_matrix(set_nums=20):
    """Return a ``(set_nums, 10)`` random matrix whose rows each sum to 1.

    Args:
        set_nums: Number of rows to draw (default 20).

    Returns:
        A tensor of ``torch.rand`` draws, each row divided by its own total.
    """
    weights = torch.rand(set_nums, 10)
    totals = weights.sum(dim=-1, keepdim=True)
    return weights.div(totals)

def gene_noise_diff(matrix, noise_rate):
    """Return a random perturbation of ``matrix`` whose rows each sum to zero.

    Every entry ``matrix[i, j]`` is multiplied by a random sign, drawn as
    ``np.sign(np.random.uniform(-1, 1))``, and by ``noise_rate``. Call these
    scaled entries ``noisy``. In the returned matrix:

    * each off-diagonal entry ``(i, j)`` is ``-noisy[i, j]``;
    * each diagonal entry ``(i, i)`` is the sum of ``noisy[i, j]`` over
      ``j != i``.

    Each row therefore sums to 0. Adding the result to a matrix whose rows
    sum to 1 keeps those row sums at 1.
    NOTE(review): presumably meant to perturb a transition matrix; confirm
    against callers.

    Args:
        matrix: Square ``(n, n)`` tensor or array-like. ``n`` is taken from
            the last dimension (the original implementation was fixed to
            ``n = 10``).
        noise_rate: Scalar scale applied to the signed entries.

    Returns:
        A ``torch.Tensor`` of shape ``(n, n)`` on ``matrix``'s device. Its
        dtype is promoted with the float64 sign matrix, so it is float64 for
        float32/float64 input.
    """
    matrix = torch.as_tensor(matrix)
    num_classes = matrix.shape[-1]
    # Keep drawing from NumPy's RNG so np.random.seed still controls the
    # signs exactly as before.
    sign_draws = np.sign(np.random.uniform(-1, 1, (num_classes, num_classes)))
    signs = torch.from_numpy(sign_draws).to(matrix.device)
    noisy_matrix = (matrix * signs) * noise_rate
    # Placing each row's total on the diagonal and subtracting the signed
    # entries makes every row sum to zero.
    return torch.diag(noisy_matrix.sum(dim=1)) - noisy_matrix
def generate_off_diagonal_same_matrix(set_number):
    matrix = np.ones((set_number, 10))
    max_prior = np.random.uniform(0.1,1,set_number)      
    temp = (1-max_prior)/10
    matrix = matrix*temp.reshape(-1,1)
    for i in range(set_number):
        if i<10:
            matrix[i,i]=max_prior[i]
        if i>=10:
            matrix[i,i-10]=max_prior[i]
    matrix = torch.from_numpy(matrix).float()
    return matrix