import os
from datetime import datetime 
from datasets import load_dataset

import multiprocessing
import logging, logging.handlers
from joblib import Parallel, delayed, parallel_config

import utils.utils as utils
import utils.gpt as gpt
import operators as op

from args import parse_args
args = parse_args()

#### setting & preparing
api_ids = gpt.get_ids(min_index=1, max_index=30)
# Hugging Face split-slice syntax "split[start:end]"; the end bound is
# exclusive. Presumably args.index is [start, ..., end] — TODO confirm.
dataset = load_dataset(
    'gsm8k',
    'main',
    split=f"{args.data}[{args.index[0]}:{args.index[-1]}]",
    cache_dir="../cache/dataset/gsm8k",
)
args.save_folder_path = os.path.join('./data/', args.data)

# mkdir for saving data & log
# exist_ok=True replaces the racy "exists? then makedirs" check.
os.makedirs(args.save_folder_path, exist_ok=True)
# NOTE: despite the ".log" suffix this path is a *directory*; per-run log
# files (e.g. error.log) are written inside it. The ':' characters in the
# timestamp are not valid in Windows paths.
log_dir = "./data/logs/log_%s.log" % (datetime.now().strftime("%Y-%m-%d_%H:%M:%S"))
os.makedirs(log_dir)


# Task name (as passed in ``type`` / ``args.type``) -> name of the function in
# the ``operators`` module that implements it.
TASK_OPERATORS = {
    "solution": "solve",
    "formalization": "autoformalize",
    "informalization": "informalize",
    "mutation": "mutate",
    "refresh": "refresh",
    "check": "check",
}


def process_file(file, index, client, type, logger):
    """Run the operator for task ``type`` on a single dataset item.

    The selected operator is called as
    ``op.<name>(file, index, client, logger, args)`` using the module-level
    ``args``. Nothing is returned.

    Args:
        file: the dataset item handed out by ``file_manager.get_next_file()``.
        index: the index paired with that item (used as ``problem_{index}``
            in the caller's log messages).
        client: API client returned by ``gpt.set_api_key``.
        type: task name; must be a key of ``TASK_OPERATORS``.
        logger: the calling worker's logger.

    Raises:
        NotImplementedError: if ``type`` is not a known task name.
    """
    try:
        operator_name = TASK_OPERATORS[type]
    except KeyError:
        raise NotImplementedError(f"Unsupported task type: {type!r}") from None
    getattr(op, operator_name)(file, index, client, logger, args)

def process_batch(file_manager, queue, type, api_ids, idx):
    """Worker loop: pull items from ``file_manager`` and process them until none remain.

    Args:
        file_manager: shared file manager; ``get_next_file()`` returns an
            ``(index, data)`` pair. The loop stops at the first falsy ``data``
            (presumably what the manager returns once exhausted — confirm
            against ``utils.FileManager``).
        queue: managed queue passed to ``utils.setup_logger``. The
            ``__main__`` block drains it with a ``QueueListener``.
        type: task name forwarded to ``process_file``.
        api_ids: API key ids. This worker uses ``api_ids[idx % len(api_ids)]``.
        idx: worker number (``__main__`` passes 1..num_process).

    Raises:
        Exception: any error from ``process_file`` is logged with its full
            traceback and re-raised unchanged.
    """
    logger = utils.setup_logger(idx, queue, log_dir)
    logger.info(f"Worker {idx} starting, the task is {type}")
    client = gpt.set_api_key(api_ids[idx % len(api_ids)])
    while True:
        index, data = file_manager.get_next_file()
        if not data:
            logger.info(f"No more file to process in worker: {idx}, quiting now!")
            break
        logger.info(f"processing problem_{index} in worker {idx}")
        try:
            process_file(data, index, client, type, logger)
        except Exception as e:
            # logger.exception attaches the traceback; logging str(e) alone
            # made failures hard to diagnose from error.log.
            logger.exception(f"Some uncaught error in problem_{index}")
            print(e)
            raise


if __name__ == '__main__':
    # The Manager hosts the shared FileManager state and the log queue for the
    # workers. The with-block shuts its server process down at the end.
    with multiprocessing.Manager() as manager:
        file_manager = utils.FileManager(dataset, manager)
        queue = manager.Queue(-1)
        error_handler = logging.FileHandler(os.path.join(log_dir, "error.log"))
        listener = logging.handlers.QueueListener(queue, error_handler)
        listener.start()
        try:
            Parallel(n_jobs=args.num_process)(
                delayed(process_batch)(file_manager, queue, args.type, api_ids, idx)
                for idx in range(1, args.num_process + 1)
            )
        finally:
            # Runs even when a worker re-raises. stop() processes the records
            # still queued (including the one describing the failure), then
            # the file handle is closed.
            listener.stop()
            error_handler.close()
