import os, sys
import torch
import torch.nn.functional as F
import torch.nn as nn
import wandb
import numpy as np
import pickle

import torch_geometric.transforms as T
import torch_geometric.utils as geo_utils
from einops import repeat
import torch_geometric
# import dataloader.belief_dataloader as belief_dataloader
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from hydra.utils import instantiate


DATAROOT = "/home/user/data"
 
def randomize_edges(data):
    """Overwrite ``data.edge_index`` with randomly re-drawn endpoints.

    Two independent NumPy permutations of the edge positions are drawn (in
    that order, from the global NumPy RNG). Both rows of the new edge index
    are filled from row 0 of the original ``edge_index``, each reordered by
    its own permutation.

    Args:
        data: Object exposing a 2 x num_edges ``edge_index`` tensor; it is
            modified in place.

    Returns:
        The same ``data`` object, with ``edge_index`` replaced by a new
        tensor of the same shape and dtype.
    """
    original = data.edge_index
    num_edges = original.shape[1]

    # Keep this draw order: results under a fixed NumPy seed depend on it.
    first_order = np.random.permutation(num_edges)
    second_order = np.random.permutation(num_edges)

    shuffled = torch.zeros_like(original)
    shuffled[0] = original[0, first_order]
    # NOTE(review): row 1 is also sampled from row 0 of the original edges,
    # not from row 1 -- confirm this is intended.
    shuffled[1] = original[0, second_order]

    data.edge_index = shuffled
    return data

def make_homophilic(data, p_threshold=0.5):
    """Rewire every edge's target so edges tend to connect same-label nodes.

    The source of each edge (row 0 of ``data.edge_index``) is kept. For each
    edge, with probability ``p_threshold`` the new target is drawn uniformly
    from the *other* nodes that share the source's label in ``data.y``;
    otherwise it is drawn uniformly from all nodes (which may be the source
    itself). When no other node shares the source's label, the draw falls
    back to all nodes instead of failing on an empty candidate set.

    Randomness comes from the global NumPy RNG (one ``uniform`` draw followed
    by one ``choice`` draw per edge).

    Args:
        data: Object exposing ``edge_index`` (2 x num_edges node indices) and
            ``y`` (one label per node). Modified in place.
        p_threshold: Probability of forcing a same-label target for an edge.

    Returns:
        The same ``data`` object. Its ``edge_index`` is a torch tensor with
        the dtype and device of the input edge index.
    """
    edges = torch.as_tensor(data.edge_index)
    sources = edges[0].detach().cpu().numpy()
    labels = torch.as_tensor(data.y).detach().cpu().numpy()
    all_nodes = np.arange(len(labels))

    # Group node ids by label once, rather than rescanning every label for
    # every edge.
    members_by_label = {
        label.item(): np.flatnonzero(labels == label)
        for label in np.unique(labels)
    }

    targets = np.empty(len(sources), dtype=np.int64)
    for position, source in enumerate(sources):
        pool = all_nodes
        if np.random.uniform() < p_threshold:
            members = members_by_label[labels[source].item()]
            peers = members[members != source]
            # A node that is alone in its class has no same-label peer;
            # keep the all-nodes pool in that case.
            if peers.size:
                pool = peers
        targets[position] = np.random.choice(pool)

    new_targets = torch.from_numpy(targets).to(
        device=edges.device, dtype=edges.dtype)
    # Store a torch tensor (not a NumPy array) so edge_index keeps the type
    # the other transforms in this module produce.
    data.edge_index = torch.stack((edges[0], new_targets))
    return data

def remove_all_edge_info(data):
    """Replace ``data.edge_index`` with one self-loop per node.

    The node count is taken from ``data.x.shape[0]``. The new edge index has
    two identical rows ``[0, 1, ..., num_nodes - 1]``, so no edge connects
    two different nodes.

    Args:
        data: Object exposing ``x`` and ``edge_index``; modified in place.

    Returns:
        The same ``data`` object.
    """
    num_nodes = data.x.shape[0]
    node_ids = torch.LongTensor(list(range(num_nodes)))
    # Duplicate the id row along a new leading axis: shape (2, num_nodes).
    data.edge_index = repeat(node_ids, 'a -> b a', b=2)
    return data


