import os
from datetime import datetime 
from datasets import load_dataset

import multiprocessing
import logging, logging.handlers
from joblib import Parallel, delayed, parallel_config
import operators as op

import MathGym.utils.utils as utils
import MathGym.utils.gpt as gpt

from args import parse_args
args = parse_args()

#### setting & preparing
# API credentials; each worker in ``process_batch`` picks api_ids[idx % len(api_ids)].
api_ids = gpt.get_ids(min_index=1, max_index=30)
# Load only the [args.index[0]:args.index[-1]] slice of the requested split;
# output files are later numbered starting from args.index[0].
dataset = load_dataset("lighteval/MATH", 'all', split=f"{args.data}[{args.index[0]}:{args.index[-1]}]", cache_dir="../cache/dataset/math")
args.save_folder_path = os.path.join('./data/', args.data)

# mkdir for saving data & log.
# exist_ok=True replaces the racy exists()/makedirs() pair and lets reruns reuse the folder.
os.makedirs(args.save_folder_path, exist_ok=True)
# NOTE: despite the ".log" suffix, log_dir is a *directory*; the error log is written
# to <log_dir>/error.log. exist_ok=True keeps two runs started in the same second
# (identical timestamp) from crashing with FileExistsError.
# NOTE(review): ':' in the timestamp is not a valid path character on Windows.
log_dir = "./data/logs/log_%s.log" % (datetime.now().strftime("%Y-%m-%d_%H:%M:%S"))
os.makedirs(log_dir, exist_ok=True)


# Maps each supported ``operator`` value to the name of the function in the
# ``operators`` module (imported as ``op``) that implements it.
_OPERATOR_HANDLERS = {
    "solution": "solve",
    "download": "download",
    "formalization": "autoformalize",
    "check": "check",
    "informalization": "informalize",
    "mutation": "mutate",
    "refresh": "refresh",
}


def process_file(file, client, operator, save_file_path, logger):
    """Run the operation named by ``operator`` on a single dataset item.

    Every handler is called as ``handler(file, client, save_file_path, logger, args)``,
    using the module-level ``args``.

    Args:
        file: one dataset item, forwarded unchanged to the handler.
        client: API client forwarded to the handler.
        operator: one of the keys of ``_OPERATOR_HANDLERS``.
        save_file_path: path of the output file for this item.
        logger: logger forwarded to the handler.

    Raises:
        NotImplementedError: if ``operator`` is not a supported operation. The
            message names the operator, so the caller's ``logger.error(e)``
            records something useful.
    """
    try:
        handler_name = _OPERATOR_HANDLERS[operator]
    except (KeyError, TypeError):
        # TypeError covers unhashable operator values; both mean "unsupported".
        raise NotImplementedError(f"Unsupported operator: {operator!r}") from None
    getattr(op, handler_name)(file, client, save_file_path, logger, args)

def process_batch(file_manager, queue, operator, api_ids, idx):
    """Worker loop: pull items from ``file_manager`` and run ``operator`` on each.

    Args:
        file_manager: shared ``utils.FileManager``; ``get_next_file()`` returns
            ``(index, data)`` and the loop ends once ``data`` is falsy.
        queue: manager queue passed to ``utils.setup_logger`` together with
            ``idx`` and the module-level ``log_dir``.
        operator: operation name dispatched by ``process_file``.
        api_ids: API key identifiers; this worker uses ``api_ids[idx % len(api_ids)]``.
        idx: worker number (the launcher passes 1..num_process).

    Raises:
        Exception: any error from ``process_file`` is logged with its traceback,
            printed, and re-raised, which ends this worker.
    """
    logger = utils.setup_logger(idx, queue, log_dir)
    logger.info(f"Worker {idx} starting, the task is {operator}")
    client = gpt.set_api_key(api_ids[idx % len(api_ids)])
    # Loop-invariant: the run's start offset and output folder.
    init_index, save_folder_path = args.index[0], args.save_folder_path

    while True:
        index, data = file_manager.get_next_file()
        if not data:
            logger.info(f"No more file to process in worker: {idx}, quiting now!")
            break

        # Presumably ``index`` is the position inside the loaded slice, so adding
        # the slice start gives a stable file number across runs — TODO confirm.
        file_name = "problem_%s" % (init_index + index)
        # Built before the try so the error message below always names the
        # file that actually failed (never unbound or stale from a prior item).
        save_file_path = os.path.join(save_folder_path, file_name + '.json')
        logger.info(f"processing {save_file_path} in worker {idx}")
        try:
            process_file(data, client, operator, save_file_path, logger)
        except Exception as e:
            # logger.exception logs at ERROR level and keeps the traceback.
            logger.exception(f"Some uncaught error in {save_file_path}")
            print(e)
            raise

if __name__ == '__main__':
    # The manager hosts the shared FileManager state and the log queue. The
    # context manager shuts its server process down once all workers are done.
    with multiprocessing.Manager() as manager:
        file_manager = utils.FileManager(dataset, manager)
        queue = manager.Queue(-1)
        handler = logging.FileHandler(os.path.join(log_dir, "error.log"))
        # Workers send records to `queue`; this listener writes them to error.log.
        listener = logging.handlers.QueueListener(queue, handler)
        listener.start()
        try:
            Parallel(n_jobs=args.num_process)(
                delayed(process_batch)(file_manager, queue, args.operator, api_ids, idx)
                for idx in range(1, args.num_process + 1)
            )
        finally:
            # Workers re-raise on failure, so Parallel may throw. Always drain the
            # queue before the manager goes away, then release the file handle.
            listener.stop()
            handler.close()