import argparse
from itertools import islice
import json
import os
import asyncio
import time
import signal
import sys
from datetime import datetime
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import pandas as pd
from api_call import *

class RateLimiter:
    """Space out API calls so that at most ``calls_per_minute`` start per minute.

    Every caller of :meth:`wait_if_needed` is serialized through an
    ``asyncio.Lock`` and sleeps until at least ``60 / calls_per_minute``
    seconds have passed since the previously permitted call.
    """

    def __init__(self, calls_per_minute=40):
        if calls_per_minute <= 0:
            raise ValueError(f"calls_per_minute must be positive, got {calls_per_minute!r}")
        self.calls_per_minute = calls_per_minute
        self.interval = 60 / calls_per_minute  # Time between requests in seconds
        # -inf guarantees the very first call never waits, whatever the
        # (arbitrary) reference point of the monotonic clock is.
        self.last_call_time = float("-inf")
        self.lock = asyncio.Lock()

    async def wait_if_needed(self):
        """Sleep just long enough to keep the configured call spacing, then record the call."""
        async with self.lock:
            # time.monotonic() cannot jump backwards (NTP/DST/manual clock changes),
            # which with time.time() could yield a negative elapsed time and a huge sleep.
            elapsed = time.monotonic() - self.last_call_time
            if elapsed < self.interval:
                await asyncio.sleep(self.interval - elapsed)
            self.last_call_time = time.monotonic()

class CheckpointManager:
    """Persist and restore experiment progress in ``<output_dir>/checkpoint.json``."""

    def __init__(self, output_dir, checkpoint_interval=10):
        """Create the manager and immediately load an existing checkpoint, if any.

        Args:
            output_dir: Directory holding ``checkpoint.json``.
            checkpoint_interval: ``should_save_checkpoint`` returns True for
                indices that are multiples of this value.

        Raises:
            ValueError: if ``checkpoint_interval`` is smaller than 1.
        """
        if checkpoint_interval < 1:
            raise ValueError(f"checkpoint_interval must be >= 1, got {checkpoint_interval!r}")
        self.output_dir = output_dir
        self.checkpoint_interval = checkpoint_interval
        self.checkpoint_file = os.path.join(output_dir, "checkpoint.json")
        self.checkpoint_data = self._default_checkpoint()
        self.load_checkpoint()

    @staticmethod
    def _default_checkpoint():
        """Return a fresh checkpoint dict describing a run that has not started."""
        return {
            "last_processed_index": -1,
            "prompt_correct_counts": {},
            "prompt_token_usage": {},
            "total_questions": 0,
            "is_first_result": True,
            "is_first_log": True,
            "timestamp": None
        }

    def load_checkpoint(self):
        """Load ``checkpoint.json`` if it exists.

        Keys missing from the file are filled with their defaults so callers can
        index ``checkpoint_data`` safely.

        Returns:
            True if a checkpoint file was loaded, False otherwise.
        """
        if not os.path.exists(self.checkpoint_file):
            return False
        with open(self.checkpoint_file, "r") as f:
            checkpoint = json.load(f)
        merged = self._default_checkpoint()
        merged.update(checkpoint)
        self.checkpoint_data = merged
        print(f"Resuming from checkpoint at index {merged['last_processed_index']}")
        return True

    def save_checkpoint(self, index, prompt_correct_counts, prompt_token_usage, total_questions,
                        is_first_result, is_first_log):
        """Snapshot the current progress into ``checkpoint_data`` and write it to disk.

        The counters and token lists are copied, so later in-place updates by
        the caller cannot silently change an already-saved checkpoint. The file
        is written to a temporary path and atomically swapped in with
        ``os.replace`` so an interruption mid-write cannot corrupt it.
        """
        # Convert numpy arrays (or any other sequence) to plain lists for JSON serialization.
        serializable_token_usage = {
            prompt: tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
            for prompt, tokens in prompt_token_usage.items()
        }

        self.checkpoint_data = {
            "last_processed_index": index,
            "prompt_correct_counts": dict(prompt_correct_counts),
            "prompt_token_usage": serializable_token_usage,
            "total_questions": total_questions,
            "is_first_result": is_first_result,
            "is_first_log": is_first_log,
            "timestamp": datetime.now().isoformat()
        }

        tmp_path = self.checkpoint_file + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(self.checkpoint_data, f, indent=4)
        os.replace(tmp_path, self.checkpoint_file)

    def should_save_checkpoint(self, current_index):
        """Return True when ``current_index`` is a multiple of the checkpoint interval."""
        return current_index % self.checkpoint_interval == 0

