import torch
import numpy as np
from torch_geometric.utils import to_undirected, add_self_loops, remove_self_loops

from torch.distributions import Bernoulli, Binomial, Multinomial

from torch_geometric.data import Data
import os
import pickle as pkl
import torch_geometric.utils as tgutils



def get_permutation(original_len, device='auto', seed=0):
    """Return a random permutation of the integers ``0 .. original_len - 1``.

    Args:
        original_len: Number of indices to permute.
        device: Accepted for API compatibility; currently ignored.
        seed: Accepted for API compatibility; currently ignored. The draw
            uses torch's global random number generator.

    Returns:
        A 1-D tensor holding the shuffled indices.
    """
    # No dedicated generator is built, so `generator=None` makes torch fall
    # back to its default (global) RNG.
    return torch.randperm(original_len, generator=None)


def compute_split_idx(original_len, split_sizes, random=True, k_fold=-1):
    """Partition ``range(original_len)`` into consecutive index chunks.

    Args:
        original_len: Total number of items to split.
        split_sizes: Sequence of fractions, one per split. With a single
            entry the full, unshuffled index range is returned as-is and the
            entry is not validated.
        random: When true (and more than one split is requested) the indices
            are shuffled with torch's global RNG before being cut.
        k_fold: Accepted but not used by this function.

    Returns:
        A list of 1-D index tensors, one per entry of ``split_sizes``. Every
        split but the last holds ``int(size * original_len)`` indices; the
        last one receives everything that remains.

    Raises:
        AssertionError: If a split size is not a float strictly between 0 and 1.
    """
    indices = torch.arange(original_len)
    if len(split_sizes) == 1:
        return [indices]

    if random:
        indices = indices[torch.randperm(original_len, generator=None)]

    last_pos = len(split_sizes) - 1
    chunks = []
    cursor = 0
    for pos, fraction in enumerate(split_sizes):
        assert isinstance(fraction, float)
        assert 0 < fraction
        assert 1 > fraction
        stop = cursor + int(fraction * original_len)
        # The final split absorbs the rounding leftovers.
        chunk = indices[cursor:] if pos == last_pos else indices[cursor:stop]
        chunks.append(chunk)
        cursor = stop

    return chunks


def num_classes_fn(data) -> int:
    r"""Returns the number of classes described by ``data.y``.

    * ``y`` is ``None``: 0.
    * ``y`` has one value per row and is floating point: the number of
      distinct values.
    * ``y`` has one value per row and is not floating point: ``max(y) + 1``.
    * Otherwise: the size of the last dimension of ``y``.
    """
    labels = data.y
    if labels is None:
        return 0

    # More than one value per row: the last dimension is the class axis.
    if labels.numel() != labels.size(0):
        return labels.size(-1)

    if torch.is_floating_point(labels):
        return torch.unique(labels).numel()
    return int(labels.max()) + 1


def get_bcsbm_datalist(n,
                       eps,
                       p,
                       q,
                       mu,
                       std_dev,
                       directed,
                       seed,
                       version,
                       root='datasets',
                       save_to_disk=False):
    """Sample one binary contextual SBM graph and return it as ``[data]``.

    Args:
        n: Number of nodes requested from the sampler.
        eps: Class-balance parameter forwarded to the sampler.
        p: Edge probability inside a block.
        q: Edge probability between different blocks.
        mu: Sequence of floats; converted to a ``FloatTensor`` and used as the
            feature mean direction by the sampler.
        std_dev: Standard deviation of the node features.
        directed: Forwarded to the sampler.
        seed: Only used to build the file name (``_fold{seed}``); it does not
            seed any random number generator here.
        version: ``'v1'`` or ``'v2'``; selects the sampler.
        root: Root folder; data lives under ``<root>/BCSBM/<version>``. The
            folder is created even when nothing is saved.
        save_to_disk: When true, the generated ``Data`` object is pickled.

    Returns:
        A single-element list holding a ``Data`` object with ``x``,
        ``edge_index``, ``y``, ``edge_label``, ``label_all``,
        ``edge_label_ind`` (the ``label_all`` value of both endpoints of each
        edge) and train/val/test masks (all nodes are marked as training
        nodes).

    Raises:
        NotImplementedError: If ``version`` is neither ``'v1'`` nor ``'v2'``.
    """
    mu = torch.FloatTensor(mu)

    # The file stem encodes the generation parameters. The `eps` field used to
    # be filled with `n`, which made graphs with different eps share a file.
    name = f"n{n}_eps{eps}_p{p}_q{q}_mu{mu[0].item()}_fold{seed}"

    assert root is not None
    data_folder = os.path.join(root, 'BCSBM', version)
    # exist_ok avoids a crash when another process creates the folder first.
    os.makedirs(data_folder, exist_ok=True)

    filename = os.path.join(data_folder, f"{name}.pkl")

    # NOTE: loading a previously pickled graph from `filename` is deliberately
    # disabled; a fresh graph is sampled on every call. If loading is ever
    # re-enabled, remember that unpickling must only be done on trusted files.

    if version == 'v1':
        sampler = binary_contextual_stochstic_blockmodel_graph
    elif version == 'v2':
        sampler = binary_contextual_stochstic_blockmodel_graph_v2
    else:
        raise NotImplementedError

    o = sampler(n=n,
                p=p,
                q=q,
                mu=mu,
                std_dev=std_dev,
                eps=eps,
                directed=directed)

    if len(o) == 4:
        # The v1 sampler does not return `label_all`; its labels are already
        # the full block labels, so a copy serves as `label_all`.
        edge_index, label, node_features, edge_label = o
        label_all = label.clone()
    else:
        edge_index, label, node_features, edge_label, label_all = o

    # Column 0 / 1: block label of the first / second endpoint of every edge.
    edge_label_ind = torch.zeros((edge_index.shape[1], 2))
    edge_label_ind[:, 0] = label_all[edge_index[0]]
    edge_label_ind[:, 1] = label_all[edge_index[1]]

    data = Data(x=node_features, edge_index=edge_index, edge_attr=None, y=label,
                edge_label=edge_label,
                label_all=label_all,
                edge_label_ind=edge_label_ind)

    data.train_mask = torch.full([n], True)
    data.val_mask = torch.full([n], False)
    data.test_mask = torch.full([n], False)

    if save_to_disk:
        print(f'Saving data: {filename}')
        with open(filename, 'wb') as handle:
            pkl.dump(data, handle, protocol=pkl.HIGHEST_PROTOCOL)
    return [data]


