import json
import pandas as pd
import ast
import os
import glob
import argparse
import sys
from tqdm import tqdm
from transformers import AutoTokenizer

def parse_python_dict_str(s):
    """Safely evaluate a string containing a Python dictionary literal.

    Parsing goes through ``ast.literal_eval``, so only literals are
    evaluated and no code in ``s`` is ever executed.

    Args:
        s: Text expected to hold a Python dict literal,
            e.g. ``"{'name': 'search', 'arguments': '{}'}"``.

    Returns:
        The parsed ``dict``, or ``None`` when ``s`` is not a string, cannot
        be parsed, or parses to something other than a dict.
    """
    if not isinstance(s, str):
        return None
    try:
        parsed = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # TypeError: e.g. an unhashable dict key such as "{[1]: 2}".
        # MemoryError / RecursionError: pathologically large or deeply nested input.
        return None
    # Callers use dict methods on the result, so reject other literal types.
    return parsed if isinstance(parsed, dict) else None

def parse_args(argv=None):
    """Parse the command-line options of the conversion script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior;
            passing a list allows programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` with ``input_dir``, ``output_dir``,
        ``model_path``, ``max_tokens`` and ``output_filename``.
    """
    parser = argparse.ArgumentParser(
        description="Convert Toucan Parquet data to Swift Agent JSONL format."
    )

    # Required locations.
    parser.add_argument(
        "--input_dir", type=str, required=True,
        help="Path to the directory containing raw .parquet files",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Directory to save the processed .jsonl file",
    )
    parser.add_argument(
        "--model_path", type=str, required=True,
        help="Path to the model or tokenizer for length calculation",
    )

    # Optional tuning.
    parser.add_argument(
        "--max_tokens", type=int, default=30000,
        help="Max token threshold for filtering (default: 30000)",
    )
    parser.add_argument(
        "--output_filename", type=str,
        default="toucan_swift_agent_format_implicit_cot.jsonl",
        help="Name of the output file",
    )

    return parser.parse_args(argv)

_MAX_ERROR_REPORTS = 5  # cap on per-row error messages echoed to the console
_PREVIEW_LIMIT = 2      # number of written entries summarised for verification


def _normalize_tools(raw_tools):
    """Return the ``tools`` cell as a JSON string, or ``"[]"`` when unusable.

    Missing cells (falsy values or NaN), values of five characters or fewer,
    and strings that are not valid JSON all collapse to ``"[]"``.
    """
    # NaN is the only float unequal to itself; treat it as a missing cell.
    is_nan = isinstance(raw_tools, float) and raw_tools != raw_tools
    if is_nan or not raw_tools or len(raw_tools) <= 5:
        return "[]"
    try:
        json.loads(raw_tools)
    except json.JSONDecodeError:
        return "[]"
    return raw_tools


def _convert_messages(original_msgs):
    """Translate the decoded source messages into Swift agent messages.

    Only the roles ``user``, ``assistant``, ``tool_call`` and
    ``tool_response`` are emitted; messages with any other role are skipped.

    Raises:
        TypeError: If an assistant message carries non-empty, non-string content.
    """
    swift_messages = []
    for msg in original_msgs:
        role = msg.get('role')
        content = msg.get('content')

        if role == 'user':
            swift_messages.append({"role": "user", "content": content})

        elif role == 'assistant':
            # Assistant text is kept verbatim, including text that precedes a
            # tool call (the "implicit CoT"). Non-text content is rejected so
            # the caller discards the whole row.
            if content and not isinstance(content, str):
                raise TypeError(
                    f"assistant content must be text, got {type(content).__name__}"
                )
            swift_messages.append({"role": "assistant", "content": content})

        elif role == 'tool_call':
            call_dict = parse_python_dict_str(content)
            if not call_dict:
                # NOTE(review): an unparseable tool call is dropped while its
                # tool_response is kept - confirm this is intended.
                continue
            raw_args = call_dict.get('arguments', '{}')
            if isinstance(raw_args, dict):
                # Already decoded; keep it instead of discarding it as "{}".
                arguments_obj = raw_args
            else:
                try:
                    arguments_obj = json.loads(raw_args)
                except (json.JSONDecodeError, TypeError):
                    arguments_obj = {}
            call_dict['arguments'] = arguments_obj
            swift_messages.append(
                {"role": "tool_call", "content": json.dumps(call_dict, ensure_ascii=False)}
            )

        elif role == 'tool_response':
            try:
                # Round-trip through json so the payload is stored un-escaped.
                content_str = json.dumps(json.loads(content), ensure_ascii=False)
            except (json.JSONDecodeError, TypeError):
                content_str = str(content)
            swift_messages.append({"role": "tool_response", "content": content_str})

    return swift_messages


