import numpy as np
import igraph as ig
import pandas as pd
from scipy import sparse
import torch
import csv


def generate_edges_list_from_adj(spr):
    """Return the stored entries of a CSR matrix as ``(row, col)`` tuples.

    Pairs come out in CSR storage order: rows ascending, and within each row
    in the order the column ids appear in ``spr.indices``. Both elements of
    every tuple are Python ints.

    Args:
        spr: A ``scipy.sparse.csr_matrix`` (anything exposing CSR-style
            ``indptr`` / ``indices`` arrays with ``indptr[0] == 0``).

    Returns:
        list of ``(row, col)`` tuples, one per stored entry.
    """
    row_ptr, col_ids = spr.indptr, spr.indices
    return [
        (row, int(col_ids[k]))
        for row in range(len(row_ptr) - 1)
        # indptr[row]:indptr[row + 1] is the slice of `indices` owned by `row`.
        for k in range(row_ptr[row], row_ptr[row + 1])
    ]

def torchgeomtric_batch(batch_id, batch_size, pre_batch_size, num_node, shuffled_x, adj):
    """Build a PyTorch-Geometric-style mini-batch of copies of one graph.

    Every sample in the batch shares the same graph structure ``adj``.
    Copy ``b`` gets its node ids shifted by ``b * num_node`` so that the
    ``batch_size`` copies form one disjoint graph.

    Args:
        batch_id: Index of the mini-batch; selects rows starting at
            ``batch_id * pre_batch_size`` of ``shuffled_x``.
        batch_size: Number of samples (graph copies) in this batch.
        pre_batch_size: Stride between consecutive batch starts (presumably
            the nominal batch size, so the last batch can be smaller — verify
            against caller).
        num_node: Number of nodes in ``adj``.
        shuffled_x: Tensor whose selected rows must hold exactly
            ``batch_size * num_node`` elements in total (required by the
            ``view`` below); presumably shape ``(N, num_node)`` — TODO confirm.
        adj: Adjacency matrix accepted by ``scipy.sparse.csr_matrix``.
            Every nonzero entry ``adj[i, j]`` becomes a directed edge ``i -> j``.

    Returns:
        ``(x, edge_index, batch)``:
        - ``x``: the selected samples viewed as ``(batch_size * num_node, 1)``.
        - ``edge_index``: LongTensor ``(2, batch_size * E)``. It holds the E
          edges of copy 0, then those of copy 1, and so on; edges within a copy
          are in CSR storage order.
        - ``batch``: LongTensor of length ``batch_size * num_node`` mapping
          each node to its copy index.
    """
    spr = sparse.csr_matrix(adj)
    nnz = spr.indptr[-1]
    # Row id of every stored entry, in CSR storage order. This is the same
    # order that generate_edges_list_from_adj produces.
    rows = np.repeat(np.arange(spr.shape[0]), np.diff(spr.indptr))
    base = torch.as_tensor(np.stack([rows, spr.indices[:nnz]]), dtype=torch.long)  # (2, E)
    num_edges = base.shape[1]

    # Broadcast (2, 1, E) + (1, B, 1) -> (2, B, E). The explicit sizes (rather
    # than -1) keep this valid when batch_size or E is 0.
    offsets = torch.arange(batch_size, dtype=torch.long) * num_node
    edge_index = (base.unsqueeze(1) + offsets.view(1, batch_size, 1)).reshape(
        2, batch_size * num_edges
    )
    batch = torch.arange(batch_size, dtype=torch.long).repeat_interleave(num_node)

    start = batch_id * pre_batch_size
    x = shuffled_x[start:start + batch_size]
    return x.view(batch_size * num_node, 1), edge_index, batch