def setup_output_files(model_name, dataset_name, resume=False):
    """Return ``(output_dir, analysis_file, result_file, log_file)`` for this run.

    With ``resume=True``, the ``optimized_results_<dataset>_<model>_*`` directory
    in the current working directory with the newest ``os.path.getctime`` is
    reused as-is. If there is none, or ``resume`` is False, a new directory is
    created and initialized: ``analysis.json`` with metadata, and
    ``results.json`` / ``log.json`` as open JSON arrays (a lone ``[``).
    """
    prefix = f'optimized_results_{dataset_name}_{model_name}_'

    def _paths(directory):
        # The three files every run directory contains.
        return (
            directory,
            os.path.join(directory, "analysis.json"),
            os.path.join(directory, "results.json"),
            os.path.join(directory, "log.json"),
        )

    if resume:
        # Look for existing output directories for this model and dataset
        dirs = [d for d in os.listdir() if d.startswith(prefix) and os.path.isdir(d)]
        if dirs:
            # Most recently changed first (stable sort keeps listdir order on ties)
            dirs.sort(key=os.path.getctime, reverse=True)
            output_dir = dirs[0]
            print(f"Resuming from existing directory: {output_dir}")
            return _paths(output_dir)

    # Create a new output directory. The timestamp only has minute resolution, so
    # a second run started within the same minute would otherwise truncate the
    # first run's files and inherit its checkpoint.json; append a numeric suffix
    # until makedirs succeeds (atomic, so concurrent starts cannot collide).
    base_dir = prefix + datetime.now().strftime("%m%d_%H%M")
    output_dir = base_dir
    suffix = 1
    while True:
        try:
            os.makedirs(output_dir)
            break
        except FileExistsError:
            output_dir = f"{base_dir}_{suffix}"
            suffix += 1

    output_dir, analysis_file, result_file, log_file = _paths(output_dir)

    # Initialize the analysis file with metadata
    with open(analysis_file, "w") as f:
        json.dump({
            "metadata": {
                "description": "Experiment comparing different prompt strategies for reasoning efficiency",
                "prompt_types": list(reduce_prompts.keys())
            },
            "summary": {},
            "per_prompt_metrics": []
        }, f, indent=4)

    # Result and log files start as open JSON arrays; items are appended later.
    for array_file in (result_file, log_file):
        with open(array_file, "w") as f:
            f.write("[\n")

    return output_dir, analysis_file, result_file, log_file

async def process_question_with_model(question, reference_answer, model_name, rate_limiter, dataset_name="math500"):
    """Run every prompt in ``reduce_prompts`` on one question with the given model.

    Prompts are evaluated one after another, not concurrently. Each API call is
    throttled by ``rate_limiter`` inside ``test_prompt_with_model``.

    Returns:
        ``None`` if ``question`` or ``reference_answer`` is empty; otherwise a
        dict with the question, its reference answer, and one result dict per
        prompt in ``reduce_prompts`` order.

    Raises:
        ValueError: if ``model_name`` has no registered API call function.
    """
    if not question or not reference_answer:
        return None

    # Map of which API call function to use based on model name
    model_api_map = {
        "deepseek-reasoner": call_ds_series_async,
        "o1-2024-12-17": call_reasoning_gpt_series_async,
        "claude-3-7-sonnet-20250219": call_claude_series_async,
        "gemini-2.5-flash-preview-04-17": call_gemini_async
    }

    api_call_func = model_api_map.get(model_name)
    if api_call_func is None:
        raise ValueError(f"Unsupported model: {model_name}")

    # Sequential on purpose: every call already waits on the shared rate limiter,
    # and this keeps the results in reduce_prompts order.
    all_results = []
    for prompt_name, prompt_template in reduce_prompts.items():
        result = await test_prompt_with_model(question, reference_answer, prompt_name,
                                              prompt_template, model_name, api_call_func,
                                              rate_limiter, dataset_name)
        all_results.append(result)

    return {
        "question": question,
        "reference_answer": reference_answer,
        "results": all_results
    }

