import pandas as pd
import os
import pickle
import gzip
import os.path as osp
from .dataset_produce import SmilesRepeat

# Function to read a gzipped CSV file
def read_gzipped_csv(file_path):
    """Load a gzip-compressed, header-less CSV file into a DataFrame.

    The file is opened in text mode via ``gzip`` and parsed with
    ``header=None``, so the first line is treated as data rather than as
    column names.

    Args:
        file_path: Path of the ``.gz`` file to read.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    with gzip.open(file_path, mode='rt') as text_stream:
        frame = pd.read_csv(text_stream, header=None)
    return frame

def gsplit(root, test_idx_path, valid_idx_path, train_idx_path):
    """Split ``<root>/raw.csv`` into test/train/valid CSV files.

    Each index file is a gzipped, header-less CSV; the values of its first
    column are used as positional row indices (``DataFrame.iloc``) into
    ``raw.csv``. For every split ``s`` the selected rows are written to
    ``<root>/<s>/<s>.csv`` (with header, without the index column). A split
    whose output file already exists is left untouched and a notice is
    printed instead.

    Args:
        root: Directory containing ``raw.csv``; split sub-directories are
            created below it.
        test_idx_path: Path of the gzipped index file for the test split.
        valid_idx_path: Path of the gzipped index file for the valid split.
        train_idx_path: Path of the gzipped index file for the train split.
    """
    # First column of each index file -> positional row indices of that split.
    split_indices = {
        'test': read_gzipped_csv(test_idx_path).values.T[0],
        'valid': read_gzipped_csv(valid_idx_path).values.T[0],
        'train': read_gzipped_csv(train_idx_path).values.T[0],
    }

    oridf = pd.read_csv(osp.join(root, 'raw.csv'))
    for s in ['test', 'train', 'valid']:
        sub_dir_name = osp.join(root, s)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(sub_dir_name, exist_ok=True)
        des = osp.join(sub_dir_name, '{}.csv'.format(s))
        if osp.exists(des):
            # Never overwrite a split that was already materialised.
            print("*********************** {} Existed! ************************".format(des))
            continue
        oridf.iloc[split_indices[s]].to_csv(des, index=False, header=True)


def csvcatg(root, polymer_type, task, rep, ratios=None):
    """Concatenate the per-repeat train/valid CSVs of a task into one file.

    For each split ``s`` in ``('train', 'valid')`` the files
    ``<root>/<polymer_type>/<task>/<r>/<s>/<s>.csv`` are read for every ``r``
    in ``rep``, optionally down-sampled, concatenated row-wise in the order
    of ``rep`` and written to
    ``<root>/<polymer_type>/<task>/concat/<str(rep)>/<s>/<s>.csv``.
    If a per-repeat file is missing, ``SmilesRepeat(r, task, root=...).repeat()``
    is invoked before reading it. An already existing output file is not
    overwritten; a notice is printed instead.

    Args:
        root: Base data directory.
        polymer_type: Sub-directory of ``root`` holding the task data.
        task: Task name (directory below ``<root>/<polymer_type>``).
        rep: Sequence of repeat identifiers whose CSVs are concatenated.
        ratios: Optional sequence (list, tuple, array, ...) with one entry
            per element of ``rep``. An entry below 1 keeps that fraction of
            the rows (sampled with ``random_state=42``); entries of 1 or more
            keep the file unchanged. ``None`` or an empty sequence means
            "keep everything".

    Raises:
        ValueError: If ``ratios`` and ``rep`` differ in length.
    """
    # Explicit None/length test instead of truthiness so that array-like
    # ratios (e.g. a NumPy array) do not raise "truth value is ambiguous".
    if ratios is None or len(ratios) == 0:
        ratios = [1] * len(rep)
    if len(ratios) != len(rep):
        raise ValueError("The number of ratios must be equal to the number of repeat times.")

    base_dir = f'{root}/{polymer_type}'
    for s in ['train', 'valid']:
        dir_name = osp.join(base_dir, task, 'concat', str(rep), s)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(dir_name, exist_ok=True)
        des = osp.join(dir_name, f'{s}.csv')
        if osp.exists(des):
            print("*********************** {} Existed! ************************".format(des))
            continue

        all_data = []
        for r, ratio in zip(rep, ratios):
            cur_raw_file = osp.join(base_dir, task, str(r), s, '{}.csv'.format(s))
            if not osp.exists(cur_raw_file):
                # Ask SmilesRepeat to produce the missing per-repeat data.
                SmilesRepeat(r, task, root=base_dir).repeat()
            df = pd.read_csv(cur_raw_file)
            if ratio < 1:
                # Fixed seed keeps the sub-sample reproducible across runs.
                df = df.sample(frac=ratio, random_state=42)
            all_data.append(df)

        result = pd.concat(all_data, axis=0, ignore_index=False)
        # SAVE CSV
        result.to_csv(des, index=False)


# Script entry point: currently a no-op, so running this file directly does
# nothing beyond executing the imports and definitions above.
if __name__ == '__main__':
    pass
    