import copy
import random
from random import random, sample

import networkx as nx
import torch
import torch.nn as nn
import torch_geometric.utils as pyg_utils
from torch_geometric.data.batch import Batch
from torch_geometric.utils import degree
from torch_geometric.utils import sort_edge_index


import gcip.utils.io as pb_io
def add_x_noise(batch, eps):
    """Perturb ``batch.x`` in place with Gaussian noise scaled by ``eps``.

    Standard-normal noise shaped like ``batch.x`` is multiplied by ``eps`` and
    added with an in-place ``+=``; the same ``batch`` object is returned.
    """
    noise = torch.randn_like(batch.x)
    batch.x += eps * noise
    return batch


def add_edge_noise_batch(batch, p, sort=True):
    """Apply ``add_edge_noise`` to every graph of ``batch`` and re-collate.

    The graphs returned by ``batch.to_data_list()`` have their ``edge_index``
    and ``edge_attr`` overwritten with the noisy versions (the graph objects
    are modified in place), then a fresh ``Batch`` is built from them.
    """
    graphs = batch.to_data_list()
    for graph in graphs:
        noisy_index, noisy_attr = add_edge_noise(
            edge_index=graph.edge_index,
            num_nodes=graph.x.size(0),
            p=p,
            edge_attr=graph.edge_attr,
            sort=sort,
        )
        graph.edge_index = noisy_index
        graph.edge_attr = noisy_attr
    return Batch.from_data_list(graphs)


