import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC
import itertools
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split 
from joblib import Parallel, delayed
from grakel import WeisfeilerLehman
import ot
from Kernels import FGW_kernels
#%% 

def hamming_dist(x, y):
    """Count the positions at which the paired elements of ``x`` and ``y`` differ.

    Elements are paired with ``zip``, so the comparison stops at the end of
    the shorter input; surplus elements of the longer one are not counted.
    """
    mismatches = 0
    for left, right in zip(x, y):
        if left != right:
            mismatches += 1
    return mismatches


def WL_matrix_parallel(attr_graphs, wl, n_jobs=2, verbose=True, dtype=np.float64):
    """Fit a grakel Weisfeiler-Lehman kernel on ``attr_graphs`` and return its output.

    Parameters
    ----------
    attr_graphs : graphs in whatever format ``WeisfeilerLehman.fit_transform``
        accepts; they are handed to it untouched.
    wl : forwarded as ``n_iter`` (the number of WL iterations).
    n_jobs : forwarded to the kernel as its ``n_jobs`` setting.
    verbose : forwarded to the kernel as its ``verbose`` flag.
    dtype : accepted but currently unused by this function.

    Returns
    -------
    The result of ``fit_transform(attr_graphs)`` with ``normalize=False``
    (per the grakel documentation, the kernel matrix between all pairs of
    the input graphs).
    """
    wl_kernel = WeisfeilerLehman(
        n_iter=wl,
        normalize=False,
        n_jobs=n_jobs,
        verbose=verbose,
    )
    return wl_kernel.fit_transform(attr_graphs)


def compute_wasserstein_distance_by_input(F1, F2, h1, h2, F1sq, ones_h1, dist='hamming'):
    """Return the optimal-transport cost between two graphs' node features.

    A cost matrix ``M`` of shape ``(F1.shape[0], F2.shape[0])`` is built
    between the rows of ``F1`` and the rows of ``F2`` and then handed to
    ``ot.emd2(h1, h2, M)``, whose result is returned.

    Parameters
    ----------
    F1, F2 : arrays with one row per node.
    h1, h2 : forwarded to ``ot.emd2`` as its first two arguments (per the POT
        documentation, the weights over the rows of ``F1`` and ``F2``).
    F1sq : optional cache of ``F1 ** 2``; only read when
        ``dist == 'euclidean'``. Recomputed from ``F1`` when ``None``.
    ones_h1 : optional cache of an all-ones array shaped like ``F1``; only
        read when ``dist == 'euclidean'``. Rebuilt when ``None``.
    dist : ``'euclidean'`` selects squared Euclidean costs; any other value
        selects squared Hamming costs.

    Returns
    -------
    The value returned by ``ot.emd2`` for the cost matrix described above.
    """
    if dist == 'euclidean':
        # The cached terms are an optimisation for callers looping over many
        # partners of F1; rebuild them here so passing None does not crash.
        if F1sq is None:
            F1sq = F1 ** 2
        if ones_h1 is None:
            ones_h1 = np.ones(F1.shape)
        # ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2 a.b for every row pair.
        first_term = F1sq.dot(np.ones((F2.shape[1], F2.shape[0])))
        second_term = ones_h1.dot((F2 ** 2).T)
        M = first_term + second_term - 2 * F1.dot(F2.T)
    else:
        M = np.zeros((F1.shape[0], F2.shape[0]))
        both_2d_arrays = (
            isinstance(F1, np.ndarray) and isinstance(F2, np.ndarray)
            and F1.ndim == 2 and F2.ndim == 2
        )
        if both_2d_arrays:
            # Accumulate mismatches one feature column at a time: same counts
            # as the pairwise loop, without a Python-level loop over node pairs.
            # Only the shared leading columns are compared, matching zip().
            width = min(F1.shape[1], F2.shape[1])
            for k in range(width):
                M += F1[:, k, None] != F2[None, :, k]
        else:
            # Generic fallback for inputs that are not plain 2-D ndarrays.
            for ni in range(F1.shape[0]):
                for nj in range(F2.shape[0]):
                    M[ni, nj] = hamming_dist(F1[ni], F2[nj])
        M = M ** 2
    return ot.emd2(h1, h2, M)

def compute_wasserstein_distance_parallel(features, masses, dist, n_jobs=2):
    """Build the symmetric matrix of pairwise graph distances.

    Entry ``D[i, j]`` is
    ``compute_wasserstein_distance_by_input(features[i], features[j],
    masses[i], masses[j], ...)``. Only pairs with ``i < j`` are evaluated;
    the lower triangle is filled by symmetry and the diagonal stays zero.

    Parameters
    ----------
    features : sequence with one feature array per graph. When
        ``dist == 'euclidean'`` each entry must be 2-D (its ``shape[0]`` and
        ``shape[1]`` are read).
    masses : sequence indexed like ``features``; ``masses[i]`` is forwarded
        as the weight argument for graph ``i``.
    dist : forwarded to the pairwise function; ``'euclidean'`` additionally
        triggers caching of the per-row helper arrays.
    n_jobs : number of joblib workers used for each row of the matrix.

    Returns
    -------
    numpy.ndarray of shape ``(n, n)`` with ``n = len(features)``.
    """
    n = len(features)
    D = np.zeros((n, n))
    use_euclidean = dist == 'euclidean'

    # The last graph has no later partner, so only the first n-1 rows are
    # visited. tqdm wraps the sliced sequence (not the enumerate iterator) so
    # it can read the length and display a total / ETA.
    for i, Fi in enumerate(tqdm(features[:(n-1)], desc='dist matrix')):
        hi = masses[i]
        if use_euclidean:
            # Terms that depend only on graph i are shared by every pair (i, j).
            Fisq = Fi ** 2
            ones_hi = np.ones((Fi.shape[0], Fi.shape[1]))
        else:
            Fisq = None
            ones_hi = None

        all_dists = Parallel(n_jobs=n_jobs)(
            delayed(compute_wasserstein_distance_by_input)(
                Fi, features[j], hi, masses[j], Fisq, ones_hi, dist=dist
            )
            for j in range(i + 1, n)
        )

        # all_dists is ordered by j = i+1 .. n-1, i.e. the upper-triangle row.
        D[i, i + 1:] = all_dists

    # Mirror the strict upper triangle into the lower one.
    return D + D.T


