import numpy as np
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_wts(results, split_property_index, n_splits=4):
    """Scatter the weights against one generated property, one panel per split.

    Args:
        results: Dict keyed by split index. Each value is itself a mapping
            that is read for ``'weights'``, ``'Z_gen'`` (indexed as
            ``Z_gen[:, split_property_index]``) and ``'n_effective'``.
            The keys are used directly as panel positions, so they must be
            integers in ``range(n_splits)``.
        split_property_index: Column of ``Z_gen`` plotted on the x-axis.
        n_splits: Number of panels (columns) in the figure.
    """
    # squeeze=False keeps the axes in a 2-D array even when n_splits == 1;
    # with the default, a single panel comes back as a bare Axes object
    # that cannot be indexed by split.
    _, axes_grid = plt.subplots(
        1, n_splits, figsize=(16, 4), sharex=True, sharey=True, squeeze=False
    )
    axes_row = axes_grid[0]

    for split_index, split_result in results.items():
        weights = split_result['weights']
        generated_samples = split_result['Z_gen'][:, split_property_index]
        n_effective = split_result['n_effective']

        ax = axes_row[split_index]
        sns.scatterplot(ax=ax, x=generated_samples, y=weights)

        ax.set_title(
            f"split {split_index}\n #samples {len(weights)} \n #neff {int(n_effective)}"
        )
        ax.set_xlabel("property value")
        ax.set_ylabel("weights value")
    plt.show()
	
def plot_wts_one_split(results, split_property_index, split_index):
    """Draw the weight-versus-property scatter plot for a single split.

    Args:
        results: Result mapping of one split; read for ``'weights'``,
            ``'Z_gen'`` and ``'n_effective'``.
        split_property_index: Column of ``results['Z_gen']`` used as x values.
        split_index: Only used to label the panel title.
    """
    _, panel = plt.subplots(1, 1, figsize=(16, 4), sharex=True, sharey=True)

    wts = results['weights']
    prop_values = results['Z_gen'][:, split_property_index]
    neff = results['n_effective']

    sns.scatterplot(x=prop_values, y=wts, ax=panel)

    # Title reports the split, the raw sample count and the (truncated)
    # effective sample size.
    title = f"split {split_index}\n #samples {len(wts)} \n #neff {int(neff)}"
    panel.set_title(title)
    panel.set_xlabel("property value")
    panel.set_ylabel("weights value")
    plt.show()

def plot_2ecdf(data1, data2):
    """Overlay the ECDFs of two data sets, one panel per property column.

    Args:
        data1: 2-D array; ``data1.shape[1]`` sets the number of panels.
        data2: 2-D array indexed with the same column indices as ``data1``.
    """
    n_test_prop = data1.shape[1]
    # One column per property. squeeze=False keeps the axes indexable even
    # when there is a single property (the default would return a bare Axes).
    fig, axes_grid = plt.subplots(
        1, n_test_prop, figsize=(14 * 4, 7), squeeze=False
    )
    fig.suptitle("ECDFs", fontsize=35, x=0.5, y=1.1)

    for test_prop_index, ax in enumerate(axes_grid[0]):
        sns.ecdfplot(x=data1[:, test_prop_index], ax=ax, label='data1',
                     linewidth=5.0, ms=15)
        sns.ecdfplot(x=data2[:, test_prop_index], ax=ax, label='data2',
                     linewidth=5.0, ms=15)

        # The legend is anchored below the panel so it does not cover the curves.
        ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.9), fontsize=40)
        ax.set_title(f"for prop {test_prop_index}")
        ax.title.set_size(40)
        ax.tick_params(axis='x', labelsize=30)
        ax.tick_params(axis='y', labelsize=30)
        ax.set_ylabel('ECDF', fontsize=40)
    plt.show()

