import torch_geometric.data as pygd
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DenseDataLoader
from gcip.utils.preparator.transductive_split import transductive_split_data, compute_split_idx
from gcip.preparators.preparator_base import BaseDatasetPreparator

from gcip.utils.io import dict_to_cn

import torch_geometric.utils as pygutils
import matplotlib.pyplot as plt
import networkx as nx

from torch_geometric.utils import degree

import torch
import numpy as np

import gcip.utils.io as pb_io
import os


# Fixed RGB palette (0-255 per channel). get_node_colors indexes it with the
# node type, so entry k is the colour drawn for node type k.
color_mapping = [
    [255, 0, 0],      # red
    [0, 255, 0],      # green
    [0, 0, 255],      # blue
    [255, 255, 0],    # yellow
    [255, 0, 255],    # magenta
    [0, 255, 255],    # cyan
    [128, 0, 0],      # maroon
    [0, 128, 0],      # dark green
    [0, 0, 128],      # navy
    [128, 128, 128],  # gray
]

# The same palette as a float32 tensor rescaled into [0, 1] for plotting.
color_mapping_tensor = torch.tensor(color_mapping, dtype=torch.float32).div(255.0)


class GraphPreparator(BaseDatasetPreparator):
    """Dataset preparator for graph-structured (PyTorch Geometric) data.

    Extends ``BaseDatasetPreparator`` with graph-specific options:
    dense vs. sparse batching, a maximum node count, a noise flag, and
    transductive (single graph) vs. inductive (list of graphs) splitting.
    Methods that raise ``NotImplementedError`` here are hooks for subclasses.
    """

    def __init__(self, name,
                 is_dense,
                 max_nodes,
                 add_noise,
                 transductive,
                 **kwargs):
        """
        Args:
            name: Dataset name, forwarded to ``BaseDatasetPreparator``.
            is_dense: If True, ``_data_loader`` builds a ``DenseDataLoader``;
                otherwise a sparse PyG ``DataLoader``.
            max_nodes: Stored as ``self.max_nodes`` (not used elsewhere in
                this class).
            add_noise: Stored as ``self.add_noise`` (not used elsewhere in
                this class).
            transductive: If True, ``_split_dataset`` expects a single
                ``pygd.Data`` graph and splits it with
                ``transductive_split_data``.
            **kwargs: Remaining options forwarded to ``BaseDatasetPreparator``.
        """
        # These attributes are assigned before super().__init__().
        # NOTE(review): presumably because the base constructor may read
        # them; confirm before reordering.
        self.transductive = transductive

        self.is_dense = is_dense
        self.max_nodes = max_nodes
        self.add_noise = add_noise
        super().__init__(name=name,
                         **kwargs)

        # Graph data lives in a dedicated sub-directory of the base root.
        self.root = os.path.join(self.root, 'graph')

    @classmethod
    def params(cls, dataset):
        """Collect constructor keyword arguments from a dataset config.

        Args:
            dataset: Config object exposing ``is_dense``, ``max_nodes`` and
                ``add_noise``, or a dict that ``dict_to_cn`` converts into one.

        Returns:
            dict: The graph-specific keys, updated with the entries from
            ``BaseDatasetPreparator.params(dataset)``. Base entries win on
            key collisions.

        NOTE(review): ``transductive`` is a required ``__init__`` argument
        but is not added here. Presumably it comes from
        ``BaseDatasetPreparator.params``; verify.
        """
        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        my_dict = {
            'is_dense': dataset.is_dense,
            'max_nodes': dataset.max_nodes,
            'add_noise': dataset.add_noise,
        }

        my_dict.update(BaseDatasetPreparator.params(dataset))

        return my_dict

    @classmethod
    def loader(cls, dataset):
        """Instantiate ``cls`` from a dataset config (object or dict).

        Uses ``GraphPreparator.params`` explicitly, not ``cls.params``.
        A subclass whose constructor needs extra keys must therefore
        override ``loader`` as well.
        """
        my_dict = GraphPreparator.params(dataset)

        return cls(**my_dict)

    @property
    def type_of_data(self):
        """Identifier of the data modality handled by this preparator."""
        return 'graph'

    def _batch_element_list(self):
        """Subclass hook; not implemented for the generic graph preparator."""
        raise NotImplementedError

    def edge_attr_dim(self):
        """Subclass hook: dimensionality of the edge attributes."""
        raise NotImplementedError

    def get_deg(self):
        """Return the in-degree histogram over the training graphs.

        Degrees are counted on ``edge_index[1]`` for each graph yielded by a
        training loader with ``batch_size=1``. Entry ``k`` of the returned
        ``torch.long`` tensor is the total number of nodes, across all
        training graphs, whose degree equals ``k``. The tensor has length
        ``max_degree + 1`` and is always at least 1.
        """
        loader = self.get_dataloader_train(batch_size=1)

        deg_histogram = torch.zeros(1, dtype=torch.long)
        # Single pass over the loader: grow the histogram whenever a graph
        # has a larger maximum degree than any seen so far.
        for data in loader:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes,
                       dtype=torch.long)
            if d.numel() == 0:
                # A graph without nodes contributes nothing; skipping it also
                # avoids calling max()/bincount on an empty tensor.
                continue
            counts = torch.bincount(d)
            if counts.numel() > deg_histogram.numel():
                counts[:deg_histogram.numel()] += deg_histogram
                deg_histogram = counts
            else:
                deg_histogram[:counts.numel()] += counts

        return deg_histogram

    def _data_loader(self, dataset, batch_size, shuffle, num_workers=0):
        """Build a PyG loader over ``dataset``.

        Uses ``DenseDataLoader`` when ``self.is_dense`` is set, otherwise
        ``DataLoader`` with ``pin_memory=False``. ``shuffle`` and
        ``num_workers`` are honoured in both modes.
        """
        if self.is_dense:
            return DenseDataLoader(dataset,
                                   batch_size=batch_size,
                                   shuffle=shuffle,
                                   num_workers=num_workers)

        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          pin_memory=False)

    def get_node_colors(self, data):
        """Map each node to an RGB colour based on its node type.

        The node type is ``argmax(data.x, dim=1)``, presumably one-hot node
        features (confirm per dataset). Types index into the module-level
        ``color_mapping_tensor``. If any type falls outside that palette, or
        the graph has no nodes, every node is coloured black.

        Returns:
            tuple: ``(node_color, 'Type')``. ``node_color`` is a
            ``(num_nodes, 3)`` numpy array with values in [0, 1].
        """
        node_type = torch.argmax(data.x, dim=1)
        num_colors = color_mapping_tensor.shape[0]

        if node_type.numel() > 0 and node_type.max().item() < num_colors:
            node_color = color_mapping_tensor[node_type].numpy()
        else:
            node_color = np.zeros([data.x.shape[0], 3])  # All black

        return node_color, 'Type'

    def convert_pyg_to_nx(self, data):
        """Convert a PyG ``Data`` object to a directed networkx graph."""
        return pygutils.to_networkx(data, to_undirected=False)

    def _plot_data(self, batch,
                   title_elem_idx=None,
                   batch_size=None,
                   nodes_with_color=False,
                   nodes_alpha_fn=None,
                   edges_alpha_fn=None,
                   **kwargs):
        """Draw each graph in ``batch`` in its own subplot (Kamada-Kawai layout).

        Args:
            batch: A PyG batch (anything with ``to_data_list()``) or a list
                of ``Data`` objects.
            title_elem_idx: Optional attribute name. Its per-graph value is
                passed to ``self.add_title`` for each subplot.
            batch_size: Forwarded to ``get_number_of_rows_and_cols``.
            nodes_with_color: Colour nodes via ``get_node_colors``.
            nodes_alpha_fn: Optional ``fn(data) -> (alpha, label)``. It is
                only used when ``nodes_with_color`` is True; ``alpha``
                becomes the RGBA alpha channel of the node colours.
            edges_alpha_fn: Optional ``fn(data) -> (alpha, label)``. Edges
                are drawn black with ``alpha`` as their transparency.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The matplotlib figure.
        """
        data_list = batch if isinstance(batch, list) else batch.to_data_list()
        num_samples, nrow, ncol = self.get_number_of_rows_and_cols(
            num_samples=len(data_list),
            batch_size=batch_size
        )

        title_el = None
        if isinstance(title_elem_idx, str):
            if not isinstance(batch, list):
                assert hasattr(batch, title_elem_idx)
                title_el = getattr(batch, title_elem_idx).numpy()
            else:
                title_el = np.array(
                    [getattr(d, title_elem_idx).item() for d in data_list])

        fig, axes = plt.subplots(nrow, ncol,
                                 figsize=(4 * ncol, 4 * nrow))

        # Fill the grid row-major and stop after num_samples graphs.
        for idx in range(min(num_samples, nrow * ncol)):
            i, j = divmod(idx, ncol)
            ax_ij = self.select_axis(nrow, ncol, i, j, axes)
            data = data_list[idx]
            g = self.convert_pyg_to_nx(data)

            node_color = "#1f78b4"
            cmap = None
            edge_color = "k"

            if nodes_with_color:
                rgb, _ = self.get_node_colors(data)
                if nodes_alpha_fn is not None:
                    node_alpha, _ = nodes_alpha_fn(data)
                    # Append the per-node alpha to obtain RGBA colours.
                    node_color = np.zeros([rgb.shape[0], 4])
                    node_color[:, :3] = rgb
                    node_color[:, -1] = node_alpha
                else:
                    node_color = rgb
                cmap = plt.cm.Reds
            if edges_alpha_fn is not None:
                edge_alpha, _ = edges_alpha_fn(data)
                # Black edges; only the alpha channel varies per edge.
                # NOTE(review): networkx draws edges in g.edges() order. This
                # matches edge_index order only if edge_index is sorted by
                # source node and has no duplicates; confirm.
                edge_color = np.zeros([edge_alpha.shape[0], 4])
                edge_color[:, -1] = edge_alpha
                cmap = plt.cm.Reds

            nx.draw_kamada_kawai(g,
                                 node_color=node_color,
                                 cmap=cmap,
                                 with_labels=True,
                                 edge_color=edge_color,
                                 ax=ax_ij)

            if title_el is not None:
                self.add_title(title_el=title_el[idx],
                               ax=ax_ij)

        plt.tight_layout()
        return fig

    def _split_dataset(self, dataset_raw):
        """Split the raw dataset according to ``self.split`` / ``self.k_fold``.

        In transductive mode, ``dataset_raw`` must be a single ``pygd.Data``
        graph, which is split by ``transductive_split_data``. Otherwise,
        ``compute_split_idx`` produces index selections and each is used to
        index ``dataset_raw``.

        Returns:
            list: One dataset per split.
        """
        if self.transductive:
            assert isinstance(dataset_raw, pygd.Data)
            return transductive_split_data(data=dataset_raw,
                                           split_sizes=self.split,
                                           task=self.task,
                                           k_fold=self.k_fold)

        splits = compute_split_idx(original_len=len(dataset_raw),
                                   split_sizes=self.split,
                                   k_fold=self.k_fold,
                                   )
        return [dataset_raw[sp] for sp in splits]

    def compute_psnr(self, x, x_recons):
        """Subclass hook; PSNR is not defined for generic graph data."""
        raise NotImplementedError

    def dim_coordinates(self):
        """Subclass hook: dimensionality of node coordinates."""
        raise NotImplementedError

    def dim_features(self):
        """Subclass hook: dimensionality of node features."""
        raise NotImplementedError

    def get_dataset_train(self):
        """Return the training split (the first entry of ``self.datasets``)."""
        return self.datasets[0]

    def get_features_train(self):
        """Return the first batch of a training loader sized to the training split.

        ``batch_size`` equals ``self.num_samples()``. This is presumably the
        whole training split in one batch; confirm against
        ``get_dataloader_train``.
        """
        loader = self.get_dataloader_train(batch_size=self.num_samples())
        batch = next(iter(loader))
        return batch

    def get_scaler_info(self):
        """Return the scaler spec list for ``self.scale``.

        Raises:
            NotImplementedError: If ``self.scale`` is not a known option.
        """
        if self.scale in ('default', 'min0_max1'):
            return [('min0_max1', None)]
        if self.scale == 'minn1_max1':
            return [('minn1_max1', None)]
        if self.scale == 'std':
            return [('std', None)]
        raise NotImplementedError(f"Unsupported scale: {self.scale!r}")

    def get_y_from_dataset(self, dataset):
        """Return ``y`` for all of ``dataset``, collated into a single batch.

        An empty dataset leads to ``batch_size=0``, which the underlying
        loader is expected to reject.
        """
        loader = self._data_loader(dataset=dataset,
                                   batch_size=len(dataset),
                                   shuffle=False,
                                   num_workers=0)
        batch = next(iter(loader))
        return batch.y

    def num_samples(self):
        """Return ``len`` of the training split (``self.datasets[0]``)."""
        return len(self.datasets[0])

    def _get_target(self, batch):
        """Return ``batch.y``, squeezing the last dim when ``y`` is 3-D.

        ``squeeze(-1)`` only removes that dim if it has size 1. The squeezed
        tensor is written back to ``batch.y``, so the batch is modified
        in place.
        """
        if batch.y.ndim == 3:
            batch.y = batch.y.squeeze(-1)
        return batch.y
