import asyncio
import gc
import os
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from pathlib import Path

import numpy as np

from llms.llm_utils import batch_call_llm, call_llm
from utils.file_utils import find_files
from utils.logger_utils import logger

from .config_run import dump_exper_args, run_config

# Upper bound on the number of jobs placed in a single batch by run_batch_mode.
# A positive "max_batch_size_runners" entry in run_config takes precedence;
# a missing or non-positive entry falls back to 150.
MAX_BATCH_SIZE = (
    run_config["max_batch_size_runners"] if run_config.get("max_batch_size_runners", 0) > 0 else 150
)


def get_task_domains(filepath="rascunho_tasks.txt"):
    """
    Parse a task listing file into a mapping of domain -> set of task IDs.

    Every non-blank line must consist of exactly two whitespace-separated
    tokens, ``<domain> <task_id>``. Lines whose second token contains the text
    "task_id" (i.e. a header row) are ignored.

    Args:
        filepath: Path of the listing file to read.

    Returns:
        Dict mapping each domain to the set of task ID strings listed for it.

    Raises:
        ValueError: If a non-blank line does not split into exactly two tokens.
    """
    tasks_by_domain = {}
    with open(filepath, "r") as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        # filter(None, ...) drops the blank lines.
        for entry in filter(None, stripped_lines):
            domain, task_id = entry.split()
            if "task_id" not in task_id:
                if domain not in tasks_by_domain:
                    tasks_by_domain[domain] = set()
                tasks_by_domain[domain].add(task_id)
    return tasks_by_domain


def get_files_and_task_ids(config, exc_strs=("/failed/",), domain=""):
    """
    Given a configuration with a trace path template, finds all matching trajectory files
    that still have to be processed.

    ``config["trace_path_template"]`` must contain a ``{task_id}`` placeholder. Files are
    discovered with ``find_files`` using ``*`` in place of the placeholder, and the task ID
    of each file is the text found between the template's prefix and suffix.

    Filters applied, in this order:
      1. Paths containing any of ``exc_strs`` are dropped.
      2. Unless ``run_config["overwrite"]`` is truthy, tasks for which
         ``<out_dir>/conversation/<task_id>.txt`` already exists are skipped (and logged).
      3. If ``run_config["task_list"]`` is set, only task IDs listed for ``domain`` in that
         file are kept; when the domain has no entries, nothing is kept.

    Args:
        config: Dict providing "trace_path_template" and "out_dir".
        exc_strs: Iterable of substrings; a file path containing any of them is excluded.
        domain: Key looked up in the task list file when ``run_config["task_list"]`` is set.

    Returns:
        List of tuples, each tuple is (file_path, task_id)
    """

    trace_path_template = config["trace_path_template"]
    split_parts = trace_path_template.split("{task_id}")

    start_dir = Path(split_parts[0]).parent
    wildcard_trace_path = trace_path_template.format(task_id="*").split(str(start_dir))[1].strip("/")
    file_paths = find_files(start_dir, wildcard_trace_path)

    # strip("./") removes every leading/trailing "." and "/" character, so no
    # additional strip("/") is needed afterwards.
    left_part = split_parts[0].strip("./")
    right_part = split_parts[1].strip("/")

    files_with_task_ids = []
    for file_path in file_paths:
        if any(exc_str in file_path for exc_str in exc_strs):
            continue
        # Cut at the FIRST occurrence of the prefix and the LAST occurrence of the
        # suffix, so a task ID that itself contains either text is not truncated.
        remainder = file_path.split(left_part, 1)[1]
        if right_part:
            # An empty suffix (template ends with the placeholder) cannot be used
            # as a separator; in that case everything after the prefix is the ID.
            remainder = remainder.rsplit(right_part, 1)[0]
        files_with_task_ids.append((file_path, remainder.strip("/")))

    conversation_dir = f"{config['out_dir'].strip('-')}/conversation"

    if run_config["overwrite"]:
        task_ids_final = files_with_task_ids
    else:
        task_ids_final = []
        for file_path, task_id in files_with_task_ids:
            full_conversation_path = f"{conversation_dir}/{task_id}.txt"
            if os.path.exists(full_conversation_path):
                logger.info(f"Skipping {task_id} because {full_conversation_path} exists.")
                continue
            task_ids_final.append((file_path, task_id))

    if run_config.get("task_list"):
        # An unknown domain yields an empty set, which filters everything out.
        task_subset = get_task_domains(run_config["task_list"]).get(domain, set())
        task_ids_final = [(file_path, task_id) for file_path, task_id in task_ids_final if task_id in task_subset]

    return task_ids_final


