import os
import random
import argparse
import concurrent.futures
from tqdm import tqdm
import logging
import time
import io
import PIL
import re
import google.generativeai as genai
from dotenv import load_dotenv
from datasets import load_from_disk, Dataset

def call_gemini_api(content, model_name="gemini-2.0-flash", temperature=0.2, max_tokens=512, retries=2, system_instruction=None):
    """
    Send ``content`` to a Gemini model and return the response text.

    PIL images in ``content`` are re-encoded as PNG inline-data parts; every
    other item (e.g. a prompt string) is passed through unchanged.

    Args:
        content: List of parts (strings and/or ``PIL.Image.Image`` objects).
        model_name: Gemini model identifier.
        temperature: Sampling temperature.
        max_tokens: Maximum number of output tokens.
        retries: Number of extra attempts after the first failure (negative
            values are treated as 0). Attempts are separated by an exponential
            back-off of 1s, 2s, 4s, ... capped at 30s.
        system_instruction: Optional system prompt; ignored when falsy.

    Returns:
        The ``text`` of the model response.

    Raises:
        Exception: Whatever the last attempt raised once all
            ``retries + 1`` attempts have failed.
    """
    # ``import PIL`` at module level does not load the ``PIL.Image`` submodule,
    # so import it explicitly instead of relying on another library having
    # imported it first.
    from PIL import Image as PILImage

    # A negative retry count would skip the loop and silently return None.
    retries = max(0, retries)

    # Build the model once; only pass a system instruction when one is given.
    model_kwargs = {"model_name": model_name}
    if system_instruction:
        model_kwargs["system_instruction"] = system_instruction
    model = genai.GenerativeModel(**model_kwargs)

    generation_config = genai.types.GenerationConfig(
        temperature=temperature,
        max_output_tokens=max_tokens,
    )

    # Convert PIL images to inline PNG parts; leave everything else untouched.
    formatted_content = []
    for item in content:
        if isinstance(item, PILImage.Image):
            with io.BytesIO() as bio:
                item.save(bio, format='PNG')
                img_bytes = bio.getvalue()
            formatted_content.append({
                "mime_type": "image/png",
                "data": img_bytes
            })
        else:
            formatted_content.append(item)

    for attempt in range(retries + 1):
        try:
            response = model.generate_content(
                formatted_content,
                generation_config=generation_config
            )
            # ``.text`` itself may raise (e.g. for a blocked response), which is
            # why it sits inside the retry ``try``.
            return response.text
        except Exception as e:
            logging.error(f"API call error (attempt {attempt + 1}/{retries + 1}): {str(e)}")
            if attempt < retries:
                time.sleep(min(2 ** attempt, 30))
                continue
            raise

def is_mcq_answer(answer):
    """Return True when ``answer`` is nothing but a multiple-choice option letter.

    Surrounding whitespace is ignored and matching is case-insensitive. The
    accepted shapes are a bare letter ``A`` and the decorated forms ``(A)``,
    ``[A]``, ``A)`` and ``A.``.
    """
    normalized = answer.strip().upper()
    # One alternation: bare letter | (A) | [A] | A) or A.
    option_shape = r'[A-Z]|\([A-Z]\)|\[[A-Z]\]|[A-Z][.)]'
    return re.fullmatch(option_shape, normalized) is not None

