import random
from copy import deepcopy
import numpy as np
import torch
from torch_geometric.data import Data
import torch.nn.functional as F

from relnet.state.graph_state import GraphState, get_graph_hash
from relnet.state.state_generators import extract_node_name


class FeatureSpec(object):
    """Description of a single node or edge feature.

    Attributes:
        name: feature identifier; the extractors in this module dispatch on it.
        node_or_edge: "node" or "edge" -- which part of the graph the feature
            belongs to.
        datatype: declared datatype string (e.g. "float", "int"); not consulted
            by the extractors in this module.
        feat_dim: number of columns the feature occupies in the feature matrix.
    """

    def __init__(self, name, node_or_edge, datatype, feat_dim):
        self.name = name
        self.node_or_edge = node_or_edge
        self.datatype = datatype
        self.feat_dim = feat_dim

    def __repr__(self):
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"node_or_edge={self.node_or_edge!r}, "
                f"datatype={self.datatype!r}, feat_dim={self.feat_dim!r})")


def get_specs(demand_rep, hyperparams, gds_metadata):
    """Return the ordered list of FeatureSpec objects for a demand representation.

    Args:
        demand_rep: "raw" (per-destination demand vectors of width
            ``max_num_nodes``) or "summed" (a single summed demand column).
        hyperparams: dict; truthy ``avg_caps_as_node_feats`` / ``use_node_id``
            entries prepend the corresponding optional node features.
        gds_metadata: dict providing ``max_num_nodes``.

    Returns:
        List of FeatureSpec. Optional features come first (``avg_caps`` before
        ``node_id``), followed by the two demand features and the capacity edge
        feature.

    Raises:
        ValueError: if ``demand_rep`` is not recognised.
    """
    n_max = gds_metadata['max_num_nodes']

    if demand_rep == "raw":
        src_name, dest_name, demand_width = "demand_src", "demand_dest", n_max
    elif demand_rep == "summed":
        src_name, dest_name, demand_width = "demand_src_sum", "demand_dest_sum", 1
    else:
        raise ValueError(f"demand representation {demand_rep} not found.")

    base_specs = [
        FeatureSpec(src_name, "node", "float", demand_width),
        FeatureSpec(dest_name, "node", "float", demand_width),
        FeatureSpec(GraphState.CAPACITY_EPROP_NAME, "edge", "float", 1),
    ]

    optional_specs = []
    if hyperparams.get('avg_caps_as_node_feats'):
        optional_specs.append(FeatureSpec("avg_caps", "node", "float", 1))
    if hyperparams.get('use_node_id'):
        optional_specs.append(FeatureSpec("node_id", "node", "int", n_max))

    # TODO.

    return optional_specs + base_specs

def compute_dist_matrix(state):
    """Return ``state.get_distance_matrix()`` wrapped as a torch tensor.

    ``torch.from_numpy`` is used, so the tensor shares memory and dtype with the
    NumPy array returned by the state.
    """
    return torch.from_numpy(state.get_distance_matrix())

def pad_to_max_nodes(twod_feat_matrix, max_num_nodes, ncol):
    """Copy a 2-D matrix into the top-left corner of a zero ``(max_num_nodes, ncol)`` tensor.

    Args:
        twod_feat_matrix: 2-D tensor no larger than the target shape.
        max_num_nodes: number of rows of the output.
        ncol: number of columns of the output.

    Returns:
        A new tensor of the default float dtype; entries outside the copied
        region are zero.
    """
    rows, cols = twod_feat_matrix.shape[0], twod_feat_matrix.shape[1]
    out = torch.zeros(max_num_nodes, ncol)
    out[:rows, :cols] = twod_feat_matrix
    return out

def get_mlp_input_dim(demand_rep, hyperparams, gds_metadata):
    """Return the length of the flat feature vector used by the MLP representation.

    The node part is a ``(max_num_nodes, node_feats_dim)`` matrix, flattened.
    ``node_feats_dim`` is taken from the same feature specs used to build that
    matrix, so optional node features enabled through ``hyperparams`` (e.g.
    ``use_node_id``) are counted. The edge part reserves ``max_num_edges`` slots.
    This is presumably meant to match the vector built by ``to_pyt_datapoint``;
    confirm against the model that consumes it.

    Raises:
        ValueError: for an unknown ``demand_rep`` (propagated from ``get_specs``).
    """
    max_nodes = gds_metadata['max_num_nodes']
    max_edges = gds_metadata['max_num_edges']

    # Single source of truth for node widths: for the default specs this equals
    # 2 * max_nodes * (1 if "summed" else max_nodes), as before.
    node_feats_dim, _ = get_feats_dim(demand_rep, hyperparams, gds_metadata)

    node_part_dim = max_nodes * node_feats_dim
    edge_part_dim = max_edges
    return node_part_dim + edge_part_dim


