import concurrent.futures
import numpy as np
import pyemd
import torch
import torch.nn.functional as F

from collections import defaultdict
from scipy.linalg import toeplitz
from torch.utils.data import Subset
from tqdm import tqdm

from ..dataset import DAGDataset

def get_out_adj_list(src, dst):
    """Group edge destinations by their source node.

    Args:
        src: Tensor (anything with ``.tolist()``) of edge source node ids.
        dst: Tensor of edge destination node ids, aligned position-by-position
            with ``src``.

    Returns:
        A ``defaultdict(list)`` mapping each source node id to the list of
        destinations of its out-edges, in edge order. Nodes without out-edges
        are absent but read back as ``[]``.
    """
    sources = src.tolist()
    targets = dst.tolist()

    adjacency = defaultdict(list)
    for edge_idx, u in enumerate(sources):
        adjacency[u].append(targets[edge_idx])
    return adjacency

def parse_sample(sample):
    """Compute topological-layer statistics of one graph.

    Layers are built Kahn-style: nodes with in-degree zero form layer 0,
    removing them exposes layer 1, and so on. A node's depth is the index of
    the layer it lands in. Nodes that are never reached (i.e. they lie on or
    behind a cycle) keep depth -1.

    Args:
        sample: A 4-tuple ``(src, dst, x_n, label)``. ``src``/``dst`` are 1-D
            integer tensors of edge endpoints (``dst`` is fed to
            ``torch.bincount``), and ``len(x_n)`` is the number of nodes. The
            label is ignored.

    Returns:
        ``(num_layers, layer_size_prob, depth_diff_prob)``:
          * ``num_layers``: int, the number of layers produced.
          * ``layer_size_prob``: float tensor whose entry ``k`` is the fraction
            of layers containing exactly ``k`` nodes; ``torch.zeros(1)`` when
            there are no layers (e.g. an empty graph).
          * ``depth_diff_prob``: float tensor whose entry ``k`` is the fraction
            of edges ``(u, v)`` with ``depth(v) - depth(u) == k``;
            ``torch.zeros(1)`` when the graph has no edges.
    """
    src, dst, x_n, _ = sample

    num_nodes = len(x_n)
    # -1 marks nodes not (yet) assigned to a layer.
    node_depth = torch.full((num_nodes,), -1.0)

    in_deg = torch.bincount(dst, minlength=num_nodes).tolist()
    out_adj_list = get_out_adj_list(src, dst)

    frontiers = [u for u in range(num_nodes) if in_deg[u] == 0]
    depth = 0
    layer_size = []
    while frontiers:
        layer_size.append(len(frontiers))

        next_frontiers = []
        for u in frontiers:
            node_depth[u] = depth
            for v in out_adj_list[u]:
                in_deg[v] -= 1
                # v joins the next layer once all its predecessors are placed.
                if in_deg[v] == 0:
                    next_frontiers.append(v)

        depth += 1
        frontiers = next_frontiers

    if len(src) == 0:
        depth_diff_prob = torch.zeros(1)
    else:
        depth_diff = (node_depth[dst] - node_depth[src]).long()
        # Tensor.max() reduces in native code; the builtin max() would iterate
        # the tensor element by element in Python.
        num_bins = depth_diff.max().item() + 1
        indicator = torch.zeros(len(src), num_bins)
        indicator[torch.arange(len(src)), depth_diff] = 1.
        depth_diff_freq = indicator.sum(dim=0)
        depth_diff_prob = depth_diff_freq / depth_diff_freq.sum()

    if len(layer_size) == 0:
        layer_size_one_hot_prob = torch.zeros(1)
    else:
        layer_size_one_hot = F.one_hot(torch.LongTensor(layer_size))
        layer_size_one_hot_freq = layer_size_one_hot.sum(dim=0)
        layer_size_one_hot_prob = layer_size_one_hot_freq / layer_size_one_hot_freq.sum()

    return depth, layer_size_one_hot_prob, depth_diff_prob

