# script to decompose integral steps

import random   
import numpy as np
import random, io
import cloudpickle as pickle 
import logging
import sympy as sp
import time
import os
import concurrent.futures
from concurrent.futures import as_completed
from dataclasses import dataclass
import argparse

from alpha_integrate.synthetic_data.expr_utils import TokensToSympy
from alpha_integrate.synthetic_data.exceptions import ImaginaryUnitException, InvalidPrefixExpression, InfinityError
from alpha_integrate.synthetic_data.timeout import TimeoutError
from typing import List

from alpha_integrate.synthetic_data.int_steps import int_steps
from alpha_integrate.synthetic_data.decompose_steps import decompose_steps, steps_to_string, list_of_rules
from alpha_integrate.synthetic_data.process_action import is_subexpr
from alpha_integrate.synthetic_data.params.step_params import ALL_STEPS
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer

# Timestamp that gives every run its own log file.
date_time_idx = time.strftime("%Y%m%d-%H%M%S")
# logging.basicConfig does not create missing parent directories, so make sure
# the log folder exists before the file handler tries to open the log file.
_log_dir = 'alpha_integrate/synthetic_data/decomposelogs'
os.makedirs(_log_dir, exist_ok=True)
logging.basicConfig(filename = f'{_log_dir}/decompose_{date_time_idx}.txt', filemode='w', level=logging.INFO)

@dataclass
class Result:
    """Counters describing how far a batch of expressions got in the pipeline.

    Three groups of counters are kept:

    * conversion to sympy: ``sympy_success`` plus one counter per failure kind
      (``imaginary_exceptions``, ``timeout_exceptions``,
      ``invalid_exceptions``, ``other_exceptions``);
    * manual integration: ``int_success``, ``int_timeout``, ``int_dontknow``,
      ``int_other``;
    * decomposition: ``decompose_success``, ``decompose_fail``,
      ``decompose_other`` and the number of written steps ``total_steps``.

    Instances can be added together, which sums every counter.
    """
    sympy_success: int = 0
    imaginary_exceptions: int = 0
    timeout_exceptions: int = 0
    invalid_exceptions: int = 0
    other_exceptions: int = 0
    int_success: int = 0
    int_timeout: int = 0
    int_dontknow: int = 0
    int_other: int = 0
    decompose_success: int = 0
    decompose_fail: int = 0
    decompose_other: int = 0
    total_steps: int = 0

    def __add__(self, other):
        """Return a new Result whose counters are the field-wise sums of both operands."""
        # Iterating over the instance attributes keeps this in sync with the
        # field list above, so a newly added counter is summed automatically.
        return Result(**{name: value + getattr(other, name) for name, value in vars(self).items()})

    @staticmethod
    def _share(count: int, total: int) -> float:
        """Return count / total, or 0.0 when total is zero (nothing was processed)."""
        return count / total if total else 0.0

    def print(self, dataset_size: int):
        """Write a summary of all counters to the log.

        Conversion counters are reported relative to ``dataset_size`` (the
        total number of expressions); integration and decomposition counters
        are reported relative to ``sympy_success``. A zero denominator is
        reported as 0.00% instead of raising ZeroDivisionError.
        """
        share = self._share
        converted = self.sympy_success

        logging.info(f"Total expressions: {dataset_size} out of which {self.sympy_success} were converted to sympy ({share(self.sympy_success, dataset_size):.2%}).")
        # The ':.2%' format already appends a percent sign, so none is added by hand.
        logging.info(f"Imaginary unit exceptions: {self.imaginary_exceptions} constituting {share(self.imaginary_exceptions, dataset_size):.2%} of the expressions.")
        logging.info(f"Timeout exceptions: {self.timeout_exceptions} constituting {share(self.timeout_exceptions, dataset_size):.2%} of the expressions.")
        logging.info(f"Invalid prefix exceptions: {self.invalid_exceptions} constituting {share(self.invalid_exceptions, dataset_size):.2%} of the expressions.")
        logging.info(f"Other exceptions: {self.other_exceptions} constituting {share(self.other_exceptions, dataset_size):.2%} of the expressions.\n")
        logging.info(f"Successfully found manual integration steps for {self.int_success} expressions ({share(self.int_success, converted):.2%}).")
        logging.info(f"Timeout exceptions for manual integration steps: {self.int_timeout} constituting {share(self.int_timeout, converted):.2%} of the expressions.")
        logging.info(f"Expressions with 'dont know' steps: {self.int_dontknow} constituting {share(self.int_dontknow, converted):.2%} of the expressions.\n")
        logging.info(f"Successfully decomposed {self.decompose_success} expressions ({share(self.decompose_success, converted):.2%}).")
        logging.info(f"Failed to decompose {self.decompose_fail} expressions ({share(self.decompose_fail, converted):.2%}).")
        # decompose_other is tracked by the workers, so report it as well.
        logging.info(f"Other exceptions during decomposition: {self.decompose_other} ({share(self.decompose_other, converted):.2%}).")
        logging.info(f"Other exceptions for manual integration steps: {self.int_other} constituting {share(self.int_other, converted):.2%} of the expressions.")
        logging.info(f"Total steps: {self.total_steps}.\n")

