import os
import re
import json
import argparse
from typing import Dict, List, Any
from tqdm import tqdm
from utils.llm_call import concurrent_call
from prompt import *

class SamplingParams:
    """Plain container for language-model sampling settings.

    Attributes:
        n: Number of completions to generate.
        temperature: Sampling temperature.
        top_p: Top-p (nucleus sampling) value.
        stop: Stop setting; always initialised to ``None`` by the constructor.
        max_tokens: Maximum number of tokens in a response.
    """

    def __init__(self, n: int, temperature: float, top_p: float, max_tokens: int):
        # Store the caller-supplied values verbatim; no validation is performed.
        self.n, self.temperature, self.top_p = n, temperature, top_p
        # Not exposed through the constructor; callers may set it afterwards.
        self.stop = None
        self.max_tokens = max_tokens

class LLMSynthesizer:
    """Base class for LLM-based synthesis tasks.

    This class provides the basic structure for processing data through LLM:
    1. Preprocessing input data
    2. Making LLM API calls
    3. Postprocessing LLM responses

    Subclasses are expected to override ``preprocess`` and ``postprocess``;
    the base implementations only raise ``NotImplementedError``.

    Attributes:
        prompt_templates (Dict[str, str]): Templates for LLM prompts
        col_mapping (Dict[str, str]): Mapping of data columns
    """

    def __init__(self, prompt_templates: Dict[str, str], col_mapping: Dict[str, str] = None):
        self.prompt_templates = prompt_templates
        self.col_mapping = col_mapping

    def preprocess(self, data: Any = None):
        """Turn one input record into a prompt string (abstract hook).

        The optional ``data`` parameter mirrors the signature used by the
        subclasses, so calling an un-overridden hook with a record raises
        ``NotImplementedError`` rather than an unrelated ``TypeError``.

        Args:
            data: One input record.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def postprocess(self, response: Any = None):
        """Turn one raw LLM response into structured output (abstract hook).

        Args:
            response: One raw LLM response.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def call_llm(self, inputs: List[str], params: Dict[str, Any]) -> List[tuple]:
        """Make concurrent calls to LLM.

        Args:
            inputs: List of input prompts
            params: LLM call parameters, forwarded to ``concurrent_call`` as
                keyword arguments

        Returns:
            List of (prompt, response) tuples
        """
        prompts, responses = concurrent_call(inputs, **params)
        return list(zip(prompts, responses))

class CodeProblemSynthesizer(LLMSynthesizer):
    """Builds prompts that ask the LLM for a new problem and parses the reply.

    The prompt is formatted from a record's question and solution; the reply
    is split into its "Part 1/2/3" markdown sections.
    """

    # One compiled pattern per section of the expected response. A section
    # ends at the next "##" or at the end of the text, so the final section is
    # still captured when nothing follows it.
    _SECTION_PATTERNS = {
        "analysis": re.compile(
            r"## Part 1: Original Problem and Solution Analysis\s*(.*?)\s*(?:##|\Z)", re.DOTALL),
        "new_problem": re.compile(
            r"## Part 2: New Problem\s*(.*?)\s*(?:##|\Z)", re.DOTALL),
        "example_test_cases": re.compile(
            r"## Part 3: Example Test Cases\s*(.*?)\s*(?:##|\Z)", re.DOTALL),
    }

    def __init__(self, prompt_templates: Dict[str, str] = None, col_mapping: Dict[str, str] = None):
        """Initialise the synthesizer.

        Args:
            prompt_templates: Optional templates keyed by name; ``preprocess``
                reads the ``'base'`` entry. When omitted, the built-in
                ``NEW_CODE_PROBLEM_SYNTHESIS_PROMPT`` is used.
            col_mapping: Must contain ``'question'`` and ``'solution'`` keys
                naming the corresponding fields of each input record.

        Raises:
            ValueError: If ``col_mapping`` is missing a required key.
        """
        # Explicit validation: asserts vanish under ``python -O`` and a None
        # mapping used to fail with an opaque TypeError.
        missing = [key for key in ('question', 'solution') if key not in (col_mapping or {})]
        if missing:
            raise ValueError(f"col_mapping is missing required keys: {missing}")

        # Resolve the default BEFORE delegating: the base constructor assigns
        # whatever it receives, so the default must be passed through it.
        if prompt_templates is None:
            prompt_templates = {
                'base': NEW_CODE_PROBLEM_SYNTHESIS_PROMPT
            }
        super().__init__(prompt_templates, col_mapping)

    def preprocess(self, data):
        """Format the base prompt from one record.

        Args:
            data: Record indexable by the column names in ``col_mapping``.

        Returns:
            The ``'base'`` template formatted with ``question`` and ``solution``.
        """
        question = data[self.col_mapping['question']]
        solution = data[self.col_mapping['solution']]
        return self.prompt_templates['base'].format(question=question, solution=solution)

    def postprocess(self, response: str):
        """Split an LLM response into its three expected sections.

        Args:
            response: Raw response text.

        Returns:
            A dict with keys ``analysis``, ``new_problem`` and
            ``example_test_cases``; each value is the stripped section text,
            or ``None`` when that section header is absent. Returns ``None``
            when ``response`` is not a string.
        """
        if not isinstance(response, str):
            return None
        extracted_parts = {}
        for key, pattern in self._SECTION_PATTERNS.items():
            match = pattern.search(response)
            extracted_parts[key] = match.group(1).strip() if match else None
        return extracted_parts

        