def add_edge_noise(edge_index, num_nodes, p, edge_attr=None, sort=True):
    """Insert random new edges into a graph, each in both directions.

    ``int((edge_index.size(1) // 2) * p)`` new node pairs are requested, which
    treats ``edge_index`` as listing every undirected edge twice. Candidate
    cells of the ``num_nodes x num_nodes`` grid are drawn without replacement.
    Self-loops and pairs already present are skipped. For each accepted pair
    ``(s, t)``, both ``[s, t]`` and ``[t, s]`` are appended. Fewer pairs than
    requested may be added if the candidate pool runs out.

    Args:
        edge_index: ``(2, E)`` long tensor of edges.
        num_nodes: number of nodes; new endpoints lie in ``range(num_nodes)``.
        p: non-negative ratio controlling how many pairs are added.
        edge_attr: optional per-edge features indexed along dim 0. Appended
            edges get all-zero rows with ``edge_attr``'s dtype and device.
        sort: if True, pass the result through ``sort_edge_index``.

    Returns:
        ``(edge_index, edge_attr)``. The inputs are returned unchanged when
        nothing is added. Otherwise the new ``edge_index`` stays on the same
        device as the input one. If the input was undirected, the result is
        also passed through ``to_undirected``.
    """
    assert p >= 0.0
    if p == 0.0:
        return edge_index, edge_attr
    num_of_new_edge = int((edge_index.size(1) // 2) * p)
    if num_of_new_edge == 0:
        return edge_index, edge_attr

    # Only needed on this path; computed from the edges *before* adding noise.
    is_undirected = pyg_utils.is_undirected(edge_index)
    edge_set = set(map(tuple, edge_index.t().tolist()))

    # Oversample. After skipping every existing edge and every self-loop,
    # roughly ``num_of_new_edge`` usable candidates should still remain.
    max_num_edges = num_nodes ** 2
    num_edge_samples = min(num_of_new_edge + len(edge_set) + num_nodes, max_num_edges)

    to_add = []
    added = 0
    # Each candidate k encodes the adjacency-matrix cell (k // n, k % n).
    for k in sample(range(max_num_edges), num_edge_samples):
        if added >= num_of_new_edge:
            break
        s, t = divmod(k, num_nodes)
        if s != t and (s, t) not in edge_set:
            added += 1
            to_add.append([s, t])
            to_add.append([t, s])
            edge_set.add((s, t))
            edge_set.add((t, s))

    if not to_add:
        return edge_index, edge_attr

    # Build the new edges on edge_index's device. The result used to be
    # silently moved to CPU while edge_attr stayed on its own device.
    new_edges = torch.tensor(to_add, dtype=torch.long, device=edge_index.device).t()
    new_edge_index = torch.cat([edge_index, new_edges], dim=1)

    if edge_attr is not None:
        # Match edge_attr's dtype/device/trailing shape so no promotion occurs.
        pad = edge_attr.new_zeros((len(to_add),) + tuple(edge_attr.shape[1:]))
        new_edge_attr = torch.cat([edge_attr, pad], dim=0)
    else:
        new_edge_attr = None

    if sort:
        if edge_attr is not None:
            new_edge_index, new_edge_attr = sort_edge_index(edge_index=new_edge_index,
                                                            edge_attr=new_edge_attr,
                                                            num_nodes=num_nodes)
        else:
            new_edge_index = sort_edge_index(edge_index=new_edge_index,
                                             num_nodes=num_nodes)

    if is_undirected:
        if edge_attr is not None:
            new_edge_index, new_edge_attr = pyg_utils.to_undirected(edge_index=new_edge_index,
                                                                    edge_attr=new_edge_attr,
                                                                    num_nodes=num_nodes)
        else:
            new_edge_index = pyg_utils.to_undirected(edge_index=new_edge_index,
                                                     num_nodes=num_nodes)

    return new_edge_index, new_edge_attr


def compute_stats_batch(batch, num_samples=1):
    """Collect ``compute_stats_graph`` results over every graph in ``batch``.

    Each statistic becomes a tensor with one value per graph. When
    ``num_samples > 1``, the values are reshaped to ``(-1, num_samples)`` and
    averaged along dim 1 (as float). That presumably assumes consecutive
    groups of ``num_samples`` graphs belong together; TODO confirm with
    callers. A statistic whose count is not divisible by ``num_samples`` makes
    ``view`` raise.
    """
    collected = {}
    for graph in batch.to_data_list():
        for key, val in compute_stats_graph(graph).items():
            collected.setdefault(key, []).append(val)

    def _to_tensor(values):
        tensor = torch.tensor(values)
        if num_samples <= 1:
            return tensor
        return tensor.view(-1, num_samples).float().mean(1)

    return {key: _to_tensor(values) for key, values in collected.items()}


def compute_stats_graph(graph):
    """Return node/edge counts for one graph, plus label-keyed copies.

    Always reports ``num_nodes`` (``graph.x.shape[0]``) and ``num_edges``
    (``graph.edge_index.shape[1]``). If ``graph.y`` holds a single scalar
    label, the same counts are also stored under ``num_nodes_{label}`` and
    ``num_edges_{label}``. A missing, ``None`` or multi-element ``y`` is
    silently skipped.
    """
    num_nodes = graph.x.shape[0]
    num_edges = graph.edge_index.shape[1]
    stats = {'num_nodes': num_nodes, 'num_edges': num_edges}

    try:
        label = graph.y.item()
    except (AttributeError, RuntimeError, ValueError):
        # AttributeError: y absent or None.
        # RuntimeError/ValueError: y is not a one-element tensor/array.
        # The old bare ``except`` also swallowed KeyboardInterrupt etc.
        return stats

    stats[f'num_nodes_{label}'] = num_nodes
    stats[f'num_edges_{label}'] = num_edges
    return stats


def get_deg(dataset, indegree=True, bincount=False):
    """Concatenate per-node degrees over every graph in ``dataset``.

    Degrees are counted over row 1 of ``edge_index`` when ``indegree`` is
    True, otherwise over row 0. With ``bincount=True`` the result is instead
    a histogram of those degrees, with ``minlength`` equal to the total number
    of nodes. The result is always returned as a float tensor.
    """
    row = 1 if indegree else 0
    degrees = torch.cat([
        degree(graph.edge_index[row], num_nodes=graph.num_nodes, dtype=torch.long)
        for graph in dataset
    ])
    if bincount:
        degrees = torch.bincount(degrees, minlength=degrees.numel())
    return degrees.float()


def compute_degree(x, edge_index, flow, indegre=True):
    """Count, for every node, how many edges point into (or out of) it.

    Args:
        x: node tensor. Only ``x.size(0)`` (number of nodes), its dtype and
            its device are used; the result has ``x``'s dtype.
        edge_index: ``(2, E)`` long tensor.
        flow: ``'source_to_target'`` treats row 0 as sources and row 1 as
            targets; any other value swaps the two rows.
        indegre: if True count incoming edges, otherwise outgoing ones.

    Returns:
        1-D tensor of length ``x.size(0)`` holding the per-node counts.
    """
    counts = x.new_zeros((x.size(0),))
    i, j = (1, 0) if flow == 'source_to_target' else (0, 1)
    endpoints = edge_index[i] if indegre else edge_index[j]
    idx, count = endpoints.unique(return_counts=True)
    # scatter_add requires src and self to share a dtype. The old hard-coded
    # .float() broke for any x that was not float32.
    return counts.scatter_add(0, idx, count.to(counts.dtype))


def remove_edges_from_batch(edges_to_remove, batch, relabel_nodes=False):
    """Drop the edges at positions ``edges_to_remove`` from ``batch`` in place.

    ``batch.edge_index`` and ``batch.edge_attr`` are replaced with the output
    of ``remove_edges``, and the same ``batch`` is returned. An empty
    ``edges_to_remove`` leaves the batch untouched.
    """
    if len(edges_to_remove) > 0:
        kept_index, kept_attr = remove_edges(
            edges_to_remove=edges_to_remove,
            num_edges=batch.edge_index.shape[1],
            edge_index=batch.edge_index,
            relabel_nodes=relabel_nodes,
            edge_attr=batch.edge_attr,
        )
        batch.edge_index = kept_index
        batch.edge_attr = kept_attr
    return batch


def remove_edges(edges_to_remove,
                 num_edges,
                 edge_index,
                 relabel_nodes=False,
                 edge_attr=None):
    """Drop edges (columns of ``edge_index``) by position.

    Args:
        edges_to_remove: list of column indices to delete.
        num_edges: total number of edges (``edge_index.shape[1]``).
        edge_index: ``(2, E)`` edge tensor.
        relabel_nodes: unused; accepted for signature compatibility.
        edge_attr: optional per-edge features, filtered along dim 0.

    Returns:
        ``(edge_index, edge_attr)`` with the surviving edges in their
        original relative order. ``edge_attr`` stays ``None`` if it was None.
    """
    assert isinstance(edges_to_remove, list)
    assert isinstance(num_edges, int)

    # Iterate range() rather than a set difference. Set iteration order is
    # not guaranteed ascending, which silently permuted the surviving edges.
    removed = set(edges_to_remove)
    edges_to_keep = torch.tensor([e for e in range(num_edges) if e not in removed],
                                 dtype=torch.long)
    edge_index = edge_index[:, edges_to_keep]

    if edge_attr is not None:
        edge_attr = edge_attr[edges_to_keep, :]
    return edge_index, edge_attr


def remove_nodes_from_batch(nodes_idx=None, mode='remove', batch=None, relabel_nodes=False,
                            has_batch_att=True):
    """Filter nodes of ``batch`` in place via ``remove_nodes``.

    ``mode``/``relabel_nodes`` are forwarded to ``remove_nodes``. Afterwards
    ``edge_index``, ``edge_attr`` and ``x`` are updated, and also ``batch``
    when ``has_batch_att`` is True. No other batch attributes are touched.
    NOTE(review): with the default ``relabel_nodes=False``, ``edge_index``
    keeps the old node numbering while ``x`` shrinks; confirm callers want
    that. An empty ``nodes_idx`` returns the batch unchanged.
    """
    if len(nodes_idx) == 0:
        return batch

    node_mask, kept_index, kept_attr = remove_nodes(
        num_nodes=batch.x.shape[0],
        edge_index=batch.edge_index,
        nodes_idx=nodes_idx,
        mode=mode,
        relabel_nodes=relabel_nodes,
        edge_attr=batch.edge_attr,
    )
    batch.edge_index = kept_index
    batch.edge_attr = kept_attr
    batch.x = batch.x[node_mask]
    if has_batch_att:
        batch.batch = batch.batch[node_mask]
    return batch


def remove_nodes_from_batch_mask(mask=None, mode='remove', batch=None, relabel_nodes=False,
                                 has_batch_att=True):
    """Remove (or keep) nodes selected by a boolean mask, graph by graph.

    Args:
        mask: boolean tensor with one entry per node, in the order of the
            graphs returned by ``batch.to_data_list()``.
        mode: ``'remove'`` drops nodes where ``mask`` is True; any other value
            drops nodes where ``mask`` is False.
        batch: the ``Batch`` to filter.
        relabel_nodes: unused; per-graph node indices are always relabelled
            so each graph stays self-consistent.
        has_batch_att: unused; kept for signature compatibility.

    Returns:
        A new ``Batch`` rebuilt from the per-graph results. Each graph carries
        ``graph.mask`` (True for kept nodes). If nothing is selected for
        removal at all, the input ``batch`` is returned as-is, without
        ``mask``.
    """
    mask_remove = mask if mode == 'remove' else ~mask
    if mask_remove.sum().item() == 0:
        return batch

    data_list_out = []
    node_idx_init = 0
    for graph in batch.to_data_list():
        num_nodes = graph.x.shape[0]
        mask_ = mask_remove[node_idx_init:node_idx_init + num_nodes]
        assert len(mask_) == num_nodes, f"{num_nodes}/{len(mask_)}"
        node_idx_init += num_nodes
        graph.mask = ~mask_

        if mask_.sum().item() > 0:
            nodes_idx = mask_.nonzero(as_tuple=False).view(-1).tolist()
            n_mask, edge_index, edge_attr = remove_nodes(num_nodes=num_nodes,
                                                         edge_index=graph.edge_index,
                                                         nodes_idx=nodes_idx,
                                                         mode='remove',
                                                         relabel_nodes=True,
                                                         edge_attr=graph.edge_attr)
            graph.x = graph.x[n_mask]
            graph.edge_index = edge_index
            # Always write edge_attr back. It was filtered with the same edge
            # mask as edge_index. Gating this on has_batch_att left a stale,
            # misaligned edge_attr when has_batch_att=False.
            graph.edge_attr = edge_attr
        data_list_out.append(graph)

    return Batch.from_data_list(data_list_out)


def remove_nodes(
        num_nodes,
        edge_index,
        nodes_idx=None,
        mode='remove',
        relabel_nodes=False,
        edge_attr=None):
    """Filter nodes and drop every edge touching a filtered-out node.

    Args:
        num_nodes: number of nodes in the graph.
        edge_index: ``(2, E)`` long tensor.
        nodes_idx: list of node indices. These are removed when
            ``mode == 'remove'``; only these are kept when ``mode == 'keep'``.
        mode: ``'remove'`` or ``'keep'``.
        relabel_nodes: if True, surviving node ids are renumbered
            ``0..K-1`` in ascending order of their original id. This matches
            the row order of ``x[n_mask]``.
        edge_attr: optional per-edge features, filtered along dim 0.

    Returns:
        ``(n_mask, edge_index, edge_attr)``. ``n_mask`` is a boolean tensor of
        length ``num_nodes`` marking kept nodes.

    Raises:
        NotImplementedError: for an unknown ``mode``.
    """
    assert isinstance(nodes_idx, list)
    assert isinstance(num_nodes, int)

    device = edge_index.device
    # Kept nodes must be in ascending order so the relabel mapping agrees
    # with x[n_mask]. A set difference does not guarantee that order, and
    # neither does the caller's order in 'keep' mode.
    if mode == 'remove':
        removed = set(nodes_idx)
        kept = [n for n in range(num_nodes) if n not in removed]
    elif mode == 'keep':
        kept = sorted(set(nodes_idx))
    else:
        raise NotImplementedError
    nodes_to_keep = torch.tensor(kept, dtype=torch.long, device=device)

    n_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    n_mask[nodes_to_keep] = True

    # Keep an edge only if both endpoints survive.
    mask = n_mask[edge_index[0]] & n_mask[edge_index[1]]
    edge_index = edge_index[:, mask]
    edge_attr = edge_attr[mask] if edge_attr is not None else None

    if relabel_nodes:
        n_idx = torch.zeros(num_nodes, dtype=torch.long, device=device)
        n_idx[nodes_to_keep] = torch.arange(nodes_to_keep.size(0), device=device)
        edge_index = n_idx[edge_index]

    return n_mask, edge_index, edge_attr


def remove_nodes2(nodes_to_remove,
                  num_nodes,
                  edge_index,
                  relabel_nodes=False):
    """Drop ``nodes_to_remove`` and every edge touching them.

    Args:
        nodes_to_remove: list of node indices to delete.
        num_nodes: number of nodes in the graph.
        edge_index: ``(2, E)`` long tensor.
        relabel_nodes: if True, surviving node ids are renumbered
            ``0..K-1`` in ascending order of their original id.

    Returns:
        ``(n_mask, mask_edge, edge_index)``.
        ``n_mask`` marks kept nodes. ``mask_edge`` marks edges whose two
        endpoints both survive. ``edge_index`` holds the surviving edges,
        relabelled if requested. Both masks live on ``edge_index``'s device.
    """
    assert isinstance(nodes_to_remove, list)
    assert isinstance(num_nodes, int)

    device = edge_index.device
    # Ascending order so the relabel mapping matches x[n_mask]. A set
    # difference does not guarantee that order.
    removed = set(nodes_to_remove)
    nodes_to_keep = torch.tensor([n for n in range(num_nodes) if n not in removed],
                                 dtype=torch.long, device=device)
    # Allocated on the same device as edge_index: indexing a CPU mask with a
    # CUDA edge_index used to fail.
    n_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    n_mask[nodes_to_keep] = True

    # Mask for edges that contain both the src and dst nodes
    mask_edge = n_mask[edge_index[0]] & n_mask[edge_index[1]]
    edge_index = edge_index[:, mask_edge]

    if relabel_nodes:
        n_idx = torch.zeros(num_nodes, dtype=torch.long, device=device)
        n_idx[nodes_to_keep] = torch.arange(nodes_to_keep.size(0), device=device)
        edge_index = n_idx[edge_index]

    return n_mask, mask_edge, edge_index


def remove_nodes_using_G(graph_i_original, num_nodes_to_be_removed):
    """Return a copy of a graph with random nodes deleted, keeping >= 1 edge.

    Works on the ``.G`` attribute of a deep copy of ``graph_i_original``. It
    repeatedly samples ``num_nodes_to_be_removed`` nodes from
    ``range(graph_i_original.num_nodes)`` and removes them from ``G`` until at
    least one edge survives. Then ``G`` is relabelled to consecutive integers
    and ``_update_tensors()`` is called on the copy. Node labels of ``G`` are
    presumably ``0..num_nodes-1``; TODO confirm against the graph class.

    Raises:
        ValueError: if ``num_nodes_to_be_removed >= num_nodes``, or the
            original graph has no edges. No subgraph can keep an edge in
            either case, so the retry loop would never end.
    """
    if num_nodes_to_be_removed == 0:
        return copy.deepcopy(graph_i_original)

    num_nodes = graph_i_original.num_nodes
    if num_nodes_to_be_removed >= num_nodes:
        raise ValueError(f"cannot remove {num_nodes_to_be_removed} of {num_nodes} nodes "
                         f"and keep an edge")
    if len(graph_i_original.G.edges) == 0:
        raise ValueError("graph has no edges; no node subset can keep one")

    while True:
        graph_i_random = copy.deepcopy(graph_i_original)
        # Use the module-level ``sample``. ``random`` is shadowed in this
        # file by ``from random import random``, so ``random.sample`` raised
        # AttributeError.
        nodes_to_remove = sample(range(num_nodes), k=num_nodes_to_be_removed)

        graph_i_random.G.remove_nodes_from(nodes_to_remove)
        if len(graph_i_random.G.edges) > 0:
            graph_i_random.G = nx.relabel.convert_node_labels_to_integers(graph_i_random.G)
            graph_i_random._update_tensors()
            return graph_i_random


def get_activation_from_str(act_str):
    """Build the ``torch.nn`` activation module named by ``act_str``.

    Only ``'relu'`` is supported; a new ``nn.ReLU()`` is created per call.

    Raises:
        NotImplementedError: for any other name.
    """
    if act_str != 'relu':
        raise NotImplementedError
    return nn.ReLU()