async def test_prompt_with_model(question, reference_answer, prompt_name, prompt_template,
                                 model_name, api_call_func, rate_limiter, dataset_name="math500"):
    """Run one prompt against the model, grade the answer, and report token usage.

    Both the model call and the grading call (``validate_answer_async``) wait
    on ``rate_limiter`` first.

    Args:
        question: Question text substituted into ``prompt_template`` as ``{q}``.
        reference_answer: Ground-truth answer passed to the grader.
        prompt_name: Name recorded in the returned dict.
        prompt_template: ``str.format`` template with a ``{q}`` placeholder.
        model_name: Model identifier; selects how the response is parsed.
        api_call_func: Coroutine function called as ``api_call_func(prompt, model_name)``.
        rate_limiter: Object exposing ``await wait_if_needed()``.
        dataset_name: Accepted for interface compatibility; not used here.

    Returns:
        dict with ``prompt_name``, ``correct``, ``reasoning_tokens``,
        ``completion_tokens``, ``total_tokens`` and ``answer``. For Claude models
        with non-empty thinking text it also includes ``thinking_text``.
    """
    formatted_prompt = prompt_template.format(q=question)

    # Apply rate limiting before making the API call
    await rate_limiter.wait_if_needed()
    response = await api_call_func(formatted_prompt, model_name)

    thinking_text = None

    # Handle different response formats based on the model
    if model_name.startswith("claude"):
        # Find content blocks by type rather than position: a response may have
        # no thinking block, or additional (e.g. redacted) blocks before the text.
        blocks = list(response.content)
        thinking_text = next(
            (block.thinking for block in blocks if getattr(block, "type", None) == "thinking"), None
        )
        answer = next(
            (block.text for block in blocks if getattr(block, "type", None) == "text"), ""
        )
        # Claude usage does not separate reasoning tokens, so estimate them from the thinking text.
        reasoning_tokens = estimate_tokens(thinking_text) if thinking_text else 0
        completion_tokens = response.usage.output_tokens
        total_tokens = response.usage.input_tokens + response.usage.output_tokens
    elif "gemini" in model_name:
        #input = total - completion (candidate) - reasoning
        #output = completion (candidate)
        #reasoning = reasoning
        answer = response.text
        usage = json.loads(response.model_dump_json())["usage_metadata"]

        reasoning_tokens = usage["thoughts_token_count"]
        completion_tokens = usage["candidates_token_count"]
        total_tokens = usage["total_token_count"]
    else:
        # For OpenAI and DeepSeek models
        #input = total - completion
        #output = completion - reasoning
        #reasoning = reasoning
        answer = response.choices[0].message.content
        usage = response.usage.to_dict()

        # completion_tokens_details may be missing/None; report None rather than crash.
        reasoning_tokens = (usage.get('completion_tokens_details') or {}).get('reasoning_tokens')
        completion_tokens = usage.get('completion_tokens')
        total_tokens = usage.get('total_tokens')

    # Apply rate limiting before making the validation API call
    await rate_limiter.wait_if_needed()
    is_correct = await validate_answer_async(question, reference_answer, answer)

    result = {
        "prompt_name": prompt_name,
        "correct": is_correct,
        "reasoning_tokens": reasoning_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "answer": answer
    }

    # Keep Claude's thinking text so it can be written to the detailed log
    if model_name.startswith("claude") and thinking_text:
        result["thinking_text"] = thinking_text

    return result

def append_to_json_file(file_path, data, is_first):
    """Append one element to a streaming (not yet closed) JSON array file.

    A ``",\\n"`` separator precedes the element unless ``is_first`` is true.
    Always returns ``False`` so callers can write
    ``is_first = append_to_json_file(path, item, is_first)``.
    """
    separator = "" if is_first else ",\n"
    with open(file_path, "a") as handle:
        handle.write(separator)
        json.dump(data, handle, indent=4)
    return False

def close_json_files(result_file, log_file):
    """Terminate the streaming JSON arrays in both files with a newline and ``]``."""
    for array_path in (result_file, log_file):
        with open(array_path, "a") as handle:
            handle.write("\n]")