def flatten_configs(env_configs):
    """
    Merge the per-environment config mappings into one flat dict.

    Args:
        env_configs: Mapping of environment name -> mapping of configs.

    Returns:
        A new dict containing every entry of every per-environment mapping.
        When the same key appears under several environments, the one from
        the environment iterated last wins.
    """
    merged = {}
    for per_env_configs in env_configs.values():
        merged.update(per_env_configs)
    return merged


def split_jobs_into_batches(all_jobs_data, max_batch_size=100):
    """
    Yield ``all_jobs_data`` as consecutive batches of near-equal size.

    The number of batches is ``ceil(len(all_jobs_data) / max_batch_size)``, and the
    items are spread as evenly as possible over them: when the split is not exact,
    the leading batches hold one item more than the trailing ones. No batch ever
    exceeds ``max_batch_size``. Item order is preserved.

    Args:
        all_jobs_data: Indexable sequence of jobs.
        max_batch_size: Maximum number of items per batch. A value <= 0 disables
            splitting and yields everything as a single batch.

    Yields:
        Lists of items from ``all_jobs_data``. Nothing is yielded for empty input.
    """
    total_tasks = len(all_jobs_data)
    if total_tasks == 0:
        return

    if max_batch_size > 0:
        # Ceiling division without going through floats.
        num_batches = -(-total_tasks // max_batch_size)
    else:
        num_batches = 1

    # The first `oversized` batches carry one extra item each.
    base_size, oversized = divmod(total_tasks, num_batches)
    start = 0
    for batch_no in range(num_batches):
        stop = start + base_size + (1 if batch_no < oversized else 0)
        yield [all_jobs_data[pos] for pos in range(start, stop)]
        start = stop


def run_batch_mode(configs_per_env, run_config, gen_config, build_llm_call_args_fn):
    """
    Run all pending tasks of every environment through the LLM in batches.

    For each environment the pending tasks of all its cases/domains are collected,
    sorted, and split into batches of at most ``MAX_BATCH_SIZE`` jobs. For every batch
    the LLM call arguments are built in a process pool and then handed to
    ``batch_call_llm`` in a single call. A batch that raises is re-queued at the end of
    the work list and attempted up to three times in total before being skipped.

    Args:
        configs_per_env: Mapping ``env -> case_name -> domain -> config``.
        run_config: Run options. Read here: "sort_by_config", "max_api_keys",
            "multiprocess_batch_mode". Also forwarded to ``build_llm_call_args_fn``.
        gen_config: Passed to ``batch_call_llm`` as ``gen_kwargs``.
        build_llm_call_args_fn: Callable ``(task_id, config, run_config)`` returning
            either ``None`` or a tuple ``(prompt, conversation_dir, usage_dir, call_id)``.
            Results that are ``None`` or have a falsy prompt are dropped. It is submitted
            to a ``ProcessPoolExecutor``.

    Returns:
        None.
    """
    max_retries = 3  # If exception during batch processing, retry up to 3 times.

    for env, configs_per_case in configs_per_env.items():
        # Collect one (task_id, config, order_id) job per pending trajectory file.
        # order_id is the position of the case within this env; it is the primary
        # sort key when "sort_by_config" is enabled.
        all_jobs_data = []
        for order_id, configs_per_domain in enumerate(configs_per_case.values()):
            for domain, config in configs_per_domain.items():
                all_jobs_data.extend(
                    (task_id, config, order_id) for _, task_id in get_files_and_task_ids(config, domain=domain)
                )

        if not all_jobs_data:
            logger.info(f"No tasks to run for env {env}")
            continue

        if run_config.get("sort_by_config", False):
            # Sort by case order first, then by task_id.
            all_jobs_data.sort(key=lambda job: (job[2], job[0]))
        else:
            # Sort by task_id only.
            all_jobs_data.sort(key=lambda job: job[0])

        # Work list of (jobs_data, attempts_so_far). all_jobs_data is non-empty here,
        # so at least one batch is always produced.
        batches = [
            (jobs_data, 0) for jobs_data in split_jobs_into_batches(all_jobs_data, max_batch_size=MAX_BATCH_SIZE)
        ]
        # Counts processing rounds (for logging); a retried batch is logged under a new index.
        batch_index = 0

        logger.info(f"Starting execution for env {env}: {len(all_jobs_data)} tasks broken into {len(batches)} batches")
        while batches:
            batch_index += 1
            jobs_data, attempt = batches.pop(0)  # Get the next batch and its current attempt number
            logger.info(
                f"Processing batch {batch_index} with {len(jobs_data)} items, attempt {attempt + 1}/{max_retries}"
            )
            batch_call_llm_args = {
                "prompts": [],
                "conversation_dirs": [],
                "usage_dirs": [],
                "call_ids": [],
            }
            gc.collect()
            batch_error_occurred = False
            try:
                if sys.gettrace():
                    # Debugging aid: when a trace function is installed (e.g. running under
                    # a debugger), build the arguments of this batch in-process and return
                    # from the whole function without issuing any LLM call.
                    for task_id, config, _ in jobs_data:
                        build_llm_call_args_fn(task_id, config, run_config)
                    return

                # Build LLM call arguments for the batch using a ProcessPoolExecutor.
                # Using a list to preserve order (this helps in downstream caching of images)
                with ProcessPoolExecutor(max_workers=20) as executor:
                    futures = [
                        executor.submit(build_llm_call_args_fn, task_id, config, run_config)
                        for task_id, config, _ in jobs_data
                    ]

                    # Iterate over futures in the order they were submitted.
                    # future.result() re-raises any exception from the worker, which
                    # fails (and possibly retries) the whole batch.
                    for future in futures:
                        result = future.result()
                        if result is None:
                            continue

                        prompt, conversation_dir, usage_dir, call_id = result
                        if not prompt:
                            continue
                        batch_call_llm_args["prompts"].append(prompt)
                        batch_call_llm_args["conversation_dirs"].append(conversation_dir)
                        batch_call_llm_args["usage_dirs"].append(usage_dir)
                        batch_call_llm_args["call_ids"].append(call_id)

                # Run the batch of LLM calls.
                if not batch_call_llm_args["prompts"]:
                    logger.info(f"Skipping batch {batch_index}. No tasks to run.")
                    continue  # the finally clause below still runs
                logger.info(f"Running {len(batch_call_llm_args['prompts'])} calls in batch mode")
                batch_call_llm(
                    gen_kwargs=gen_config,
                    prompts=batch_call_llm_args["prompts"],
                    conversation_dirs=batch_call_llm_args["conversation_dirs"],
                    usage_dirs=batch_call_llm_args["usage_dirs"],
                    call_ids=batch_call_llm_args["call_ids"],
                    max_batch_size=-1,
                    num_workers=run_config.get("max_api_keys", 2),
                    multiprocess_mode=run_config.get("multiprocess_batch_mode", False),
                    verbose=True,
                    return_outputs=False,
                    max_api_keys=run_config.get("max_api_keys", 5),
                    # order_by_payload_size=True,
                )
                gc.collect()

            except Exception as e:
                logger.error(f"Error during batch call: {e}")
                batch_error_occurred = True

            finally:
                # Drop the references to the batch arguments before the next round.
                batch_call_llm_args.clear()
                gc.collect()

            # If an error occurred processing this batch, and we haven't hit the max retries, re-add it.
            if batch_error_occurred:
                if attempt + 1 < max_retries:
                    logger.error(
                        f"Batch {batch_index} encountered an error. Re-adding batch for retry (attempt {attempt + 1}/{max_retries})."
                    )
                    batches.append((jobs_data, attempt + 1))
                else:
                    logger.error(f"Batch {batch_index} exceeded max retries. Skipping batch.")
            # end of single batch run.
        logger.info(f"Finished execution for env {env}")
        # end of env run.


async def run_sequential(all_configs, run_config, gen_config, build_llm_call_args_fn):
    """
    Build the LLM call arguments for every pending task and run the calls one at a time.

    A producer walks ``all_configs``, builds the call arguments of each pending task and
    puts them on a queue; a consumer takes them off the queue and issues one ``call_llm``
    per item. A ``None`` sentinel marks the end of production.

    Failures are logged and skipped at the smallest possible granularity: a task whose
    arguments cannot be built, a config whose tasks cannot be listed, or a failing
    ``call_llm`` does not prevent the remaining work from running.

    Args:
        all_configs: Mapping whose values are the configs to process (keys are ignored).
        run_config: Run options, forwarded to ``build_llm_call_args_fn``.
        gen_config: Passed to ``call_llm`` as ``gen_kwargs``.
        build_llm_call_args_fn: Callable ``(task_id, config, run_config)`` returning
            either ``None`` or a tuple ``(prompt, conversation_dir, usage_dir, call_id)``.
            Results that are ``None`` or have a falsy prompt are skipped.

    Returns:
        None.

    Raises:
        Exception: Whatever ``dump_exper_args`` raises is re-raised after the already
            queued calls have been processed.
    """
    # Queue of (prompt, conversation_dir, usage_dir, call_id) tuples, ended by None.
    queue = asyncio.Queue()

    async def async_call_llm(prompt, gen_config, conversation_dir, usage_dir, call_id):
        # NOTE: call_llm is invoked as a plain (non-awaited) call, so it runs to
        # completion before control returns to the event loop.
        try:
            return call_llm(
                prompt=prompt,
                gen_kwargs=gen_config,
                conversation_dir=conversation_dir,
                usage_dir=usage_dir,
                call_id=call_id,
                verbose=True,
            )
        except Exception as e:
            logger.warning(f"Error executing call associated to {conversation_dir}, call_id: {call_id}: {e}")
            return None

    async def producer():
        try:
            for _, config in all_configs.items():
                dump_exper_args(config)
                try:
                    pending = get_files_and_task_ids(config)
                except Exception as e:
                    logger.warning(f"Error collecting tasks for config {config}: {e}")
                    continue

                # get_files_and_task_ids returns (file_path, task_id) tuples; only the
                # task ID is handed to the builder, as in run_batch_mode.
                for _, task_id in pending:
                    try:
                        llm_call_args = build_llm_call_args_fn(task_id, config, run_config)
                    except Exception as e:
                        logger.warning(f"Error creating prompt for task {task_id}, config {config}: {e}")
                        continue
                    if not llm_call_args or not llm_call_args[0]:
                        continue
                    # Enqueue the prompt and associated call parameters.
                    await queue.put(llm_call_args)
        finally:
            # Always signal the end of production, even when the producer fails;
            # otherwise the consumer would wait on the queue forever. The queue is
            # unbounded, so put_nowait cannot raise QueueFull.
            queue.put_nowait(None)

    async def consumer():
        while True:
            item = await queue.get()
            try:
                if item is None:
                    # No more items to process.
                    break
                prompt, conversation_dir, usage_dir, call_id = item
                # Sequential asynchronous call: await ensures one at a time.
                await async_call_llm(
                    prompt=prompt,
                    gen_config=gen_config,
                    conversation_dir=conversation_dir,
                    usage_dir=usage_dir,
                    call_id=call_id,
                )
            finally:
                # Every get() -- including the one that fetched the sentinel -- must be
                # matched by task_done(), or queue.join() below never returns.
                queue.task_done()

    # Start the producer as a task, then drain the queue in this coroutine.
    producer_task = asyncio.create_task(producer())

    # Run the consumer which processes llm calls sequentially.
    await consumer()
    await queue.join()  # Ensures all queued items are processed.
    await producer_task  # Re-raises an exception that escaped the producer, if any.
