import matplotlib
import matplotlib.pyplot as plt
import numpy as np

def check_class_balance(X_train, y_train, X_test, y_test):
    """Plot one random training sample per class and a train/test balance chart.

    The figure contains one panel per class label (the union of the labels
    found in ``y_train`` and ``y_test``), each showing a randomly chosen
    training sample of that class, followed by a final panel with grouped
    bars of the per-class counts. Counts are normalised separately for each
    split by that split's most frequent class, so the largest class has
    ratio 1.0. The figure is displayed with ``plt.show()``; nothing is
    returned.

    Args:
        X_train: Training samples, indexable by integer position; the
            selected sample is passed directly to ``Axes.plot``.
        y_train: Training labels, one per training sample.
        X_test: Not used; kept so existing callers keep working.
        y_test: Test labels.

    Raises:
        ValueError: If neither ``y_train`` nor ``y_test`` contains a label.
    """
    # Coerce so that ``y_train == label`` is an element-wise comparison even
    # when plain Python lists are passed in.
    y_train = np.asarray(y_train)
    y_test = np.asarray(y_test)

    train_labels, train_counts = np.unique(y_train, return_counts=True)
    test_labels, test_counts = np.unique(y_test, return_counts=True)

    # Both splits are aligned on the same sorted label axis so that a class
    # missing from one split gets a zero-height bar instead of shifting the
    # remaining bars or breaking the array lengths.
    labels = np.union1d(train_labels, test_labels)
    if labels.size == 0:
        raise ValueError("y_train and y_test contain no labels")

    train_counts_ratio = _aligned_count_ratio(labels, train_labels, train_counts)
    test_counts_ratio = _aligned_count_ratio(labels, test_labels, test_counts)

    # One panel per class plus one for the bar chart. squeeze=False keeps the
    # result a 2-D array regardless of the number of panels.
    n_panels = len(labels) + 1
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)
    axes = axes[0]

    # Panels are addressed by position, not by label value, so labels do not
    # have to be the contiguous integers 0..K-1.
    for pos, label in enumerate(labels):
        panel = axes[pos]
        panel.set_title(str(label), size=16)
        panel.tick_params(labelsize=14)

        candidates = np.where(y_train == label)[0]
        if candidates.size == 0:
            # Class occurs only in the test split: leave the panel empty.
            continue
        idx = np.random.choice(candidates)
        panel.plot(X_train[idx], color=f'C{pos}')

    x = np.arange(len(labels))  # the label locations
    width = 0.35  # the width of the bars

    bar_ax = axes[-1]
    bar_ax.bar(x - width / 2, train_counts_ratio, width, label='Train')
    bar_ax.bar(x + width / 2, test_counts_ratio, width, label='Test')

    bar_ax.set_ylabel('ratio')
    bar_ax.set_title('ratio by train and test')
    bar_ax.set_xticks(x)
    bar_ax.set_xticklabels(labels)
    # The bars carry 'Train'/'Test' labels; the legend makes them visible.
    bar_ax.legend()

    fig.tight_layout()

    plt.show()


def _aligned_count_ratio(all_labels, present_labels, counts):
    """Return counts placed on ``all_labels`` and divided by the largest count.

    ``all_labels`` must be sorted and contain every entry of
    ``present_labels``. Labels absent from ``present_labels`` get ratio 0.
    An all-zero result is returned unchanged to avoid dividing by zero.
    """
    aligned = np.zeros(len(all_labels), dtype=float)
    aligned[np.searchsorted(all_labels, present_labels)] = counts
    peak = aligned.max()
    return aligned / peak if peak > 0 else aligned
    
def draw_proportion_cirlces(proportion_matrix):
    """Draw a grid of red circles whose areas encode the given proportions.

    Entry ``[row, col]`` of ``proportion_matrix`` is drawn as a circle
    centred at ``(col, row)``. The circle *area* is proportional to the
    entry, so the radius grows with its square root: proportions 0.2 and
    0.8 (ratio 1:4) give areas in ratio 1:4 and radii in ratio 1:2.

    The x axis is labelled 'Client index' and the y axis 'Class index'.
    The figure is created but neither shown nor returned.
    """
    n_rows = proportion_matrix.shape[0]
    n_cols = proportion_matrix.shape[1]

    # Radius is proportional to sqrt(proportion); the factor rows * cols and
    # the constant 0.45 set the overall circle scale.
    radii = 0.45 * np.sqrt(proportion_matrix.flatten() * (n_rows * n_cols))

    # Circle centres in row-major order, matching the flattened radii.
    col_grid, row_grid = np.meshgrid(
        np.arange(n_cols, dtype=float),
        np.arange(n_rows, dtype=float),
    )

    _, ax = plt.subplots(figsize=(n_cols, n_rows))
    for centre_x, centre_y, radius in zip(col_grid.ravel(), row_grid.ravel(), radii):
        ax.add_patch(plt.Circle((centre_x, centre_y), radius, color='r'))

    ax.set_xlabel('Client index')
    ax.set_ylabel('Class index')
    ax.set_title('Proportions of data in each clients')

    # Ticks run one step beyond the grid on both sides.
    ax.set_xticks(np.arange(-1, n_cols + 1, 1.0))
    ax.set_yticks(np.arange(-1, n_rows + 1, 1.0))

    ax.invert_yaxis()
    ax.grid(True)
    ax.axis('scaled')