def _resolve_dataset_fields(dataset_name):
    """Return the ``(question_key, answer_key)`` column names for ``dataset_name``.

    Matching is by substring of the lower-cased name, so e.g. "gpqa_diamond"
    resolves through the "gpqa" entry.

    Raises:
        ValueError: if no known dataset pattern occurs in ``dataset_name``.
    """
    field_map = {
        "math500":    ("problem",        "answer"),
        "gsm8k":      ("question",       "answer"),
        "svamp":      ("question_concat","Answer"),
        "aime_2024":  ("problem",        "answer"),
        "gpqa":       ("Question",       "Correct Answer")
    }
    name_lower = dataset_name.lower()
    for pattern, keys in field_map.items():
        if pattern in name_lower:
            return keys
    raise ValueError(
        f"Unsupported dataset: {dataset_name!r}. "
        f"Supported datasets must contain one of: {', '.join(field_map)}."
    )


def _reopen_json_array(file_path):
    """Make a streaming JSON-array file appendable again; report whether it is empty.

    Result/log files are kept as an *open* array: a ``[`` followed by
    comma-separated items and no closing bracket. The interrupt handler in
    ``process_dataset`` closes the arrays before exiting. On resume, that
    trailing ``]`` must be removed, or later appends would land after it and
    the file would no longer be valid JSON.

    Returns:
        ``None`` if the file is missing or is not a JSON array. Otherwise,
        ``True`` when the array holds no items yet, so the next append must
        omit the separator.
    """
    if not os.path.exists(file_path):
        return None
    with open(file_path, "r") as f:
        text = f.read()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Still open (the normal state while running): nothing to undo.
        return text.strip() == "["
    if not isinstance(parsed, list):
        return None
    # The text parsed as a list, so its last non-blank character is the closing "]".
    body = text.rstrip()[:-1].rstrip()
    if body == "[":
        body += "\n"  # same layout as a freshly created file
    with open(file_path, "w") as f:
        f.write(body)
    return not parsed


def _token_stats(token_counts):
    """Return ``(mean, median)`` of the known token counts, or ``(0, 0)`` if there are none.

    ``None`` entries (a response without a reasoning-token count) are skipped
    instead of making NumPy raise a TypeError.
    """
    known = [count for count in (token_counts or ()) if count is not None]
    if not known:
        return 0, 0
    return np.mean(known), np.median(known)


def _count_min_token_wins(result_file, prompt_names):
    """Count, per prompt, the questions it answered correctly with the fewest reasoning tokens.

    Reads the closed JSON array in ``result_file`` once. Entries are
    de-duplicated by ``question_id`` (the last occurrence wins), because a
    resumed run may re-append questions written after the last checkpoint.
    Ties go to the prompt listed first in an entry's ``answers``. A missing or
    ``None`` token count ranks as +inf. Winners not in ``prompt_names`` are
    not counted.
    """
    with open(result_file, "r") as f:
        results = json.load(f)

    latest = {}
    for position, entry in enumerate(results):
        latest[entry.get("question_id", ("#", position))] = entry

    def _tokens(item):
        tokens = item[1].get("reasoning_tokens")
        return float("inf") if tokens is None else tokens

    wins = dict.fromkeys(prompt_names, 0)
    for entry in latest.values():
        correct = [(name, data) for name, data in entry.get("answers", {}).items()
                   if data.get("correct", False)]
        if not correct:
            continue
        winner = min(correct, key=_tokens)[0]
        if winner in wins:
            wins[winner] += 1
    return wins


def _build_question_entries(question_id, question, reference_answer, prompt_results, model_name):
    """Return ``(result_entry, log_entry)``: the compact and the detailed record for one question."""
    result_entry = {
        "question_id": question_id,
        "question": question,
        "reference_answer": reference_answer,
        "answers": {}
    }
    log_entry = {
        "question_id": question_id,
        "question": question,
        "reference_answer": reference_answer,
        "prompt_results": []
    }
    for result in prompt_results:
        prompt_name = result["prompt_name"]
        result_entry["answers"][prompt_name] = {
            "answer": result["answer"],
            "correct": result["correct"],
            "reasoning_tokens": result["reasoning_tokens"]
        }
        prompt_log_entry = {
            "prompt_name": prompt_name,
            "prompt_text": reduce_prompts[prompt_name].format(q=question),
            "answer": result["answer"],
            "correct": result["correct"],
            "reasoning_tokens": result["reasoning_tokens"],
            "completion_tokens": result["completion_tokens"],
            "total_tokens": result["total_tokens"]
        }
        # Claude results may carry the raw thinking text; keep it in the detailed log only.
        if model_name.startswith("claude") and "thinking_text" in result:
            prompt_log_entry["thinking_text"] = result["thinking_text"]
        log_entry["prompt_results"].append(prompt_log_entry)
    return result_entry, log_entry


