import warnings
import pickle as pkl
import sys, os

import scipy.sparse as sp
import numpy as np
import pickle
import torch
import torch.nn.functional as F
import scipy
from collections import defaultdict

from sklearn.preprocessing import OneHotEncoder
from utils import sparse_mx_to_torch_sparse_tensor, normalize, symmetrize, remove_self_loop, sparse_tensor_add_self_loop, adj_values_one

# NOTE(review): this silences every warning category for the whole process
# as a side effect of importing this module, not just warnings raised here.
warnings.simplefilter(action="ignore")
# Small epsilon constant; not referenced by the code visible in this module.
EOS = 1.0e-10


def encode_onehot(labels):
    """Return a dense one-hot encoding of *labels*.

    The labels are reshaped into a single column, a fresh ``OneHotEncoder``
    is fitted on that column, and its (sparse) transform of the same column
    is densified with ``toarray()``. Column order follows the categories the
    encoder learns during ``fit``.
    """
    column = labels.reshape(-1, 1)
    return OneHotEncoder().fit(column).transform(column).toarray()


def preprocess_features(features):
    """Row-normalise *features* by dividing each row by its sum.

    Rows whose sum is exactly zero would get an infinite scale factor; that
    factor is replaced with 0, so such rows come back as all zeros.

    Args:
        features: 2-D tensor (rows are summed over ``dim=1``); presumably
            floating point, since ``torch.pow`` rejects negative exponents on
            integer tensors.

    Returns:
        A new tensor of the same shape; the input is not modified.
    """
    inv_sums = torch.pow(features.sum(dim=1), -1).flatten()
    # Zero out +/-inf produced by zero row sums.
    inv_sums = torch.where(torch.isinf(inv_sums), torch.zeros_like(inv_sums), inv_sums)
    return features * inv_sums.view(-1, 1)


def preprocess_adjs(adj1, adj2):
    """Run both adjacency matrices through the same preprocessing pipeline.

    Stages, each applied to both matrices before the next stage starts:
    convert to coalesced sparse, ``remove_self_loop`` (called once on the
    pair), ``sparse_tensor_add_self_loop``, ``adj_values_one`` followed by
    densification, and finally ``normalize(..., mode='sym')``.
    NOTE(review): exact semantics of the ``utils`` helpers are not visible
    here; presumably this yields binarised, self-looped, symmetrically
    normalised adjacencies — confirm against ``utils``.

    Returns:
        A list ``[processed_adj1, processed_adj2]``.
    """
    pair = [adj1.to_sparse().coalesce(), adj2.to_sparse().coalesce()]
    pair = remove_self_loop(pair)
    looped = list(map(sparse_tensor_add_self_loop, pair))
    dense = [adj_values_one(a).coalesce().to_dense() for a in looped]
    return [normalize(a, mode='sym') for a in dense]

_DEFAULT_PRETRAIN_DATA_PATH = 'data/tcga_pan_cancer_train_7k.pth'
_DEFAULT_FINE_TUNE_DATA_PATH = (
    '/share/home/wukun/workplace/EXP/GSL/PGGCL/graph_data/'
    '2omic_features_graph_data/fpkm_uq_unstranded_log_transformed/'
    'hallmarks_signatures/tcga_lung_3cls_trainTest_graph.pth'
)


def _read_views(data, suffix=''):
    """Return ``(omic, omic_adj, image, image_adj)`` from *data* for one split.

    Reads the keys ``rna_features``, ``rna_adj``, ``slide_features`` and
    ``slide_adj``, each with *suffix* appended (e.g. ``'_train'``).
    """
    omic = data['rna_features' + suffix]
    omic_adj = data['rna_adj' + suffix]
    image = data['slide_features' + suffix]
    image_adj = data['slide_adj' + suffix]
    return omic, omic_adj, image, image_adj


def load_mine(pretrain_data_path=_DEFAULT_PRETRAIN_DATA_PATH,
              fine_tune_data_path=_DEFAULT_FINE_TUNE_DATA_PATH,
              map_location=None):
    """Load pre-train, fine-tune (train split) and test graph data.

    Args:
        pretrain_data_path: file loaded with ``torch.load``; must provide
            ``omic_sizes``, ``rna_features``, ``rna_adj``, ``slide_features``
            and ``slide_adj``.
        fine_tune_data_path: file loaded with ``torch.load``; must provide
            ``case_id``, ``labels``, ``test_mask``, the feature/adjacency keys
            with a ``_train`` suffix (fine-tune split) and without a suffix
            (test graph).
        map_location: forwarded to ``torch.load``. ``None`` (the default)
            keeps ``torch.load``'s default device restoration.

    Returns:
        A 13-tuple ``(case_id, pretrain_features, fine_tune_features,
        test_features, image_dim, train_mask, test_mask, adjs_pretrain,
        adjs_ft, adjs_test, labels, nclasses, omic_sizes)``. Each
        ``*_features`` is ``[omic, image]``, ``image_dim`` is
        ``fine_tune_image.shape[1]``, ``case_id`` comes from the fine-tune
        file, and each ``adjs_*`` is the output of ``preprocess_adjs``.
    """
    # NOTE: torch.load unpickles its input; only point this at trusted files.
    pretrain_data = torch.load(pretrain_data_path, map_location=map_location)
    omic_sizes = pretrain_data['omic_sizes']
    (pretrain_omic, pretrain_omic_adjs,
     pretrain_image, pretrain_image_adjs) = _read_views(pretrain_data)
    adjs_pretrain = preprocess_adjs(pretrain_omic_adjs, pretrain_image_adjs)

    fine_tune_data = torch.load(fine_tune_data_path, map_location=map_location)
    # Only the fine-tune case ids are returned; the pre-train file's
    # 'case_id' is not used, so it is not read.
    case_id = fine_tune_data['case_id']

    # Fine-tune (training) split: keys carry a '_train' suffix.
    (fine_tune_omic, fine_tune_omic_adjs,
     fine_tune_image, fine_tune_image_adjs) = _read_views(fine_tune_data, '_train')
    adjs_ft = preprocess_adjs(fine_tune_omic_adjs, fine_tune_image_adjs)

    # Test graph: un-suffixed keys of the same file.
    (test_omic, test_omic_adjs,
     test_image, test_image_adjs) = _read_views(fine_tune_data)
    adjs_test = preprocess_adjs(test_omic_adjs, test_image_adjs)

    labels = fine_tune_data['labels'].long()
    test_mask = fine_tune_data['test_mask'].bool()
    # Every node not marked as test is treated as training.
    train_mask = ~test_mask
    # Assumes labels are contiguous class ids starting at 0 — TODO confirm.
    nclasses = int(labels.max() + 1)

    pretrain_features = [pretrain_omic, pretrain_image]
    fine_tune_features = [fine_tune_omic, fine_tune_image]
    test_features = [test_omic, test_image]

    return (case_id, pretrain_features, fine_tune_features, test_features,
            fine_tune_image.shape[1], train_mask, test_mask, adjs_pretrain,
            adjs_ft, adjs_test, labels, nclasses, omic_sizes)



def load_data(args):
    """Dispatch to the dataset loader selected by ``args.dataset``.

    Currently only ``'mine'`` is supported, which delegates to ``load_mine()``
    and returns its tuple unchanged.

    Raises:
        ValueError: if ``args.dataset`` names an unsupported dataset. Previously
            this case silently returned ``None``, which only surfaced later as a
            confusing unpacking error in the caller.
    """
    if args.dataset == 'mine':
        return load_mine()
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")
    

if "__main__" == __name__:
    # Smoke-run the full load/preprocess path; the result is discarded.
    load_mine()