import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.pipelines import run_math_inference
import argparse
from vllm import LLM
import torch
from transformers import AutoTokenizer
from src.data_prep import process_aime, process_gsm, process_mmlu, process_med, process_sudoku
from datasets import DatasetDict
from src.helper import extract_math_answer
import json
import random
from src.model_loader import model_inference_batch_vllm_aio
from enum import Enum
import gc

class ErrorType(Enum):
    """Category of a generated record; ``.value`` is written to ``row['error_type']``."""

    # The model's extracted answer matched the reference label.
    CORRECT = "correct"
    # A wrong answer the model produced on its own during inference.
    NATURAL_ERROR = "natural_error"
    # A wrong answer deliberately derived from a correct one (meant to be easy to fix).
    SYNTHETIC_ERROR = "synthetic_error"

def parse_args():
    """Parse the command-line options for dataset generation from ``sys.argv``.

    Returns:
        argparse.Namespace: model, dataset, generation, data-distribution and
        advanced options (see the individual ``add_argument`` calls).

    Note:
        argparse applies ``%``-formatting to every help string when rendering
        ``--help``, so a literal percent sign must be written as ``%%``.
        A bare ``%`` makes ``--help`` crash.
    """
    parser = argparse.ArgumentParser()
    # Model configurations
    parser.add_argument('--model_name', type=str, required=True, help='Main model for inference')
    parser.add_argument('--error_gen_model_name', type=str, default=None,
                        help='Model for generating synthetic errors (defaults to main model)')

    # Dataset configurations
    parser.add_argument('--dataset_name', type=str, default='GSM', choices=['MMLU', 'GSM', 'AIME', 'MED', 'SUDOKU'])
    parser.add_argument('--med_subset', type=str, default=None, help='Subset of MED dataset to process')  # prompt is for surgery
    parser.add_argument('--sudoku_num_prefilled', type=int, default=10, help='Number of prefilled cells in Sudoku puzzles')

    parser.add_argument('--stop', type=int, default=9999999, help='Max samples to process')
    parser.add_argument('--output_dir', type=str, default='output')

    # Generation configurations
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--max_new_tokens', type=int, default=2000)

    # Data distribution configurations ("%%" renders as a literal "%" in --help)
    parser.add_argument('--total_samples', type=int, default=1000,
                        help='Total number of samples to generate')
    parser.add_argument('--wrong_ratio', type=float, default=0.5,
                        help='Ratio of wrong answers in final dataset (0.5 means 50%% wrong)')
    parser.add_argument('--synth_wrong_ratio', type=float, default=0.5,
                        help='Ratio of synthetic errors within wrong answers (0.5 means 50%% of wrong are synthetic)')

    # Advanced options
    parser.add_argument('--synthetic_error_prompt_template', type=str, default=None,
                        help='Custom prompt for generating synthetic errors')
    parser.add_argument('--max_generation_attempts', type=int, default=10,
                        help='Maximum attempts to generate required data distribution')

    return parser.parse_args()

def categorize_responses(data):
    """Split inference records into correct and incorrect lists.

    A record counts as correct when ``record['answer']`` equals
    ``extract_math_answer(record['label'])``. Each record is tagged in place
    with ``error_type`` set to ``ErrorType.CORRECT.value`` or
    ``ErrorType.NATURAL_ERROR.value``. A record whose check raises is reported
    on stdout and left out of both lists.

    Returns:
        tuple[list, list]: ``(correct, incorrect)``, each in input order.
    """
    correct, incorrect = [], []

    for record in data:
        try:
            matched = record['answer'] == extract_math_answer(record['label'])
            tag, bucket = (
                (ErrorType.CORRECT.value, correct) if matched
                else (ErrorType.NATURAL_ERROR.value, incorrect)
            )
            record['error_type'] = tag
            bucket.append(record)
        except Exception as e:
            print(f"Error processing row: {e}")

    return correct, incorrect