def get_feats_dim(demand_rep, hyperparams, gds_metadata):
    """Return ``(node_feats_dim, edge_feats_dim)``: the summed widths of node and edge specs.

    Specs come from ``get_specs``; any ``ValueError`` it raises propagates.
    """
    node_width = 0
    edge_width = 0
    for spec in get_specs(demand_rep, hyperparams, gds_metadata):
        if spec.node_or_edge == "node":
            node_width += spec.feat_dim
        elif spec.node_or_edge == "edge":
            edge_width += spec.feat_dim
    return node_width, edge_width

def extract_edge_types(state, all_edges, hyperparams, **kwargs):
    """Assign an integer type ("colour") to every edge in ``all_edges``.

    Supported values of ``hyperparams['edge_coloring']``:
      * ``'uniform'``: every edge gets type 0.
      * ``'random'``: each edge gets type 0 or 1, drawn from an RNG seeded with
        ``get_graph_hash(state, include_demands=False)``. Colours are therefore
        reproducible for states with the same hash.

    Returns:
        LongTensor with one entry per edge in ``all_edges``.

    Raises:
        ValueError: for any other colouring. An empty tensor would not line up
            with ``all_edges``.
    """
    coloring = hyperparams['edge_coloring']
    if coloring == 'uniform':
        all_edge_types = [0 for _ in all_edges]
    elif coloring == 'random':
        # Random(seed) is equivalent to Random() followed by .seed(seed).
        randy = random.Random(get_graph_hash(state, include_demands=False))
        all_edge_types = [randy.randint(0, 1) for _ in all_edges]
    else:
        raise ValueError(f"edge coloring {coloring} not known.")

    return torch.LongTensor(all_edge_types)

def to_pyt_datapoint(state, demand_rep, hyperparams, gds_metadata, **kwargs):
    """Encode ``state`` as a single flat float tensor for the MLP representation.

    The result is the flattened ``(max_num_nodes, nf_dim)`` node-feature matrix
    from ``extract_mlp_node_feats``, followed by one capacity slot per edge of
    ``orig_graph`` (indexed by ``orig_graph.edge_idx``). Slots of original edges
    that are absent from ``state`` hold -1.

    Args:
        state: graph state to encode.
        demand_rep: demand representation passed to ``get_specs``.
        hyperparams: hyperparameter dict passed to ``get_specs``.
        gds_metadata: dict providing ``max_num_nodes``.
        orig_graph (kwarg, required): the original graph the state derives from.

    Returns:
        1-D float tensor.
    """
    specs = get_specs(demand_rep, hyperparams, gds_metadata)
    max_num_nodes = gds_metadata['max_num_nodes']
    orig_graph = kwargs['orig_graph']

    all_old_edges = orig_graph.edge_list

    _, x = extract_mlp_node_feats(state, demand_rep, specs, hyperparams, max_num_nodes, orig_graph=orig_graph)

    # Explicit float dtype: with an int fill value torch.full infers int64, which
    # would silently truncate fractional capacities written below.
    caps = torch.full((len(all_old_edges),), fill_value=-1.0, dtype=torch.float)
    for t_edge in translate_edges_to_orig_graph(state, orig_graph):
        caps[orig_graph.edge_idx[t_edge]] = float(
            orig_graph.get_edge_property(t_edge, GraphState.CAPACITY_EPROP_NAME))

    return torch.hstack((torch.flatten(x), caps))