class UtilityFunctionGenerator(LLMSynthesizer):
    """Generates test input utility functions using LLM.

    This class handles the generation of:
    1. Input constraint parsing
    2. Test input generation functions
    3. Input validation functions
    """

    def __init__(self, prompt_templates=None, col_mapping=None):
        """Initialise the generator.

        Args:
            prompt_templates: Optional templates; ``preprocess`` reads the
                ``'standard_io_based'`` and ``'function_based'`` entries.
                When omitted, the built-in test-input-generation prompts are
                used.
            col_mapping: Must contain ``'question'`` and ``'starter_code'``
                keys naming the corresponding fields of each input record.

        Raises:
            ValueError: If ``col_mapping`` is missing a required key.
        """
        # Explicit validation: asserts vanish under ``python -O`` and a None
        # mapping used to fail with an opaque TypeError.
        missing = [key for key in ('question', 'starter_code') if key not in (col_mapping or {})]
        if missing:
            raise ValueError(f"col_mapping is missing required keys: {missing}")

        # Resolve the default BEFORE delegating: the base constructor assigns
        # whatever it receives, so the default must be passed through it.
        if prompt_templates is None:
            prompt_templates = {
                'standard_io_based': TEST_INPUT_GENERATION_PROMPT_FOR_STANDARD_IO_BASED_PROBLEMS,
                'function_based': TEST_INPUT_GENERATION_PROMPT_FOR_FUNCTION_BASED_PROBLEM
            }
        super().__init__(prompt_templates, col_mapping)
        # Lazily filled by _load_cyaron_doc so the file is read at most once.
        self._cyaron_doc = None

    def _load_cyaron_doc(self) -> str:
        """Return the contents of ``cyaron_docs.md``, reading it only once per instance."""
        cached = getattr(self, '_cyaron_doc', None)
        if cached is None:
            with open('cyaron_docs.md', 'r', encoding='utf-8') as f:
                cached = f.read()
            self._cyaron_doc = cached
        return cached

    def preprocess(self, data):
        """Build the prompt for one record.

        The function-based template is used when the record has a non-empty
        starter code field; otherwise the standard-IO template is used. A
        missing, empty or ``None`` starter code all select the standard-IO
        template.

        Args:
            data: Record indexable by the column names in ``col_mapping``.

        Returns:
            The CYaRon documentation followed by the formatted prompt.
        """
        question = data[self.col_mapping['question']]
        starter_key = self.col_mapping['starter_code']
        # Truthiness instead of len(): a null starter_code must not crash.
        starter_code = data[starter_key] if starter_key in data else None
        if starter_code:
            prompt = self.prompt_templates['function_based'].format(question=question, starter_code=starter_code)
        else:
            prompt = self.prompt_templates['standard_io_based'].format(question=question)

        # Add CYaRon doc before the prompt
        cyaron_doc = self._load_cyaron_doc()
        prompt = 'Here is the documentation for the CYaRon library:\n\n' + cyaron_doc + '\n\n' + prompt + 'You need to answer in English.'
        return prompt

    def postprocess(self, response: str) -> dict:
        """Process the response from LLM to extract utility functions.

        Args:
            response (str): Raw response text from LLM

        Returns:
            dict: Dictionary containing:
                - input_constraints: Parsed input constraints
                - test_input_generation_code: Function for generating test inputs
                - test_input_validation_code: Function for validating test inputs
            Every value is an empty string when the response cannot be
            processed (e.g. it is not a string); a section whose header is not
            found also yields an empty string.
        """
        def extract_section(text: str, section_name: str) -> str:
            """Extract a section from the response text."""
            # A "\n##" sentinel is appended so the last section also terminates.
            pattern = "## " + re.escape(section_name) + ":\n(.*?)\n##"
            match = re.search(pattern, text + "\n##", re.DOTALL)
            return match.group(1).strip() if match else ""

        def extract_code_block(text: str) -> str:
            """Extract Python code block from text."""
            pattern = r"```python\n(.*?)\n```"
            match = re.search(pattern, text, re.DOTALL)
            # Fall back to the whole section when no fenced block is present.
            return match.group(1).strip() if match else text.strip()

        try:
            # Extract main sections
            input_constraints = extract_section(response, "Part 1: Parse Input Constraints")
            test_gen_section = extract_section(response, "Part 2: Code for Test Input Generation")
            test_val_section = extract_section(response, "Part 3: Code to Validate Test Input")

            # Extract code blocks from sections
            test_gen_function = extract_code_block(test_gen_section)
            test_val_function = extract_code_block(test_val_section)

            return {
                'input_constraints': input_constraints,
                'test_input_generation_code': test_gen_function,
                'test_input_validation_code': test_val_function
            }
        except Exception as e:
            # Best-effort: a malformed response yields empty fields, not a crash.
            print(f"Error processing response: {str(e)}")
            return {
                'input_constraints': "",
                'test_input_generation_code': "",
                'test_input_validation_code': ""
            }
            
