import torch

import collections
import collections.abc

# Python 3.10 removed the ABC aliases from the top-level `collections`
# namespace; re-attach them so legacy code doing `collections.Mapping`
# (etc.) keeps working.
for _abc_name in ("Mapping", "Iterable", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name

from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.loader import DataLoader

from torch_geometric.datasets import ZINC

import torch.nn.functional as F
import torch.nn as nn

from tqdm import tqdm


import random
from utils import *
from embeddings import *

import pandas as pd
from rdkit import Chem
import torch
from torch_geometric.data import InMemoryDataset, Data
import os

from sklearn.model_selection import train_test_split
import numpy as np

import copy
import time
import scipy.sparse as sp

from torch_geometric.transforms import LargestConnectedComponents

from global_embeddings import GlobalEmbeddings

from torch.utils.data import DataLoader as torchDL

from compute_alternative_targets import compute_alternative_targets

def smiles_to_data(smiles):
    """Build a feature-less PyG ``Data`` graph from a SMILES string.

    Every RDKit bond becomes two directed edges (i -> j and j -> i), and
    ``num_nodes`` is set to the molecule's atom count. No node or edge
    features are attached.

    Args:
        smiles: SMILES string to parse with ``Chem.MolFromSmiles``.

    Returns:
        ``Data`` with ``edge_index`` of shape ``(2, 2 * num_bonds)`` when the
        molecule has at least one bond. For a bond-less molecule the empty
        edge list produces a 1-D tensor of shape ``(0,)`` rather than
        ``(2, 0)``. Callers such as ``DrugBankDataset.process`` use that shape
        to drop such molecules, so it is kept as-is.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``. ``MolFromSmiles`` signals
            this by returning ``None``; previously that surfaced as an opaque
            ``AttributeError`` on ``None.GetBonds()``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES string: {smiles!r}")

    edge_list = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Store both directions so the graph is undirected in PyG's convention.
        edge_list.extend(((i, j), (j, i)))

    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

    data = Data(edge_index=edge_index)
    # Set explicitly: isolated atoms would otherwise be missed by inference
    # from edge_index.
    data.num_nodes = mol.GetNumAtoms()

    return data


class DrugBankDataset(InMemoryDataset):
    """Molecular graphs built from the ``SMILES`` column of a CSV file.

    Each usable row becomes a ``Data`` object produced by ``smiles_to_data``.
    It carries only a bond-derived ``edge_index`` and ``num_nodes``; no node
    features or labels are set here. The collated result is cached as
    ``drugbank.pt`` in ``self.processed_dir``. Delete that file to force
    re-processing.

    Args:
        root: dataset root directory, passed to ``InMemoryDataset``.
        csv_path: CSV file containing a ``SMILES`` column.
        transform, pre_transform, pre_filter: standard PyG dataset hooks.
    """

    def __init__(self, root, csv_path, transform=None, pre_transform=None, pre_filter=None):
        # Must be set before super().__init__, which may call process().
        self.csv_path = csv_path
        super().__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = self._load_processed(self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load the ``(data, slices)`` tuple that ``process`` saved to ``path``."""
        try:
            # torch>=2.6 defaults to weights_only=True, which refuses to
            # unpickle PyG Data objects. This file is written locally by
            # process(), so full unpickling is acceptable here; do not point
            # `path` at untrusted files.
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch<1.13 has no `weights_only` keyword.
            return torch.load(path)

    @property
    def processed_file_names(self):
        return ['drugbank.pt']

    def process(self):
        """Parse every SMILES in ``csv_path``, filter, transform, collate and save.

        Rows are skipped, with a summary printed, when any of these holds:
        - the SMILES cell is missing or blank (pandas reads empty cells as NaN);
        - ``smiles_to_data`` raises ``ValueError`` (unparsable molecule);
        - the molecule has no bonds (no usable ``(2, E)`` edge index).
        The remaining graphs then go through ``pre_filter`` and
        ``pre_transform`` when those hooks are set.

        Raises:
            RuntimeError: if no molecule survives, instead of failing inside
                ``collate``.
        """
        df = pd.read_csv(self.csv_path)
        data_list = []
        n_missing = n_invalid = n_bondless = 0

        for smiles in df['SMILES']:
            if not isinstance(smiles, str) or not smiles.strip():
                n_missing += 1
                continue
            try:
                data = smiles_to_data(smiles)
            except ValueError:
                # Treat a ValueError from smiles_to_data as an unparsable molecule.
                n_invalid += 1
                continue

            # Bond-less molecules do not yield a non-empty (2, E) edge index;
            # they are dropped, as before.
            if data.edge_index.dim() == 2 and data.edge_index.numel() > 0:
                data_list.append(data)
            else:
                n_bondless += 1

        if n_missing or n_invalid or n_bondless:
            print(f"DrugBankDataset: kept {len(data_list)} molecules; skipped "
                  f"{n_missing} without SMILES, {n_invalid} unparsable, "
                  f"{n_bondless} without bonds")

        # The constructor accepts pre_filter; honour it (PyG convention: filter
        # before pre_transform).
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        if not data_list:
            raise RuntimeError(f"No usable molecules found in {self.csv_path!r}")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def get_idx_split_dg(self, train_ratio=0.8, valid_ratio=0.1, test_ratio=0.1, seed=42):
        """Randomly split graph indices into train / valid / test.

        ``train_ratio`` of all graphs go to train. The remainder is divided
        between valid and test in proportion ``valid_ratio : test_ratio``, so
        the three ratios need not sum to 1. Both splits use ``seed`` and are
        reproducible.

        Returns:
            dict with ``'train'``, ``'valid'`` and ``'test'`` numpy index arrays.
        """
        indices = np.arange(len(self))

        train_idx, temp_idx = train_test_split(indices, train_size=train_ratio, random_state=seed, shuffle=True)

        # Fraction of the held-out remainder that becomes the validation split.
        valid_size = valid_ratio / (valid_ratio + test_ratio)
        valid_idx, test_idx = train_test_split(temp_idx, train_size=valid_size, random_state=seed, shuffle=True)

        return {
            'train': train_idx,
            'valid': valid_idx,
            'test':  test_idx
        }

def get_padded_eigvecs(adj: torch.Tensor, max_graph_size: int):
    """Eigendecompose ``get_lap(adj)`` and fit the result to a fixed size.

    ``torch.linalg.eigh`` returns eigenvalues in ascending order, with the
    eigenvectors as the columns of the second output.

    If the graph has more than ``max_graph_size`` nodes, both outputs are
    truncated: the first ``max_graph_size`` eigenvalues, and the leading
    ``max_graph_size x max_graph_size`` block of the eigenvector matrix.
    Otherwise both are zero-padded on the right and bottom.

    Returns:
      evals: Tensor of shape (max_graph_size,)
      evecs: Tensor of shape (max_graph_size, max_graph_size)
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(get_lap(adj))
    size = eigenvalues.size(0)

    if size > max_graph_size:
        return eigenvalues[:max_graph_size], eigenvectors[:max_graph_size, :max_graph_size]

    # Allocate zero buffers with the same dtype/device, then copy the real
    # spectrum into the top-left corner.
    padded_vals = eigenvalues.new_zeros(max_graph_size)
    padded_vals[:size] = eigenvalues

    padded_vecs = eigenvectors.new_zeros((max_graph_size, max_graph_size))
    padded_vecs[:size, :size] = eigenvectors

    return padded_vals, padded_vecs


class DataPreTransform:
    """Per-graph preprocessing pipeline (the name suggests use as a PyG ``pre_transform``).

    Steps, in order:
      1. optionally keep only the largest connected component;
      2. keep the raw index as ``edge_index_orig``, then replace ``edge_index``
         with a coalesced sparse adjacency, also aliased as ``adj`` and
         ``adjacency``;
      3. when ``use_supervised`` is set, attach padded Laplacian eigenpairs
         (``eigvecs`` / ``eigvals``);
      4. append node embeddings via ``DataEmbeddings``;
      5. when ``feature_type == "orig_features_only"``, keep only the first
         11 columns of ``x``;
      6. when ``use_alt_targets`` is set, compute alternative targets.
    """

    def __init__(self, config):
        self.config = config

    def __call__(self, data: Data) -> Data:
        cfg = self.config
        total_nodes_before = data.num_nodes

        if cfg.use_largest_connected_components:
            # Same as PyG's LargestConnectedComponents transform, keeping one component.
            data = LargestConnectedComponents(1)(data)
            if data.num_nodes != total_nodes_before:
                print(f"Taking largest connected component: {data.num_nodes} out of {total_nodes_before}")

        data.edge_index_orig = data.edge_index
        sparse_adj = edge_index_to_sparse_adj(data.edge_index, data.num_nodes).coalesce()
        data.edge_index = sparse_adj
        data.adj = sparse_adj
        data.adjacency = sparse_adj

        if cfg.use_supervised:
            eigvals, eigvecs = get_padded_eigvecs(data.adj, cfg.evec_len)
            data.eigvecs = eigvecs
            data.eigvals = eigvals

        data = DataEmbeddings(cfg)(data)

        # NOTE(review): recorded after the optional LCC step, so this is the
        # node count of the kept component, not of the raw input graph — confirm intent.
        data.original_num_nodes = data.num_nodes

        if cfg.feature_type == "orig_features_only":
            # Presumably the first 11 columns are the dataset's raw node
            # features — TODO confirm for each dataset this is used with.
            data.x = data.x[:, :11]

        if cfg.use_alt_targets:
            data = compute_alternative_targets(data, cfg)

        return data



class DataEmbeddings:
    """Append the configured node embeddings to ``data.x`` as extra columns.

    Each embedding is computed only when the config flag of the same name is
    truthy. Embeddings are concatenated in a fixed order: diffusion, wavelet,
    wavelet-positional, scatter, global-scatter, wavelet-moments,
    neighbor-bump, diffused-dirac. That order fixes the column layout of the
    result.

    Assumes ``data.x`` is already a tensor (concatenation is along its last
    dim) — TODO confirm for datasets without node features.

    Side effects on ``data``:
      * ``perms``: list of ``[start, stop]`` column ranges that
        RandomTransform / ForcedOrderTransform permute. Only the diffusion
        embedding is registered.
      * ``emb_runtimes``: wall-clock seconds per computed embedding.
      * ``eigvecs`` / ``eigvals``: if present, trimmed to the first
        ``config.num_eigenvectors`` columns / entries.
    """

    def __init__(self, config):
        self.config = config

    @staticmethod
    def _append(data, key, compute):
        """Concatenate ``compute()`` onto ``data.x`` and time it.

        Stores the elapsed wall-clock time in ``data.emb_runtimes[key]`` and
        returns the ``[start, stop]`` column range the new block occupies.
        """
        t0 = time.time()
        start = data.x.shape[-1]
        data.x = torch.cat((data.x, compute()), dim=-1)
        stop = data.x.shape[-1]
        data.emb_runtimes[key] = time.time() - t0
        return [start, stop]

    def __call__(self, data: Data) -> Data:
        cfg = self.config
        data.perms = []
        data.emb_runtimes = {}
        filters = generate_wavelet_bank(data, num_scales=10, lazy_parameter=0.5, abs_val=False)

        # PERM INVARIANCE REQUIRED: diffusion columns are registered in
        # data.perms so downstream transforms can reorder them.
        if cfg.diffusion_emb:
            data.perms.append(self._append(data, 'diffusion_emb', lambda: diffusion_emb(data)))

        # Filter-bank embeddings as (config flag == runtime key, producer).
        # The lambdas defer lookup of the embedding function until its flag is set.
        filter_embeddings = (
            ('wavelet_emb', lambda: wavelet_emb(data, filters)),
            ('wavelet_positional_emb', lambda: wavelet_positional_emb(data, filters)),
            ('scatter_emb', lambda: scatter_emb(data, filters)),
            ('global_scatter_emb', lambda: global_scatter_emb(data, filters)),
            ('wavelet_moments_emb', lambda: wavelet_moments_emb(data, filters)),
            ('neighbor_bump_emb', lambda: neighbors_signal_emb(data, filters)),
            ('diffused_dirac_emb', lambda: diffused_dirac_emb(data, filters)),
        )
        for flag, compute in filter_embeddings:
            if getattr(cfg, flag):
                self._append(data, flag, compute)

        # Eigenpairs exist only if an earlier step attached them (e.g.
        # DataPreTransform with use_supervised). Trim them when present,
        # rather than raising AttributeError when they are absent.
        if getattr(data, 'eigvecs', None) is not None:
            data.eigvecs = data.eigvecs[:, 0:cfg.num_eigenvectors]
        if getattr(data, 'eigvals', None) is not None:
            data.eigvals = data.eigvals[0:cfg.num_eigenvectors]
        return data



class RandomTransform:
    """Randomly permute columns inside every permutation-tracked feature block.

    ``data.perms`` lists ``[start, stop]`` column ranges of ``data.x``. Each
    range is split into consecutive chunks of ``config.evec_len`` columns, and
    one random permutation, drawn per call, is applied to every chunk.

    The input sample is left untouched: the result is written to a cloned
    feature matrix on a shallow copy of ``data``. Writing into ``data.x`` in
    place would corrupt the object held by the dataset, so permutations would
    accumulate across accesses.

    Raises:
        ValueError: if a range width is not a multiple of ``config.evec_len``.
    """

    def __init__(self, config):
        self.config = config

    def __call__(self, data: "Data") -> "Data":
        chunk = self.config.evec_len
        perm_indices = torch.randperm(chunk)

        data = copy.copy(data)
        x = data.x.clone()

        for start, stop in data.perms:
            if (stop - start) % chunk != 0:
                raise ValueError(
                    f"perm interval [{start}, {stop}) has width {stop - start}, "
                    f"not a multiple of evec_len={chunk}")
            for left in range(start, stop, chunk):
                # Advanced indexing copies the source columns before assignment.
                x[:, left:left + chunk] = x[:, perm_indices + left]

        data.x = x
        return data


class ForcedOrderTransform:
    """Deterministically reorder each permutation-tracked feature block per node.

    Let P = 0.5 * get_diffusion(adj) + 0.5 * I be a lazy random walk, and
    square it four times to get P^16. For node i, every ``evec_len``-wide
    chunk listed in ``data.perms`` is reordered so that its first
    ``num_nodes`` columns follow ascending values of column i of P^16. The
    trailing padding columns keep their positions. Assumes
    ``data.num_nodes <= config.evec_len``; ``torch.arange`` fails otherwise.

    The input sample is left untouched: results go into a cloned feature
    matrix on a shallow copy of ``data``. Writing in place would apply this
    fixed permutation again on every access to a stored sample, which
    composes it with itself.

    Raises:
        ValueError: if a range width is not a multiple of ``config.evec_len``.
    """

    def __init__(self, config):
        self.config = config

    def __call__(self, data: "Data") -> "Data":
        chunk = self.config.evec_len
        adj = data.edge_index  # sparse adjacency after DataPreTransform-style preprocessing

        diff_op = 0.5 * get_diffusion(adj) + 0.5 * torch.eye(adj.shape[0])  # lazy random walk
        for _ in range(4):
            diff_op = diff_op @ diff_op  # P^(2^4) = P^16

        # Row i: column indices sorted by node i's column of the diffusion operator.
        perm_indices = torch.argsort(diff_op, dim=0).T
        # Padded positions map to themselves.
        padding = torch.arange(data.num_nodes, chunk).repeat(data.num_nodes, 1)
        perm_indices = torch.cat((perm_indices, padding), dim=-1)

        data = copy.copy(data)
        x = data.x.clone()

        for start, stop in data.perms:
            if (stop - start) % chunk != 0:
                raise ValueError(
                    f"perm interval [{start}, {stop}) has width {stop - start}, "
                    f"not a multiple of evec_len={chunk}")
            for left in range(start, stop, chunk):
                right = left + chunk
                x[:, left:right] = torch.gather(x[:, left:right], 1, perm_indices)

        data.x = x
        return data



class CustomGraphDataset(InMemoryDataset):
    """Thin PyG dataset over an in-memory list of ``Data`` objects.

    The optional ``transform`` is applied by PyG's ``Dataset.__getitem__``
    each time a sample is indexed (``dataset[idx]``, or through a
    DataLoader). ``get`` therefore returns the stored sample unchanged.
    """

    def __init__(self, data_list, transform=None):
        # root=None: this dataset defines no download/process steps.
        super().__init__(root=None, transform=transform)
        self.data_list = data_list
        self.transform = transform

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        # Do NOT apply self.transform here. Dataset.__getitem__ already applies
        # it to whatever get() returns, so doing it here too would run every
        # transform twice per access. For example, ForcedOrderTransform would
        # apply its permutation twice.
        return self.data_list[idx]