def to_pyg_datapoint_demand(state, demand_rep, hyperparams, gds_metadata, **kwargs):
    """Convert ``state`` into a PyG ``Data`` object for the graph-based models.

    Args:
        state: graph state to encode.
        demand_rep: "raw" or "summed" (see ``get_specs``).
        hyperparams: dict; ``edge_coloring`` selects how edge types are assigned
            ("unique_edge" uses each edge's index in the original graph, other
            values are delegated to ``extract_edge_types``).
        gds_metadata: dict providing ``max_num_nodes``.
        orig_graph (kwarg, required): the original graph the state derives from.

    Returns:
        ``Data`` holding ``x``, ``edge_index``, ``edge_attr``, ``edge_type``, the
        feature widths ``nf_dim`` / ``ef_dim``, and the state's node and edge
        counts as 1-element IntTensors.
    """
    specs = get_specs(demand_rep, hyperparams, gds_metadata)
    max_num_nodes = gds_metadata['max_num_nodes']
    orig_graph = kwargs['orig_graph']

    edges = state.edge_list
    n_nodes = state.num_nodes
    n_edges = state.num_edges

    nf_dim, node_x = extract_standard_node_feats(state, demand_rep, specs, hyperparams, max_num_nodes,
                                                 orig_graph=orig_graph)

    # Edge list in PyG's (2, num_edges) "edge_index" layout.
    edge_index = edge_list_to_edge_index(edges)

    if hyperparams['edge_coloring'] != "unique_edge":
        edge_types = extract_edge_types(state, edges, hyperparams)
    else:
        # Every edge is coloured by its own index in the original graph.
        orig_edges = translate_edges_to_orig_graph(state, orig_graph)
        edge_types = torch.LongTensor([orig_graph.edge_idx[oe] for oe in orig_edges])

    edge_attr, ef_dim = extract_standard_edge_feats(state, specs, hyperparams, edges, n_edges)

    return Data(
        x=node_x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        edge_type=edge_types,
        nf_dim=nf_dim,
        ef_dim=ef_dim,
        state_num_nodes=torch.IntTensor([n_nodes]),
        state_num_edges=torch.IntTensor([n_edges]),
    )


def translate_edges_to_orig_graph(state, orig_graph):
    """Return ``state.edge_list`` with each endpoint replaced by its id in ``orig_graph``.

    Node labels of ``state`` are used as positions into the list returned by
    ``find_ids_in_original_graph``.
    """
    orig_ids = find_ids_in_original_graph(state, orig_graph)

    label_to_orig = {}
    for label in state.node_labels:
        label_to_orig[label] = orig_ids[label]

    return [(label_to_orig[edge[0]], label_to_orig[edge[1]]) for edge in state.edge_list]


def find_ids_in_original_graph(state, orig_graph):
    """Map each node of ``state`` (in ``state.node_labels`` order) to its id in ``orig_graph``.

    Nodes are matched by ``extract_node_name`` applied to their 'label' property.
    If several original nodes share a name, the last one in
    ``orig_graph.node_labels`` wins.

    Raises:
        KeyError: if a state node's name does not occur in ``orig_graph``.
    """
    name_to_orig_id = {}
    for orig_node in orig_graph.node_labels:
        orig_name = extract_node_name(orig_graph.get_node_property(orig_node, 'label'))
        name_to_orig_id[orig_name] = orig_node

    return [name_to_orig_id[extract_node_name(state.get_node_property(node, 'label'))]
            for node in state.node_labels]


def extract_standard_edge_feats(state, specs, hyperparams, all_edges, total_num_edges):
    """Build the ``(total_num_edges, width)`` edge-feature matrix from the edge specs.

    Only the capacity feature is supported. Its value is read from the state's
    capacity edge property, and edges without that property get -1.

    Returns:
        ``(edge_attr, ef_dim)`` where ``ef_dim`` is the total edge-feature width.

    Raises:
        ValueError: for an edge spec with an unknown name.
    """
    edge_specs = [spec for spec in specs if spec.node_or_edge == "edge"]
    total_width = sum(spec.feat_dim for spec in edge_specs)
    edge_attr = torch.zeros((total_num_edges, total_width))

    offset = 0
    for spec in edge_specs:
        if spec.name != GraphState.CAPACITY_EPROP_NAME:
            raise ValueError(f"edge feature {spec.name} not known.")

        column_block = torch.zeros((total_num_edges, spec.feat_dim))
        for row, edge in enumerate(all_edges):
            if not state.has_edge_property(edge, GraphState.CAPACITY_EPROP_NAME):
                column_block[row, :] = -1
                continue
            capacity = state.get_edge_property(edge, GraphState.CAPACITY_EPROP_NAME)
            column_block[row, :] = torch.FloatTensor([capacity])

        edge_attr[:, offset:offset + spec.feat_dim] = column_block
        offset += spec.feat_dim

    return edge_attr, offset