def generate_synthetic_errors(correct_responses, model, tokenizer, max_new_tokens,
                              temperature, prompt_template=None):
    """Ask a model to corrupt correct answers into easy-to-fix wrong ones.

    Args:
        correct_responses: Records with ``question``, ``response``, ``answer``
            and ``label`` keys.
        model, tokenizer: Passed straight to ``model_inference_batch_vllm_aio``.
        max_new_tokens, temperature: Generation settings for that call.
        prompt_template: ``str.format`` template with ``{question}`` and
            ``{response}`` placeholders. If None, a built-in template is used.

    Returns:
        list[dict]: One record per generation whose extracted answer is
        non-empty and differs from the original correct answer. Each record is
        tagged ``ErrorType.SYNTHETIC_ERROR.value``. Generations that fail
        extraction are reported and skipped.
    """
    if prompt_template is None:
        prompt_template = """You will be given a pair of question and answer. Your task is to create an error in the answer that is:
1. Easy to identify and fix
2. Educational for training purposes
3. Realistic but obvious

Original Question: {question}
Original Answer: {response}

Provide ONLY the modified answer with the error, no explanations."""

    # One single-turn chat per source record, aligned with correct_responses.
    chats = [
        [{
            "role": "user",
            "content": prompt_template.format(question=item["question"],
                                              response=item["response"]),
        }]
        for item in correct_responses
    ]

    generations = model_inference_batch_vllm_aio(
        model, tokenizer, chats,
        max_new_tokens=max_new_tokens,
        temperature=temperature
    )

    synthetic = []
    for generated, source in zip(generations, correct_responses):
        try:
            predicted = extract_math_answer(generated)
            # Keep the sample only if the model really changed the answer.
            if predicted and predicted != source['answer']:
                synthetic.append({
                    'answer': predicted,
                    'response': generated,
                    'label': extract_math_answer(source['label']),
                    'question': source['question'],
                    'error_type': ErrorType.SYNTHETIC_ERROR.value,
                    'original_correct_response': source['response'],
                    'original_correct_answer': source['answer']
                })
        except Exception as e:
            print(f"Error processing synthetic error: {e}")

    return synthetic

def balance_dataset(correct_data, natural_errors, synthetic_errors, args):
    """Sample a mixed dataset whose composition follows the configured ratios.

    Target counts, before capping to what is available:
    - total_wrong = floor(total * wrong_ratio)
    - n_synthetic = floor(total_wrong * synth_wrong_ratio)
    - n_natural   = total_wrong - n_synthetic
    - n_correct   = total - total_wrong

    Products are rounded to 6 decimals before flooring. This stops float
    error (e.g. ``100 * 0.29 == 28.999999999999996``) from silently dropping
    a sample. Genuine fractions such as 3.5 still floor to 3.

    Each count is then clamped to ``[0, len(pool)]``. A short pool therefore
    yields a smaller dataset rather than an error, and the achieved ratios are
    reported in the stats.

    Args:
        correct_data: Pool of correct records.
        natural_errors: Pool of naturally wrong records.
        synthetic_errors: Pool of synthetic-error records.
            All three pools are sampled without replacement and are not
            modified or reordered.
        args: Namespace providing ``total_samples``, ``wrong_ratio`` and
            ``synth_wrong_ratio``.

    Returns:
        tuple: ``(balanced_data, stats)``.
        - ``balanced_data`` is a new, shuffled list.
        - ``stats`` holds the per-category counts, ``total_wrong``, ``total``,
          ``actual_wrong_ratio`` and ``actual_synth_ratio_of_wrong``.
    """

    def _floor_count(value):
        # Floor after removing float noise beyond 6 decimals.
        return int(round(value, 6))

    def _clamp(count, pool):
        return max(0, min(count, len(pool)))

    total = args.total_samples

    # Calculate counts based on ratios
    total_wrong = _floor_count(total * args.wrong_ratio)
    n_synthetic = _floor_count(total_wrong * args.synth_wrong_ratio)
    n_natural = total_wrong - n_synthetic  # Remainder of wrong answers
    n_correct = total - total_wrong  # Remainder are correct

    # Ensure we neither exceed available data nor go negative
    n_correct = _clamp(n_correct, correct_data)
    n_natural = _clamp(n_natural, natural_errors)
    n_synthetic = _clamp(n_synthetic, synthetic_errors)

    # random.sample draws without replacement and leaves the caller's lists untouched
    balanced_data = (
        random.sample(correct_data, n_correct) +
        random.sample(natural_errors, n_natural) +
        random.sample(synthetic_errors, n_synthetic)
    )
    random.shuffle(balanced_data)

    n_wrong = n_natural + n_synthetic
    return balanced_data, {
        'correct': n_correct,
        'natural_errors': n_natural,
        'synthetic_errors': n_synthetic,
        'total_wrong': n_wrong,
        'total': len(balanced_data),
        'actual_wrong_ratio': n_wrong / len(balanced_data) if balanced_data else 0,
        'actual_synth_ratio_of_wrong': n_synthetic / n_wrong if n_wrong > 0 else 0
    }

