"""TCGA cancer dataset. The prediction task is classification with 17 classes.
The goal is to predict the tissue site of the tumor from DNA methylation data.
The 17 locations are:
  1. breast
  2. lung
  3. kidney
  4. brain
  5. ovary
  6. endometrial
  7. head and neck
  8. central nervous system
  9. thyroid
  10. prostate
  11. colon
  12. stomach
  13. bladder
  14. liver
  15. cervical
  16. bone marrow
  17. pancreas
As a preprocessing step, we also ran feature selection to reduce the number of input
features. This was carried out with STG (Stochastic Gates, https://arxiv.org/abs/1810.04247)
followed by XGBoost feature importance. The 21 chosen features are:
  'c7orf51',
  'def6',
  'dnase1l3',
  'efs',
  'foxe1',
  'gpr81',
  'gria2',
  'gsdmc',
  'hoxa9',
  'kaag1',
  'klf5',
  'loc283392',
  'ltbr',
  'lyplal1',
  'pon3',
  'pou3f3',
  'serpinb1',
  'st6galnac1',
  'tmem106a',
  'znf583',
  'znf790'
"""



import os
import os.path as osp

import numpy as np
import pandas as pd
import torch

from datasets.preprocessing_utils import preprocess_and_save_data


# Seed NumPy's global RNG at import time so that NumPy-based randomness later in
# the script (presumably the shuffle done by preprocess_and_save_data) repeats.
SEED = 4150
np.random.seed(SEED)

# Dataset directory, relative to the current working directory.
path = osp.join("datasets", "data", "tcga")

# Fractions of the samples used for the train and validation splits.
train_ratio, val_ratio = 0.8, 0.1



# Tumor tissue sites kept for the task, mapped to integer class labels. The
# order follows the numbered list in the module docstring (label = number - 1).
SITE_TO_LABEL = {
  "breast": 0,
  "lung": 1,
  "kidney": 2,
  "brain": 3,
  "ovary": 4,
  "endometrial": 5,
  "head and neck": 6,
  "central nervous system": 7,
  "thyroid": 8,
  "prostate": 9,
  "colon": 10,
  "stomach": 11,
  "bladder": 12,
  "liver": 13,
  "cervical": 14,
  "bone marrow": 15,
  "pancreas": 16,
}

# Column indices of the 21 best features chosen by STG and XGBoost. They index
# the feature matrix *after* `filter_missing` has dropped sparse columns, so they
# are only valid for the same input CSVs and missingness thresholds.
# NOTE(review): the correspondence with the gene names listed in the module
# docstring is not checked here.
BEST_FEATURES = np.array([1826, 3324, 3518, 3751, 4523, 5068, 5104, 5143,
                          5515, 6050, 6320, 6743, 6930, 6961, 9160, 9177,
                          10608, 11528, 12125, 13570, 13643])


def read_extracted_csv(data_dir, filename):
  """Read ``<data_dir>/extracted/<filename>`` into a DataFrame.

  Raises:
    Exception: if the file does not exist. The original FileNotFoundError,
      which names the missing path, is chained as ``__cause__``.
  """
  csv_path = osp.join(data_dir, "extracted", filename)
  try:
    return pd.read_csv(csv_path, low_memory=False)
  except FileNotFoundError as err:
    raise Exception("Please download the TCGA data (https://www.cancer.gov/ccg/research/genome-sequencing/tcga).") from err


def labels_by_sample(clinical, site_to_label):
  """Map each sample id to the integer label of its tumor tissue site.

  Args:
    clinical: clinical table with a sample-id column "Unnamed: 0" and a
      "tumor_tissue_site" column.
    site_to_label: mapping from tissue-site name to class label. Rows whose
      site is not a key (including missing sites) are dropped.

  Returns:
    dict from sample id to label. If a kept sample id occurs more than once,
    its last occurrence wins.
  """
  kept = clinical.loc[clinical["tumor_tissue_site"].isin(list(site_to_label))]
  return {
    sample: site_to_label[site]
    for sample, site in zip(kept["Unnamed: 0"], kept["tumor_tissue_site"])
  }


def align_features_and_labels(methylation, labels):
  """Select the methylation rows that have a label and pair each with it.

  Labels are looked up by sample id. The result is therefore correct even when
  the methylation and clinical tables list samples in different orders, or
  contain different sets of samples.

  Args:
    methylation: table whose sample-id column "Unnamed: 0" is assumed to be
      the first column. Every other column is treated as a numeric feature.
    labels: dict from sample id to class label (see `labels_by_sample`).

  Returns:
    (X, y): float feature matrix with one row per labeled methylation sample,
    in methylation-file order, and the matching int64 label vector.
  """
  rows = methylation.loc[methylation["Unnamed: 0"].isin(list(labels))]
  X = rows.values[:, 1:].astype(float)
  y = np.array([labels[sample] for sample in rows["Unnamed: 0"]], dtype=np.int64)
  return X, y


def filter_missing(X, y, max_feature_missing=0.15, max_sample_missing=0.1):
  """Drop sparse features first, then sparse samples (NaN marks missing).

  A feature column is kept when its NaN fraction is <= max_feature_missing.
  A sample row is then kept when its NaN fraction, over the surviving columns,
  is <= max_sample_missing. The order matters because BEST_FEATURES indexes the
  columns that survive this step.

  Returns:
    (X, y) restricted to the kept columns and rows, in their original order.
  """
  X = X[:, np.isnan(X).mean(axis=0) <= max_feature_missing]
  keep_rows = np.isnan(X).mean(axis=1) <= max_sample_missing
  return X[keep_rows], y[keep_rows]


def mask_and_fill(X):
  """Return (X with NaNs replaced by 0, mask that is 1.0 where X was observed)."""
  missing = np.isnan(X)
  return np.where(missing, 0, X), 1.0 - missing


def main():
  """Build the TCGA tumor-site dataset and pass it to preprocess_and_save_data."""
  # Load in labels.
  print("Loading labels.")
  clinical = read_extracted_csv(path, "clinical.csv")
  print("Labels loaded.")
  labels = labels_by_sample(clinical, SITE_TO_LABEL)

  # Load in features and pair every kept methylation sample with its label by id.
  print("Loading features.")
  methylation = read_extracted_csv(path, "methylation.csv")
  print("Features loaded.")
  X, y = align_features_and_labels(methylation, labels)

  # Remove features with more than 15% and samples with more than 10% missingness.
  X, y = filter_missing(X, y, max_feature_missing=0.15, max_sample_missing=0.1)

  # Keep the 21 selected features; M is 1.0 for observed entries, 0.0 for missing.
  X, M = mask_and_fill(X[:, BEST_FEATURES])

  X = torch.tensor(X).float()
  M = torch.tensor(M).float()
  y = torch.tensor(y).long()

  dataset_dict = {
    "num_con_features": X.shape[-1],
    "num_cat_features": 0,
    "most_categories": 0,
    # Number of classes comes from the label mapping, not the data, so a class
    # with no surviving samples does not shrink the output dimension.
    "out_dim": len(SITE_TO_LABEL),
    "metric": "accuracy",
    "max_dim": None,
  }

  # train/val sizes are fractions of the sample count; the remaining samples
  # presumably form the test split inside preprocess_and_save_data.
  preprocess_and_save_data(
    path=path,
    dataset_dict=dataset_dict,
    train_size=int(X.shape[0]*train_ratio),
    val_size=int(X.shape[0]*val_ratio),
    X=X,
    y=y,
    M=M,
    shuffle=True,
    num_bins=200,
    size_normal=1e-5,
    ratio_uniform=0.05,
  )


if __name__ == "__main__":
  main()