def compute_wl_distance_by_input(F1, F2, F1sq, ones_h1, dist='hamming'):
    """Return the sum of all pairwise squared costs between rows of ``F1`` and ``F2``.

    A cost matrix ``M`` of shape ``(F1.shape[0], F2.shape[0])`` is built
    between the rows of the two inputs and ``M.sum()`` is returned. No
    optimal-transport problem is solved here.

    Parameters
    ----------
    F1, F2 : arrays with one row per node.
    F1sq : optional cache of ``F1 ** 2``; only read when
        ``dist == 'euclidean'``. Recomputed from ``F1`` when ``None``.
    ones_h1 : optional cache of an all-ones array shaped like ``F1``; only
        read when ``dist == 'euclidean'``. Rebuilt when ``None``.
    dist : ``'euclidean'`` selects squared Euclidean costs; any other value
        selects squared Hamming costs.

    Returns
    -------
    Scalar sum of the entries of the cost matrix.
    """
    if dist == 'euclidean':
        # The cached terms are an optimisation for callers looping over many
        # partners of F1; rebuild them here so passing None does not crash.
        if F1sq is None:
            F1sq = F1 ** 2
        if ones_h1 is None:
            ones_h1 = np.ones(F1.shape)
        # ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2 a.b for every row pair.
        first_term = F1sq.dot(np.ones((F2.shape[1], F2.shape[0])))
        second_term = ones_h1.dot((F2 ** 2).T)
        M = first_term + second_term - 2 * F1.dot(F2.T)
    else:
        M = np.zeros((F1.shape[0], F2.shape[0]))
        both_2d_arrays = (
            isinstance(F1, np.ndarray) and isinstance(F2, np.ndarray)
            and F1.ndim == 2 and F2.ndim == 2
        )
        if both_2d_arrays:
            # Accumulate mismatches one feature column at a time: same counts
            # as the pairwise loop, without a Python-level loop over node pairs.
            # Only the shared leading columns are compared, matching zip().
            width = min(F1.shape[1], F2.shape[1])
            for k in range(width):
                M += F1[:, k, None] != F2[None, :, k]
        else:
            # Generic fallback for inputs that are not plain 2-D ndarrays.
            for ni in range(F1.shape[0]):
                for nj in range(F2.shape[0]):
                    M[ni, nj] = hamming_dist(F1[ni], F2[nj])
        M = M ** 2
    return M.sum()

def compute_wl_distance_parallel(features, dist, n_jobs=2):
    """Build the symmetric matrix of pairwise summed-cost distances.

    Entry ``D[i, j]`` is
    ``compute_wl_distance_by_input(features[i], features[j], ...)``. Only
    pairs with ``i < j`` are evaluated; the lower triangle is filled by
    symmetry and the diagonal stays zero.

    Parameters
    ----------
    features : sequence with one feature array per graph. When
        ``dist == 'euclidean'`` each entry must be 2-D (its ``shape[0]`` and
        ``shape[1]`` are read).
    dist : forwarded to the pairwise function; ``'euclidean'`` additionally
        triggers caching of the per-row helper arrays.
    n_jobs : number of joblib workers used for each row of the matrix.

    Returns
    -------
    numpy.ndarray of shape ``(n, n)`` with ``n = len(features)``.
    """
    n = len(features)
    D = np.zeros((n, n))
    use_euclidean = dist == 'euclidean'

    # The last graph has no later partner, so only the first n-1 rows are
    # visited. tqdm wraps the sliced sequence (not the enumerate iterator) so
    # it can read the length and display a total / ETA.
    for i, Fi in enumerate(tqdm(features[:(n-1)], desc='dist matrix')):
        if use_euclidean:
            # Terms that depend only on graph i are shared by every pair (i, j).
            Fisq = Fi ** 2
            ones_hi = np.ones((Fi.shape[0], Fi.shape[1]))
        else:
            Fisq = None
            ones_hi = None

        all_dists = Parallel(n_jobs=n_jobs)(
            delayed(compute_wl_distance_by_input)(
                Fi, features[j], Fisq, ones_hi, dist=dist
            )
            for j in range(i + 1, n)
        )

        # all_dists is ordered by j = i+1 .. n-1, i.e. the upper-triangle row.
        D[i, i + 1:] = all_dists

    # Mirror the strict upper triangle into the lower one.
    return D + D.T