def binary_contextual_stochstic_blockmodel_graph(n, p, q, mu, std_dev, eps, directed=False):
    """Sample a two-block contextual SBM with a random class split.

    The size of class 1 is the number of successes among ``n`` Bernoulli
    draws with probability ``eps``; class 0 gets the remaining nodes. The
    feature means passed on are ``-mu`` for class 0 and ``+mu`` for class 1;
    the edge probability is ``p`` within a class and ``q`` across classes.

    Returns:
        Whatever ``contextual_stochastic_blockmodel_graph`` returns for these
        settings.
    """
    membership_draws = Bernoulli(eps).sample([n])
    size_pos = membership_draws.sum().item()
    size_neg = n - size_pos

    return contextual_stochastic_blockmodel_graph(
        block_sizes=torch.LongTensor([size_neg, size_pos]),
        edge_probs=torch.FloatTensor([[p, q],
                                      [q, p]]),
        mu_list=[-1 * mu, 1 * mu],
        std_dev=std_dev,
        directed=directed,
    )


def binary_contextual_stochstic_blockmodel_graph_v2(n, p, q, mu, std_dev, eps, directed=False):
    """Sample a three-block contextual SBM and expose it as a binary task.

    The nodes are cut into three blocks of (near) equal size with feature
    means ``0``, ``-mu`` and ``+mu``. Edges appear with probability ``p``
    inside a block and ``q`` between blocks. The returned ``label`` merges
    blocks 1 and 2 into class 1, while ``label_all`` keeps the original
    three-way block assignment.

    Raises:
        AssertionError: If ``eps`` is not ``0.5`` (the only supported value).

    Returns:
        Tuple ``(edge_index, label, node_features, edge_label, label_all)``.
    """
    assert eps == 0.5

    # Deterministic split: two blocks of int(n/3) nodes, the third takes the rest.
    third = int(1/3*n)
    block_sizes = torch.LongTensor([third, third, n - third - third])
    print(f"Blocks: {block_sizes}")

    same_block, cross_block = p, q
    edge_probs = torch.FloatTensor([
        [same_block, cross_block, cross_block],
        [cross_block, same_block, cross_block],
        [cross_block, cross_block, same_block],
    ])

    edge_index, label, node_features, edge_label = contextual_stochastic_blockmodel_graph(
        block_sizes=block_sizes,
        edge_probs=edge_probs,
        mu_list=[0 * mu, -1 * mu, 1 * mu],
        std_dev=std_dev,
        directed=directed,
    )

    # Keep the three-way labels before collapsing {1, 2} -> 1 in place.
    label_all = label.clone()
    label.clamp_(max=1)
    return edge_index, label, node_features, edge_label, label_all