def _ends_with_assistant_text(messages):
    """Return True when the last message is an assistant turn with real text."""
    if not messages:
        return False
    last = messages[-1]
    content = last.get("content")
    # Checking the type first keeps a null content from passing as "None".
    return (
        last.get("role") == "assistant"
        and isinstance(content, str)
        and bool(content.strip())
    )


def _build_entry(row):
    """Build one output record from a data row, or None if its format is invalid."""
    tools_json_str = _normalize_tools(row.get('tools', '[]'))
    swift_messages = _convert_messages(json.loads(row['messages']))
    if not _ends_with_assistant_text(swift_messages):
        return None
    return {"tools": tools_json_str, "messages": swift_messages}


def _count_tokens(tokenizer, entry):
    """Approximate an entry's token length from its raw role/content text."""
    lines = [entry["tools"]]
    for m in entry["messages"]:
        lines.append(str(m.get('role', '')))
        lines.append(str(m.get('content', '')))
    text = "\n".join(lines) + "\n"
    return len(tokenizer.encode(text, add_special_tokens=False))


def main():
    """Convert every ``*.parquet`` file in the input directory to one JSONL file.

    Each row is converted to a ``{"tools", "messages"}`` record. Rows are
    discarded when conversion fails, when the conversation does not end with
    non-empty assistant text, or when the approximate token count exceeds
    ``--max_tokens``. A summary of written and discarded rows is printed at
    the end. Exits with status 1 when no input files are found or the
    tokenizer cannot be loaded.
    """
    args = parse_args()

    print("\n=== Starting Data Conversion ===")
    print(f"Input Directory (Base): {os.path.basename(os.path.normpath(args.input_dir))}")
    print(f"Output Directory: {args.output_dir}")
    print(f"Tokenizer: {os.path.basename(os.path.normpath(args.model_path))}")
    print(f"Max Tokens Threshold: {args.max_tokens}")

    # Look for inputs before the slow tokenizer load so a bad path fails fast.
    # Sorting makes the output line order reproducible across runs.
    parquet_files = sorted(glob.glob(os.path.join(args.input_dir, "*.parquet")))
    if not parquet_files:
        print(f"Error: No .parquet files found in {args.input_dir}")
        sys.exit(1)
    print(f"Found {len(parquet_files)} parquet files.")

    print("Loading tokenizer...")
    try:
        # NOTE(security): trust_remote_code=True lets the model repository run
        # its own Python code; only point --model_path at trusted sources.
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        print("Please ensure '--model_path' points to a valid HuggingFace model directory.")
        sys.exit(1)

    os.makedirs(args.output_dir, exist_ok=True)
    output_file = os.path.join(args.output_dir, args.output_filename)

    total_processed = 0
    discarded_format = 0
    discarded_length = 0
    reported_errors = 0

    with open(output_file, 'w', encoding='utf-8') as f_out:
        for p_file in parquet_files:
            file_name = os.path.basename(p_file)
            print(f"Processing file: {file_name} ...")

            try:
                df = pd.read_parquet(p_file)
            except Exception as e:
                print(f"Error reading {file_name}: {e}")
                continue

            for index, row in tqdm(df.iterrows(), total=len(df), desc="Converting"):
                try:
                    entry = _build_entry(row)
                    if entry is None:
                        discarded_format += 1
                        continue

                    token_count = _count_tokens(tokenizer, entry)
                    if token_count > args.max_tokens:
                        discarded_length += 1
                        continue

                    f_out.write(json.dumps(entry, ensure_ascii=False) + '\n')
                except OSError:
                    # A failing disk is not a data problem: abort rather than
                    # silently counting every remaining row as malformed.
                    raise
                except Exception as e:
                    # Best effort: a malformed row must not stop the run, but
                    # show the first few causes so systematic issues are visible.
                    discarded_format += 1
                    if reported_errors < _MAX_ERROR_REPORTS:
                        tqdm.write(
                            f"[Skipped row {index} in {file_name}] "
                            f"{type(e).__name__}: {e}"
                        )
                        reported_errors += 1
                    continue

                total_processed += 1

                # Preview the first entries for verification.
                if total_processed <= _PREVIEW_LIMIT:
                    tqdm.write(f"\n[Preview entry {total_processed}]")
                    tqdm.write(f"  - Tool definitions length: {len(entry['tools'])}")
                    tqdm.write(f"  - Messages count: {len(entry['messages'])}")
                    tqdm.write(f"  - Token count: {token_count}")

    print("\n" + "="*40)
    print("Processing Complete!")
    print(f"Output File: {args.output_filename}")
    print(f"Successfully written: {total_processed}")
    print(f"Discarded (Invalid Format): {discarded_format}")
    print(f"Discarded (Length > {args.max_tokens}): {discarded_length}")
    print("="*40)

# Run the conversion only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or file I/O.
if __name__ == "__main__":
    main()