import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import os
# Target proportions for every fold: train / validation / test = 60% / 20% / 20%.
# The data is divided into 5 cross-validation splits; each split holds out a
# different 20% of the rows as its test set.

def create_cross_validation_split(path, save_dir):
    """Write 5-fold cross-validation train/val/test CSVs for one annotation file.

    The CSV at ``path`` is loaded, rows whose ``split`` column equals
    ``'test'`` are dropped, and only the first row for each ``note_id`` is
    kept. The remaining rows are divided into five shuffled folds. For each
    fold, the held-out 20% becomes the test set and the other 80% is split
    again into training (60% of all rows) and validation (20% of all rows).

    Results are written to ``{save_dir}/split-{k}/train.csv``, ``val.csv``
    and ``test.csv`` for ``k`` from 1 to 5. Directories are created as needed
    and existing CSV files are overwritten. Both random splits use a fixed
    ``random_state`` of 42, so repeated runs on the same input produce the
    same partitions.

    Args:
        path: Path to the input CSV. It must contain the ``split`` and
            ``note_id`` columns.
        save_dir: Directory under which the ``split-{k}`` folders are created.
    """
    df = pd.read_csv(path)
    df = df[df['split'] != 'test']
    df = df.drop_duplicates(subset='note_id', keep='first')
    # Reset to a 0..n-1 index so the positional indices produced by KFold
    # line up with the rows selected through ``iloc`` below.
    df = df.reset_index(drop=True)

    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    for split, (train_idx, test_idx) in enumerate(kf.split(df), start=1):
        # The test fold already holds 20% of the rows. Taking 25% of the
        # remaining 80% gives a validation set that is 20% of the whole data,
        # leaving 60% for training (0.2 here would give 64/16/20 instead).
        train_idx, val_idx = train_test_split(
            train_idx, test_size=0.25, random_state=42
        )

        split_dir = os.path.join(save_dir, f'split-{split}')
        os.makedirs(split_dir, exist_ok=True)

        df.iloc[train_idx].to_csv(os.path.join(split_dir, 'train.csv'), index=False)
        df.iloc[val_idx].to_csv(os.path.join(split_dir, 'val.csv'), index=False)
        df.iloc[test_idx].to_csv(os.path.join(split_dir, 'test.csv'), index=False)

# Input annotation CSV and the directory that receives the per-split CSVs.
# Both paths are relative, so they resolve against the current working directory.
path = 'data/annotation/ontreatment_split/on_treatment.csv'
save_dir = 'data/annotation/ontreatment_split/cross_validation'

# Generate the splits only when this file is executed as a script, so that
# importing the module does not rewrite the CSV files on disk.
if __name__ == '__main__':
    create_cross_validation_split(path, save_dir)