async def process_dataset(dataset, output_dir, analysis_file, result_file, log_file, model_name,
                          rate_limiter, checkpoint_manager, dataset_name="math500"):
    """Evaluate every prompt on every question, stream results to disk, then write the analysis.

    Progress resumes from ``checkpoint_manager.checkpoint_data``. A question
    counts as completed only after its entries are written to ``result_file``
    and ``log_file``, so a saved checkpoint never includes a question that a
    resumed run would process again.

    While the loop runs, SIGINT/SIGTERM save a checkpoint, close the JSON
    arrays and exit with status 0. The previous handlers are restored
    afterwards. After the loop, the arrays are closed, per-prompt metrics are
    computed, ``analysis_file`` and ``<output_dir>/analysis_summary.csv`` are
    written, and the checkpoint is renamed to ``final_checkpoint.json``.

    Returns:
        The analysis dict that is also written to ``analysis_file``.

    Raises:
        ValueError: if ``dataset_name`` matches no known dataset.
        Exception: an error while processing a question is re-raised after a
            checkpoint of the last completed question is saved.
    """
    q_key, a_key = _resolve_dataset_fields(dataset_name)
    prompt_names = list(reduce_prompts.keys())
    saved = checkpoint_manager.checkpoint_data

    # Copy mutable state out of the checkpoint so in-place updates below never
    # leak into checkpoint_manager.checkpoint_data (main() re-saves it on errors).
    total_questions = saved.get("total_questions", 0)
    prompt_correct_counts = dict(saved.get("prompt_correct_counts",
                                           {prompt_name: 0 for prompt_name in prompt_names}))
    saved_usage = saved.get("prompt_token_usage", {})
    prompt_token_usage = {name: list(saved_usage.get(name, [])) for name in prompt_names}
    is_first_result = saved.get("is_first_result", True)
    is_first_log = saved.get("is_first_log", True)
    # Index of the last question that is fully written and counted (or skipped).
    last_completed_index = saved.get("last_processed_index", -1)
    start_index = last_completed_index + 1

    # An interrupted run leaves the arrays closed; reopen them, and trust the
    # files themselves for whether the next item needs a separator.
    result_empty = _reopen_json_array(result_file)
    if result_empty is not None:
        is_first_result = result_empty
    log_empty = _reopen_json_array(log_file)
    if log_empty is not None:
        is_first_log = log_empty

    def signal_handler(sig, frame):
        print("\nReceived interrupt signal. Saving checkpoint before exiting...")
        checkpoint_manager.save_checkpoint(last_completed_index, prompt_correct_counts, prompt_token_usage,
                                           total_questions, is_first_result, is_first_log)
        close_json_files(result_file, log_file)
        print("Checkpoint saved. Exiting.")
        sys.exit(0)

    previous_handlers = {
        sig: signal.signal(sig, signal_handler)
        for sig in (signal.SIGINT, signal.SIGTERM)  # Ctrl+C and termination
    }
    try:
        for i, data in enumerate(tqdm(dataset)):
            if i < start_index:
                continue  # Already processed in a previous run
            question = data[q_key]
            reference_answer = data[a_key]

            if not question or not reference_answer:
                last_completed_index = i  # nothing to evaluate; a resumed run skips it too
                continue

            try:
                question_results = await process_question_with_model(
                    question, reference_answer, model_name, rate_limiter, dataset_name
                )
                if not question_results:
                    last_completed_index = i
                    continue

                prompt_results = question_results["results"]
                result_entry, log_entry = _build_question_entries(
                    i, question, reference_answer, prompt_results, model_name
                )
                is_first_result = append_to_json_file(result_file, result_entry, is_first_result)
                is_first_log = append_to_json_file(log_file, log_entry, is_first_log)

                # Commit the question to the running totals only now that it is on disk.
                total_questions += 1
                for result in prompt_results:
                    name = result["prompt_name"]
                    if result["correct"]:
                        prompt_correct_counts[name] = prompt_correct_counts.get(name, 0) + 1
                    prompt_token_usage.setdefault(name, []).append(result["reasoning_tokens"])
                last_completed_index = i

                if checkpoint_manager.should_save_checkpoint(i):
                    checkpoint_manager.save_checkpoint(i, prompt_correct_counts, prompt_token_usage,
                                                       total_questions, is_first_result, is_first_log)
                    print(f"Checkpoint saved at index {i}")

                if (i + 1) % 10 == 0:
                    print(f"Processed {i + 1} questions. Current accuracies:")
                    for name in prompt_names:
                        accuracy = prompt_correct_counts.get(name, 0) / total_questions * 100 if total_questions > 0 else 0
                        avg_tokens = _token_stats(prompt_token_usage.get(name))[0]
                        print(f"  {name}: {accuracy:.2f}% accuracy, {avg_tokens:.2f} avg tokens")

            except Exception as e:
                print(f"Error processing question {i}: {e}")
                checkpoint_manager.save_checkpoint(last_completed_index, prompt_correct_counts, prompt_token_usage,
                                                   total_questions, is_first_result, is_first_log)
                print(f"Checkpoint saved at index {last_completed_index} due to error")
                raise  # stop processing
    finally:
        # Uninstall our handlers so a later signal cannot close the arrays a second time.
        for sig, handler in previous_handlers.items():
            signal.signal(sig, handler if handler is not None else signal.SIG_DFL)

    close_json_files(result_file, log_file)

    try:
        min_token_counts = _count_min_token_wins(result_file, prompt_names)
    except Exception as e:
        print(f"Error calculating minimum token counts: {e}")
        raise

    analysis = {
        "metadata": {
            "total_questions": total_questions,
            "model_name": model_name,
            "dataset_name": dataset_name
        },
        "summary": {
            "overall_accuracy": {},
            "token_efficiency": {},
            "correct_with_minimum_tokens_count": 0,
            "best_performing_prompt": ""
        },
        "per_prompt_metrics": []
    }

    best_accuracy = 0
    best_prompt = "normal"  # reported if no prompt has a positive accuracy

    for name in prompt_names:
        correct_count = prompt_correct_counts.get(name, 0)
        accuracy = correct_count / total_questions * 100 if total_questions > 0 else 0
        avg_tokens, median_tokens = _token_stats(prompt_token_usage.get(name))

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_prompt = name

        analysis["per_prompt_metrics"].append({
            "prompt_name": name,
            "accuracy": accuracy,
            "correct_count": correct_count,
            "avg_reasoning_tokens": avg_tokens,
            "median_reasoning_tokens": median_tokens,
            "min_tokens_and_correct_count": min_token_counts.get(name, 0)
        })
        analysis["summary"]["overall_accuracy"][name] = accuracy
        analysis["summary"]["token_efficiency"][name] = avg_tokens

    analysis["summary"]["best_performing_prompt"] = best_prompt
    analysis["summary"]["correct_with_minimum_tokens_count"] = sum(min_token_counts.values())

    with open(analysis_file, "w") as f:
        json.dump(analysis, f, indent=4)

    # CSV report for easy analysis
    csv_file = os.path.join(output_dir, "analysis_summary.csv")
    pd.DataFrame(analysis["per_prompt_metrics"]).to_csv(csv_file, index=False)

    # Retire the checkpoint now that the run completed; os.replace (unlike
    # os.rename) also overwrites an existing target on Windows.
    if os.path.exists(checkpoint_manager.checkpoint_file):
        os.replace(checkpoint_manager.checkpoint_file, os.path.join(output_dir, "final_checkpoint.json"))

    return analysis