def extract_mlp_node_feats(state, demand_rep, specs, hyperparams, max_num_nodes, **kwargs):
    """Build the fixed-size node-feature matrix used by the MLP representation.

    The matrix always has ``max_num_nodes`` rows, so it can be flattened into a
    fixed-length vector. Nodes of ``orig_graph`` that are absent from ``state``
    get zero rows (and zero demand columns), inserted at their original-graph ids.

    NOTE(review): zero rows are inserted at the missing ids, but otherwise the
    node order of ``state`` is kept. Rows therefore line up with original-graph
    ids only if ``state`` lists its nodes in increasing original-id order. The
    padding also yields ``max_num_nodes`` rows only when ``orig_graph`` has
    exactly ``max_num_nodes`` nodes. Both are presumably guaranteed by callers;
    confirm.

    Args:
        state: graph state providing ``demands`` (a NumPy array) and node labels.
        demand_rep: unused here; the specs already encode the representation.
        specs: list of FeatureSpec; only node specs are used.
        hyperparams: unused here.
        max_num_nodes: number of rows of the output and width of one-hot/raw
            demand encodings.
        orig_graph (kwarg, required): original graph providing ``all_nodes_set``.

    Returns:
        ``(nf_dim, x)`` with ``x`` a float tensor of shape ``(max_num_nodes, nf_dim)``.

    Raises:
        ValueError: for an unknown node feature name. "avg_caps" is not
            supported by this extractor.
    """
    node_feats = [f for f in specs if f.node_or_edge == "node"]
    node_feats_dim = sum([nf.feat_dim for nf in node_feats])
    x = torch.zeros((max_num_nodes, node_feats_dim))

    nf_dim = 0

    orig_graph = kwargs['orig_graph']

    node_ids_in_orig = find_ids_in_original_graph(state, orig_graph)
    node_ids = torch.LongTensor(node_ids_in_orig)

    orig_demands = torch.from_numpy(state.demands)
    # Original-graph ids with no counterpart in `state`. Sorted ascending, so each
    # id is also its final index once the zero rows/columns have been inserted.
    missing_nodes = sorted(list(orig_graph.all_nodes_set - set(node_ids_in_orig)))

    for nf in node_feats:
        if nf.name == "demand_src":
            feat = torch.zeros((max_num_nodes, max_num_nodes))
            padded_demands = pad_demands_for_missing(state, orig_demands, missing_nodes, how="rows")
            # Re-index demand columns from state order to original-graph ids.
            feat.index_add_(1, node_ids, padded_demands)
        elif nf.name == "demand_dest":
            feat = torch.zeros((max_num_nodes, max_num_nodes))
            padded_demands = pad_demands_for_missing(state, orig_demands, missing_nodes, how="cols")
            feat.index_add_(1, node_ids, padded_demands.T)
        elif nf.name == "demand_src_sum":
            padded_demands = pad_demands_for_missing(state, orig_demands, missing_nodes, how="rows")
            feat = padded_demands.sum(dim=1).view(-1, 1)
        elif nf.name == "demand_dest_sum":
            padded_demands = pad_demands_for_missing(state, orig_demands, missing_nodes, how="cols")
            feat = padded_demands.T.sum(dim=1).view(-1, 1)
        elif nf.name == "dummy":
            feat = torch.ones(max_num_nodes, 1, dtype=torch.float)
        elif nf.name == "node_id":
            # The one-hot has one row per *state* node. Insert zero rows for missing
            # nodes, exactly as the demand features are padded, so its row count
            # matches x. Previously the assignment failed whenever nodes were missing.
            feat = F.one_hot(node_ids, num_classes=max_num_nodes).float()
            for mn in missing_nodes:
                feat = torch.vstack((feat[:mn], torch.zeros((1, max_num_nodes)), feat[mn:]))
        else:
            raise ValueError(f"node feature {nf.name} not known.")

        x[:, nf_dim:nf_dim + nf.feat_dim] = feat
        nf_dim += nf.feat_dim

    return nf_dim, x


def pad_demands_for_missing(state, padded_demands, missing_nodes, how="rows"):
    """Insert zero rows (``how="rows"``) or zero columns (any other value) at each missing index.

    Indices are processed in the given order, each against the matrix as already
    padded. Callers pass them sorted ascending, so each index is its final
    position. Inserted rows are ``state.num_nodes`` wide and inserted columns are
    ``state.num_nodes`` tall. If ``missing_nodes`` is empty, the input tensor
    itself is returned.
    """
    result = padded_demands
    insert_rows = how == "rows"
    for idx in missing_nodes:
        if insert_rows:
            zero_row = torch.zeros((1, state.num_nodes))
            result = torch.vstack((result[:idx, :], zero_row, result[idx:, :]))
        else:
            zero_col = torch.zeros((state.num_nodes, 1))
            result = torch.hstack((result[:, :idx], zero_col, result[:, idx:]))
    return result