def get_data_dict(set_q):
    """Collect per-graph layer statistics for every sample in ``set_q``.

    Args:
        set_q: An indexable, sized collection of samples accepted by
            ``parse_sample``.

    Returns:
        A dict with keys:
          * ``'num_layers'``: float tensor whose entry ``k`` is the fraction of
            graphs with exactly ``k`` layers.
          * ``'layer_size'``: list of per-graph layer-size distributions
            (numpy arrays).
          * ``'depth_diffs'``: list of per-graph edge depth-difference
            distributions (numpy arrays).

    Raises:
        ValueError: If ``set_q`` is empty.
    """
    if len(set_q) == 0:
        raise ValueError('get_data_dict requires a non-empty set of samples')

    data_dict = {
        'depth_diffs': [],
        'num_layers': [],
        'layer_size': []
    }

    for i in range(len(set_q)):
        num_layers_i, layer_size_prob_i, depth_diff_prob_i = parse_sample(set_q[i])

        data_dict['num_layers'].append(num_layers_i)
        data_dict['layer_size'].append(layer_size_prob_i.numpy())
        data_dict['depth_diffs'].append(depth_diff_prob_i.numpy())

    # Histogram of layer counts (bincount avoids a dense N x L indicator
    # matrix), normalized to a probability vector.
    num_layers_freq = torch.bincount(torch.tensor(data_dict['num_layers']))
    data_dict['num_layers'] = num_layers_freq / num_layers_freq.sum()

    return data_dict

def emd(p, q):
    """Distance between two 1-D histograms via their cumulative sums.

    The shorter histogram is right-padded with zeros. The result is the L1
    distance between the two cumulative sums. This equals the earth mover's
    distance with unit spacing between adjacent bins when both inputs carry
    the same total mass.

    Args:
        p, q: 1-D tensors of bin masses.

    Returns:
        The distance as a Python float.
    """
    width = max(len(p), len(q))

    cdf_p = torch.cumsum(F.pad(p, (0, width - len(p))), dim=0)
    cdf_q = torch.cumsum(F.pad(q, (0, width - len(q))), dim=0)

    gap = cdf_p - cdf_q
    return gap.abs().sum().item()

def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0):
    ''' Gaussian kernel with squared distance in exponential term replaced by EMD

    Computes ``exp(-EMD(x, y)**2 / (2 * sigma**2))``, where the EMD uses
    ground distance ``|i - j| / distance_scaling`` between bins ``i`` and ``j``.

    Args:
      x, y: 1D pmfs (array-likes) over the bin indices 0, 1, 2, ...; the
        shorter one is right-padded with zeros so both share the same support.
      sigma: kernel bandwidth (standard deviation of the Gaussian).
      distance_scaling: divisor applied to the bin-index ground distance.

    Returns:
      The kernel value as a numpy float in [0, 1].
    '''
    support_size = max(len(x), len(y))
    distance_mat = toeplitz(range(support_size)).astype(np.float64) / distance_scaling

    # pyemd takes float64 histograms of equal length; make contiguous float64
    # copies (this also accepts lists/tensors, not just ndarrays) and pad.
    x = np.ascontiguousarray(x, dtype=np.float64)
    y = np.ascontiguousarray(y, dtype=np.float64)
    x = np.pad(x, (0, support_size - len(x)))
    y = np.pad(y, (0, support_size - len(y)))

    # Named emd_value so the module-level ``emd`` function is not shadowed.
    emd_value = pyemd.emd(x, y, distance_mat)
    return np.exp(-emd_value * emd_value / (2 * sigma * sigma))

def kernel_parallel_unpacked(x, samples2):
    """Sum ``gaussian_emd(x, s)`` over every ``s`` in ``samples2``.

    Returns the integer 0 when ``samples2`` is empty.
    """
    return sum(gaussian_emd(x, other) for other in samples2)

