import numpy as np
import random
import tensorflow as tf
import io
def get_model_summary(model):
    """Return the text produced by ``model.summary()`` as a single string.

    Parameters
    ----------
    model : object
        Anything exposing ``summary(print_fn=...)``; presumably a Keras
        model -- confirm against callers.

    Returns
    -------
    str
        Every line handed to ``print_fn``, each terminated by a newline.
    """
    def _collect(line):
        # Each summary line arrives without a trailing newline; add one.
        buffer.write(line + '\n')

    with io.StringIO() as buffer:
        model.summary(print_fn=_collect)
        return buffer.getvalue()


def data_poison(images, poison_number):
    """Stamp a four-pixel trigger onto randomly chosen images, in place.

    Parameters
    ----------
    images : numpy array
        Indexed as ``images[sample, row, col]``; rows/cols up to 27 are
        written, so presumably 28x28 images -- confirm against callers.
    poison_number : int
        How many distinct images to poison.

    Returns
    -------
    images : the same array object that was passed in (now modified).
    poison_list : list of int
        Indices of the images that received the trigger.
    """
    chosen = random.sample(range(images.shape[0]), poison_number)
    # (row, col) coordinates of the trigger pixels, all set to 1.0.
    trigger_pixels = ((27, 27), (26, 26), (25, 27), (27, 25))
    for row, col in trigger_pixels:
        images[chosen, row, col] = 1.0
    return images, chosen

def data_poison_cifar(images, poison_number):
    """Stamp a 6x4 all-channel trigger patch onto random images, in place.

    The patch covers rows 26..31 and columns 28..31; every channel of each
    covered pixel is set to 1.0.

    Parameters
    ----------
    images : numpy array
        Indexed as ``images[sample, row, col, channel]``; presumably
        32x32 CIFAR-style images -- confirm against callers.
    poison_number : int
        How many distinct images to poison.

    Returns
    -------
    images : the same array object that was passed in (now modified).
    poison_list : list of int
        Indices of the images that received the trigger.
    """
    chosen = random.sample(range(images.shape[0]), poison_number)
    print(images.shape)
    # Walk the patch row by row, left to right.
    for row in range(26, 32):
        for col in range(28, 32):
            images[chosen, row, col, :] = 1.0
    return images, chosen

def get_poisoned_matrix(passive_matrix, need_poison, poison_grad, amplify_rate=1):
    """Build a float32 tensor from ``passive_matrix`` with selected entries replaced.

    Parameters
    ----------
    passive_matrix : object with a ``.numpy()`` method
        Presumably an eager ``tf.Tensor`` -- confirm against callers.
    need_poison : index or mask
        Selects the entries of the NumPy copy to overwrite.
    poison_grad : array-like
        Replacement value(s) for the selected entries.
    amplify_rate : number, optional
        Factor applied to ``poison_grad`` before assignment (default 1).

    Returns
    -------
    Tensor of dtype float32 named ``'poisoned_matrix'``.
    """
    as_numpy = passive_matrix.numpy()
    as_numpy[need_poison] = poison_grad*amplify_rate
    return tf.convert_to_tensor(as_numpy, dtype=tf.float32, name='poisoned_matrix')

def copy_grad(passive_matrix, need_copy):
    """Return the first entry of ``passive_matrix`` selected by ``need_copy``.

    Parameters
    ----------
    passive_matrix : indexable object whose selections expose ``.numpy()``
        Presumably an eager ``tf.Tensor`` -- confirm against callers.
    need_copy : index or mask
        Selects candidate entries; only the first selected one is returned.

    Returns
    -------
    Element 0 of the selection, as produced by ``.numpy()``. Indexing
    element 0 fails if the selection is empty.
    """
    selected = passive_matrix[need_copy].numpy()
    first_selected = selected[0]
    return first_selected

def need_poison_trigger_check(images):
    """Flag the images that carry the four-pixel trigger.

    An image is flagged when all four trigger pixels -- (27, 27), (26, 26),
    (25, 27) and (27, 25) -- are greater than 0.9.

    Parameters
    ----------
    images : array-like
        Indexed as ``images[sample, row, col]`` and supporting ``len()``.

    Returns
    -------
    numpy array of bool, one entry per image. The dtype is forced to bool
    so that an empty input yields an empty boolean mask rather than the
    float64 array ``np.array([])`` would produce.
    """
    trigger_pixels = ((27, 27), (26, 26), (25, 27), (27, 25))
    need_poison_list = [
        # all() short-circuits on the first pixel that fails the threshold.
        all(images[indx, row, col] > 0.9 for row, col in trigger_pixels)
        for indx in range(len(images))
    ]
    return np.array(need_poison_list, dtype=bool)


def need_poison_trigger_check_cifar(images):
    """Flag the images that carry the 6x4 all-channel trigger patch.

    An image is flagged when every value in rows 26..31, columns 28..31,
    channels 0..2 equals exactly 1.0.

    Parameters
    ----------
    images : array-like
        Indexed as ``images[sample, row, col, channel]`` and supporting
        ``len()``.

    Returns
    -------
    numpy array of bool, one entry per image. The dtype is forced to bool
    so that an empty input yields an empty boolean mask rather than the
    float64 array ``np.array([])`` would produce.
    """
    trigger_rows = range(26, 32)
    trigger_cols = range(28, 32)
    trigger_channels = range(3)
    need_poison_list = [
        # Explicit element indexing (not slicing) so an image too small to
        # hold the patch raises IndexError instead of silently matching.
        all(
            images[indx, row, col, channel] == 1.0
            for row in trigger_rows
            for col in trigger_cols
            for channel in trigger_channels
        )
        for indx in range(len(images))
    ]
    return np.array(need_poison_list, dtype=bool)


def calculate_l21_blocknorm(e, A2):
    """Sum the square roots of ``A2 @ (e.T * e.T)`` over all entries.

    Parameters
    ----------
    e : array-like
        Converted to a NumPy array and transposed before being squared
        element-wise.
    A2 : matrix-like
        Left operand of the matrix product; presumably a block/group
        indicator matrix -- confirm against callers.

    Returns
    -------
    The scalar sum of the element-wise square roots.
    """
    transposed = np.array(e).T
    squared = np.multiply(transposed, transposed)
    block_energy = A2 @ squared
    return np.sqrt(block_energy).sum()

def calculate_l21_rownorm(X):
    r"""
    Calculate the l21 norm of a matrix X taken over rows,
    i.e., \sum_i ||X[i,:]||_2

    The docstring is a raw string so that the backslash in ``\sum`` is not
    parsed as an (invalid) escape sequence.

    Input:
    -----
    X: {numpy array}
        2-D matrix; it is reduced along axis 1.
    Output:
    ------
    l21_norm: {float}
        Sum of the Euclidean norms of the rows of X.
    """
    squared = np.multiply(X, X)
    # Axis 1 collapses the columns, leaving one squared norm per row.
    row_norms = np.sqrt(squared.sum(1))
    return row_norms.sum()

def calculate_l21_colnorm(X):
    r"""
    Calculate the l21 norm of a matrix X taken over columns,
    i.e., \sum_j ||X[:,j]||_2

    The docstring is a raw string so that the backslash in ``\sum`` is not
    parsed as an (invalid) escape sequence.

    Input:
    -----
    X: {numpy array}
        Matrix whose squared entries are summed along axis 0.
    Output:
    ------
    l21_norm: {float}
        Sum of the Euclidean norms of the columns of X.
    """
    squared = np.multiply(X, X)
    # Axis 0 collapses the rows, leaving one squared norm per column.
    col_norms = np.sqrt(squared.sum(0))
    return col_norms.sum()


