# from: https://github.com/snap-stanford/MLAgentBench/blob/main/MLAgentBench/benchmarks/amp-parkinsons-disease-progression-prediction/scripts/prepare.py

import subprocess
import pandas as pd
import random
import os

taskname = "amp-parkinsons-disease-progression-prediction"
download_dir = "./amp_pd"
os.makedirs(download_dir, exist_ok=True)

input(f"Consent to the competition at https://www.kaggle.com/competitions/{taskname}/data; Press any key after you have accepted the rules online.")

subprocess.run(["kaggle", "competitions", "download", "-c", taskname], cwd=download_dir) 
subprocess.run(["unzip", "-n", f"{taskname}.zip"], cwd=download_dir) 
subprocess.run(["rm", f"{taskname}.zip"], cwd=download_dir) 
subprocess.run(["rm", "-r", "amp_pd_peptide"], cwd=download_dir)
subprocess.run(["rm", "-r", "amp_pd_peptide_310"], cwd=download_dir)

# ## split train to train and test in env

data_proteins     = pd.read_csv(f'{download_dir}/train_proteins.csv')
data_clinical     = pd.read_csv(f'{download_dir}/train_clinical_data.csv')
data_peptides     = pd.read_csv(f'{download_dir}/train_peptides.csv')
data_supplemental = pd.read_csv(f'{download_dir}/supplemental_clinical_data.csv')

# raise Exception('stop here')

random.seed(42)

patient_id = data_clinical['patient_id'].unique()
patiend_from_supplemental = data_supplemental['patient_id'].unique()

total_test_patient = int(len(patient_id) * 0.2)
patient_id_not_in_supplemental = [x for x in patient_id if x not in patiend_from_supplemental]
test_patient_id = random.sample(patient_id_not_in_supplemental, total_test_patient)
train_patient_id = [x for x in patient_id if x not in test_patient_id]

print('train_patient_id', len(train_patient_id))
print('test_patient_id', len(test_patient_id), 'ratio', len(test_patient_id) / len(patient_id))

data_proteins[~data_proteins['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/train_proteins.csv', index=False)
data_clinical[~data_clinical['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/train_clinical_data.csv', index=False)
data_peptides[~data_peptides['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/train_peptides.csv', index=False)
data_supplemental[~data_supplemental['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/supplemental_clinical_data.csv', index=False)

data_proteins[data_proteins['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/example_test_files/test_proteins.csv', index=False)
data_peptides[data_peptides['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/example_test_files/test_peptides.csv', index=False)
test_clinical = data_clinical[data_clinical['patient_id'].isin(test_patient_id)]


# Create test.csv
temp_list = []
for i in range(1, 5):
    temp = test_clinical.copy()
    temp['level_3'] = i
    temp['updrs_test'] = f'updrs_{i}'
    temp_list.append(temp)
mock_train = pd.concat(temp_list)
mock_train['row_id'] = (mock_train[['patient_id', 'visit_month', 'level_3']]
                      .apply((lambda r: f"{r.patient_id}_{int(r.visit_month)}_updrs_{r.level_3}"), axis=1))
mock_train[['visit_id', 'patient_id', 'visit_month','row_id', 'updrs_test']].to_csv(f'{download_dir}/example_test_files/test.csv', index=False)

# Create sample_submission.csv
temp_list = []
for wait in [0, 6, 12, 24]:
    temp = mock_train.copy()
    temp['wait'] = wait
    temp_list.append(temp)
y = pd.concat(temp_list)
y = y[y.visit_month + y.wait <= 108]
y['prediction_id'] = (y[['patient_id', 'visit_month', 'wait', 'level_3']]
                      .apply((lambda r: f"{r.patient_id}_{int(r.visit_month)}_updrs_{r.level_3}_plus_{r.wait}_months"), axis=1))

def get_rating(row):
    rating = test_clinical[test_clinical["visit_id"] == f'{row.patient_id}_{int(row.visit_month) + int(row.wait) }' ][f'updrs_{row.level_3}']
    if len(rating) == 0:
        return None
    return rating.item()

y['rating'] = (y[['patient_id', 'visit_month', 'wait', 'level_3']].apply(get_rating, axis=1))
y = y.dropna()
y[['prediction_id', 'rating', 'visit_month']].to_csv(f'{download_dir}/example_test_files/answer.csv', index=False)

y['rating'] = 0
y[['prediction_id', 'rating', 'visit_month']].to_csv(f'{download_dir}/example_test_files/sample_submission.csv', index=False)