# Registry: each accepted --synthesizer value and the class that implements it.
SynthesizerType = dict(
    code_problem_synthesis=CodeProblemSynthesizer,
    utility_function_generation=UtilityFunctionGenerator,
)

def main():
    """Main entry point for running synthesis tasks.

    Parses the command line, loads the input JSON, builds one prompt per
    record, sends the prompts to the LLM in batches, post-processes every
    response and, when ``--output_json`` is given, rewrites the output file
    after each batch with all results collected so far.
    """
    parser = argparse.ArgumentParser(description='Run LLM-based code synthesis tasks')
    # Both are needed unconditionally below, so fail early with a usage error.
    parser.add_argument('--qaf', type=str, required=True, help='Path to input data file')
    parser.add_argument('--synthesizer', type=str, required=True, choices=list(SynthesizerType),
                      help='Type of synthesis task to perform')
    parser.add_argument('--query_model', type=str, help='LLM model to use')
    parser.add_argument('--sampling_n', type=int, default=1,
                      help='Number of completions to generate')
    parser.add_argument('--sampling_temperature', type=float, default=0.8,
                      help='Temperature for LLM sampling')
    parser.add_argument('--sampling_top_p', type=float, default=0.95,
                      help='Top-p for nucleus sampling')
    parser.add_argument('--sampling_max_tokens', type=int, default=4096,
                      help='Maximum tokens in response')
    parser.add_argument('--output_json', type=str, default=None,
                      help='Path to save results')
    # Read by the batching loop below; previously never declared.
    parser.add_argument('--batch_size', type=int, default=32,
                      help='Number of prompts sent to the LLM per batch')

    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')

    # Load input data
    with open(args.qaf, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Initialize appropriate synthesizer
    col_mappings = {
        'code_problem_synthesis': {'question': 'question', 'solution': 'solution'},
        'utility_function_generation': {'question': 'question', 'starter_code': 'starter_code'},
    }
    col_mapping = col_mappings[args.synthesizer]

    # prompt_templates is passed explicitly because the constructors declare
    # it as their first parameter.
    synthesizer = SynthesizerType[args.synthesizer](
        prompt_templates=None,
        col_mapping=col_mapping
    )

    # Prepare inputs
    # NOTE: records that produce identical prompts share one entry; the last
    # such record's index wins.
    inputs2idx = {}
    print('Data length:', len(data))
    for i, d in enumerate(data):
        prompt = synthesizer.preprocess(d)
        inputs2idx[prompt] = i

    inputs = list(inputs2idx)
    sampling_params = SamplingParams(n=args.sampling_n,
                                  temperature=args.sampling_temperature,
                                  top_p=args.sampling_top_p,
                                  max_tokens=args.sampling_max_tokens
                                 )

    # Process data in batches
    results = []
    call_llm_params = {'query_model': args.query_model, 'sampling_params': sampling_params}
    # Ceiling division so a trailing partial batch is counted by the progress bar.
    num_batches = (len(inputs) + args.batch_size - 1) // args.batch_size
    for start in tqdm(range(0, len(inputs), args.batch_size), desc='Batch call', total=num_batches):
        input_output_lists = synthesizer.call_llm(inputs[start:start + args.batch_size], params=call_llm_params)
        for query, response in input_output_lists:
            # `response` is iterated, i.e. one entry per completion.
            response_info = [synthesizer.postprocess(r) for r in response]
            record = data[inputs2idx[query]]
            record['query_prompt'] = query
            record['response'] = response
            record['response_info'] = response_info
            results.append(record)

        # Rewrite the output after every batch so partial progress survives a crash.
        if args.output_json:
            with open(args.output_json, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=4)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
    