def main():
    """Command-line entry point: parse arguments, load the dataset, and run the experiment.

    If an unexpected error occurs, the traceback is printed and the current
    checkpoint data is saved again so the run can continue with ``--resume``.
    The error is then re-raised.
    """
    parser = argparse.ArgumentParser(description="Run reasoning experiments with different prompt strategies")
    parser.add_argument("--sample_size", type=int, default=0, help="Number of questions to sample (0 for all)")
    parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility")
    parser.add_argument("--model", type=str, default="deepseek-reasoner",
                        choices=["deepseek-reasoner", "o1-2024-12-17", "claude-3-7-sonnet-20250219", "gemini-2.5-flash-preview-04-17"],
                        help="Model to use for inference")
    parser.add_argument("--dataset", type=str, default="math500",
                        choices=["math500", "gsm8k", "svamp", "aime_2024", "gpqa_diamond"],
                        help="Dataset to use for evaluation")
    parser.add_argument("--rate_limit", type=int, default=40,
                        help="Maximum number of API calls per minute (default: 40)")
    parser.add_argument("--checkpoint_interval", type=int, default=10,
                        help="Save checkpoint after processing this many questions (default: 10)")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from last checkpoint if available")
    args = parser.parse_args()

    output_dir, analysis_file, result_file, log_file = setup_output_files(args.model, args.dataset, args.resume)
    checkpoint_manager = CheckpointManager(output_dir, args.checkpoint_interval)
    rate_limiter = RateLimiter(calls_per_minute=args.rate_limit)

    # Dataset name substring -> (load_dataset positional args, split to evaluate)
    dataset_name = args.dataset.lower()
    dataset_sources = {
        "math500": (('HuggingFaceH4/MATH-500',), "test"),
        "gsm8k": (('openai/gsm8k', 'main'), "test"),
        "svamp": (('ChilleD/SVAMP',), "test"),
        "aime_2024": (('HuggingFaceH4/aime_2024',), "train"),
        "gpqa": (('Idavidrein/gpqa', dataset_name), "train"),
    }
    for pattern, (load_args, split) in dataset_sources.items():
        if pattern in dataset_name:
            break
    else:
        raise ValueError(
            f"Unsupported dataset: {args.dataset}. Supported datasets are: {', '.join(dataset_sources)}."
        )

    # shuffle() returns a new dataset instead of shuffling in place, so its result
    # must be kept. The fixed seed keeps the order (and thus checkpoint indices)
    # reproducible. NOTE: resuming a run that was started without shuffling maps
    # its checkpoint indices to different questions.
    split_dataset = load_dataset(*load_args)[split].shuffle(seed=args.seed)
    if args.sample_size > 0:
        test_dataset = split_dataset.select(range(min(len(split_dataset), args.sample_size)))
    else:
        test_dataset = split_dataset

    print(f"Starting experiment with {args.model} on {args.dataset} dataset")
    print(f"Processing up to {args.sample_size} samples" if args.sample_size > 0 else "Processing all samples")
    print(f"Rate limit: {args.rate_limit} API calls per minute")
    print(f"Checkpoint interval: every {args.checkpoint_interval} questions")
    print(f"Run with these reduce prompt:{list(reduce_prompts.keys())}")
    try:
        analysis = asyncio.run(process_dataset(
            test_dataset,
            output_dir,
            analysis_file,
            result_file,
            log_file,
            args.model,
            rate_limiter,
            checkpoint_manager,
            args.dataset
        ))

        print("\nExperiment completed successfully!")
        print(f"Results saved to {output_dir}/")
        print("\nSummary of results:")

        for prompt_name, accuracy in analysis["summary"]["overall_accuracy"].items():
            token_efficiency = analysis["summary"]["token_efficiency"][prompt_name]
            print(f"  {prompt_name}: {accuracy:.2f}% accuracy, {token_efficiency:.2f} avg tokens")

        print(f"\nBest performing prompt: {analysis['summary']['best_performing_prompt']}")

    except Exception as e:
        print(f"Error running experiment: {e}")
        import traceback
        traceback.print_exc()

        # Save final checkpoint on unexpected error
        data = checkpoint_manager.checkpoint_data
        checkpoint_manager.save_checkpoint(
            data.get("last_processed_index", -1),
            data.get("prompt_correct_counts", {}),
            data.get("prompt_token_usage", {}),
            data.get("total_questions", 0),
            data.get("is_first_result", True),
            data.get("is_first_log", True)
        )
        print("Final checkpoint saved. You can resume the experiment using --resume flag.")
        raise

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()