import os
import json
import argparse
from datasets import Dataset
from PIL import Image
import re
from tqdm import tqdm
from collections import Counter

def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``k12_only`` (bool), ``input_jsonl`` (str),
        ``input_base`` (str) and ``output_path`` (str or None).
    """
    cli = argparse.ArgumentParser(
        description='Process dataset with filtering options'
    )
    cli.add_argument(
        '--k12-only',
        action='store_true',
        help='If set, only accept images with "K12" in URL and not "MMPR"',
    )
    # The remaining options are all plain string-valued paths; registered in
    # this order so --help lists them exactly as before.
    string_options = (
        ('--input-jsonl', "./sampled_data/MM-Eureka/dataset.jsonl",
         'Path to input JSONL file'),
        ('--input-base', "./sampled_data/MM-Eureka/",
         'Base path for input images'),
        ('--output-path', None,
         'Custom output path for the processed dataset'),
    )
    for flag, default_value, help_text in string_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()

# Primary extraction: everything after the first "Question:" marker.
_QUESTION_RE = re.compile(r'Question:\s*(.*?)$', re.DOTALL)
# Fallback cleanup: drop a leading "<image>" tag. Without DOTALL, '.' does not
# cross newlines, so this only fires when the tag is on the first line.
_LEADING_IMAGE_TAG_RE = re.compile(r'^.*?<image>\s*')
# Fallback cleanup: drop the "Your answer must be ... </answer>." instructions.
_ANSWER_FORMAT_RE = re.compile(
    r'Your answer must be.*?tags, respectively,.*?</answer>\.\s*', re.DOTALL
)


def _extract_problem_text(user_content):
    """Return the question text from a user turn, stripped of surrounding whitespace."""
    question_match = _QUESTION_RE.search(user_content)
    if question_match:
        return question_match.group(1).strip()
    # No "Question:" marker anywhere, so only the image tag and the
    # answer-format boilerplate need removing.
    problem_text = _LEADING_IMAGE_TAG_RE.sub('', user_content)
    problem_text = _ANSWER_FORMAT_RE.sub('', problem_text)
    # Strip for consistency with the "Question:" branch.
    return problem_text.strip()


def _open_image(image_path):
    """Open and fully decode the image at ``image_path``; return None if unreadable.

    ``Image.open`` is lazy and keeps the file handle open until the pixels are
    loaded. Holding many lazily-opened images can therefore exhaust the
    process's file-descriptor limit, after which every further open fails.
    ``load()`` decodes the pixels and, for single-frame formats, closes the
    file. It also surfaces truncated/corrupt files here instead of later.
    Trade-off: decoded pixels stay in memory for every kept sample.
    """
    image = None
    try:
        image = Image.open(image_path)
        image.load()
    except Exception as e:
        # Best-effort: an unreadable image only drops this one sample.
        print(f"Error loading image {image_path}: {e}")
        if image is not None:
            image.close()
        return None
    return image


def load_dataset(args):
    """Read the JSONL manifest and build the list of formatted samples.

    Samples are skipped when their answer is "yes"/"no" (case- and
    whitespace-insensitive), when ``args.k12_only`` is set and the image URL
    lacks "K12" or contains "MMPR", or when the image is absent/unreadable.

    Args:
        args: Namespace providing ``input_jsonl`` (path to the JSONL file),
            ``input_base`` (directory that image URLs are relative to) and
            ``k12_only`` (bool filter flag).

    Returns:
        list of dicts with keys 'id', 'images' (loaded PIL image),
        'problem' ("<image>\\n" + question text) and 'answer' ('$' removed).

    Raises:
        FileNotFoundError: if ``args.input_jsonl`` does not exist.
        json.JSONDecodeError: on a malformed (non-blank) line.
        KeyError: if a sample lacks 'id', 'conversations' or 'answer'.
    """
    data = []
    jsonl_file_path = args.input_jsonl
    images_base_path = args.input_base

    if not os.path.exists(jsonl_file_path):
        raise FileNotFoundError(f"JSONL file not found at {jsonl_file_path}")

    # First pass only sizes the progress bar; blank lines are not samples.
    with open(jsonl_file_path, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for line in f if line.strip())

    print(f"Found {total_lines} samples in {os.path.basename(jsonl_file_path)}")

    # Counters for the summary printed at the end.
    total_samples = 0
    skipped_yes_no = 0
    missing_images = 0
    filtered_by_k12 = 0

    with open(jsonl_file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=total_lines, desc="Processing samples"):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # letting json.loads("") abort the whole run.
                continue
            total_samples += 1

            sample = json.loads(line)
            sample_id = sample['id']

            # Use the first user turn; "" if there is none.
            user_content = next(
                (conv['content'] for conv in sample['conversations']
                 if conv['role'] == 'user'),
                "",
            )
            problem_text = _extract_problem_text(user_content)

            answer = sample['answer'].replace('$', '')

            # Drop binary yes/no questions; ignore surrounding whitespace so
            # answers like "Yes " are caught as well.
            if answer.strip().lower() in ('yes', 'no'):
                skipped_yes_no += 1
                continue

            # 'image_urls' may be missing, null, or an empty list; all mean "no image".
            image_url = (sample.get('image_urls') or [''])[0]

            if args.k12_only and (
                not image_url or "K12" not in image_url or "MMPR" in image_url
            ):
                filtered_by_k12 += 1
                continue

            image = None
            if image_url:
                image_path = os.path.join(images_base_path, image_url)
                if os.path.exists(image_path):
                    image = _open_image(image_path)

            # No URL, missing file, or unreadable file all count as missing.
            if image is None:
                missing_images += 1
                continue

            # NOTE(review): the problem is prefixed with one "<image>" tag; if
            # problem_text itself still contains "<image>" (e.g. the tag was not on
            # the first line in the fallback path) the prompt holds two — confirm.
            data.append({
                'id': sample_id,
                'images': image,
                'problem': f"<image>\n{problem_text}",
                'answer': answer,
            })

    print(f"\nPre-processing Statistics:")
    print(f"  - Total samples processed: {total_samples}")
    print(f"  - Samples with 'yes' or 'no' answers (skipped): {skipped_yes_no}")
    print(f"  - Samples with missing images: {missing_images}")
    if args.k12_only:
        print(f"  - Samples filtered by K12 criteria: {filtered_by_k12}")
    print(f"  - Valid samples after filtering: {len(data)}")

    return data

def analyze_dataset(dataset):
    """Print summary statistics for a collection of formatted samples.

    Reports the sample count, the mean character length of the 'answer' and
    'problem' fields, the five most frequent image resolutions, and the first
    five answers.

    Args:
        dataset: Indexable, iterable collection of dicts holding 'answer' and
            'problem' strings and an 'images' object exposing ``width`` and
            ``height`` attributes (as PIL images do).
    """
    n_samples = len(dataset)

    def _mean_length(field):
        # Average len() of `field` across all samples; 0 when there are none.
        lengths = [len(record[field]) for record in dataset]
        if not lengths:
            return 0
        return sum(lengths) / len(lengths)

    mean_answer_len = _mean_length('answer')
    mean_problem_len = _mean_length('problem')

    # Resolutions are tallied as "WIDTHxHEIGHT" strings.
    resolution_counts = Counter(
        f"{record['images'].width}x{record['images'].height}" for record in dataset
    )

    print("\nDataset Statistics:")
    print(f"  - Number of samples: {n_samples}")
    print(f"  - Average answer length: {mean_answer_len:.2f} characters")
    print(f"  - Average problem length: {mean_problem_len:.2f} characters")
    print("  - Most common image sizes:")
    for resolution, hits in resolution_counts.most_common(5):
        share = hits / n_samples * 100
        print(f"    * {resolution}: {hits} images ({share:.1f}%)")

    print("\nSample answers:")
    preview_count = min(5, n_samples)
    for idx in range(preview_count):
        print(f"  - {dataset[idx]['answer']}")

def create_hf_dataset(args):
    """Load and filter the samples, then wrap them in a Hugging Face ``Dataset``.

    Args:
        args: Parsed CLI namespace, forwarded unchanged to ``load_dataset``.

    Returns:
        datasets.Dataset built with ``Dataset.from_list`` from the filtered samples.
    """
    print("Loading dataset...")
    data_samples = load_dataset(args)

    print("Creating Hugging Face dataset...")
    dataset = Dataset.from_list(data_samples)

    print(f"Created dataset with {len(dataset)} samples")
    print(f"Dataset features: {dataset.features}")

    if len(dataset) > 0:
        # Fetch the row once: each `dataset[0]` materializes the whole row,
        # including the image column, so indexing it per field repeated that work.
        first_row = dataset[0]
        print(f"Example sample: {first_row['id'], first_row['problem'], first_row['answer']}")
        print(f"Image dimensions: {first_row['images'].size}")

    # Summarize from the in-memory sample list (same content as `dataset`).
    analyze_dataset(data_samples)

    return dataset

# Script entry point: parse flags, resolve the output directory, build and save.
if __name__ == "__main__":
    args = parse_arguments()

    # Destination precedence: an explicit --output-path wins, then the fixed
    # K12 location, otherwise "<input-base directory name>-HF" under ./sampled_data.
    if args.output_path:
        output_path = args.output_path
    else:
        output_path = (
            "./sampled_data/K12-HF"
            if args.k12_only
            else f"./sampled_data/{os.path.basename(os.path.normpath(args.input_base))}-HF"
        )

    print(f"Output will be saved to: {output_path}")

    hf_dataset = create_hf_dataset(args)

    # Ensure the target directory exists before save_to_disk writes into it.
    os.makedirs(output_path, exist_ok=True)

    print(f"\nSaving dataset to {output_path}...")
    hf_dataset.save_to_disk(output_path)

    print(f"Dataset successfully saved to {output_path}")