def load_biogrid_df(name, pwd, colnames=('Source', 'Target')):
    """Load a whitespace-separated edge list into a two-column DataFrame.

    Reads ``{pwd}/{name}.txt`` and skips its first line, treating it as a
    header. For every remaining non-blank line, the first two
    whitespace-separated tokens are taken as one edge; any further tokens on
    that line are ignored.

    Args:
        name: File stem, without the ``.txt`` extension.
        pwd: Directory containing the file.
        colnames: Names for the two endpoint columns.

    Returns:
        pandas.DataFrame with one row per edge; values are kept as strings.

    Raises:
        IndexError: if a non-blank data line has fewer than two tokens.
    """
    path = '%s/%s.txt' % (pwd, name)
    edges = []
    with open(path, 'r') as f:
        next(f, None)  # first line is the header
        for line in f:
            tokens = line.split()
            if not tokens:  # tolerate blank lines (e.g. trailing newlines)
                continue
            edges.append([tokens[0], tokens[1]])
    # Tuple default avoids a shared mutable default; pandas wants a list.
    return pd.DataFrame(edges, columns=list(colnames))

def read_from_tsv(file_path, column_names):
    """Read a tab-separated file into a list of dicts keyed by ``column_names``.

    Because ``column_names`` is passed as explicit field names, every
    non-empty line of the file is returned as a record. This includes a header
    line, if the file has one. A line with fewer fields than
    ``column_names`` gets ``None`` for the missing keys. Any extra fields on a
    line are collected under the key ``None`` (standard ``csv.DictReader``
    behaviour).

    The format parameters are passed straight to the reader. Nothing is
    registered in the process-global ``csv`` dialect registry, so the
    function has no lasting side effects and cannot clobber a dialect the
    caller registered.

    Args:
        file_path: Path of the TSV file.
        column_names: Field names to assign to the columns, in order.

    Returns:
        list of dicts, one per record.
    """
    # newline='' is what the csv module documentation requires for reader files.
    with open(file_path, "r", newline="") as wf:
        reader = csv.DictReader(
            wf, fieldnames=column_names, delimiter='\t', quoting=csv.QUOTE_ALL
        )
        return [dict(row) for row in reader]

def sample_disea_dict(file_path, column_names):
    """Map each sample id to its ``_primary_disease`` value from a TSV file.

    Records come from ``read_from_tsv``. That function is given explicit field
    names, so it returns the file's first line as an ordinary record; that
    first record is discarded here (presumably the header line — verify the
    input files always carry one). If a sample id appears more than once,
    the last occurrence wins.

    Args:
        file_path: Path of the TSV file.
        column_names: Field names for ``read_from_tsv``. They must include
            ``'sample'`` and ``'_primary_disease'``.

    Returns:
        dict mapping sample id -> primary disease string.
    """
    records = read_from_tsv(file_path, column_names)
    sample_ids = [rec['sample'] for rec in records]
    diseases = [rec['_primary_disease'] for rec in records]
    # Skip the first record before pairing ids with labels.
    return dict(zip(sample_ids[1:], diseases[1:]))

def dataset_prepare(X, Y, p=0.8, num_subtype=33):
    """Split ``X``/``Y`` into train and test sets with a per-class split.

    For each label ``i`` in ``range(num_subtype)``, the samples with
    ``Y == i`` are taken in their original order. The first
    ``int(count_i * p)`` of them go to the training set and the rest go to the
    test set. No shuffling is done here (presumably the caller shuffles
    beforehand — verify). Samples whose label falls outside
    ``range(num_subtype)`` are dropped.

    Args:
        X: Sample tensor indexed along dim 0.
        Y: 1-D label tensor aligned with ``X`` along dim 0 (assumed 1-D —
            confirm against caller).
        p: Fraction of each class assigned to the training set.
        num_subtype: Number of classes; labels ``0 .. num_subtype-1`` are used.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``. Each is grouped by class in
        ascending label order, keeping the original order within a class.
        When ``num_subtype <= 0`` all four are empty along dim 0.
    """
    train_parts, test_parts = [], []
    for label in range(num_subtype):
        class_idx = torch.nonzero(Y == label, as_tuple=True)[0]
        n_train = int(class_idx.numel() * p)
        train_parts.append(class_idx[:n_train])
        test_parts.append(class_idx[n_train:])

    # torch.cat([]) raises, so fall back to an empty index when no classes.
    empty = torch.empty(0, dtype=torch.long, device=Y.device)
    train_idx = torch.cat(train_parts) if train_parts else empty
    test_idx = torch.cat(test_parts) if test_parts else empty
    # One gather per output instead of 2 * num_subtype mask selections + cat.
    return X[train_idx], Y[train_idx], X[test_idx], Y[test_idx]