from tqdm import tqdm
import re
import os
import datasets
import json
from datasets import Dataset, DatasetDict
import random

import argparse

def load_jsonl(jsonl_file: str) -> list:
    """
    Read a JSON Lines file, silently dropping lines that do not decode as JSON.

    Args:
        jsonl_file (str): Path to the JSONL file (read as UTF-8 text).

    Returns:
        list: Every successfully decoded JSON value, in file order.
    """
    records = []
    with open(jsonl_file, encoding='utf-8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            try:
                record = json.loads(text)
            except json.JSONDecodeError:
                # Malformed or blank line: skip it and keep reading.
                continue
            records.append(record)
    return records

def is_dialogue_valid(item, max_turns=500):
    """
    Strictly validate the structure of one data item.

    An item is valid when:
      * it is a dict whose "instance_id", "source" and "file" values are strings;
      * its "dialogue" is a non-empty list with at most ``max_turns`` turns;
      * every turn is a dict with exactly the keys "role" and "content",
        both holding strings (empty strings are allowed).

    Args:
        item: One parsed data item (normally a dict loaded from JSONL).
        max_turns (int): Maximum number of dialogue turns accepted. Defaults to 500.

    Returns:
        bool: True if the item passes every check, False otherwise.
    """
    # load_jsonl can yield any JSON value (lists, numbers, strings); reject
    # those here instead of raising TypeError on the key lookups below.
    if not isinstance(item, dict):
        return False

    # instance_id, source and file must exist and be strings
    for key in ("instance_id", "source", "file"):
        if not isinstance(item.get(key), str):
            return False

    dialogue = item.get("dialogue")
    if not isinstance(dialogue, list):
        return False
    if not dialogue or len(dialogue) > max_turns:
        return False

    for turn in dialogue:
        if not isinstance(turn, dict):
            return False
        # Must have exactly "role" and "content" keys
        if set(turn.keys()) != {"role", "content"}:
            return False
        # Both values must be strings; empty strings are accepted.
        if not isinstance(turn["role"], str) or not isinstance(turn["content"], str):
            return False

    return True

def construct_hf_dataset_from_standard_data(standard_data, train_split=0.8, seed=42):
    """
    Build a train/test ``DatasetDict`` from a list of record dicts.

    Args:
        standard_data (list): One dict per example, e.g.
            [{'instance_id': '...', 'dialogue': [...], 'source': '...'}].
        train_split (float): Fraction of rows placed in the 'train' split; must lie
            strictly between 0 and 1. The remaining rows form the 'test' split.
        seed (int): Seed handed to the splitter so the partition is reproducible.

    Returns:
        datasets.DatasetDict: Holds the 'train' and 'test' splits.

    Raises:
        ValueError: If ``train_split`` is not strictly between 0 and 1 (NaN included).
    """
    if not (0 < train_split < 1):
        raise ValueError("train_split value must be between 0 and 1")

    # Column names and feature types are inferred from the records themselves.
    full_dataset = Dataset.from_list(standard_data)

    # train_test_split returns a DatasetDict keyed 'train' / 'test'. Passing
    # train_size (rather than test_size) makes train_split the exact train fraction.
    return full_dataset.train_test_split(train_size=train_split, seed=seed)

def extract_solution(dialogue):
    """Return the ``content`` of the last turn of *dialogue* (used as the ground truth)."""
    final_turn = dialogue[-1]
    return final_turn['content']

# Factory for the per-row transform used with Dataset.map(..., with_indices=True).
def make_map_fn(split):
    """
    Build the row-conversion function for one dataset split.

    Args:
        split (str): Split name stored in each row's ``extra_info['split']``.

    Returns:
        Callable[[dict, int], dict]: ``process_fn(example, idx)``. It reads the
        'source', 'file', 'dialogue' and 'instance_id' fields of a raw record and
        returns the output row. The final turn's content becomes the rule-style
        ground truth, and ``idx`` is stored as ``extra_info['index']``.
    """
    def process_fn(example, idx):
        dialogue = example['dialogue']
        ground_truth = extract_solution(dialogue)
        # NOTE(review): 'prompt' receives the full dialogue, including the final
        # turn that is also used as the ground truth -- confirm this is intended.
        return {
            "data_source": example['source'],
            "file_name": example['file'],
            "prompt": dialogue,
            "messages": dialogue,
            "ability": "code",
            "reward_model": {
                "style": "rule",
                "ground_truth": ground_truth,
            },
            "extra_info": {
                'split': split,
                'index': idx,
                'instance_id': example['instance_id'],
            },
        }

    return process_fn

def filter_by_language(data_list, target_language):
    """
    Keep only the items tagged with a given language.

    Args:
        data_list: Items to filter. When filtering is active, non-dict entries are dropped.
        target_language: Required value of each item's 'language' field. A falsy
            value (such as the empty string) disables filtering.

    Returns:
        The very same list object when filtering is disabled; otherwise a new list
        of the matching dict items, in their original order.
    """
    if target_language:
        return [
            entry for entry in data_list
            if isinstance(entry, dict) and entry.get('language') == target_language
        ]
    # Falsy language means "all languages": return the input untouched.
    return data_list

def sample_data(data_list, sample_ratio, seed=42):
    """
    Randomly sample a fraction of ``data_list`` without replacement.

    Args:
        data_list: List of data items.
        sample_ratio: Fraction to keep (0.0-1.0). Values >= 1.0 return ``data_list``
            itself; values <= 0.0 return an empty list.
        seed: Seed for a private ``random.Random`` generator, so the draw is
            reproducible without reseeding the module-level RNG.

    Returns:
        A list of ``int(len(data_list) * sample_ratio)`` randomly chosen items, or
        the original list when no sampling is needed.
    """
    if sample_ratio >= 1.0:  # No sampling needed
        return data_list

    if sample_ratio <= 0.0:
        return []

    # A dedicated generator draws exactly the same sample as
    # `random.seed(seed); random.sample(...)`, but leaves the global RNG state alone.
    rng = random.Random(seed)
    sample_size = int(len(data_list) * sample_ratio)
    return rng.sample(data_list, sample_size)

def base_instance_id(instance_id, suffix='-completion'):
    """
    Remove one exact trailing ``suffix`` from ``instance_id``, if present.

    Presumably completion items carry ids of the form '<base id>-completion'
    (confirm against the data producer). Stripping the suffix lets them match the
    ids listed in the stop-instances file. Unlike ``str.rstrip('-completion')``,
    which strips any trailing run of the characters '-', 'c', 'o', 'm', 'p', 'l',
    'e', 't', 'i', 'n' and so turns 'repo-main' into 'repo-ma', only the exact
    suffix is removed.

    Args:
        instance_id (str): Id to normalize.
        suffix (str): Exact suffix to remove. Defaults to '-completion'.

    Returns:
        str: ``instance_id`` without a trailing ``suffix``.
    """
    # The `suffix and` guard avoids `instance_id[:-0]`, which would return ''.
    if suffix and instance_id.endswith(suffix):
        return instance_id[:-len(suffix)]
    return instance_id


if __name__ == '__main__':
    from collections import Counter

    parser = argparse.ArgumentParser(description="Convert filtered JSONL data to Parquet format.")
    parser.add_argument("--intermediate_dir", type=str, default="intermediate_data", help="Directory containing filtered intermediate jsonl files")
    parser.add_argument("--stop_instances_file", type=str, default="stop_instances.json", help="Path to stop_instances json file")
    parser.add_argument("--output_dir", type=str, default="data_cache", help="Directory to save parquet files")
    parser.add_argument("--train_file", type=str, default="train_multi.parquet", help="Train parquet filename")
    parser.add_argument("--test_file", type=str, default="test_multi.parquet", help="Test parquet filename")
    parser.add_argument("--train_split", type=float, default=0.9, help="Train split ratio")
    parser.add_argument("--with_completion", action="store_true", help="Include completion_data in output if set.")
    parser.add_argument("--language", type=str, default="", help="Language to filter (empty string means all languages)")
    parser.add_argument("--sample_ratio", type=float, default=1.0, help="Sampling ratio for each data type (0.0-1.0, default 1.0 means no sampling)")
    args = parser.parse_args()
    print("=== Loading data from intermediate files and converting to Parquet format ===")

    # Load filtered intermediate jsonl files
    file_loc_data = load_jsonl(os.path.join(args.intermediate_dir, "filtered_file_loc_data.jsonl"))
    func_loc_data = load_jsonl(os.path.join(args.intermediate_dir, "filtered_func_loc_data.jsonl"))
    task_data = load_jsonl(os.path.join(args.intermediate_dir, "filtered_task_data.jsonl"))
    completion_data = []
    if args.with_completion:
        completion_data = load_jsonl(os.path.join(args.intermediate_dir, "filtered_completion_data.jsonl"))

    print(f"Loaded File localization data: {len(file_loc_data)}")
    print(f"Loaded Function localization data: {len(func_loc_data)}")
    print(f"Loaded Task data: {len(task_data)}")
    if args.with_completion:
        print(f"Loaded Completion data: {len(completion_data)}")

    # Filter data by language (if specified)
    if args.language:
        print(f"\n--- Filtering by language: {args.language} ---")
        original_counts = [len(file_loc_data), len(func_loc_data), len(task_data), len(completion_data)]

        file_loc_data = filter_by_language(file_loc_data, args.language)
        func_loc_data = filter_by_language(func_loc_data, args.language)
        task_data = filter_by_language(task_data, args.language)
        if args.with_completion:
            completion_data = filter_by_language(completion_data, args.language)

        print("After language filtering:")
        print(f"File localization data: {len(file_loc_data)}/{original_counts[0]}")
        print(f"Function localization data: {len(func_loc_data)}/{original_counts[1]}")
        print(f"Task data: {len(task_data)}/{original_counts[2]}")
        if args.with_completion:
            print(f"Completion data: {len(completion_data)}/{original_counts[3]}")

    # Sample data by sampling ratio (if specified)
    if args.sample_ratio < 1.0:
        print(f"\n--- Sampling data with ratio: {args.sample_ratio} ---")
        original_counts = [len(file_loc_data), len(func_loc_data), len(task_data), len(completion_data)]

        file_loc_data = sample_data(file_loc_data, args.sample_ratio)
        func_loc_data = sample_data(func_loc_data, args.sample_ratio)
        task_data = sample_data(task_data, args.sample_ratio)
        if args.with_completion:
            completion_data = sample_data(completion_data, args.sample_ratio)

        print("After sampling:")
        print(f"File localization data: {len(file_loc_data)}/{original_counts[0]}")
        print(f"Function localization data: {len(func_loc_data)}/{original_counts[1]}")
        print(f"Task data: {len(task_data)}/{original_counts[2]}")
        if args.with_completion:
            print(f"Completion data: {len(completion_data)}/{original_counts[3]}")

    # Merge all data
    gross_data_original_clean = file_loc_data + func_loc_data + task_data + (completion_data if args.with_completion else [])

    print(f"Total merged data: {len(gross_data_original_clean)} items")
    print("Note: Data has been filtered by token length in the first stage, no need to process again.")

    final_clean_data = gross_data_original_clean

    stop_instance_ids = set()
    if os.path.exists(args.stop_instances_file):
        with open(args.stop_instances_file, 'r', encoding='utf-8') as f:
            stop_instance_ids = {item["instance_id"] for item in json.load(f)}
        print(f"Loaded {len(stop_instance_ids)} stop instances from {args.stop_instances_file}")
    else:
        print(f"Stop instances file {args.stop_instances_file} not found, no instances will be filtered.")

    def is_stop_instance(item):
        """Return True if the item's id, minus a '-completion' suffix, is a stop instance."""
        # Items without a string instance_id are kept here so the structural
        # validation below counts them as invalid instead of this filter crashing.
        instance_id = item.get("instance_id") if isinstance(item, dict) else None
        return isinstance(instance_id, str) and base_instance_id(instance_id) in stop_instance_ids

    final_clean_data_wo_stop_instances = [
        item for item in final_clean_data if not is_stop_instance(item)
    ]

    print(f"Removed {len(final_clean_data) - len(final_clean_data_wo_stop_instances)} stop instances. Remaining: {len(final_clean_data_wo_stop_instances)} samples.")

    random.seed(42)
    random.shuffle(final_clean_data_wo_stop_instances)

    # Final step: strict validation of data structure
    valid_data = []
    invalid_count = 0
    for item in tqdm(final_clean_data_wo_stop_instances, desc="Validating data structure"):
        if is_dialogue_valid(item):
            valid_data.append(item)
        else:
            invalid_count += 1

    print(f"✅ Found {len(valid_data)} valid samples.")
    if invalid_count > 0:
        print(f"❌ Skipped {invalid_count} invalid samples due to malformed dialogue structure.")

    # Only completely valid data is used to build the Dataset
    dataset = construct_hf_dataset_from_standard_data(
        valid_data,
        args.train_split
    )

    train_dataset = dataset['train']
    test_dataset = dataset['test']
    print(f'Train Size: {len(train_dataset)} | Test Size: {len(test_dataset)}')

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    # Count rows per data source in the final train and test sets
    def print_source_stats(ds, name):
        # Read only the 'data_source' column instead of decoding every full row.
        counter = Counter(ds['data_source'])
        print(f"{name} source stats:")
        for k, v in counter.items():
            print(f"  {k}: {v}")

    print_source_stats(train_dataset, "Train")
    print_source_stats(test_dataset, "Test")

    os.makedirs(args.output_dir, exist_ok=True)
    train_dataset.to_parquet(os.path.join(args.output_dir, args.train_file))
    test_dataset.to_parquet(os.path.join(args.output_dir, args.test_file))

