import logging
import sys
import os
import time
import re
import json
from tqdm import tqdm
from typing import Sequence, Any
from collections import defaultdict

class LazyFileHandler(logging.FileHandler):
    """
    A file handler that defers creating the log file until the first record is emitted.

    Together with ``delay=True`` this means a run that never logs anything
    leaves no empty log file (or log directory) behind.
    """
    def __init__(self, filename, mode='a', encoding='UTF-8', delay=True):
        # delay=True makes FileHandler postpone opening the file until emit().
        super().__init__(filename, mode, encoding, delay)
        # Kept for backward compatibility; emit() relies on the absolute
        # ``baseFilename`` that FileHandler derives from ``filename``.
        self._filename = filename

    def emit(self, record):
        # ``stream`` is None while the file has not been opened yet.
        if self.stream is None:
            # baseFilename is absolute, so dirname() is never "" here. Using the
            # raw filename would make a bare name like "run.log" call
            # os.makedirs(""), which raises FileNotFoundError.
            os.makedirs(os.path.dirname(self.baseFilename), exist_ok=True)
        super().emit(record)

class LazyLoggerManager:
    """
    Process-wide singleton that configures and hands out the 'olym_gen' logger.

    The first successful get_logger() call attaches a stdout console handler
    (INFO and above) and a LazyFileHandler (DEBUG and above). As a result, at
    most one log file is created per process, and only once something is
    actually logged.
    """
    # The single shared instance returned by __new__.
    _instance = None
    # The configured logger; None until get_logger() has succeeded once.
    _logger = None
    # Set once the handlers have been attached to the logger.
    _initialized = False

    def __new__(cls):
        # Always hand back the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_logger(self, log_path: str | None = None, name: str | None = None) -> logging.Logger:
        """
        Return the shared 'olym_gen' logger, configuring it on first use.

        Args:
            log_path: Directory for the log file. Defaults to ``<cwd>/log``.
            name: Log file name. Defaults to a ``%Y-%m-%d-%H-%M-%S.log`` timestamp.

        Returns:
            The configured logger. Once it has been configured, the same logger
            is returned and ``log_path`` / ``name`` are ignored.
        """
        if self._logger is not None:
            return self._logger

        # Use current working directory for logs instead of file directory
        if log_path is None:
            log_path = os.path.join(os.getcwd(), 'log')
        if name is None:
            name = f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.log"

        # A named logger avoids conflicts with the root logger's configuration.
        logger = logging.getLogger('olym_gen')

        if not self._initialized:
            # Detach AND close handlers left from an earlier configuration;
            # handlers.clear() alone would leak their open file streams.
            for stale_handler in list(logger.handlers):
                logger.removeHandler(stale_handler)
                stale_handler.close()

            # Console handler (always active).
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(logging.INFO)

            # Lazy file handler: the file is only created on the first record.
            log_file_path = os.path.join(log_path, name)
            file_handler = LazyFileHandler(log_file_path, encoding='UTF-8')
            file_handler.setLevel(logging.DEBUG)

            simple_formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(filename)s - %(levelname)s - %(message)s'
            )
            detailed_formatter = logging.Formatter(
                '%(asctime)s | %(name)s:%(filename)s:%(lineno)03d | %(funcName).15s | %(levelname)-8s | %(message)s'
            )
            console_handler.setFormatter(simple_formatter)
            file_handler.setFormatter(detailed_formatter)

            # Hide INFO records from 'httpx*' loggers on the console.
            # NOTE(review): 'httpx' is not a child of 'olym_gen' and propagate is
            # False below, so this only matters if these handlers receive httpx
            # records some other way — confirm it is still needed.
            def filter_httpx_info(record):
                return not (record.name.startswith('httpx') and record.levelno == logging.INFO)
            console_handler.addFilter(filter_httpx_info)

            logger.addHandler(console_handler)
            logger.addHandler(file_handler)
            logger.setLevel(logging.DEBUG)

            # Prevent propagation to the root logger to avoid duplicate logs.
            logger.propagate = False

            self._initialized = True

        # Publish the logger only after configuration succeeded, so a failure
        # above is retried on the next call instead of caching a bare logger.
        self._logger = logger
        return logger

# Module-level singleton backing get_logger(); created once at import time.
_logger_manager = LazyLoggerManager()