# Few-shot examples placed in front of every sample sent to the model.
_FEW_SHOT_EXAMPLES = """

Example 1 (VERIFIABLE - GOOD FOR EXACT MATCHING):
Original Problem:
<image>
Calculate the result of the following matrix:
A = \\begin{pmatrix} 12/5 \\\\ 16/5 \\end{pmatrix}

Original Answer: \\begin{pmatrix} 12/5 \\\\ 16/5 \\end{pmatrix}

Transformed Problem:
<image>
Calculate the result of the following matrix:
A = \\begin{pmatrix} 12/5 \\\\ 16/5 \\end{pmatrix}

Processed Answer: \\begin{pmatrix} 12/5 \\\\ 16/5 \\end{pmatrix}
Answer Type: Verifiable (mathematical expression)

Example 2 (NOT VERIFIABLE - BAD FOR EXACT MATCHING):
Original Problem:
<image>
Two refrigerators lost power. The door of one fridge was slightly open, and the door of the other was closed. During this time, thermal energy was transferred from ___ to ___.
A) the surroundings to each refrigerator
B) each refrigerator to the surroundings 
C) the open refrigerator to the closed refrigerator
D) the closed refrigerator to the open refrigerator

Original Answer: B

Transformed Problem:
<image>
Two refrigerators lost power. The door of one fridge was slightly open, and the door of the other was closed. During this time, where was thermal energy transferred?

Processed Answer: each refrigerator to the surroundings
Answer Type: Not verifiable (complex relationship description)

Example 3 (VERIFIABLE - GOOD FOR EXACT MATCHING):
Original Problem: 
<image>
As shown in the figure, the three vertices of triangle $$ABC$$ are on lines $$a$$ and $$b$$, respectively, and $$a \\parallel b$$. If $$\\angle 1=120^{ \\circ }$$ and $$\\angle 2=80^{ \\circ }$$, then the measure of $$\\angle 3$$ is ___.?

Original Answer: 60^{\\circ}

Transformed Problem:
<image>
As shown in the figure, the three vertices of triangle $$ABC$$ are on lines $$a$$ and $$b$$, respectively, and $$a \\parallel b$$. If $$\\angle 1=120^{ \\circ }$$ and $$\\angle 2=80^{ \\circ }$$, then the measure of $$\\angle 3$$ is ___ degrees?

Processed Answer: 60
Answer Type: Verifiable (specific numerical value without units)

Example 4 (VERIFIABLE - GOOD FOR EXACT MATCHING):
Original Problem:
<image>
The picture shows the terrain map of a certain place. This image includes a contour map and a directional guide marker.
What is the altitude of the easternmost point?
A) -6.8 km
B) -5.2 km
C) -4.0 km
D) -3.5 km

Original Answer: A

Transformed Problem:
<image>
The picture shows the terrain map of a certain place. This image includes a contour map and a directional guide marker.
What is the altitude of the easternmost point?

Processed Answer: -6.8
Answer Type: Verifiable (specific numerical value without units)
"""

# Unit hints matched against the lower-cased answer.
_LOWERCASE_UNIT_MARKERS = ('°', 'km', 'm', 'kg', 's', 'circ')
# Unit hints matched against the answer as written ("L" for litres would never
# match a lower-cased string).
_CASE_SENSITIVE_UNIT_MARKERS = ('L',)


def _is_yes_no(text):
    """Return True when ``text`` is 'yes' or 'no', ignoring case and surrounding whitespace."""
    return text.strip().lower() in ('yes', 'no')


def _has_unit_marker(answer):
    """Heuristically decide whether ``answer`` carries a unit worth stripping.

    The markers are plain substrings, so single letters such as 'm' and 's'
    also match many unit-free answers (e.g. anything containing '\\sqrt');
    those answers are simply routed through the LLM as well.
    """
    lowered = answer.lower()
    return (any(marker in lowered for marker in _LOWERCASE_UNIT_MARKERS)
            or any(marker in answer for marker in _CASE_SENSITIVE_UNIT_MARKERS))


def _make_result(sample_id, sample_image, problem, answer, *, processed,
                 is_mcq=False, is_yes_no=False, is_open_ended=False,
                 already_good=False, **extra):
    """Build a per-sample result record; ``extra`` adds optional fields such as 'error'."""
    record = {
        'id': sample_id,
        'images': sample_image,
        'problem': problem,
        'answer': answer,
        'processed': processed,
        'is_mcq': is_mcq,
        'is_yes_no': is_yes_no,
        # Also used as the "not verifiable for exact matching" flag.
        'is_open_ended': is_open_ended,
        'already_good': already_good,
    }
    record.update(extra)
    return record


def _build_prompt(problem, original_answer):
    """Return the few-shot prompt asking the model to transform one sample."""
    return (
        f"\n{_FEW_SHOT_EXAMPLES}\n\n"
        "Now process the following:\n\n"
        f"Original Problem:\n{problem}\n\n"
        f"Original Answer:\n{original_answer}\n\n"
        "Transformed Problem:\n"
        "Processed Answer: \n"
        "Answer Type:\n"
    )


