import os
import argparse
import json
import ijson
import logging
from typing import Any, Dict, List

from provers import *
from solvers import *

def setup_logger(log_path: str) -> logging.Logger:
    """Configure and return this module's logger, writing records to ``log_path``.

    The logger level is set to INFO. On the first call a file handler is
    attached (opened in ``"w"`` mode, so an existing file is truncated).
    Later calls return the same logger without attaching another handler
    and without touching ``log_path``.

    Args:
        log_path: Destination log file. Missing parent directories are
            created; a bare file name (no directory part) is allowed.

    Returns:
        The logger named after this module.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Look only at this logger's own handlers: ``hasHandlers()`` also reports
    # handlers attached to ancestors, so a configured root logger would keep
    # the log file from ever being attached.
    if logger.handlers:
        return logger

    # ``os.makedirs("")`` raises FileNotFoundError, so skip directory creation
    # when the log file lives in the current working directory.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Build the handler only once we know it will be attached; constructing a
    # FileHandler opens (and here truncates) the file immediately.
    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(file_handler)
    return logger

def load_data(input_file: str) -> List[Dict[str, Any]]:
    """Load and return the JSON document stored in ``input_file``.

    The file is decoded as UTF-8 (the standard JSON encoding) rather than the
    platform's locale encoding, so non-ASCII content loads the same way on
    every system.

    Args:
        input_file: Path to the JSON file.

    Returns:
        The parsed JSON value; annotated as a list of dicts, although the
        file's actual top-level shape is not checked here.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        return json.load(f)

def load_large_data(input_file: str) -> List[Dict[str, Any]]:
    """Load a top-level JSON array from ``input_file`` using ijson's streaming parser.

    Each element of the top-level array (ijson prefix ``'item'``) is parsed
    incrementally and collected into a list.

    Args:
        input_file: Path to a JSON file whose top-level value is an array.

    Returns:
        A list containing every element of the top-level array, in file order.
    """
    # ijson consumes bytes; handing it a text-mode file is deprecated and makes
    # it re-encode every chunk, so open the file in binary mode.
    # NOTE(review): ijson yields non-integer numbers as decimal.Decimal by
    # default, which json.dump cannot serialize -- confirm whether callers
    # re-save this data.
    with open(input_file, "rb") as f:
        return list(ijson.items(f, "item"))

def save_json(data: Any, path: str) -> None:
    """Save ``data`` as indented JSON at ``path``, creating parent directories.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``) and the
    file is always encoded as UTF-8, independent of the platform locale.

    Args:
        data: Any JSON-serializable value.
        path: Destination file. May be a bare file name, in which case the
            file is written to the current working directory.
    """
    # ``os.makedirs("")`` raises FileNotFoundError, so only create a directory
    # when the path actually has a directory component.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

def flatten_nested_array(nested_array):
    """Return the leaves of an arbitrarily nested list as one flat list.

    Only ``list`` instances are descended into; every other element
    (including tuples) is kept as a single leaf. Leaves appear in
    depth-first, left-to-right order.
    """
    def _leaves(node):
        # Walk one nesting level, delegating to a sub-generator per sublist.
        for element in node:
            if isinstance(element, list):
                yield from _leaves(element)
            else:
                yield element

    return list(_leaves(nested_array))

def fold_array_to_nested(flat_array: List[Any], nested_template: List[List[Any]]) -> List[Any]:
    """Fold a one-dimensional array into a nested array
    according to the structure of the nested template.

    Every non-list leaf of ``nested_template`` is replaced, in depth-first
    left-to-right order, by the next value from ``flat_array``. Sublists
    (including empty ones) are reproduced as new lists.

    Args:
        flat_array: Values to distribute over the template's leaf positions.
        nested_template: Nested list whose shape the result copies.

    Returns:
        A new nested list with the template's shape and ``flat_array``'s values.

    Raises:
        AssertionError: If the number of leaves in ``nested_template`` differs
            from the number of values in ``flat_array``.
    """
    mismatch = "The length of the flattened template does not match the length of the one-dimensional array."
    values = iter(flat_array)
    exhausted = object()  # unique marker so that None is still a valid value

    def _fold(template):
        folded = []
        for item in template:
            if isinstance(item, list):
                folded.append(_fold(item))
                continue
            value = next(values, exhausted)
            if value is exhausted:
                # Raised explicitly (not via an ``assert`` statement) so the
                # check is still enforced when Python runs with -O.
                raise AssertionError(mismatch)
            folded.append(value)
        return folded

    result = _fold(nested_template)
    # Leftover values mean the flat array is longer than the template.
    if next(values, exhausted) is not exhausted:
        raise AssertionError(mismatch)
    return result


# Factories that map solver/prover names to their classes and instantiate them
def get_solver(solver_name: str, args: argparse.Namespace):
    """
    Build the solver registered under ``solver_name`` from parsed arguments.

    Reads ``gpu``, ``solver_k`` (forwarded as ``n``), ``max_tokens``,
    ``temperature`` and ``top_p`` from ``args``. When ``args`` carries a
    truthy ``solver_path`` it is forwarded as ``model_path``.

    Raises:
        ValueError: If ``solver_name`` is not one of the registered names.
    """
    registry = {
        "mistral": MistralSolver,
        "erdos": ErdosSolver,
        "deepseek_qwen_8b": DeepSeekQwen8BSolver,
        "qwen3": Qwen3Solver,
        "gpt_oss": GPTOSSSolver,
    }
    solver_cls = registry.get(solver_name)
    if solver_cls is None:
        raise ValueError(f"Unknown solver_name: {solver_name}")

    # Keyword arguments shared by every solver class
    init_kwargs = dict(
        gpu=args.gpu,
        n=args.solver_k,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    # The model path is optional: pass it only when present and non-empty
    model_path = getattr(args, "solver_path", None)
    if model_path:
        init_kwargs["model_path"] = model_path

    return solver_cls(**init_kwargs)
    
def get_prover(prover_name: str, args: argparse.Namespace):
    """
    Build the prover registered under ``prover_name`` from parsed arguments.

    Reads ``gpu``, ``prover_k`` (forwarded as ``n``), ``max_tokens``,
    ``temperature`` and ``top_p`` from ``args``. When ``args`` carries a
    truthy ``prover_path`` it is forwarded as ``model_path``.

    Raises:
        ValueError: If ``prover_name`` is not one of the registered names.
    """
    registry = {
        "goedel": GoedelProver,
        "deepseek_v15_rl": DeepSeekProverV15RL,
        "deepseek_v2_cot": DeepSeekProverV2CoT,
        "kimina": KiminaProver,
        "deepseek_v2_non_cot": DeepSeekProverV2nonCoT,
        "stp": STP,
        "leana": Leana,
        "erdos": ErdosProver,
    }
    prover_cls = registry.get(prover_name)
    if prover_cls is None:
        raise ValueError(f"Unknown prover_name: {prover_name}")

    # Keyword arguments shared by every prover class
    init_kwargs = dict(
        gpu=args.gpu,
        n=args.prover_k,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    # The model path is optional: pass it only when present and non-empty
    model_path = getattr(args, "prover_path", None)
    if model_path:
        init_kwargs["model_path"] = model_path

    return prover_cls(**init_kwargs)