class LazyBaseGeneratorManager:
    """
    Manages singleton instances of GeneratorBase for each provider configuration.
    Ensures only one GeneratorBase instance is created per unique provider+model combination.
    """
    _instance = None
    _generators: dict[str, Any] = {}
    
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    
    def get_generator_base(self, provider: str = "dummy", model: str | None = None, extra_model_paras: dict[str, Any] | None = None) -> Any:
        """
        Returns a GeneratorBase instance for the given provider configuration.
        Creates a new instance only if one doesn't exist for this configuration.
        
        Args:
            provider: The provider name (e.g., "openai", "deepseek", etc.)
            model: The model name to use (None for default)
            extra_model_paras: Extra model parameters
            
        Returns:
            GeneratorBase instance
        """
        # Create a unique key for this configuration
        # Convert extra_model_paras to a sorted tuple for hashing
        extra_params_key = tuple(sorted(extra_model_paras.items())) if extra_model_paras else None
        config_key = (provider, model, extra_params_key)
        
        if config_key not in self._generators:
            # Import here to avoid circular imports
            from olym_gen.generator.base_generator import GeneratorBase
            self._generators[config_key] = GeneratorBase(
                provider=provider,
                model=model,
                extra_model_paras=extra_model_paras
            )
        
        return self._generators[config_key]

# Module-level singleton backing get_generator_base(); created once at import time.
_generator_manager = LazyBaseGeneratorManager()

def get_generator_base(provider: str = "dummy", model: str | None = None, extra_model_paras: dict[str, Any] | None = None) -> Any:
    """
    Fetch the shared GeneratorBase for a provider/model/parameter combination.

    Delegates to the module-level LazyBaseGeneratorManager. The manager builds
    the instance on the first request for a combination and returns the cached
    one afterwards, so different generator types can reuse it.

    Args:
        provider: Provider name (e.g. "openai", "deepseek").
        model: Model name, or None for the provider's default.
        extra_model_paras: Additional model parameters, or None.

    Returns:
        The cached GeneratorBase instance for this configuration.
    """
    manager = _generator_manager
    return manager.get_generator_base(
        provider=provider,
        model=model,
        extra_model_paras=extra_model_paras,
    )

def get_logger(log_path: str | None = None, name: str | None = None) -> logging.Logger:
    """
    Return the project-wide 'olym_gen' logger, configuring it on the first call.

    Console output goes to stdout at INFO and above. A DEBUG-level log file is
    written under ``log_path`` (default: ``<cwd>/log``) with file name ``name``
    (default: a ``%Y-%m-%d-%H-%M-%S.log`` timestamp). That file is only created
    when the first record is emitted.

    The logger is a singleton: only the arguments of the first call take
    effect, and every later call returns the same logger.
    """
    manager = _logger_manager
    return manager.get_logger(log_path=log_path, name=name)

# Placeholder returned by retrieve_id_from_name() when a file name has no proof index.
UNKNOWN_INDEX = -1
# Placeholder string for an unknown/missing field value.
UNKNOWN_FIELD = "unknown"

def retrieve_id_from_name(file_name: str) -> tuple[int, int, int]:
    """
    Parse the problem, proof, and generation indices out of a generated file name.

    Accepts names of the form ``problem_<p>_proof_<k>_generate_<g>.json`` or
    ``problem_<p>_generate_<g>.json``; any leading directory components are
    ignored. When the proof index is absent, ``UNKNOWN_INDEX`` (-1) is
    returned in its place.

    Args:
        file_name: A bare file name or a path to one.

    Returns:
        ``(problem_index, proof_index, generation_index)``.

    Raises:
        ValueError: If the base name does not match either pattern exactly
            (e.g. ``.jsonl`` files or names with trailing text).
    """
    base_name = os.path.split(file_name)[-1]
    # fullmatch rather than match: re.match only anchors at the start, so
    # "problem_1_generate_3.jsonl" or "...json.bak" would otherwise be accepted.
    m = re.fullmatch(r"problem_(\d+)_proof_(\d+)_generate_(\d+)\.json", base_name)
    if m is not None:
        return int(m.group(1)), int(m.group(2)), int(m.group(3))
    m = re.fullmatch(r"problem_(\d+)_generate_(\d+)\.json", base_name)
    if m is not None:
        return int(m.group(1)), UNKNOWN_INDEX, int(m.group(2))
    raise ValueError(
        f"File name {base_name} does not match the expected pattern."
    )

# def from_json_to_jsonl(path: str, save_path: str) -> None:
#     """
#     Our pipeline generally reads a JSONL file as its dataset, but we generate separate JSON files so that the data stays well organized and easy to read. These JSON files therefore need to be converted into a single JSONL file for further processing.
#     Convert JSON files containing questions and their proofs to a JSONL file.
#     NOTE: for now, the function is only designed to handle generated outputs that are to be checked.
#     Args:
#         path (str): The path to the input JSON files.
#         save_path (str): The path to save the output JSONL file.
#     """
#     import json
#     import os

#     if not os.path.exists(path):
#         raise FileNotFoundError(f"File {path} does not exist.")
#     jsonl_dict = {}
#     for file_name in os.listdir(path):
#         try:
#             problem_index, _, _ = retrieve_id_from_name(file_name)  # Ensure the function is called to validate the file name format
#         except ValueError as e:
#             # not a json file, skip it
#             continue
        
#         with open(os.path.join(path, file_name), 'r', encoding='UTF-8') as f:
#             data = json.load(f)
        