def kernel_parallel_worker(t):
    """Unpack an ``(x, samples2)`` task pair and run ``kernel_parallel_unpacked``.

    ``executor.map`` hands each task a single argument, so the caller packs
    both values into one tuple.
    """
    x, samples2 = t
    return kernel_parallel_unpacked(x, samples2)

def disc(samples1, samples2, parallel):
    ''' Discrepancy between 2 samples

    Average of ``gaussian_emd(s1, s2)`` over all pairs with ``s1`` from
    ``samples1`` and ``s2`` from ``samples2``.

    Args:
      samples1, samples2: sequences of 1D pmfs accepted by ``gaussian_emd``.
      parallel: if True, each row of the kernel matrix is computed in a worker
        process; otherwise rows are computed serially with a progress bar.

    Returns:
      The mean kernel value.

    Raises:
      ZeroDivisionError: if either sample is empty.
    '''
    if parallel:
        import os

        tasks = [(s1, samples2) for s1 in samples1]
        # Batch several rows per task: with the default chunksize=1 every row
        # pickles its own copy of samples2, whereas a chunk is pickled as one
        # object and shares a single copy.
        num_workers = os.cpu_count() or 1
        chunksize = max(1, len(tasks) // (4 * num_workers))
        with concurrent.futures.ProcessPoolExecutor() as executor:
            row_sums = list(executor.map(kernel_parallel_worker, tasks, chunksize=chunksize))
    else:
        # Same per-row summation as the parallel path, so both modes return
        # identical results.
        row_sums = [kernel_parallel_unpacked(s1, samples2) for s1 in tqdm(samples1)]

    d = sum(row_sums)
    d /= len(samples1) * len(samples2)
    return d

def get_mmd(real_dist_list, syn_dist_list, parallel=True):
    """Squared MMD between two sets of distributions under the Gaussian-EMD kernel.

    Uses MMD^2 = mean k(r, r') + mean k(s, s') - 2 * mean k(r, s), where each
    mean is the all-pairs average from ``disc``. Self-pairs are included, so
    this is the biased estimator.
    """
    k_real = disc(real_dist_list, real_dist_list, parallel)
    k_syn = disc(syn_dist_list, syn_dist_list, parallel)
    k_cross = disc(real_dist_list, syn_dist_list, parallel)
    return k_real + k_syn - 2 * k_cross

def eval_sets(real_set_q, syn_set_q, include_depth_diffs=False):
    """Compare layer statistics between a real and a synthetic set of graphs.

    Args:
        real_set_q: Real samples (indexable and sized, as for ``get_data_dict``).
        syn_set_q: Synthetic samples in the same format.
        include_depth_diffs: If True, also report ``'depth_diffs_mmd'``, the
            MMD between per-graph edge depth-difference distributions.
            Off by default.

    Returns:
        A dict with ``'num_layers_emd'`` (EMD between the layer-count
        distributions), ``'layer_size_mmd'`` (MMD between per-graph layer-size
        distributions) and, optionally, ``'depth_diffs_mmd'``.
    """
    real_data_dict = get_data_dict(real_set_q)
    syn_data_dict = get_data_dict(syn_set_q)

    data_dict = {
        'num_layers_emd': emd(real_data_dict['num_layers'], syn_data_dict['num_layers']),
        'layer_size_mmd': get_mmd(real_data_dict['layer_size'], syn_data_dict['layer_size']),
    }
    if include_depth_diffs:
        data_dict['depth_diffs_mmd'] = get_mmd(
            real_data_dict['depth_diffs'], syn_data_dict['depth_diffs'])

    return data_dict

def sort_by_label(subset):
    """Return the samples of ``subset`` as a list ordered by label.

    Samples are read as ``subset[i]`` for ``0 <= i < len(subset)`` and ordered
    by their element at index 3 (the label). The sort is stable, so samples
    with equal labels keep their original relative order.
    """
    return sorted(
        (subset[idx] for idx in range(len(subset))),
        key=lambda sample: sample[3],
    )

def create_dag_dataset(list_set):
    """Build a labeled ``DAGDataset`` from a sequence of samples.

    Each element of ``list_set`` is unpacked as ``(src, dst, x_n, y)`` and
    passed to ``DAGDataset.add_data`` in order.
    """
    # num_categories=0 is just a placeholder here.
    dataset = DAGDataset(num_categories=0, label=True)

    for idx in range(len(list_set)):
        edge_src, edge_dst, node_feats, label = list_set[idx]
        dataset.add_data(edge_src, edge_dst, node_feats, label)

    return dataset

def _pad_syn_set(real_set, syn_set):
    """Insert empty graphs into ``syn_set`` (in place) so its labels align with ``real_set``.

    Both lists are expected to be sorted by label. Wherever ``syn_set`` has run
    out, or its label at position ``i`` differs from the real label, an empty
    graph carrying the real label is inserted at ``i``.
    """
    for i in range(len(real_set)):
        if (i >= len(syn_set)) or (real_set[i][-1] != syn_set[i][-1]):
            syn_set.insert(i, (torch.tensor([]).long(), torch.tensor([]).long(), torch.tensor([]), real_set[i][-1]))


def eval_quantile_stats(real_set, syn_set, num_quantiles=4):
    """Compare layer statistics of real and synthetic graphs per label quantile.

    Both sets are sorted by label and split by position into ``num_quantiles``
    contiguous, near-equal chunks. ``eval_sets`` compares the real and
    synthetic graphs within each chunk, and each metric is averaged over the
    non-empty chunks. The averages are printed and returned.

    Args:
        real_set: Indexable, sized collection of ``(src, dst, x_n, label)``.
        syn_set: Synthetic samples in the same format. It may be smaller than
            ``real_set``, in which case it is padded with empty graphs so that
            labels line up position by position.
        num_quantiles: Number of chunks (must be >= 1).

    Returns:
        Dict mapping ``'num_layers_emd'`` and ``'layer_size_mmd'`` to their
        mean over non-empty chunks.

    Raises:
        ValueError: If ``real_set`` is empty, ``syn_set`` is larger than
            ``real_set``, or ``num_quantiles < 1``.
    """
    if num_quantiles < 1:
        raise ValueError(f'num_quantiles must be >= 1, got {num_quantiles}')

    real_set = sort_by_label(real_set)
    syn_set = sort_by_label(syn_set)

    if not real_set:
        raise ValueError('real_set is empty')
    if len(syn_set) > len(real_set):
        raise ValueError(
            f'syn_set ({len(syn_set)} samples) must not be larger than '
            f'real_set ({len(real_set)} samples)')
    if len(syn_set) < len(real_set):
        print('Warning: syn_set is smaller than real_set, will pad it with empty graphs')
        _pad_syn_set(real_set, syn_set)

    real_set = create_dag_dataset(real_set)
    syn_set = create_dag_dataset(syn_set)

    metric_keys = ['num_layers_emd', 'layer_size_mmd']
    all_results = {key: [] for key in metric_keys}

    num_samples = len(real_set)
    quantiles = np.linspace(0, num_samples, num_quantiles + 1)
    for q in range(num_quantiles):
        start_q = int(quantiles[q])
        end_q = int(quantiles[q + 1])
        # With fewer samples than quantiles some chunks are empty; they have
        # nothing to compare and would make get_data_dict fail.
        if start_q == end_q:
            continue
        indices_q = list(range(start_q, end_q))

        real_set_q = Subset(real_set, indices_q)
        syn_set_q = Subset(syn_set, indices_q)

        results_q = eval_sets(real_set_q, syn_set_q)

        for key in metric_keys:
            all_results[key].append(results_q[key])

    mean_results = {key: float(np.mean(all_results[key])) for key in metric_keys}
    for key, value in mean_results.items():
        print(f'{key}: ', value)
    return mean_results
