import pandas as pd
import numpy as np
import math
import tqdm
import csv
import os
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
import random
import igraph as ig
from scipy import sparse
import torch.nn.functional as F
from torch.autograd import Variable

from ugc_gnn import UGCGNN
from util import *

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the later ``.to(device)`` calls do not fail on GPU-less machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Interaction tables (STRING links CSV and the BioGRID table returned by
# ``load_biogrid_df`` from ``util``).
string_intact = pd.read_csv('./data/9606.protein.links.detailed.v11.0_sym.csv', sep=',')
biogrid_intact = load_biogrid_df('BIOGRIDALL', './data')

# Per-sample feature tables; their columns are used as sample identifiers below.
gene_exp = pd.read_pickle("./data/gene_exp.pkl")
copy_num = pd.read_pickle("./data/copy_num.pkl")

# Adjacency matrix with its first row and first column dropped.
# NOTE(review): presumably that row/column is a header or padding entry -- confirm.
adj = np.load('./data/adj.npy')
adj = adj[1:, 1:]

# Tab-separated survival table; provides the 'sample' and
# 'cancer type abbreviation' columns used when the dataset is assembled.
survival_info = pd.read_csv('./data/Survival_SupplementalTable_S1_20171025_xena_sp.csv', sep='\t')

# Phenotype file and the columns handed to ``sample_disea_dict`` (from ``util``).
file_path = './data/TCGA_phenotype_denseDataOnlyDownload.tsv'
column_names = [
    'sample',
    'sampleID',
    'sample_type',
    '_primary_disease',
]
sample_dis_dict = sample_disea_dict(file_path, column_names)

# Sample ids that appear in every data source (phenotype mapping, survival
# table, gene expression and copy number).  The order of the list is arbitrary
# because it is derived from a set intersection.
all_patients = list(
    set(sample_dis_dict.keys())
    & set(survival_info['sample'])
    & set(gene_exp.columns)
    & set(copy_num.columns)
)

# Map every cancer-type abbreviation to an integer class index.
# ``dict.fromkeys`` de-duplicates while keeping first-appearance order, so the
# label <-> index mapping is identical on every run.  Iterating a ``set`` (as
# before) could reorder the classes between runs, which silently invalidates
# any checkpoint that is loaded in a later process.
cancer_type_dict = {
    cancer: label
    for label, cancer in enumerate(dict.fromkeys(survival_info['cancer type abbreviation']))
}

# sample id -> cancer type of the FIRST matching survival row.  Built once so
# the loop below does O(1) lookups instead of scanning the table per sample.
sample_to_cancer = {}
for sample_id, cancer in zip(survival_info['sample'],
                             survival_info['cancer type abbreviation']):
    sample_to_cancer.setdefault(sample_id, cancer)

# Number of expression values a sample must have to be kept.
NUM_GENES = 3686

y_list = []        # integer class label per kept sample
x_list = []        # expression vector per kept sample
existing_id = []   # ids of kept samples, in insertion order
seen_ids = set()   # same ids as ``existing_id``; used for O(1) duplicate checks
pbar = tqdm(list(gene_exp.columns))
for pat_id in pbar:
    # Skip samples without survival information and samples already added.
    if pat_id in seen_ids or pat_id not in sample_to_cancer:
        continue
    # Materialise the column once; it is used for the length check and as the
    # stored feature vector.
    expression = list(gene_exp[pat_id])
    if len(expression) != NUM_GENES:
        continue
    x_list.append(expression)
    y_list.append(cancer_type_dict[sample_to_cancer[pat_id]])
    existing_id.append(pat_id)
    seen_ids.add(pat_id)

# Number of samples kept (the original script tracked this in ``i``).
i = len(existing_id)

# Convert the collected samples and labels to tensors.
X = torch.FloatTensor(x_list)
Y = torch.LongTensor(y_list)

# Split via ``dataset_prepare`` (from ``util``).
# NOTE(review): ``p`` is presumably the training fraction -- confirm in util.
p = 0.8
X_train, Y_train, X_test, Y_test = dataset_prepare(X, Y, p)
num_sample = X.shape[0]

# Directory where the training loop writes its checkpoints.
res_dir = './results'

# Pathformer is another name of CLGNN
# NOTE(review): the class instantiated here is UGCGNN; confirm the naming above.
# ``Nnodes`` matches the 3686-value length check applied when the samples were
# collected.  NOTE(review): ``out_dim=33`` presumably equals the number of
# cancer types in ``cancer_type_dict`` -- confirm.
pathformer = UGCGNN(
    n_copy=6,
    copy_dim=20,
    d_model=32,
    n_head=4,
    d_ff=32,
    drop_out=0.25,
    n_layers=2,
    hidden_dim=256,
    out_hidden=128,
    Nnodes=3686,
    out_dim=33,
    device=None,
)

# Move the model to its device BEFORE building the optimizer, as recommended
# by the torch.optim documentation, so the optimizer is created over the
# parameters in their final location.
pathformer.to(device)

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(pathformer.parameters(), lr=0.001)

# Train the model
EPOCH = 150
batch_size = 256

# Make sure the checkpoint directory exists before the first ``torch.save``.
os.makedirs(res_dir, exist_ok=True)

# Convert the adjacency matrix once and keep it on the target device; each
# batch only stacks B references to it instead of re-converting the NumPy
# array and copying a (B, n, n) tensor from host memory on every step.
adj_device = torch.FloatTensor(adj).to(device)

pathformer.train()
for i in range(EPOCH):
    # Reshuffle the training set at the start of every epoch.
    N = X_train.size(0)
    idx = list(range(N))
    random.shuffle(idx)
    X_epoch = X_train[idx]
    Y_epoch = Y_train[idx]

    acc_list = []    # number of correct predictions per batch
    loss_list = []   # batch loss multiplied by batch size
    total_B = 0      # number of samples seen in this epoch

    # Ceiling division: the previous ``N // batch_size + 1`` produced an empty
    # trailing batch (and a division by zero) whenever N was an exact multiple
    # of ``batch_size``.
    num_batches = (N + batch_size - 1) // batch_size
    pbar = tqdm(range(num_batches))
    for k in pbar:
        start = k * batch_size
        # Slicing past the end is safe, so the last (short) batch needs no
        # special case.
        batch_x = X_epoch[start:start + batch_size].to(device)
        batch_y = Y_epoch[start:start + batch_size].to(device)
        B = batch_x.size(0)

        # One copy of the adjacency matrix per sample in the batch.
        A = torch.stack([adj_device] * B, dim=0)

        # The model receives the features with a trailing singleton dimension.
        pred, _ = pathformer(batch_x.unsqueeze(2), A)
        loss = loss_func(pred, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        target_true = torch.sum(torch.argmax(pred, dim=-1) == batch_y).item()
        acc = target_true / B
        acc_list.append(target_true)
        total_B += B
        loss_list.append(loss.item() * B)
        pbar.set_description('Epoch: {0} loss:{1:.3f} acc: {2:.4f}, :num: {3}'.format(i, loss.item(), acc, target_true))

    # Checkpoint every 50 epochs (the original tested the undefined name
    # ``epoch`` here, which raised NameError after the first epoch).
    if i % 50 == 0:
        print("save current model...")
        model_name = os.path.join(res_dir, 'model_checkpoint{}.pth'.format(i))
        optimizer_name = os.path.join(res_dir, 'optimizer_checkpoint{}.pth'.format(i))
        torch.save(pathformer.state_dict(), model_name)
        torch.save(optimizer.state_dict(), optimizer_name)

    # Sample-weighted epoch averages (the original divided by the undefined
    # name ``B_total``; the counter is ``total_B``).
    print('Epoch:', i, 'Train accuracy:', np.sum(acc_list)/total_B, 'train loss', np.sum(loss_list)/total_B)



# Evaluate the trained model on the held-out split.
N = X_test.size(0)
acc_list = []    # number of correct predictions per batch
loss_list = []   # batch loss multiplied by batch size

# Convert the adjacency matrix once and keep it on the target device instead
# of rebuilding it from NumPy for every batch.
adj_device = torch.FloatTensor(adj).to(device)

# Switch to inference mode: the model is built with drop_out=0.25, and without
# ``eval()`` dropout would stay active and distort the test metrics.
pathformer.eval()

# Ceiling division avoids the empty trailing batch that
# ``N // batch_size + 1`` produced when N was a multiple of ``batch_size``
# (an empty batch yields a NaN loss, which made the reported test loss NaN).
for k in range((N + batch_size - 1) // batch_size):
    start = k * batch_size
    # Slicing past the end is safe, so the last (short) batch needs no
    # special case.
    batch_x = X_test[start:start + batch_size].to(device)
    batch_y = Y_test[start:start + batch_size].to(device)
    B = batch_x.size(0)

    # One copy of the adjacency matrix per sample in the batch.
    A = torch.stack([adj_device] * B, dim=0)

    # No gradients are needed for evaluation.
    with torch.no_grad():
        pred, _ = pathformer(batch_x.unsqueeze(2), A)
        loss = loss_func(pred, batch_y)

    target_true = torch.sum(torch.argmax(pred, dim=-1) == batch_y).item()
    acc_list.append(target_true)
    loss_list.append(loss.item() * B)

# Sample-weighted averages over the whole test split.
print('Test accuracy:', np.sum(acc_list)/N, 'test loss', np.sum(loss_list)/N)