import os
import sys

import pandas as pd

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, 'src')))
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from experiment_utils import make_experiment
from utils.metadata import DATA_DIRECTORY
from utils.utils import Bunch, ceil_div


def make_chunks(liste, chunksize):
    """
    Yield consecutive slices of ``liste`` that each hold at most ``chunksize`` items.

    :param liste: Sequence supporting ``len()`` and slicing (e.g. a list) that should be split up.
    :param chunksize: Maximum number of items per chunk; must be a positive integer.
    :return: Generator over the slices in their original order. Only the last slice may be shorter than
        ``chunksize``; an empty ``liste`` yields nothing.
    :raises ValueError: If ``chunksize`` is smaller than 1. As this is a generator, the error surfaces when the
        iteration starts, not when the function is called.
    """
    # A non-positive chunk size would otherwise either fail with an unrelated error or silently yield no data.
    if chunksize < 1:
        raise ValueError(f'chunksize must be a positive integer, got {chunksize!r}')

    # Stepping the start index by chunksize removes the need for an explicit chunk count,
    # and slicing already clamps the end of the last (possibly shorter) chunk.
    for start in range(0, len(liste), chunksize):
        yield liste[start:start + chunksize]


# Script entry point: the experiment is only set up when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    # Create the experiment object; its `config` and `automain` decorators are applied to the functions below.
    # NOTE(review): the decorator names look like the Sacred experiment API -- confirm in experiment_utils.
    experiment = make_experiment()


    # Default parameters of this script, registered on the experiment.
    # NOTE(review): presumably the framework captures the local `params` and injects it into `main` -- confirm.
    @experiment.config
    def config():
        params = dict(
            # Directory that holds the combined input file and receives the per-fault output files.
            data_path=os.path.join(DATA_DIRECTORY, 'TEP_harvard'),
            # Number of rows pandas reads from the input CSV per chunk.
            read_chunksize=1000000,
            # Fault numbers for which a separate file is written (1 to 20 by default).
            faults=list(range(1, 21)),
            # Number of faults that are collected in memory during one pass over the input file.
            fault_chunksize=7,
            # True: split the 'Training' file, False: split the 'Testing' file.
            training=False,
            # True: append '.gz' to the output file names, False: write plain '.csv' files.
            gzip=False,
        )


    @experiment.automain
    def main(params, _run):
        """
        Split the data converted by convert_tep_data.py into one CSV file per fault number.

        The combined file ``TEP_Faulty_{Training|Testing}.csv.gz`` inside ``params.data_path`` is read once per group
        of ``params.fault_chunksize`` faults, in row chunks of ``params.read_chunksize``. The rows of every fault in
        the group are then written to ``TEP_Faulty_{Training|Testing}_{NN}.csv`` (plus ``.gz`` if ``params.gzip`` is
        set) in the same directory. Having one file per fault allows loading faults selectively, possibly saving RAM.
        """
        cfg = Bunch(params)

        kind = 'Faulty'
        if cfg.training:
            split = 'Training'
        else:
            split = 'Testing'

        # Input location and output extension are the same for every fault, so they are built only once.
        source_path = os.path.join(cfg.data_path, f'TEP_{kind}_{split}.csv.gz')
        suffix = '.csv.gz' if cfg.gzip else '.csv'

        for fault_group in make_chunks(cfg.faults, cfg.fault_chunksize):
            print(f'Reading Faults {fault_group}...')
            reader = pd.read_csv(source_path, chunksize=cfg.read_chunksize)
            # Keep only the rows of the current fault group from each chunk, so the full file is never held in memory.
            kept_parts = [
                part.loc[part['faultNumber'].isin(fault_group)].copy()
                for part in reader
            ]
            group_df = pd.concat(kept_parts)
            print('Done reading chunks!')

            for fault in fault_group:
                print(f'Writing data for fault {fault:02d}...')
                fault_rows = group_df.loc[group_df['faultNumber'] == fault]
                target_path = os.path.join(cfg.data_path, f'TEP_{kind}_{split}_{fault:02d}{suffix}')
                fault_rows.to_csv(target_path, index=False)
                print('Done!')