def contextual_stochastic_blockmodel_graph(block_sizes, edge_probs, mu_list, std_dev, directed=False):
    r"""Samples a contextual stochastic blockmodel graph.

    Nodes are laid out block after block. Block :obj:`i` contributes
    :obj:`block_sizes[i]` nodes with label :obj:`i`, whose features are drawn
    from a normal distribution with mean :obj:`mu_list[i]` and standard
    deviation :obj:`std_dev`. Every candidate node pair is kept as an edge
    with the probability given by :obj:`edge_probs` for the blocks of its two
    endpoints, and a self-loop is added to every node.

    Args:
        block_sizes ([int] or LongTensor): The sizes of blocks.
        edge_probs ([[float]] or FloatTensor): The density of edges going
            from each block to each other block. Must be symmetric if the
            graph is undirected.
        mu_list ([Tensor]): One feature-mean tensor per block.
        std_dev (float): Standard deviation of the node features; must be
            positive.
        directed (bool, optional): If set to :obj:`True`, will return a
            directed graph, i.e. sampled edges are NOT mirrored.
            (default: :obj:`False`)

    Returns:
        tuple: :obj:`(edge_index, label, node_features, edge_label)` where
        :obj:`label` holds the block index of every node and
        :obj:`edge_label` is 1 for edges whose endpoints lie in different
        blocks and 0 otherwise (self-loops are therefore 0).

    Raises:
        AssertionError: If the shapes of the arguments are inconsistent,
            :obj:`edge_probs` is not symmetric for an undirected graph,
            :obj:`std_dev` is not positive, or an entry of :obj:`mu_list`
            is not a tensor.
    """

    size, prob = block_sizes, edge_probs

    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size, dtype=torch.long)
    if not isinstance(prob, torch.Tensor):
        prob = torch.tensor(prob, dtype=torch.float)

    assert size.dim() == 1
    assert prob.dim() == 2 and prob.size(0) == prob.size(1)
    assert size.size(0) == prob.size(0)
    if not directed:
        assert torch.allclose(prob, prob.t())

    assert len(mu_list) == prob.size(0)
    assert std_dev > 0
    assert all(isinstance(mu, torch.Tensor) for mu in mu_list)

    node_idx_list = []
    node_features_list = []
    for i, b in enumerate(size):
        # Block label for each of the b nodes of block i.
        node_idx_list.append(size.new_full((b,), i))
        # Features of block i: b draws around mu_list[i].
        node_features_list.append(normal(loc=mu_list[i], scale=std_dev, shape=[b]))

    node_idx = torch.cat(node_idx_list)
    node_features = torch.cat(node_features_list)
    num_nodes = node_idx.size(0)

    if directed:
        # Enumerate every ordered pair (row, col) with row != col.
        idx = torch.arange((num_nodes - 1) * num_nodes)
        idx = idx.view(num_nodes - 1, num_nodes)
        idx = idx + torch.arange(1, num_nodes).view(-1, 1)
        idx = idx.view(-1)
        row = idx.div(num_nodes, rounding_mode='floor')
        col = idx % num_nodes
    else:
        # Enumerate every unordered pair once (row < col).
        row, col = torch.combinations(torch.arange(num_nodes), r=2).t()

    mask = bernoulli(probs=prob[node_idx[row], node_idx[col]])
    edge_index = torch.stack([row[mask], col[mask]], dim=0)

    if not directed:
        # Mirror each sampled pair so both directions are present. For a
        # directed graph both orientations were already sampled on their own;
        # mirroring there would symmetrise the graph and duplicate edges.
        edge_index = torch.cat((
            edge_index,
            torch.stack((edge_index[1], edge_index[0]), dim=0)
        ), dim=-1)

    # Pass num_nodes explicitly: otherwise the node count is inferred from
    # the largest index in edge_index and trailing isolated nodes would not
    # get a self-loop.
    edge_index, _ = add_self_loops(edge_index, None, num_nodes=num_nodes)

    # 1 = inter-block edge (different class), 0 = intra-block edge / self-loop.
    diff_classes = node_idx[edge_index[0]] != node_idx[edge_index[1]]
    edge_label = diff_classes.to(torch.long)

    label = node_idx
    return edge_index, label, node_features, edge_label









def normal(loc, scale, shape, device='auto', seed=0):
    """Draw normally distributed samples with the given mean and std.

    Args:
        loc: Mean of the distribution (number or tensor).
        scale: Standard deviation of the distribution (number or tensor).
        shape: Leading sample shape; the broadcast shape of ``loc``/``scale``
            is appended to it.
        device: Accepted for API compatibility; currently ignored.
        seed: Accepted for API compatibility; currently ignored. Sampling
            uses torch's global random number generator.

    Returns:
        A tensor of samples, created without gradient tracking.
    """
    dist = torch.distributions.Normal(loc=loc, scale=scale)
    full_shape = dist._extended_shape(shape)
    mean = dist.loc.expand(full_shape)
    std = dist.scale.expand(full_shape)
    with torch.no_grad():
        # generator=None -> default (global) RNG.
        samples = torch.normal(mean, std, generator=None)
    return samples


def bernoulli(probs, device='auto', seed=0):
    """Draw one Bernoulli sample per entry of ``probs`` as a boolean tensor.

    Args:
        probs: Tensor of success probabilities.
        device: Accepted for API compatibility; currently ignored.
        seed: Accepted for API compatibility; currently ignored. Sampling
            uses torch's global random number generator.

    Returns:
        A ``torch.bool`` tensor with the same shape as ``probs``.
    """
    # generator=None -> default (global) RNG.
    draws = torch.bernoulli(probs, generator=None)
    return draws.to(torch.bool)