def _parse_llm_response(api_response):
    """Split a model reply into (transformed_problem, processed_answer, answer_type).

    Every line after a "Transformed Problem:" header is collected until a
    "Processed Answer:" header is seen; the other two fields are single-line.
    Parts that are missing (or empty) come back as None or "".
    """
    processed_answer = None
    answer_type = None
    transformed_lines = []
    in_transformed_problem = False

    for line in api_response.strip().split('\n'):
        if line.startswith("Transformed Problem:"):
            in_transformed_problem = True
            # A later header restarts the capture (e.g. if examples are echoed).
            transformed_lines = [line.replace("Transformed Problem:", "").strip()]
        elif line.startswith("Processed Answer:"):
            in_transformed_problem = False
            processed_answer = line.replace("Processed Answer:", "").strip()
        elif line.startswith("Answer Type:"):
            answer_type = line.replace("Answer Type:", "").strip()
        elif in_transformed_problem:
            transformed_lines.append(line)

    transformed_problem = "\n".join(transformed_lines).strip() if transformed_lines else None
    return transformed_problem, processed_answer, answer_type


def process_sample(sample, system_instruction, verbose=False, model_name="gemini-2.0-flash"):
    """Convert one dataset sample into a free-form, exact-match-friendly record.

    Routing:
      * answers that are "yes"/"no" (ignoring case and whitespace) are rejected;
      * non-MCQ answers longer than 8 words are rejected as not verifiable;
      * non-MCQ answers with no unit marker are kept unchanged ("already good");
      * everything else (MCQ option letters and answers that look like they
        carry units) is sent to Gemini, which returns a transformed problem,
        a processed answer and a verifiability label.

    Args:
        sample: Mapping with 'id', 'images', 'problem' and 'answer' keys
            ('answer' is assumed to be a str).
        system_instruction: System prompt forwarded to the model.
        verbose: Print per-sample diagnostics.
        model_name: Gemini model used for the LLM step.

    Returns:
        A result dict. ``processed`` is True only for records to keep;
        ``is_mcq`` always reflects whether the ORIGINAL answer was an option
        letter, including on the failure paths.
    """
    sample_id = sample['id']
    sample_image = sample['images']
    problem = sample['problem']
    original_answer = sample['answer']

    if _is_yes_no(original_answer):
        if verbose:
            print(f"Skipping sample {sample_id}: Answer is '{original_answer}'")
        return _make_result(sample_id, sample_image, problem, original_answer,
                            processed=False, error='Answer is yes/no', is_yes_no=True)

    is_mcq = is_mcq_answer(original_answer)

    if not is_mcq:
        # Long free-text answers cannot be checked by exact matching.
        if len(original_answer.split()) > 8:
            return _make_result(sample_id, sample_image, problem, original_answer,
                                processed=False,
                                error='Answer too long/complex for exact matching',
                                is_open_ended=True)
        # Short non-MCQ answers only need the LLM when a unit must be removed.
        if not _has_unit_marker(original_answer):
            return _make_result(sample_id, sample_image, problem, original_answer,
                                processed=True, already_good=True)

    prompt = _build_prompt(problem, original_answer)

    try:
        api_response = call_gemini_api(
            content=[prompt],
            model_name=model_name,
            temperature=0.2,
            system_instruction=system_instruction
        )
        transformed_problem, processed_answer, answer_type = _parse_llm_response(api_response)

        if not processed_answer or not answer_type or not transformed_problem:
            if verbose:
                print(f"Could not extract all needed parts for sample {sample_id}")
                print(f"API Response: {api_response}")
            return _make_result(sample_id, sample_image, problem, original_answer,
                                processed=False, is_mcq=is_mcq,
                                error='Failed to extract processed answer, transformed problem, or answer type',
                                api_response=api_response)

        llm_fields = {
            'processed_answer': processed_answer,
            'transformed_problem': transformed_problem,
            'answer_type': answer_type,
        }

        if "not verifiable" in answer_type.lower():
            if verbose:
                print(f"Skipping non-verifiable sample {sample_id}")
            return _make_result(sample_id, sample_image, problem, original_answer,
                                processed=False, is_mcq=is_mcq, is_open_ended=True,
                                error='Answer not verifiable for exact matching',
                                **llm_fields)

        if _is_yes_no(processed_answer):
            if verbose:
                print(f"Skipping yes/no answer sample {sample_id}")
            return _make_result(sample_id, sample_image, problem, original_answer,
                                processed=False, is_mcq=is_mcq, is_yes_no=True,
                                error='Answer is yes/no', **llm_fields)

        if verbose:
            print(f"\n{'-'*80}\nProcessed sample {sample_id}:\n")
            print(f"Original Problem: {problem}")
            print(f"Transformed Problem: {transformed_problem}")
            print(f"Original Answer: {original_answer}")
            print(f"Processed Answer: {processed_answer}")
            print(f"Answer Type: {answer_type}")
            print(f"{'-'*80}\n")

        # Keep the sample, replacing problem/answer with the transformed versions.
        return _make_result(sample_id, sample_image, transformed_problem, processed_answer,
                            processed=True, is_mcq=is_mcq,
                            original_problem=problem,
                            original_answer=original_answer,
                            answer_type=answer_type)
    except Exception as e:
        if verbose:
            print(f"\n{'-'*80}\nError processing sample {sample_id}: {str(e)}\n")
        return _make_result(sample_id, sample_image, problem, original_answer,
                            processed=False, is_mcq=is_mcq, error=str(e))


