import math
from itertools import product

import numpy as np
import pandas as pd
import torch

import dgl
from dgl.data import DGLDataset


class EEGGraphDataset(DGLDataset):
    """Dataset of complete 8-node EEG graphs, one graph per window.

    Every node is treated as the same type. Each sample is a fully
    connected graph (self-loops included, 8 * 8 = 64 edges) whose nodes
    are the bipolar montage channels in ``ch_names``.

    Parameters
    ----------
    x : array-like, memory-mapped
        Per-window node features. ``x[i]`` is reshaped to
        ``(num_nodes, -1)``; in the original pipeline this is 8 x 6.
    y : array-like, memory-mapped
        Labels (diseased/healthy), indexed like ``x``.
    num_nodes : int
        Number of nodes of the graph. In our case, it is 8. It has to match
        ``len(ch_names)`` because the edge weights are computed for the
        8 x 8 edge list built from ``ch_names``.
    indices : sequence of int
        Patient level indices. Position ``k`` of the dataset maps to the
        window ``indices[k]`` of ``x``, ``y`` and the spectral coherence
        table.
    coords_path : str, optional
        Tab-separated file with ``label``, ``x``, ``y``, ``z`` columns of
        10-10 electrode positions. Defaults to ``"standard_1010.tsv.txt"``
        in the current working directory.
    spec_coh_path : str, optional
        ``.npy`` file with the spectral coherence values, indexed by window
        on the first axis. Defaults to ``"spec_coh_values.npy"`` in the
        current working directory.

    Output
    ------
    ``dataset[k]`` returns a complete 8-node DGLGraph with node features
    (``ndata["x"]``) and edge weights (``edata["edge_weights"]``), together
    with the dataset-wide window index and its label.
    """

    # Default data file locations (relative to the working directory).
    coords_path = "standard_1010.tsv.txt"
    spec_coh_path = "spec_coh_values.npy"

    def __init__(
        self, x, y, num_nodes, indices, coords_path=None, spec_coh_path=None
    ):
        # NOTE(review): DGLDataset.__init__ is not called here, so no
        # base-class state is set up - confirm this is intended.

        # CAUTION - x and labels are memory-mapped, used as if they are in RAM.
        self.x = x
        self.labels = y
        self.indices = indices
        self.num_nodes = num_nodes
        if coords_path is not None:
            self.coords_path = coords_path
        if spec_coh_path is not None:
            self.spec_coh_path = spec_coh_path

        # NOTE: this order decides the node index, keep consistent!
        self.ch_names = [
            "F7-F3",
            "F8-F4",
            "T7-C3",
            "T8-C4",
            "P7-P3",
            "P8-P4",
            "O1-P3",
            "O2-P4",
        ]

        # Electrodes of the 10-10 system lying in between the two 10-20
        # electrodes of each channel in ch_names; used for calculating
        # edge weights.
        # Note: "O1" is used for "PO3", and "O2" for "PO4".
        self.ref_names = ["F5", "F6", "C5", "C6", "P5", "P6", "O1", "O2"]

        # edge indices source to target - 2 x E = 2 x 64
        # every ordered pair of nodes (self-loops included): 8*8=64 edges
        self.node_ids = range(len(self.ch_names))
        self.edge_index = (
            torch.tensor(
                list(product(self.node_ids, self.node_ids)),
                dtype=torch.long,
            )
            .t()
            .contiguous()
        )

        # edge attributes - E x 1
        # only the spatial distance between electrodes for now, min-max
        # scaled to lie between 0 and 1
        raw = np.array(self.get_sensor_distances())
        lowest = np.min(raw)
        self.distances = (raw - lowest) / (np.max(raw) - lowest)

        # NOTE(review): allow_pickle=True can execute arbitrary code when
        # the file is not trusted - only load files you produced yourself.
        self.spec_coh_values = np.load(self.spec_coh_path, allow_pickle=True)

    # sensor distances don't depend on window ID
    def get_sensor_distances(self):
        """Return the geodesic distance of every edge in ``edge_index``.

        Reads the 10-10 coordinate table from ``self.coords_path`` and
        returns a list of ``E`` floats ordered like the columns of
        ``self.edge_index``.
        """
        coords_1010 = pd.read_csv(self.coords_path, sep="\t")
        return [
            self.get_geodesic_distance(src_node, dst_node, coords_1010)
            for src_node, dst_node in self.edge_index.t().tolist()
        ]

    @staticmethod
    def _sensor_xyz(coords_1010, label):
        """Return the ``(x, y, z)`` floats stored for sensor ``label``."""
        rows = coords_1010.loc[coords_1010["label"] == label, ["x", "y", "z"]]
        if len(rows) != 1:
            raise ValueError(
                f"expected exactly one row for sensor {label!r} in the "
                f"coordinate table, found {len(rows)}"
            )
        x, y, z = (float(value) for value in rows.iloc[0])
        return x, y, z

    def get_geodesic_distance(
        self, montage_sensor1_idx, montage_sensor2_idx, coords_1010
    ):
        """Return the great-circle distance between two montage channels.

        Parameters
        ----------
        montage_sensor1_idx, montage_sensor2_idx : int or 0-d integer tensor
            Node indices into ``ch_names`` / ``ref_names``.
        coords_1010 : pandas.DataFrame
            Table with ``label``, ``x``, ``y`` and ``z`` columns.

        Returns
        -------
        float
            Angle in radians (radius 1) between the two reference sensors.

        Raises
        ------
        ValueError
            If a reference sensor is missing from, or duplicated in,
            ``coords_1010``.
        """
        # get the reference sensor in the 10-10 system for the current
        # montage pair in 10-20 system
        ref_sensor1 = self.ref_names[int(montage_sensor1_idx)]
        ref_sensor2 = self.ref_names[int(montage_sensor2_idx)]

        x1, y1, z1 = self._sensor_xyz(coords_1010, ref_sensor1)
        x2, y2, z2 = self._sensor_xyz(coords_1010, ref_sensor2)

        # https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere
        r = 1  # since coords are on unit sphere
        # rounding is for numerical stability (kept so that existing
        # distances do not change); the clamp guarantees the acos domain
        # of [-1, 1] even when the rounded value overshoots.
        cos_angle = round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r**2), 2)
        cos_angle = max(-1.0, min(1.0, cos_angle))
        return r * math.acos(cos_angle)

    # returns size of dataset = number of indices
    def __len__(self):
        """Return the number of samples, i.e. ``len(indices)``."""
        return len(self.indices)

    # retrieve one sample from the dataset after applying all transforms
    def __getitem__(self, idx):
        """Build the graph of the ``idx``-th sample.

        Parameters
        ----------
        idx : int or tensor
            Position inside ``self.indices`` (0 <= idx < len(self)).

        Returns
        -------
        tuple
            ``(graph, window_index, label)`` where ``window_index`` is the
            dataset-wide index ``self.indices[idx]`` (not the input
            position) and both it and ``label`` are tensors.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # map input idx (ranging from 0 to __len__() inside self.indices)
        # to an idx in the whole dataset (inside self.x)
        idx = self.indices[idx]
        num_nodes = int(self.num_nodes)

        # one row of features per node
        node_features = torch.from_numpy(self.x[idx].reshape(num_nodes, -1))

        # spectral coherence between 2 montage channels!
        spec_coh_values = self.spec_coh_values[idx, :]

        # combine distances and spectral coherence into one value per edge.
        # torch.tensor infers the dtype from the NumPy array, so the
        # weights stay floating point.
        edge_weights = torch.tensor(self.distances + spec_coh_values)

        # create the complete graph: edge k goes from node k // num_nodes
        # to node k % num_nodes, the same order as self.edge_index
        node_range = torch.arange(num_nodes)
        u = node_range.repeat_interleave(num_nodes)
        v = node_range.repeat(num_nodes)
        g = dgl.graph((u, v))

        # add node features and edge features
        g.ndata["x"] = node_features
        g.edata["edge_weights"] = edge_weights
        return g, torch.tensor(idx), torch.tensor(self.labels[idx])