def turn_graph_to_tree(data):
    """Replace ``data.edge_index`` with the edges of a minimum spanning tree.

    The edge list is turned into a dense adjacency matrix, handed to SciPy's
    ``minimum_spanning_tree``, and the resulting sparse matrix is converted
    back into an edge index placed on the original edge index's device.

    Args:
        data: Graph object exposing ``edge_index``; modified in place.

    Returns:
        The same ``data`` object with ``edge_index`` overwritten.
    """
    def to_torch_sparse(sparse_mx):
        """Convert a scipy sparse matrix to a float32 torch sparse COO tensor."""
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
        indices = torch.from_numpy(
            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
        values = torch.from_numpy(sparse_mx.data)
        shape = torch.Size(sparse_mx.shape)
        # torch.sparse.FloatTensor(indices, values, shape) is a deprecated
        # legacy constructor; sparse_coo_tensor builds the same tensor
        # (dtype float32 is inferred from ``values``).
        return torch.sparse_coo_tensor(indices, values, shape)

    edge_index = data.edge_index
    # to_dense_adj returns a batched result; index 0 is this graph's matrix.
    # NOTE(review): the dense matrix presumably needs memory quadratic in the
    # node count -- confirm this is acceptable for the graphs used here.
    adj = geo_utils.to_dense_adj(edge_index)
    csr_numpy = csr_matrix(adj[0].cpu().numpy())
    Tcsr = minimum_spanning_tree(csr_numpy)

    # Round-trip through a coalesced torch sparse tensor to get the edge index.
    # NOTE(review): SciPy presumably stores each tree edge in one direction
    # only -- confirm whether downstream code expects both directions.
    Tcsr_coo = to_torch_sparse(Tcsr)
    coalesced = Tcsr_coo.coalesce()
    tree_edge_index = geo_utils.to_edge_index(coalesced)[0]

    data.edge_index = tree_edge_index.to(edge_index.device)
    return data

def transform_data(data, transform_type):
    """Apply the graph transform selected by ``transform_type``.

    Args:
        data: Graph object to transform.
        transform_type: ``'make_undirected'`` applies ``T.ToUndirected()``;
            ``'normal'`` returns ``data`` untouched.

    Returns:
        The transformed (or unchanged) ``data``.

    Warns:
        UserWarning: If ``transform_type`` is not one of the values above.
            ``data`` is still returned unchanged in that case.
    """
    import warnings

    if transform_type == 'make_undirected':
        return T.ToUndirected()(data)
    if transform_type == 'normal':
        return data

    # The previous code built a ValueError here without raising it, so a
    # bad value was silently ignored. Callers in this module rely on the
    # pass-through result, so the problem is reported as a warning instead
    # of an exception.
    warnings.warn(
        f"Invalid Transform arg: {transform_type!r}; returning data unchanged",
        UserWarning,
        stacklevel=2,
    )
    return data


def load_flipflop_dataset(dataset_path, filename, example, transform):
    """Load one pickled graph example and make its edges undirected.

    The file read is ``dataset_path/filename/example``.

    Args:
        dataset_path: Root directory of the dataset.
        filename: Sub-directory (under ``dataset_path``) holding the examples.
        example: File name of the pickled example inside that directory.
        transform: Accepted for interface compatibility; not used here.

    Returns:
        A single-element list containing the loaded object, whose
        ``edge_index`` has been passed through ``to_undirected``.
    """
    example_path = os.path.join(dataset_path, filename, example)

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(example_path, 'rb') as handle:
        graph = pickle.load(handle)

    graph.edge_index = torch_geometric.utils.to_undirected(graph.edge_index)
    return [graph]


def get_dataset(args):
    """Instantiate the configured dataset and return its first graph.

    The loader described by ``args.dataset.loader_params`` is instantiated
    with feature normalisation, the first graph is passed through
    ``transform_data``, and split masks are then fixed up per dataset family:

    * ``WikipediaNetwork`` / ``WebKB`` / ``Actor``: boolean train/val/test
      masks are read from ``args.dataset.splits_file``.
    * ``HeterophilousGraphDataset``: column ``args.dataset.split_idx`` of the
      existing two-dimensional masks is selected.

    Args:
        args: Config object with a ``dataset`` section.

    Returns:
        The prepared graph object.
    """
    cfg = args.dataset
    dataset = instantiate(cfg.loader_params, transform=T.NormalizeFeatures())
    data = transform_data(dataset[0], transform_type=cfg.transform_type)

    if cfg.name in ('WikipediaNetwork', 'WebKB', 'Actor'):
        splits = np.load(cfg.splits_file)
        # Read every mask before assigning any of them.
        raw_masks = {
            key: splits[key]
            for key in ('train_mask', 'val_mask', 'test_mask')
        }
        for key, values in raw_masks.items():
            setattr(data, key, torch.tensor(values, dtype=torch.bool))

    if cfg.name in ('HeterophilousGraphDataset',):
        column = cfg.split_idx
        # Masks hold one column per split; keep only the requested one.
        selected = {
            key: getattr(data, key)[:, column]
            for key in ('train_mask', 'test_mask', 'val_mask')
        }
        for key, mask in selected.items():
            setattr(data, key, mask)

    return data


def _attach_split_masks(data, splits_path):
    """Set boolean train/val/test masks on *data* from an ``.npz`` split file."""
    # The context manager closes the archive once the masks are copied out.
    with np.load(splits_path) as splits:
        masks = {
            key: torch.tensor(splits[key], dtype=torch.bool)
            for key in ('train_mask', 'val_mask', 'test_mask')
        }
    for key, mask in masks.items():
        setattr(data, key, mask)
    return data


def get_dataset_old(args):
    """Legacy dataset loader with hard-coded paths under ``DATAROOT``.

    Supported ``args.dataset.name`` values: ``'cora'``,
    ``'WikipediaNetwork'``, ``'WebKB'``, ``'Actor'`` and the pickled
    ``'flipflop'`` / ``'jt'`` / ``'ising'`` examples. The loaded graph is
    passed through ``transform_data``; for the WikipediaNetwork, WebKB and
    Actor datasets, train/val/test masks are read from a split ``.npz`` file.
    If ``args.model.params.use_gdc`` is truthy, a GDC transform is applied
    last.

    Args:
        args: Config object with ``dataset`` and ``model.params`` sections.

    Returns:
        The prepared graph object.

    Raises:
        ValueError: If ``args.dataset.name`` is not a supported dataset.
        AssertionError: If ``args.dataset.sub_name`` is invalid for the
            chosen dataset.
    """
    dataset_name = args.dataset.name
    # transform_data expects the transform name, not the whole config.
    # Configs without a ``transform_type`` key fall back to the pass-through
    # 'normal' mode.
    transform_type = getattr(args.dataset, 'transform_type', 'normal')

    if dataset_name == 'cora':
        from torch_geometric.datasets import Planetoid
        path = os.path.join(f'{DATAROOT}/graph_datasets', 'data', 'Planetoid')
        dataset = Planetoid(
            path, 'Cora', transform=T.NormalizeFeatures(), split='full')
        data = transform_data(dataset[0], transform_type)
    # Exact match: the previous ``in 'WikipediaNetwork'`` was a substring
    # test that also accepted '' or 'Wiki'.
    elif dataset_name == 'WikipediaNetwork':
        from torch_geometric.datasets import WikipediaNetwork
        sub_name = args.dataset.sub_name
        assert sub_name in ['chameleon', 'squirrel'], "Invalid subname"
        dataset = WikipediaNetwork(
            root=DATAROOT,
            name=sub_name,
            transform=T.NormalizeFeatures(),
        )
        data = transform_data(dataset[0], transform_type)
        data = _attach_split_masks(
            data,
            f'{DATAROOT}/{sub_name}/'
            f'geom_gcn/raw/'
            f'{sub_name}_split_0.6_0.2_0.npz')
    elif dataset_name == 'WebKB':
        from torch_geometric.datasets import WebKB
        sub_name = args.dataset.sub_name
        assert sub_name in [
            'cornell', 'texas', 'wisconsin'], "Invalid subname"
        dataset = WebKB(root=DATAROOT, name=sub_name)
        data = transform_data(dataset[0], transform_type)
        data = _attach_split_masks(
            data, f'{DATAROOT}/{sub_name}/raw/{sub_name}_split_0.6_0.2_0.npz')
    elif dataset_name == 'Actor':
        from torch_geometric.datasets import Actor
        sub_name = args.dataset.sub_name
        assert sub_name in ['film'], "Invalid subname"
        dataset = Actor(root=DATAROOT)
        data = transform_data(dataset[0], transform_type)
        data = _attach_split_masks(
            data, f'{DATAROOT}/raw/{sub_name}_split_0.6_0.2_0.npz')
    elif dataset_name in ['flipflop', 'jt', 'ising']:
        filepath = os.path.join(args.dataset.dataset_path, "example_0000.pt")
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        data = transform_data(data, transform_type)
    else:
        # Previously the exception was built but not raised, so an unknown
        # name only failed later with an unbound ``data`` variable.
        raise ValueError("Invalid Dataset Name")

    # Optional graph diffusion transformation.
    if args.model.params.use_gdc:
        transform = T.GDC(
            self_loop_weight=1,
            normalization_in='sym',
            normalization_out='col',
            diffusion_kwargs=dict(method='ppr', alpha=0.05),
            sparsification_kwargs=dict(method='topk', k=128, dim=0),
            exact=True,)
        data = transform(data)

    return data