def plot_ecdf_one_split(split_strat, results, n_test_prop, split_property_index, model_type):
    """Plot ECDFs for one split, one panel per test property.

    Every panel shows four curves for the same property column: the
    unweighted generated samples, the generated samples weighted by
    ``results["weights"]``, the held-out data and the training data.

    Args:
        split_strat: Split strategy name, shown in the figure title.
        results: Mapping read for ``"Z_gen"``, ``"Z_held"``, ``"Z_train"``
            (each indexed as ``Z[:, test_prop_index]``), ``"stats"``
            (indexed by test property; the value is rounded to 2 decimals
            for the legend) and ``"weights"``.
        n_test_prop: Number of test properties, i.e. number of panels.
        split_property_index: Zero-based index of the split property; shown
            one-based in the figure title.
        model_type: Model name used in the legend labels.
    """
    # squeeze=False keeps the axes indexable when n_test_prop == 1; the
    # default would return a single Axes object instead of an array.
    fig, axes_grid = plt.subplots(
        1, n_test_prop, figsize=(14 * 4, 7), squeeze=False
    )
    fig.suptitle(
        rf"For {split_strat} splits, $\ell = $ {split_property_index+1}",
        fontsize=35, x=0.5, y=1.1,
    )

    Z_gen = results["Z_gen"]
    Z_held = results["Z_held"]
    Z_train = results["Z_train"]
    stats = results["stats"]
    weights = results["weights"]

    for test_prop_index, ax in enumerate(axes_grid[0]):
        stat = stats[test_prop_index]
        # Brace the subscript so that two-digit property numbers are
        # subscripted as a whole ($Z_{10}$ rather than $Z_10$).
        prop_symbol = rf'$Z_{{{test_prop_index+1}}}$'

        sns.ecdfplot(x=Z_gen[:, test_prop_index], ax=ax,
                     label=rf'Un Weighted {model_type}',
                     linewidth=5.0, ms=15)
        sns.ecdfplot(x=Z_gen[:, test_prop_index], weights=weights, ax=ax,
                     label=rf'Weighted {model_type}({round(stat,2)})',
                     linewidth=5.0, ms=15)
        sns.ecdfplot(x=Z_held[:, test_prop_index], ax=ax,
                     label=rf'Held data for {prop_symbol}',
                     color="black", linewidth=5.0)
        sns.ecdfplot(x=Z_train[:, test_prop_index], ax=ax,
                     label=rf'Train data for {prop_symbol}',
                     color="red", linewidth=5.0)

        # The legend is anchored below the panel so it does not cover the curves.
        ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.9), fontsize=40)
        ax.set_title(rf"test property $\ell' = $ {test_prop_index+1} ")
        ax.title.set_size(40)
        ax.tick_params(axis='x', labelsize=30)
        ax.tick_params(axis='y', labelsize=30)
        ax.set_ylabel('Weighted ECDF', fontsize=40)
    plt.show()