def extract_standard_node_feats(state, demand_rep, specs, hyperparams, max_num_nodes, **kwargs):
    """Build the per-node feature matrix used by the graph (PyG) models.

    Rows follow the node order of ``state``. Columns are the concatenation of
    the node features in ``specs``, in spec order.

    Args:
        state: graph state providing ``num_nodes``, ``node_labels``, ``edge_list``,
            ``demands`` (a NumPy array) and edge-property accessors.
        demand_rep: unused here; the specs already encode the representation.
        specs: list of FeatureSpec; only node specs are used.
        hyperparams: unused here.
        max_num_nodes: width of the one-hot and raw-demand encodings.
        orig_graph (kwarg, required): original graph used to map state nodes to
            their original-graph ids.

    Returns:
        ``(nf_dim, x)`` with ``x`` a float tensor of shape ``(state.num_nodes, nf_dim)``.

    Raises:
        ValueError: for an unknown node feature name.
    """
    node_feats = [f for f in specs if f.node_or_edge == "node"]
    node_feats_dim = sum([nf.feat_dim for nf in node_feats])
    x = torch.zeros((state.num_nodes, node_feats_dim))

    nf_dim = 0
    # Position i holds the original-graph id of the i-th node of `state`.
    node_ids = torch.LongTensor(find_ids_in_original_graph(state, kwargs['orig_graph']))

    for nf in node_feats:
        # TODO: can further encapsulate by letting each spec extract what it needs from the state...
        if nf.name == "demand_src":
            # Columns of state.demands are re-indexed to original-graph ids.
            feat = torch.zeros((state.num_nodes, max_num_nodes))
            feat.index_add_(1, node_ids, torch.from_numpy(state.demands))
        elif nf.name == "demand_dest":
            # Same re-indexing, applied to the transposed demand matrix.
            feat = torch.zeros((state.num_nodes, max_num_nodes))
            feat.index_add_(1, node_ids, torch.from_numpy(state.demands.T))
        elif nf.name == "demand_src_sum":
            src_demands = torch.from_numpy(state.demands)
            feat = src_demands.sum(dim=1).view(-1, 1)
        elif nf.name == "demand_dest_sum":
            dest_demands = torch.from_numpy(state.demands.T)
            feat = dest_demands.sum(dim=1).view(-1, 1)

        elif nf.name == "dummy":
            # Constant feature sized to x's row count (state.num_nodes). Using
            # max_num_nodes here made the slice assignment below fail whenever the
            # state had fewer nodes than the maximum.
            feat = torch.ones(state.num_nodes, 1, dtype=torch.float)

        elif nf.name == "node_id":
            feat = F.one_hot(node_ids, num_classes=max_num_nodes)
        elif nf.name == "avg_caps":
            # Mean capacity over the edges incident to each node. Edges without a
            # capacity property count as 1 (unlike the edge features, which use -1).
            cap_vals = {n: [] for n in state.node_labels}

            for i, edge in enumerate(state.edge_list):
                edge_from, edge_to = edge

                if state.has_edge_property(edge, GraphState.CAPACITY_EPROP_NAME):
                    cap = state.get_edge_property(edge, GraphState.CAPACITY_EPROP_NAME)
                else:
                    cap = 1

                cap_vals[edge_from].append(cap)
                cap_vals[edge_to].append(cap)

            # NOTE(review): a node with no incident edges yields np.mean([]) -> nan
            # (with a RuntimeWarning). Rows are ordered by sorted node label, which
            # assumes labels coincide with x's row order; confirm.
            avg_caps = [np.mean(cv[1]) for cv in sorted(cap_vals.items(), key=lambda x: x[0])]
            feat = torch.FloatTensor(avg_caps).view(-1, 1)

        else:
            raise ValueError(f"node feature {nf.name} not known.")

        x[:, nf_dim:nf_dim + nf.feat_dim] = feat
        nf_dim += nf.feat_dim

    return nf_dim, x


def edge_list_to_edge_index(edge_list):
    """Convert a list of ``(src, dst, ...)`` edges into a ``(2, num_edges)`` LongTensor.

    Row 0 holds the first element of each edge and row 1 the second. An empty
    list gives shape ``(2, 0)``.
    """
    sources = torch.LongTensor([edge[0] for edge in edge_list])
    targets = torch.LongTensor([edge[1] for edge in edge_list])
    return torch.stack((sources, targets))