def print_stat_dict(stat_dict: dict):
    """Log every entry of stat_dict with its count and its share of the 'total' entry.

    Args:
        stat_dict: mapping from a name to a count; the 'total' entry is used
            as the denominator for the percentages (and is itself logged).

    When 'total' is zero or missing, every share is reported as 0.00%
    instead of raising ZeroDivisionError.
    """

    logging.info("Statistics for the decomposition experiment:\n")
    total = stat_dict.get('total', 0)
    for rule, count in stat_dict.items():
        share = count / total if total else 0.0
        logging.info(f"{rule}: {count} ({share:.2%})")

    logging.info("\n")

def data_len(data_path: str):
    """Return the number of lines in the UTF-8 text file at data_path."""

    line_count = 0
    with io.open(data_path, mode='r', encoding='utf-8') as handle:
        # Stream the file so it never has to be held in memory as a whole.
        for _ in handle:
            line_count += 1

    return line_count

def get_indices(remaining: int, data_size: int, duplicate_ids: dict = None) -> dict:
    """Draw distinct random line indices that have not been used before.

    Args:
        remaining: number of indices requested.
        data_size: number of lines available; valid indices are
            0 .. data_size - 1.
        duplicate_ids: mapping whose keys are indices that must not be drawn
            again (assumed to be integers). Not modified.

    Returns:
        A dict mapping every drawn index to itself. It holds fewer than
        ``remaining`` entries when not enough unused indices are left.
    """

    if duplicate_ids is None:
        duplicate_ids = {}
    if data_size <= 0:
        return dict()

    # Cap the request at what is actually available; otherwise the rejection
    # sampling below could never terminate.
    already_used = sum(1 for idx in duplicate_ids if 0 <= idx < data_size)
    target = min(remaining, data_size - already_used)

    perm_dict = dict()
    while len(perm_dict) < target:
        # randrange excludes data_size itself, which is not a valid line index.
        rn = random.randrange(data_size)
        if rn not in perm_dict and rn not in duplicate_ids:
            perm_dict[rn] = rn

    return perm_dict