def plot_ecdf_for_data(split_strat, results, test_prop_index, split_property_index, model_type, n_splits=4):
    """Plot ECDFs of one test property, one panel per split.

    Every panel shows four curves for column ``test_prop_index``: the
    unweighted generated samples, the generated samples weighted by that
    split's ``"weights"``, the held-out data and the training data.

    Args:
        split_strat: Split strategy name, shown in the figure title.
        results: Container indexed by split index ``0 .. n_splits - 1``.
            Each entry is read for ``"Z_gen"``, ``"Z_held"``, ``"Z_train"``
            (each indexed as ``Z[:, test_prop_index]``), ``"stats"``
            (indexed by ``test_prop_index``; the value is rounded to 2
            decimals for the legend) and ``"weights"``.
        test_prop_index: Zero-based column of the property being plotted.
        split_property_index: Zero-based index of the split property; shown
            one-based in the figure title.
        model_type: Model name used in the legend labels.
        n_splits: Number of splits (panels) to draw. Defaults to 4, the
            value that used to be hard-coded.
    """
    # squeeze=False keeps the axes indexable when n_splits == 1; the
    # default would return a single Axes object instead of an array.
    fig, axes_grid = plt.subplots(1, n_splits, figsize=(14 * 4, 7), squeeze=False)
    fig.suptitle(
        rf"For {split_strat} splits, $\ell = $ {split_property_index+1}, $\ell' = $ {test_prop_index+1}",
        fontsize=35, x=0.5, y=1.1,
    )

    # Brace the subscript so that two-digit property numbers are
    # subscripted as a whole ($Z_{10}$ rather than $Z_10$).
    prop_symbol = rf'$Z_{{{test_prop_index+1}}}$'

    for split_index, ax in enumerate(axes_grid[0]):
        split_result = results[split_index]
        Z_gen = split_result["Z_gen"]
        Z_held = split_result["Z_held"]
        Z_train = split_result["Z_train"]
        stats = split_result["stats"]
        weights = split_result["weights"]

        stat = stats[test_prop_index]

        sns.ecdfplot(x=Z_gen[:, test_prop_index], ax=ax,
                     label=rf'Un Weighted {model_type}',
                     linewidth=5.0, ms=15)
        sns.ecdfplot(x=Z_gen[:, test_prop_index], weights=weights, ax=ax,
                     label=rf'Weighted {model_type}({round(stat,2)})',
                     linewidth=5.0, ms=15)
        # Held-out and training data are always drawn as reference curves.
        sns.ecdfplot(x=Z_held[:, test_prop_index], ax=ax,
                     label=rf'Held data for {prop_symbol}',
                     color="black", linewidth=5.0)
        sns.ecdfplot(x=Z_train[:, test_prop_index], ax=ax,
                     label=rf'Train data for {prop_symbol}',
                     color="red", linewidth=5.0)

        # The legend is anchored below the panel so it does not cover the curves.
        ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.9), fontsize=40)
        ax.set_title(rf"Split Index {split_index+1} ")
        ax.title.set_size(40)
        ax.tick_params(axis='x', labelsize=30)
        ax.tick_params(axis='y', labelsize=30)
        ax.set_ylabel('Weighted ECDF', fontsize=40)
    plt.show()
	
	
def plot_splits(splits, Z_all, prop_index):
    """Histogram one property for the full data set and for every split.

    The first grid column (spanning both rows) shows all of
    ``Z_all[:, prop_index]``. Each split then gets its own column: the
    split is iterated, and the index set at position 0 is drawn in the top
    row (titled "train") while the one at position 1 is drawn in the bottom
    row (titled "held"). All panels share the same bin edges and share
    their axes with the "All data" panel.

    The figure is laid out with ``tight_layout`` but not shown here.

    Args:
        splits: Sized iterable of splits; each split yields the index sets
            used to select rows of the property column. The grid has two
            rows, so a split may yield at most two index sets.
        Z_all: 2-D array holding all data; column ``prop_index`` is plotted.
        prop_index: Column of ``Z_all`` to histogram.
    """
    n_splits = len(splits)
    Z_all_l = Z_all[:, prop_index]

    # Square-root rule for the number of bin edges. Fewer than two edges
    # define no bin at all (this happens for fewer than 4 samples), so
    # clamp to 2 to always get at least one bin.
    n_edges = max(2, int(np.sqrt(len(Z_all_l))))
    bins = np.linspace(np.min(Z_all_l), np.max(Z_all_l), n_edges)

    fig = plt.figure(figsize=np.array([4*n_splits, 4*2])*0.8, dpi=150)
    gs = fig.add_gridspec(2, n_splits+1)

    ax1 = fig.add_subplot(gs[:, 0])
    sns.histplot(x=Z_all_l, bins=bins, ax=ax1)
    ax1.set_title('All data')

    for split_i, split in enumerate(splits):
        for jj, ind in enumerate(split):
            ax = fig.add_subplot(gs[jj, split_i+1], sharex=ax1, sharey=ax1)
            sns.histplot(x=Z_all_l[ind], bins=bins, ax=ax)
            ax.set_title(f'Split {split_i} {"train" if jj == 0 else "held"}')

    plt.tight_layout()