def generate_comprehensive_dataset(args):
    """Build, balance and save the training dataset described by ``args``.

    Pipeline:
    1. Create ``args.output_dir`` and ``<output_dir>/star_0/initial_run``.
    2. Load ``args.model_name`` with vLLM (fp16, tensor parallel over every GPU
       reported by ``torch.cuda.device_count()``) plus its tokenizer. For model
       names without "Instruct", a prompt is read from
       ``prompts/base_model/inference_<dataset_name>.md`` if that file exists.
    3. Load the dataset and carve a 10% validation split (seed 42) out of
       train. Truncate every split to ``args.stop`` rows; only the train split
       is used below.
    4. Run CoT inference on the train split up to
       ``args.max_generation_attempts`` times. Accumulate correct responses and
       natural errors until both targets are met.
    5. If ``args.error_gen_model_name`` names a different model, switch to it.
       Turn randomly sampled correct responses into synthetic errors until
       that target is met or generation stops making progress.
    6. Balance with ``balance_dataset`` and write ``full_record.json``,
       ``dataset_statistics.json`` and ``data_samples.json``.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        tuple: ``(final_dataset, stats)`` from ``balance_dataset``.
    """

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'star_0/initial_run'), exist_ok=True)
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs for generating data: {num_gpus}")

    # Target counts. Products are rounded to 6 decimals before truncation so
    # float error (e.g. 100 * 0.29 == 28.999999999999996) does not drop a
    # sample, while genuine fractions such as 3.5 still floor to 3.
    total_wrong = int(round(args.total_samples * args.wrong_ratio, 6))
    needed_synthetic = int(round(total_wrong * args.synth_wrong_ratio, 6))
    needed_natural = total_wrong - needed_synthetic
    needed_correct = args.total_samples - total_wrong

    print(f"\n=== Target Distribution ===")
    print(f"Total samples: {args.total_samples}")
    print(f"Correct samples needed: {needed_correct} ({(1-args.wrong_ratio)*100:.1f}%)")
    print(f"Wrong samples needed: {total_wrong} ({args.wrong_ratio*100:.1f}%)")
    print(f"  - Natural errors: {needed_natural} ({(1-args.synth_wrong_ratio)*100:.1f}% of wrong)")
    print(f"  - Synthetic errors: {needed_synthetic} ({args.synth_wrong_ratio*100:.1f}% of wrong)")

    # Load main model
    model = LLM(
        model=args.model_name,
        dtype='half',
        tensor_parallel_size=num_gpus,
        max_model_len=8*args.max_new_tokens
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Non-"Instruct" models get a dataset-specific prompt from disk, if one exists.
    base_model_prompt = None
    if "Instruct" not in args.model_name:
        prompt_path = f'prompts/base_model/inference_{args.dataset_name}.md'
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                base_model_prompt = f.read()

    # Load dataset (MED and SUDOKU processors take an extra argument)
    DATASET_PROCESSORS = {
        "GSM": process_gsm,
        "AIME": process_aime,
        "MMLU": process_mmlu
    }
    if args.dataset_name == "MED":
        dataset = process_med(args.med_subset)
    elif args.dataset_name == "SUDOKU":
        dataset = process_sudoku(args.sudoku_num_prefilled)
    else:
        dataset = DATASET_PROCESSORS[args.dataset_name]()
    dataset_train_valid = dataset['train'].train_test_split(test_size=0.1, seed=42)
    # NOTE: slicing a datasets.Dataset with [:n] yields a dict of column lists,
    # so full_dataset['train']['question'] below is a plain list.
    full_dataset = DatasetDict({
        'train': dataset_train_valid['train'][:args.stop],
        'val': dataset_train_valid['test'][:args.stop],
        'test': dataset['test'][:args.stop]
    })

    # Containers for different types of responses
    all_correct = []
    all_natural_errors = []
    all_synthetic_errors = []

    for attempt in range(1, args.max_generation_attempts + 1):
        print(f"\n=== Generation Attempt {attempt} ===")

        # Each pass re-runs inference on the same questions; with a non-zero
        # temperature this can produce new correct/incorrect responses.
        _, train_full_record = run_math_inference(
            model,
            tokenizer,
            data=(full_dataset['train']['question'], full_dataset['train']['answer']),
            save_path=f'{args.output_dir}/temp',
            batch_size=99999999,
            max_new_tokens=args.max_new_tokens,
            COT=True,
            temperature=args.temperature,
            base_model_prompt=base_model_prompt
        )

        # Categorize responses
        correct, natural_errors = categorize_responses(train_full_record)
        all_correct.extend(correct)
        all_natural_errors.extend(natural_errors)

        print(f"Generated {len(correct)} correct, {len(natural_errors)} natural errors")
        print(f"Total so far: {len(all_correct)} correct, {len(all_natural_errors)} natural errors")

        # Check if we have enough data
        if (len(all_correct) >= needed_correct and
                len(all_natural_errors) >= needed_natural):
            print("Sufficient data for correct and natural errors generated")
            break

    if len(all_correct) < needed_correct or len(all_natural_errors) < needed_natural:
        print(f"Warning: only {len(all_correct)}/{needed_correct} correct and "
              f"{len(all_natural_errors)}/{needed_natural} natural errors after "
              f"{args.max_generation_attempts} attempts; the final dataset will be smaller")

    # Generate synthetic errors
    print(f"\n=== Generating Synthetic Errors ===")
    print(f"Need {needed_synthetic} synthetic errors")

    # Swap in the error-generation model only if it will actually be used.
    if (args.error_gen_model_name
            and args.error_gen_model_name != args.model_name
            and needed_synthetic > 0
            and all_correct):
        del model
        del tokenizer
        # Collect unreachable objects first so empty_cache() can release
        # the CUDA blocks they were holding.
        gc.collect()
        torch.cuda.empty_cache()
        model = LLM(
            args.error_gen_model_name,
            dtype='half',
            tensor_parallel_size=num_gpus,
            max_model_len=4*args.max_new_tokens
        )
        tokenizer = AutoTokenizer.from_pretrained(args.error_gen_model_name)

    # args.batch_size (default 500) caps how many correct responses each round converts.
    synth_batch_cap = max(1, args.batch_size)
    stalled_batches = 0
    while len(all_synthetic_errors) < needed_synthetic and len(all_correct) > 0:
        batch_size = min(synth_batch_cap, len(all_correct), needed_synthetic - len(all_synthetic_errors))
        batch = random.sample(all_correct, batch_size)

        synthetic_batch = generate_synthetic_errors(
            batch, model, tokenizer,
            args.max_new_tokens, args.temperature,
            args.synthetic_error_prompt_template
        )

        all_synthetic_errors.extend(synthetic_batch)
        print(f"Generated {len(synthetic_batch)} synthetic errors (total: {len(all_synthetic_errors)})")

        # Outputs that are not actually wrong are discarded. If the model keeps
        # yielding nothing usable, stop instead of looping forever.
        if synthetic_batch:
            stalled_batches = 0
        else:
            stalled_batches += 1
            if stalled_batches >= args.max_generation_attempts:
                print(f"Warning: {stalled_batches} consecutive batches produced no synthetic "
                      f"errors; stopping with {len(all_synthetic_errors)}/{needed_synthetic}")
                break

    # Balance the final dataset
    print("\n=== Balancing Dataset ===")
    final_dataset, stats = balance_dataset(
        all_correct, all_natural_errors, all_synthetic_errors, args
    )

    # Save results
    output_path = os.path.join(args.output_dir, 'star_0/initial_run/full_record.json')
    with open(output_path, 'w') as f:
        json.dump(final_dataset, f, indent=2)

    # Save statistics
    stats_path = os.path.join(args.output_dir, 'dataset_statistics.json')
    with open(stats_path, 'w') as f:
        json.dump(stats, f, indent=2)

    # Save samples of each type for inspection
    samples_path = os.path.join(args.output_dir, 'data_samples.json')
    samples = {
        'correct': all_correct[:5],
        'natural_errors': all_natural_errors[:5],
        'synthetic_errors': all_synthetic_errors[:5]
    }
    with open(samples_path, 'w') as f:
        json.dump(samples, f, indent=2)

    print("\n=== Final Statistics ===")
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"{key}: {value:.3f}")
        else:
            print(f"{key}: {value}")

    print(f"\nDataset saved to: {output_path}")

    return final_dataset, stats

def validate_args(args):
    """Raise ValueError if parsed CLI arguments are outside their valid ranges.

    Checks:
    - ``wrong_ratio`` and ``synth_wrong_ratio`` lie in [0, 1]. NaN is rejected
      because it fails every comparison.
    - ``total_samples`` and ``max_generation_attempts`` are positive.
    """
    if not 0 <= args.wrong_ratio <= 1:
        raise ValueError(f"wrong_ratio must be between 0 and 1, got {args.wrong_ratio}")
    if not 0 <= args.synth_wrong_ratio <= 1:
        raise ValueError(f"synth_wrong_ratio must be between 0 and 1, got {args.synth_wrong_ratio}")
    if args.total_samples <= 0:
        raise ValueError(f"total_samples must be positive, got {args.total_samples}")
    if args.max_generation_attempts <= 0:
        raise ValueError(f"max_generation_attempts must be positive, got {args.max_generation_attempts}")


if __name__ == "__main__":
    args = parse_args()
    validate_args(args)

    print(f"\n=== Configuration ===")
    print(f"Wrong ratio: {args.wrong_ratio} (of total)")
    print(f"Synthetic wrong ratio: {args.synth_wrong_ratio} (of wrong answers)")

    dataset, stats = generate_comprehensive_dataset(args)