def process_dataset_parallel(dataset, max_workers=32, system_instruction=None, limit=None, verbose=False,
                             model_name="gemini-2.0-flash"):
    """Run ``process_sample`` over ``dataset`` with a thread pool.

    Args:
        dataset: A ``datasets.Dataset`` (``.select`` is used when ``limit``
            applies) whose rows are dicts.
        max_workers: Number of worker threads.
        system_instruction: System prompt forwarded to every sample.
        limit: If truthy and smaller than the dataset, process a random
            sample of this many rows.
        verbose: Print per-sample diagnostics.
        model_name: Gemini model forwarded to ``process_sample``.

    Returns:
        One result dict per processed row, in the order of the (possibly
        sub-sampled) dataset rather than in completion order.
    """
    if limit and limit < len(dataset):
        random_indices = random.sample(range(len(dataset)), limit)
        dataset = dataset.select(random_indices)

    results = [None] * len(dataset)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Keep each row next to its future so the fallback below does not
        # have to re-index (and re-decode) the dataset.
        future_to_item = {
            executor.submit(process_sample, sample, system_instruction, verbose, model_name): (i, sample)
            for i, sample in enumerate(dataset)
        }

        for future in tqdm(concurrent.futures.as_completed(future_to_item), total=len(future_to_item)):
            idx, sample = future_to_item[future]
            try:
                results[idx] = future.result()
            except Exception as exc:
                # Reached only when process_sample raises before its own
                # try block (e.g. a missing key or a non-string answer).
                if verbose:
                    print(f"\n{'-'*80}\nException in executor for sample {idx}: {str(exc)}\n{'-'*80}\n")

                raw_answer = sample.get('answer')
                answer_text = raw_answer if isinstance(raw_answer, str) else ''
                results[idx] = {
                    'id': sample.get('id', f"sample_{idx}"),
                    'images': sample.get('images'),
                    'error': str(exc),
                    'processed': False,
                    'problem': sample.get('problem', 'N/A'),
                    'answer': sample.get('answer', 'N/A'),
                    'is_mcq': is_mcq_answer(answer_text),
                    'is_yes_no': answer_text.lower() in ['yes', 'no'],
                    'is_open_ended': False,
                    'already_good': False
                }

    return results