def get_random_expressions(data_path: str, num_expressions: int = None, duplicate_ids: dict = None):
    '''
    This function is used to get a random sample of expressions from the data_path
    If num_expressions is None, all expressions are returned

    Each line is expected to hold two tab-separated fields; lines with any
    other number of fields are dropped. The first field is split on
    whitespace and its first two tokens are discarded (presumably a header
    prefix of the data format - confirm against the dataset).

    duplicate_ids maps already used line indices to True. A dict passed by
    the caller is updated in place with the indices read here; when omitted,
    a fresh dict is created for every call.

    Returns the list of token lists and the updated duplicate_ids dict.
    '''

    # A fresh dict per call: a shared default would leak used indices from
    # one call into the next.
    if duplicate_ids is None:
        duplicate_ids = {}

    # The file only has to be counted when a random subset is requested.
    perm_dict = None
    if num_expressions is not None:
        perm_dict = get_indices(num_expressions, data_len(data_path), duplicate_ids)

    lines = []
    with io.open(data_path, mode='r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if perm_dict is not None and i not in perm_dict:
                continue
            lines.append(line.rstrip())
            duplicate_ids[i] = True
            # Stop early once the requested number of lines has been read.
            if len(lines) == num_expressions:
                break

    expressions = []
    for line in lines:
        fields = line.split('\t')
        if len(fields) == 2:
            expressions.append(fields[0].split()[2:])

    return expressions, duplicate_ids

def get_next_file_num(directory: str):
    """Return the next free running number for files named '<prefix>_<number>.<ext>'.

    Only regular files directly inside ``directory`` are considered. The
    number is taken from the text between the last underscore and the first
    following dot. Files whose name does not carry such a number are ignored
    rather than aborting the scan.

    Returns:
        0 when no numbered file exists, otherwise the highest number plus one.
    """

    numbers = []
    for name in os.listdir(directory):
        if not os.path.isfile(os.path.join(directory, name)):
            continue
        suffix = name.split('_')[-1].split('.')[0]
        try:
            numbers.append(int(suffix))
        except ValueError:
            # Unrelated file (no numeric suffix) - it must not break numbering.
            continue

    if not numbers:
        return 0
    return max(numbers) + 1


def _open_next_output_file(path_save: str, dataset: str):
    """Create and open the next unused '<dataset>_<n>.txt' file in path_save for writing."""
    next_number = get_next_file_num(path_save)
    while True:
        path = f'{path_save}/{dataset}_{next_number}.txt'
        try:
            # Mode 'x' fails if the file already exists. Several worker
            # processes pick their file number concurrently, so with 'w' two
            # of them could silently overwrite each other's output.
            return open(path, 'x')
        except FileExistsError:
            next_number += 1


def process_expressions(expressions: List[str], path_save: str, dataset: str, pid: int):
    """Convert, integrate, decompose and save one chunk of expressions.

    The chunk goes through three stages:

    1. every token sequence is converted to a sympy expression;
    2. manual integration steps are computed and decomposed for every
       converted expression;
    3. the decomposed steps are tokenized and written to a new
       '<dataset>_<n>.txt' file in ``path_save``: one line per step with the
       fields expression, sub-expression, rule and result separated by two
       tabs, and one empty line after each expression.

    Args:
        expressions: token sequences of the expressions to process.
        path_save: existing directory that receives the output file.
        dataset: dataset name, used as prefix of the output file name.
        pid: identifier of this chunk, used for log messages and returned
            unchanged.

    Returns:
        A tuple ``(result, pid, stat_dict)``: the Result counters of this
        chunk, the given ``pid`` and a dict that counts, for every rule in
        ALL_STEPS, the fully written expressions using it (plus 'total').
    """

    logging.info(f"Process {pid} started.")

    env = TokensToSympy()
    tokenizer = ExpressionTokenizer()
    new_expressions = []

    r = Result()
    stat_dict = {rule: 0 for rule in ALL_STEPS}
    stat_dict['total'] = 0

    # Stage 1: token sequences -> sympy expressions.
    for expression in expressions:
        try:
            sp_expr = env.seq_to_sp_direct(expression)
            new_expressions.append(sp_expr)
        except TimeoutError:
            r.timeout_exceptions += 1
        except ImaginaryUnitException:
            r.imaginary_exceptions += 1
        except InvalidPrefixExpression:
            r.invalid_exceptions += 1
        except Exception:
            r.other_exceptions += 1

    r.sympy_success = len(new_expressions)

    logging.info(f"Process {pid} finished processing {len(expressions)} expressions.")
    logging.info(f"Process {pid} trying to get manual integration steps...")

    # Stage 2: integration steps and their decomposition.
    decomposed_expr_steps = []

    symbol = sp.Symbol('x')
    for sp_expr in new_expressions:
        # try to get integral steps with sympy
        try:
            steps = int_steps(sp_expr, symbol)
            if steps.contains_dont_know():
                r.int_dontknow += 1
                continue

            # Counted directly; the step objects themselves are not needed later.
            r.int_success += 1
            try:
                variable_list = list()
                decomposed_steps = decompose_steps(sp_expr, symbol, variable_list, steps)
                if None in decomposed_steps:
                    r.decompose_fail += 1
                else:
                    decomposed_expr_steps.append(decomposed_steps)
            except Exception as e:
                logging.info(f"Exception: {e} for {sp_expr}.")
                r.decompose_other += 1

        except TimeoutError:
            r.int_timeout += 1
        except Exception:
            r.int_other += 1

    # Stage 3: tokenize the decomposed steps and write them to disk.
    with _open_next_output_file(path_save, dataset) as f:
        for d_step in decomposed_expr_steps:
            success = True
            for step in d_step:
                expr, subexpr, rule, result = step
                # Steps rejected by is_subexpr are skipped without marking
                # the expression as failed.
                if not is_subexpr(expr, subexpr):
                    continue
                try:
                    tokenized_expr = ' '.join(tokenizer.sp_to_seq(expr))
                    tokenized_subexpr = ' '.join(tokenizer.sp_to_seq(subexpr))
                    # Rule parts are either plain strings or sympy
                    # expressions; anything else is left out.
                    rule_parts = []
                    for ru in rule:
                        if isinstance(ru, str):
                            rule_parts.append(ru)
                        elif isinstance(ru, sp.Expr):
                            rule_parts.append(' '.join(tokenizer.sp_to_seq(ru)))
                    tokenized_rule = '\t'.join(rule_parts).rstrip('\t')
                    tokenized_result = ' '.join(tokenizer.sp_to_seq(result))
                    tokenized_step = f"{tokenized_expr}\t\t{tokenized_subexpr}\t\t{tokenized_rule}\t\t{tokenized_result}\n"
                    f.write(tokenized_step)
                    r.total_steps += 1
                except InfinityError:
                    success = False
                    break
                except Exception:
                    # Any tokenization failure ends this expression; steps
                    # written so far stay in the file.
                    success = False
                    break

            f.write('\n')
            # if we managed to save all steps
            if success:
                r.decompose_success += 1
                # Count every rule once per expression.
                for rule in set(list_of_rules(d_step)):
                    stat_dict[rule] += 1
                stat_dict['total'] += 1

    logging.info(f"Process {pid} finished processing {len(expressions)} expressions and saved {r.total_steps} steps.")

    # Debugging aid (disabled): sample 10 decomposed steps and print them
    #
    # N_sample = 10
    # if len(decomposed_expr_steps) > N_sample:
    #     sample = random.sample(decomposed_expr_steps, N_sample)
    #     for sample_step in sample:
    #         logging.info("\n" + steps_to_string(sample_step))
    #         logging.info("\n")

    return r, pid, stat_dict


def process_dataset(DATASET: str, num_expressions: int, workers: int):
    """Sample expressions from a dataset and process them in parallel.

    Reads (a random subset of) the training file of ``DATASET``, records
    which line indices were used in '<save dir>/ids/ids.pkl' so that later
    runs skip them, splits the expressions into chunks and hands every chunk
    to ``process_expressions`` in a process pool.

    Args:
        DATASET: dataset name; selects the input file and the output folder.
        num_expressions: number of random expressions to read, or None for
            the whole file.
        workers: number of CPUs to plan with. Three of them are left as
            margin for the main process; at least one worker process is
            always used.

    Returns:
        A tuple ``(result, stat_dict, n)``: the summed Result counters, the
        summed per-rule statistics and the number of expressions read.
    """

    path_train = f'alpha_integrate/synthetic_data/dataset/{DATASET}/{DATASET}.train'
    path_save = f'alpha_integrate/synthetic_data/steps_dataset/{DATASET}'
    path_duplicates = f'alpha_integrate/synthetic_data/steps_dataset/{DATASET}/ids'

    logging.info(f"Processing data in {path_train}...")
    logging.info(f"Save path: {path_save}.")
    logging.info(f"Duplicates path: {path_duplicates}.\n")

    # Create the output folders when missing and start with an empty id
    # store if there is none yet (also when only the folder already exists).
    os.makedirs(path_save, exist_ok=True)
    os.makedirs(path_duplicates, exist_ok=True)
    path_duplicates += '/ids.pkl'
    if not os.path.exists(path_duplicates):
        with open(path_duplicates, 'wb') as f:
            pickle.dump(dict(), f)

    # NOTE: unpickling executes code from the file; ids.pkl is written by this
    # script itself and must never be replaced by data from an untrusted source.
    with open(path_duplicates, 'rb') as f:
        duplicate_ids = pickle.load(f)

    # leave some margin for the main process, but always keep one worker
    # (replaces an assert, which is skipped when Python runs with -O)
    n_procs = max(1, (workers or 1) - 3)

    start = time.time()
    expressions, new_duplicate_ids = get_random_expressions(path_train, num_expressions, duplicate_ids)

    # update duplicate ids
    with open(path_duplicates, 'wb') as f:
        pickle.dump(new_duplicate_ids, f)

    # shuffle the expressions
    random.shuffle(expressions)

    end = time.time()

    logging.info(f"Read {len(expressions)} (random) expressions from {path_train} in {end - start} seconds.\n")

    # Chunk by the number of expressions actually read: it can be smaller
    # than the requested number because malformed lines are dropped. The
    # chunk size is at least 1 so that small inputs do not produce a zero step.
    chunk_size = max(1, len(expressions) // n_procs)
    logging.info(f"Data split into chunks of size {chunk_size}.")
    chunks = [expressions[first:first + chunk_size] for first in range(0, len(expressions), chunk_size)]

    logging.info(f"Created {len(chunks)} chunks - starting the parallel processes...\n")
    start = time.time()

    r = Result()
    stat_dict = {rule: 0 for rule in ALL_STEPS}
    stat_dict['total'] = 0

    with concurrent.futures.ProcessPoolExecutor(max_workers=n_procs) as executor:
        futures = [executor.submit(process_expressions, chunk, path_save, DATASET, i) for i, chunk in enumerate(chunks)]

        # this helps collecting results dynamically
        for f in as_completed(futures):
            r1, pid, d = f.result()
            r += r1
            for rule in ALL_STEPS:
                stat_dict[rule] += d[rule]
            stat_dict['total'] += d['total']
            logging.info(f"Process {pid} finished.")

    end = time.time()

    logging.info(f"Decomposition experiment for {len(expressions)} expressions took {end - start} seconds.\n")

    return r, stat_dict, len(expressions)

def main(args):
    """Run the decomposition experiment for one dataset and log a summary.

    Args:
        args: parsed command-line arguments providing ``dataset`` (name of
            the dataset) and ``num_expressions`` (how many random
            expressions to process).
    """

    dataset = args.dataset
    # get this many random expressions from each dataset and see if we can manual integrate
    num_expressions = args.num_expressions

    # workers = number of cpus; os.cpu_count() returns None when the number
    # cannot be determined, so fall back to a single CPU in that case
    workers = os.cpu_count() or 1

    logging.info("Starting the manual integration experiment...")
    logging.info(f"Running with {workers} workers.")
    logging.info(f"Built the environment - starting decomposition for expressions in {dataset}...\n")

    # Only one dataset is processed, so its results are the overall results.
    r, stat_dict, length = process_dataset(dataset, num_expressions, workers)

    logging.info("Summary:\n")
    r.print(length)
    print_stat_dict(stat_dict)

if __name__ == '__main__':

    # Command-line interface: both options are mandatory.
    cli = argparse.ArgumentParser(description='Decompose steps of expressions.')
    cli.add_argument('--num_expressions', type=int, required=True, help='Number of expressions to process')
    cli.add_argument('--dataset', type=str, required=True, help='Name of the dataset to process')
    cli_args = cli.parse_args()

    # Measure the wall-clock time of the whole run and record it in the log.
    started_at = time.time()
    main(cli_args)
    finished_at = time.time()

    logging.info(f"Total time: {finished_at - started_at} seconds.")
    