import os
from typing import List, Dict, Tuple
from lean_utils import *
import signal
import ray
from tqdm import tqdm

class LeanVerifier:
    """Verify Lean code through a Lean REPL, singly or in parallel via Ray.

    Batch verification (``verify_batch``) fans work out to Ray tasks and needs
    Ray to be started first, either with ``initialize()`` or by using the
    verifier as a context manager::

        with LeanVerifier() as verifier:
            results = verifier.verify_batch(codes)

    ``verify_single`` starts a REPL in the current process and does not use Ray.
    """

    def __init__(self):
        """Create a verifier without starting Ray.

        Call ``initialize()`` or enter the object as a context manager before
        calling ``verify_batch``.
        """
        # Directory handed to run_env_build when a REPL is started; defaults to
        # the absolute current working directory at construction time.
        self.work_dir = os.path.abspath("./")
        # Log file forwarded to run_env_build by the Ray (batch) path only.
        self.log_file = None
        # True from initialize() until shutdown().
        self._initialized = False
        # Not used by the methods in this class; kept for compatibility.
        self._remote_tqdm = None
        # Ray remote-task handle created by initialize(), cleared by shutdown().
        self._verify_single_remote = None

    def initialize(self):
        """Start Ray and register the remote verification task (alternative to ``with``).

        Requests about 80% of the local CPUs (never fewer than one). If
        ``ray.init(num_cpus=...)`` raises ``ValueError`` (e.g. when connecting to
        an existing cluster), it retries with a plain ``ray.init()`` and sizes
        from the cluster's CPU count instead. Each task reserves an equal share
        of 95% of the cluster's ``memory`` resource. A no-op if already
        initialized.
        """
        if self._initialized:
            return

        try:
            # os.cpu_count() may return None, and int(0.8) == 0 on a 1-CPU host;
            # clamp to at least one process so the division below is safe.
            num_processes = max(1, int((os.cpu_count() or 1) * 0.8))
            ray.init(num_cpus=num_processes)
        except ValueError as e:
            print("Error: ", e)
            print("Warning: When connecting to an existing cluster, num_cpus and num_gpus must not be provided; Reset it and try again.")
            ray.init()
            num_processes = max(1, int(ray.cluster_resources()['CPU'] * 0.8))
        max_repl_memory = ray.cluster_resources()["memory"] * 0.95 / num_processes
        self._initialized = True

        @ray.remote(memory=max_repl_memory)
        def _verify_single(code, timeout):
            # NOTE(review): this closure captures `self`, so the verifier instance
            # is serialized with the task and work_dir/log_file are read in the worker.
            repl = run_env_build(self.work_dir, log_file=self.log_file)
            response = {}
            try:
                cmd = {"cmd": code}
                response["response"] = send_command_with_timeout(repl, cmd, timeout)
                response["error"] = None
            except Exception as e:
                # Record the failure for this item instead of failing the whole batch.
                response["response"] = None
                response["error"] = str(e)
            finally:
                # Best-effort teardown: a broken pipe on close must not skip the kill.
                try:
                    repl.stdin.close()
                except OSError:
                    pass
                # NOTE(review): assumes run_env_build starts the REPL in its own
                # process group (e.g. start_new_session=True); otherwise killpg
                # would signal the worker's own group -- confirm in lean_utils.
                os.killpg(os.getpgid(repl.pid), signal.SIGKILL)
                repl.wait()
            return response

        self._verify_single_remote = _verify_single

    def shutdown(self):
        """Shut down Ray if ``initialize()`` has run; otherwise do nothing."""
        if self._initialized:
            ray.shutdown()
            self._initialized = False
            self._verify_single_remote = None

    def __enter__(self):
        """Start Ray (``initialize`` is idempotent) and return this verifier."""
        self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut Ray down when leaving the ``with`` block; exceptions propagate."""
        self.shutdown()

    def verify_batch(self, codes: List[str], timeout: int = 60, use_tqdm: bool = False) -> List[Dict]:
        """
        Verify many Lean code strings in parallel, one Ray task per string.

        Args:
            codes: List of Lean code strings
            timeout: Timeout in seconds, passed to each REPL command
            use_tqdm: Whether to show a progress bar

        Returns:
            list: One result per input, in the same order as ``codes``. Each
            result is a dict with keys ``"response"`` and ``"error"``.

        Raises:
            RuntimeError: If Ray has not been initialized.
            Any exception re-raised by ``ray.get`` for a failed task (e.g. if
            starting the REPL itself fails in a worker).
        """
        if not self._initialized:
            raise RuntimeError("LeanVerifier must be used as a context manager or initialized manually")

        # Submit everything up front; remember each ObjectRef's input position.
        ref_to_index = {self._verify_single_remote.remote(code, timeout): i for i, code in enumerate(codes)}
        results = [None] * len(codes)

        pending = list(ref_to_index)
        # The context manager closes the bar even if ray.get raises.
        with tqdm(total=len(codes), desc="Batch verifying Lean code", disable=not use_tqdm) as pbar:
            while pending:
                # Collect results in completion order, not submission order.
                done, pending = ray.wait(pending, num_returns=1)
                ref = done[0]
                results[ref_to_index[ref]] = ray.get(ref)
                pbar.update(1)

        return results

    def verify_single(self, code: str, timeout: int = 60) -> Dict:
        """
        Verify one Lean code string in the current process (no Ray required).

        Starts a REPL via run_env_build with ``log_file=None`` (``self.log_file``
        is not used here), sends the code, and stops the REPL with
        ``terminate()``.

        Returns:
            dict: ``{"response": ..., "error": None}`` on success, or
            ``{"response": None, "error": <message>}`` if sending or reading fails.
        """
        repl = run_env_build(self.work_dir, log_file=None)
        response = {}
        try:
            cmd = {"cmd": code}
            write_to_process(repl.stdin, cmd)
            response["response"] = read_from_process(repl.stdout, timeout)
            response["error"] = None
        except Exception as e:
            response["response"] = None
            response["error"] = str(e)
        finally:
            repl.terminate()
            repl.wait()
        return response

    def parse_convert_thms(self, thm_name: str, response: Dict) -> List[Dict]:
        """
        Parse a LeanREPL response for converted theorems.

        Delegates to the module-level ``parse_convert_thms`` (from lean_utils);
        the method name does not shadow it because class attributes are not in
        the function's global scope.
        """
        return parse_convert_thms(thm_name, response)

    def parse_mutated_thms(self, thm_name: str, response: Dict) -> Dict:
        """
        Parse a LeanREPL response for mutated theorems.

        Delegates to the module-level ``parse_mutated_thms`` (from lean_utils).
        """
        return parse_mutated_thms(thm_name, response)

    def parse_extracted_thms(self, thm_name: str, response: Dict) -> Dict:
        """
        Parse a LeanREPL response for extracted theorems.

        Delegates to the module-level ``parse_extracted_thms`` (from lean_utils).
        """
        return parse_extracted_thms(thm_name, response)

    def parse_results(self, response: List[Dict]) -> List[Dict]:
        """
        Analyze a list of LeanREPL responses.

        Args:
            response: Responses from the LeanREPL, e.g. as returned by verify_batch

        Returns:
            list: ``parse_client_response`` applied to each item, in order.
        """
        return [parse_client_response(r) for r in response]
    