def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is 0."""
    return (part / whole) * 100 if whole else 0.0


def print_stats(results):
    """Print a summary of ``results`` as produced by ``process_dataset_parallel``.

    Safe to call with an empty list (all percentages print as 0.00%).
    """
    from collections import Counter

    total_samples = len(results)

    # Non-MCQ answers kept as-is.
    already_good_samples = sum(1 for r in results if r.get('already_good', False))

    # Samples whose original answer was an option letter.
    mcq_samples = sum(1 for r in results if r.get('is_mcq', False))
    successfully_processed_mcq = sum(1 for r in results if r.get('is_mcq', False) and r.get('processed', False))

    # Kept samples (already good + successfully converted).
    successful_samples = sum(1 for r in results if r.get('processed', False))
    failed_samples = total_samples - successful_samples

    yes_no_count = sum(1 for r in results if r.get('is_yes_no', False))
    non_verifiable_count = sum(1 for r in results if r.get('is_open_ended', False) and not r.get('is_yes_no', False))
    other_errors = sum(
        1 for r in results
        if not r.get('processed', False) and not r.get('is_yes_no', False) and not r.get('is_open_ended', False)
    )

    error_types = Counter(r['error'] for r in results if not r.get('processed', False) and 'error' in r)

    print("\n===== PROCESSING STATISTICS =====")
    print(f"Total samples: {total_samples}")
    print(f"Already good (non-MCQ): {already_good_samples} ({_percent(already_good_samples, total_samples):.2f}%)")
    print(f"Required MCQ processing: {mcq_samples} ({_percent(mcq_samples, total_samples):.2f}%)")
    print(f"Successfully processed MCQ: {successfully_processed_mcq} ({(successfully_processed_mcq/mcq_samples)*100:.2f}% of MCQ samples)" if mcq_samples else "Successfully processed MCQ: 0 (0.00%)")

    print(f"\nTotal samples for final dataset: {successful_samples} ({_percent(successful_samples, total_samples):.2f}%)")
    print(f"  - Already good samples: {already_good_samples}")
    print(f"  - Converted MCQ samples: {successful_samples - already_good_samples}")

    print(f"\nFiltered out samples: {failed_samples} ({_percent(failed_samples, total_samples):.2f}%)")
    print(f"  - Yes/No answers: {yes_no_count} ({_percent(yes_no_count, total_samples):.2f}%)")
    print(f"  - Not verifiable (bad for exact matching): {non_verifiable_count} ({_percent(non_verifiable_count, total_samples):.2f}%)")
    print(f"  - Other errors: {other_errors} ({_percent(other_errors, total_samples):.2f}%)")

    if error_types:
        print("\nCommon error types:")
        for error, count in error_types.most_common(5):
            print(f"  - {error}: {count} occurrences")

    print("===============================\n")


def main():
    """Command-line entry point: load a dataset, convert it with Gemini, save the result.

    Reads ``./sampled_data/<dataset_path>`` and writes the successfully
    processed samples (id, images, problem, answer) to
    ``./sampled_data/<dataset_path>-FREEFORM``.
    """
    parser = argparse.ArgumentParser(description="Process dataset: convert MCQ questions and answers to freeform format")
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the dataset")
    parser.add_argument("--workers", type=int, default=64, help="Number of parallel workers")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of samples to process (random sampling)")
    parser.add_argument("--model", type=str, default="gemini-2.0-flash", help="Gemini model to use")
    parser.add_argument("--verbose", action="store_true", help="Show detailed processing information")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("processing.log"),
            logging.StreamHandler()
        ]
    )

    load_dotenv()
    api_key = os.getenv("GOOGLE_API_KEY")
    if api_key:
        genai.configure(api_key=api_key)
    else:
        logging.error("No API key found. Please set GOOGLE_API_KEY in .env file")
        return

    # System prompt covering both problem and answer transformation, with an
    # emphasis on stripping units from numerical answers.
    system_instruction = """You are a format converter for educational problems and answers.

    Your job is to:
    1. Transform MCQ problems by removing the multiple choice options and creating a direct question format
    2. For multiple-choice questions with letter/number answers (A, B, C, D, etc.), extract the full text of the answer choice
    3. Classify each answer as either "Verifiable" or "Not verifiable" for exact matching purposes
    4. Remove units from numerical answers (e.g., "45^{\\circ}" or "45°" becomes "45", "-6.8 km" becomes "-6.8")

    When transforming problems:
    - Keep the main question content intact
    - Remove the answer options (A, B, C, D, etc.)
    - Convert "which of the following" or similar phrases to direct questions
    - Preserve all mathematical expressions, formulas, and context
    - IMPORTANT: When the answer contains units (degrees, meters, etc.), incorporate those units into the question format
    * For example, change "the angle is ___" to "the angle is ___ degrees" when the answer has degrees
    * For blank-completion questions, format as "... is ___ [unit]" where [unit] is the unit removed from the answer
    - Ensure the transformed problem is complete and sensible on its own
    - Keep any <image> tags in the same position

    When processing answers:
    - For square root numbers, use LaTeX format like "8 \\sqrt{5}" instead of "8√5"
    - For numerical values, ALWAYS REMOVE UNITS (degrees, meters, kilograms, etc.)
    - Remove any LaTeX unit notations like "^{\\circ}" (for degrees)
    - Keep only the number itself (including negative signs and decimals if present)
    - For example:
    * "45^{\\circ}" or "45°" becomes "45"
    * "-6.8 km" becomes "-6.8"
    * "10 m/s" becomes "10"
    * "8√5" becomes "8 \\sqrt{5}"
    - Do not remove units from mathematical expressions or LaTeX notation

    Verifiable answers (GOOD for exact matching):
    - Short and precise numerical values (including fractions, decimals)
    - Mathematical expressions and formulas
    - Mathematical objects like matrices, vectors, coordinates
    - LaTeX notation (e.g., \\begin{pmatrix} ... \\end{pmatrix})
    - Specific formulas or equations
    - Single concepts or terms
    - Short, definite answers

    Not verifiable answers (BAD for exact matching):
    - Long sentences or explanatory paragraphs
    - Complex relationships (e.g., "X is transferred from Y to Z")
    - Comparative statements (e.g., "X is greater than Y")
    - Descriptive or explanatory answers (not mathematical)
    - Verbal reasoning or arguments

    For each problem, provide:
    1. The transformed problem without MCQ options, with units from the answer incorporated into the question when applicable
    2. The processed answer text with units removed for numerical values and proper LaTeX formatting for square roots
    3. Whether the answer is "Verifiable" or "Not verifiable"."""

    # Normalize so a trailing separator ("foo/") does not turn the output path
    # into "./sampled_data/foo/-FREEFORM", i.e. a directory inside the input.
    dataset_path = os.path.normpath(args.dataset_path)

    logging.info(f"Loading dataset from {args.dataset_path}")
    full_dataset_path = os.path.join('./sampled_data', dataset_path)
    try:
        dataset = load_from_disk(full_dataset_path)
        logging.info(f"Loaded dataset with {len(dataset)} samples")
    except Exception as e:
        logging.error(f"Error loading dataset: {str(e)}")
        return

    logging.info(f"Processing dataset with {args.workers} workers, limit={args.limit}, verbose={args.verbose}")
    logging.info(f"Using model {args.model}")
    results = process_dataset_parallel(
        dataset,
        max_workers=args.workers,
        system_instruction=system_instruction,
        limit=args.limit,
        verbose=args.verbose,
        model_name=args.model
    )

    print_stats(results)

    clean_results = [r for r in results if r.get('processed', False)]

    if not clean_results:
        logging.error("No successfully processed samples. Exiting without creating dataset.")
        return

    # Keep only the fields the downstream dataset needs.
    processed_dataset = Dataset.from_dict({
        'id': [r['id'] for r in clean_results],
        'images': [r['images'] for r in clean_results],
        'problem': [r['problem'] for r in clean_results],
        'answer': [r['answer'] for r in clean_results]
    })

    output_path = os.path.join('./sampled_data', f"{dataset_path}-FREEFORM")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    processed_dataset.save_to_disk(output_path)
    logging.info(f"Saved processed dataset with {len(processed_dataset)} samples to {output_path}")
    print(f"Saved processed dataset with {len(processed_dataset)} samples to {output_path}")

# Run the command-line entry point only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()