import argparse
import logging
import time
import os
import io
import numpy as np
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer

def get_steps(data_path: str):
    """Read one steps file and split every non-blank line into its fields.

    Args:
        data_path: Path of a UTF-8 text file to read.

    Returns:
        A ``(rows, blank_count)`` tuple. ``rows`` holds, for each line that is
        not exactly a newline, the list produced by splitting that line on a
        double tab; the last field keeps its trailing newline when the line
        had one. ``blank_count`` is the number of lines that are exactly a
        newline, which callers use as the expression count.
    """
    step_rows = []
    blank_count = 0
    with io.open(data_path, mode='r', encoding='utf-8') as handle:
        for raw_line in handle:
            if raw_line != '\n':
                # Fields are separated by a double tab; single tabs stay
                # inside the individual fields.
                step_rows.append(raw_line.split('\t\t'))
                continue
            blank_count += 1
    return step_rows, blank_count

def _longest_expression_run(data_path: str) -> int:
    """Return the largest number of consecutive step lines in one file.

    A line that is exactly a newline closes the current expression, which is
    the same rule ``get_steps`` applies when it counts expressions.
    """
    longest = 0
    current = 0
    with io.open(data_path, mode='r', encoding='utf-8') as handle:
        for raw_line in handle:
            if raw_line == '\n':
                longest = max(longest, current)
                current = 0
            else:
                current += 1
    # Account for a final expression that is not followed by a blank line.
    return max(longest, current)


def main(args, max_length: int = 1024):
    """Log step-count and token-length statistics for one steps dataset.

    Every ``.txt`` file in the dataset directory is read in the order of the
    integer that follows the last underscore of its name. The length of a
    step is the number of whitespace-separated tokens in its expression,
    sub-expression and rule fields combined; the result field is not counted.

    Args:
        args: Parsed command-line namespace; only ``args.dataset`` is read.
        max_length: Threshold used for the "length > N" count and for the
            filtered mean/median. Defaults to 1024.

    Raises:
        ValueError: If a file name has no integer suffix, or a step line does
            not split into exactly four double-tab-separated fields.
    """
    directory = f'alpha_integrate/synthetic_data/steps_dataset/{args.dataset}/'
    files = sorted(
        (name for name in os.listdir(directory) if name.endswith('.txt')),
        key=lambda name: int(name.split('_')[-1].split('.')[0]),
    )

    lengths = []
    num_expressions = 0
    step_count = 0
    maxsteps = 0

    for file in files:
        data_path = os.path.join(directory, file)
        steps, unique_expressions = get_steps(data_path)
        num_expressions += unique_expressions
        step_count += len(steps)
        # Track the longest single expression, not the total steps in a file:
        # one file can hold several blank-line-delimited expressions.
        maxsteps = max(maxsteps, _longest_expression_run(data_path))

        for fields in steps:
            if len(fields) != 4:
                raise ValueError(
                    f"{data_path}: expected 4 double-tab-separated fields, "
                    f"got {len(fields)}: {fields!r}"
                )
            expr, subexpr, rule, _result = fields
            # str.split() with no argument also splits on the single tabs
            # that separate the parts of a rule, so one call counts them all.
            lengths.append(
                len(expr.split()) + len(subexpr.split()) + len(rule.split())
            )

    logging.info(f"Total number of unique expressions: {num_expressions}")
    logging.info(f"Total number of steps: {step_count}")

    if not lengths:
        # np.max / np.min raise ValueError on an empty sequence.
        logging.warning("No steps found; skipping length statistics")
        return

    all_lengths = np.asarray(lengths)
    logging.info("Length statistics")
    logging.info(f"Mean: {np.mean(all_lengths)}")
    logging.info(f"Median: {np.median(all_lengths)}")
    logging.info(f"Max: {np.max(all_lengths)}")
    logging.info(f"Min: {np.min(all_lengths)}")
    logging.info(f"Max number of steps in a single expression: {maxsteps}")
    too_long = int(np.count_nonzero(all_lengths > max_length))
    logging.info(f"Number of steps with length > {max_length}: {too_long}\n")

    # Keep only the steps whose length does not exceed the threshold.
    kept = all_lengths[all_lengths <= max_length]

    logging.info(f"Total number of steps (length <= {max_length}): {len(kept)}")
    if kept.size == 0:
        # Mean/median of an empty array would only yield nan plus a warning.
        logging.warning("No steps within the length limit; skipping filtered statistics")
        return
    logging.info(f"Mean (length <= {max_length}): {np.mean(kept)}")
    logging.info(f"Median (length <= {max_length}): {np.median(kept)}")


if __name__ == '__main__':
    # Command-line entry point: parse the dataset name, send INFO-level
    # logging to a fresh timestamped file, then compute the statistics.
    parser = argparse.ArgumentParser(description='Process expressions dataset to calculate statistics.')
    parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset to process')
    args = parser.parse_args()

    # logging.basicConfig opens the log file immediately and fails if the
    # parent directory is missing, so create it first.
    log_dir = 'alpha_integrate/synthetic_data/readlogs'
    os.makedirs(log_dir, exist_ok=True)

    date_time_idx = time.strftime("%Y%m%d-%H%M%S")
    logging.basicConfig(filename=f'{log_dir}/logs_{date_time_idx}.txt', filemode='w', level=logging.INFO)

    logging.info(f"Processing dataset: {args.dataset}\n")
    main(args)
