import json
import multiprocessing as mp
import os
# import jax

class InferenceEngine:
    """
    Thin wrapper that gives vLLM and SGLang offline engines a common interface.

    The backend package is imported lazily inside __init__, so only the
    selected backend needs to be installed.
    """

    def __init__(self, engine_type: str, model: str, gpu_memory_utilization: float = 0.9, **engine_args):
        """
        Build the underlying inference engine.

        Args:
            engine_type: Backend to use, either 'vllm' or 'sgl'.
            model: Model name or path; passed as ``model`` to ``vllm.LLM`` and
                as ``model_path`` to ``sglang.Engine``.
            gpu_memory_utilization: Fraction of GPU memory the engine may use;
                passed as ``gpu_memory_utilization`` (vllm) or
                ``mem_fraction_static`` (sgl).
            **engine_args: Extra backend-specific keyword arguments, forwarded
                unchanged to the engine constructor.

        vllm engine_args: https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py#L89
        sgl engine_args: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py#L41

        Raises:
            ValueError: If engine_type is not 'vllm' or 'sgl'.
        """
        self.engine_type = engine_type
        self.engine_args = engine_args
        if engine_type == 'vllm':
            import vllm
            self.llm = vllm.LLM(model=model, gpu_memory_utilization=gpu_memory_utilization, **engine_args)
        elif engine_type == 'sgl':
            import sglang as sgl
            self.llm = sgl.Engine(model_path=model, mem_fraction_static=gpu_memory_utilization, **engine_args)
        else:
            # Fail fast: otherwise self.llm is never set and the mistake only
            # surfaces later as an AttributeError in generate()/cleanup().
            raise ValueError(f"Unsupported engine_type {engine_type!r}; expected 'vllm' or 'sgl'")

    def generate(self, *args, **kwargs):
        """
        Forward all arguments unchanged to the backend engine's ``generate``.

        Accepted arguments and the return value are those of the selected backend:

        vllm: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py (LLM.generate),
              sampling params: https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
        sgl: https://docs.sglang.ai/backend/sampling_params.html
        """
        return self.llm.generate(*args, **kwargs)

    def cleanup(self):
        """
        Release the engine and the GPU / distributed state associated with it.

        For vllm, model-parallel groups are destroyed first. The engine
        reference is then dropped, garbage is collected, the CUDA cache is
        emptied, and the default torch.distributed process group is destroyed
        if one is initialized in this process.
        """
        import gc
        import torch
        if self.engine_type == 'vllm':
            from vllm.distributed.parallel_state import destroy_model_parallel
            destroy_model_parallel()
        del self.llm
        gc.collect()
        torch.cuda.empty_cache()
        # destroy_process_group() raises when no default group exists (e.g. the
        # backend did not initialize one in this process), so guard the call.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
        print("Finished cleanup")


def load_json_file(file_path):
    """
    Load a single JSON file and report whether it could be parsed.

    Per-file problems never raise; they are reported through the returned
    dict instead, so the function can be mapped over many paths (e.g. by a
    multiprocessing pool) and the failures inspected afterwards.

    Args:
        file_path: Path to the JSON file.

    Returns:
        dict with keys:
            "file_path": the path that was requested,
            "data": the parsed JSON value, or None if loading failed,
            "status": one of "valid", "empty", "invalid_json", "error",
            "error": None on success, otherwise a human-readable message.
    """
    result = {
        "file_path": file_path,
        "data": None,
        "status": "unknown",
        "error": None
    }

    try:
        if not os.path.exists(file_path):
            result["status"] = "error"
            result["error"] = "File does not exist"
            return result

        # Report zero-byte files distinctly rather than as invalid JSON.
        if os.path.getsize(file_path) == 0:
            result["status"] = "empty"
            result["error"] = "File is empty"
            return result

        # JSON text is UTF-8 (RFC 8259); don't rely on the locale's default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        result["data"] = data
        result["status"] = "valid"
        return result

    except json.JSONDecodeError as e:
        result["status"] = "invalid_json"
        result["error"] = f"Invalid JSON: {str(e)}"
        return result
    except Exception as e:
        # Deliberately broad: any other failure (permissions, directory path,
        # decode errors, ...) is reported rather than raised.
        result["status"] = "error"
        result["error"] = f"Error loading file: {str(e)}"
        return result


def load_json_files_parallel(file_paths, num_processes=None):
    """
    Load multiple JSON files in parallel with a multiprocessing pool.

    Args:
        file_paths: Iterable of paths to JSON files. It is materialized into a
            list once, so generators and other one-shot iterables are accepted.
        num_processes: Number of worker processes (defaults to the CPU count).
            The pool is never given more workers than there are files.

    Returns:
        List of parsed JSON values in the same order as file_paths
        (an empty list if file_paths is empty).

    Raises:
        AssertionError: If any file is missing, empty, unreadable, or not valid
            JSON. The message names the first such file (in input order) and
            the error reported for it.
    """
    # Materialize once: pool.map consumes the iterable, so a generator would
    # already be exhausted when results are paired back up with their paths.
    file_paths = list(file_paths)

    if num_processes is None:
        num_processes = mp.cpu_count()

    # Nothing to load: skip the cost of spawning a worker pool.
    if not file_paths:
        return []

    # Workers beyond the number of files would only sit idle.
    num_processes = min(num_processes, len(file_paths))
    print(f"Parallel loading files with {num_processes} processes: ", file_paths)

    with mp.Pool(processes=num_processes) as pool:
        # pool.map preserves input order, so results line up with file_paths.
        results = pool.map(load_json_file, file_paths)

    output_results = []
    for i, (path, result) in enumerate(zip(file_paths, results)):
        if result["status"] != "valid":
            # Explicit raise (not `assert`) so the check survives `python -O`;
            # AssertionError is kept so existing callers' except-clauses still match.
            raise AssertionError(
                f"path {i} ({path}) has invalid status {result['status']} with error {result['error']}"
            )
        output_results.append(result["data"])
    return output_results