#         if problem_index not in jsonl_dict:
#             jsonl_dict[problem_index] = {
#                 "question": data["question"],
#                 "orig_solution": data["orig_solution"], # NOTE: a question may have multiple solutions, and this original solution may not be the one that was rephrased. It is only used to help the reader understand the problem in the report.
#                 "field": data["field"],
#                 "proofs": [],
#                 "source": [],
#             }
#         jsonl_dict[problem_index]["proofs"].append(data["new_solution"])
#         jsonl_dict[problem_index]["source"].append(file_name)
                
#     # Write to JSONL file
#     if not os.path.exists(os.path.dirname(save_path)):
#         os.makedirs(os.path.dirname(save_path), exist_ok=True)
#     with open(save_path, 'w', encoding='UTF-8') as f:
#         for problem_index, content in sorted(jsonl_dict.items(), key=lambda x: x[0]):
#             content["problem_index"] = problem_index
#             f.write(json.dumps(content, ensure_ascii=False) + '\n')
            
def re_utf_for_json(path: str, ext: None | Sequence[str] = None) -> None:
    """
    Convert the encoding of a file to UTF-8.
    Args:
        path (str): The path to the file to be converted.
    """
    for file_name in os.listdir(path):
        file_ext = os.path.splitext(file_name)[-1]
        if ext is not None and file_ext not in ext:
            continue
        file_path = os.path.join(path, file_name)
        if os.path.isfile(file_path):
            with open(file_path, 'r') as f:
                try:
                    content = json.load(f)
                except json.JSONDecodeError:
                    continue
            with open(file_path, 'w', encoding='UTF-8') as f:
                json.dump(content, f, ensure_ascii=False, indent=4)

def _join_groundtruth_proof(groundtruth_proof) -> str:
    """Join the first element of each ground-truth proof entry with newlines, skipping empty/None entries."""
    # NOTE(review): assumes each entry is a sequence whose first element is the
    # proof text — confirm against the producer of these files.
    if not groundtruth_proof:
        return ""
    return "\n".join(item[0] for item in groundtruth_proof if item and item[0] is not None)


def normalize_mask_completion(source_dir: str, dest_dir: str | None = None) -> None:
    """
    Normalize the format of masked proof completion files.

    Every ``*.json`` file in ``source_dir`` is read and mapped to a record with
    the keys ``question``, ``field``, ``orig_solution``, ``masked_proof``,
    ``new_solution``, ``source`` and ``problem_index``. The problem index is
    parsed from the record's ``source`` file name via
    ``retrieve_id_from_name``. Files whose ``source`` cannot be parsed are
    logged and skipped.

    Records are grouped by problem index and numbered 0, 1, ... within each
    problem (in sorted source-file-name order), stored as ``proof_index``. Each
    record is written to ``dest_dir`` as
    ``problem_{problem_index}_proof_{proof_index}_generate_0.json``.

    Args:
        source_dir (str): The directory containing the source files.
        dest_dir (str | None): The directory to save the normalized files.
            Defaults to ``source_dir + "_norm"``; created if missing.
    """
    if dest_dir is None:
        dest_dir = source_dir + "_norm"
    os.makedirs(dest_dir, exist_ok=True)

    data_dict = defaultdict(list)

    # Sorted so the proof numbering assigned below is deterministic;
    # os.listdir() returns entries in arbitrary order.
    for file_name in tqdm(sorted(os.listdir(source_dir))):
        if not file_name.endswith('.json'):
            continue
        source_path = os.path.join(source_dir, file_name)
        with open(source_path, 'r', encoding='UTF-8') as f:
            data = json.load(f)

        source = data.get("source", "")
        try:
            problem_idx, _, _ = retrieve_id_from_name(source)
        except (ValueError, TypeError) as e:
            # ValueError: name doesn't match the expected pattern;
            # TypeError: "source" is not a string (e.g. null in the JSON).
            logger = get_logger()
            logger.warning(f"Failed to retrieve indices from file name {file_name}: {e}, skipping.")
            continue

        data_dict[problem_idx].append({
            "question": data.get("question", ""),
            "field": data.get("field", UNKNOWN_FIELD),
            # A missing "groundtruth_proof" previously crashed the join (default [[None]]).
            "orig_solution": _join_groundtruth_proof(data.get("groundtruth_proof")),
            "masked_proof": data.get("masked_proof", ""),
            "new_solution": data.get("completed_proof", ""),
            "source": source,
            "problem_index": problem_idx,
        })

    for problem_idx, entries in data_dict.items():
        for proof_idx, entry in enumerate(entries):
            entry['proof_index'] = proof_idx
            save_name = f"problem_{problem_idx}_proof_{proof_idx}_generate_0.json"
            dest_path = os.path.join(dest_dir, save_name)
            with open(dest_path, 'w', encoding='UTF-8') as f:
                json.dump(entry, f, ensure_ascii=False, indent=4)
