# set PYTHONPATH to include src

import pandas as pd
import json
import re
from typing import List, Dict, Literal
from pathlib import Path
from eval.run_pipeline import get_valid_data_prefixes, get_typed_data_prefixes
import os
import logging
from eval.util import load_simulation_dfs

# Module logging: INFO and above go to the root console handler (via
# basicConfig) and, through `fh`, to dpo_data.log in the working directory.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)

logger = logging.getLogger(__name__)
logger.setLevel(_LOG_LEVEL)

# mode="w" truncates the log file each time this module is imported.
fh = logging.FileHandler("dpo_data.log", mode="w")
fh.setLevel(_LOG_LEVEL)
logger.addHandler(fh)


def parse_input_prompt(input_prompt: str) -> List[Dict[str, str]]:
    """Split a ``$$Role$$:``-delimited prompt into chat messages.

    Every ``$$Role$$: content`` segment becomes ``{"role": ..., "content": ...}``.
    Roles are lower-cased and ``human`` is renamed to ``user``. Content runs
    up to the next ``$$...$$:`` marker (or the end of the string) and is
    whitespace-stripped. Any text before the first marker is ignored.

    Args:
        input_prompt: Raw prompt text containing ``$$Role$$:`` markers.

    Returns:
        The messages in order of appearance; an empty list if no marker is found.
    """
    segment_re = r'\$\$([^$]+)\$\$:\s*(.*?)(?=\$\$[^$]+\$\$:|$)'

    def _normalise_role(raw_role: str) -> str:
        lowered = raw_role.lower()
        return "user" if lowered == "human" else lowered

    return [
        {"role": _normalise_role(match.group(1)), "content": match.group(2).strip()}
        for match in re.finditer(segment_re, input_prompt, re.DOTALL)
    ]


def process_dp(
    data_prefix: str,
    model_name: str,
    human_file_path: str,
    llm_file_path: str,
) -> List[Dict[str, object]]:
    """Build DPO preference examples for a single data prefix.

    Loads the simulation dataframes with ``load_simulation_dfs`` and turns each
    usable row of the second returned dataframe (``dfs[1]``) into one example::

        {"input": {"messages": [...], "tools": [], "parallel_tool_calls": True},
         "preferred_output": [{"role": "assistant", "content": <text>}],
         "non_preferred_output": [{"role": "assistant", "content": <llm_text>}]}

    The ``text`` column is the preferred response and ``llm_text`` the
    non-preferred one. Rows where ``input_prompt``, ``llm_text`` or ``text`` is
    missing, or whose prompt parses to no messages, are skipped.

    Args:
        data_prefix: Identifier of the data file being processed.
        model_name: Model name forwarded to ``load_simulation_dfs``.
        human_file_path: CSV path forwarded to ``load_simulation_dfs``.
        llm_file_path: CSV path forwarded to ``load_simulation_dfs``.

    Returns:
        The list of examples. It is empty when the dataframe is empty, or when
        any step raises: the error is logged with its traceback and swallowed,
        so one bad file does not abort a whole dataset build.
    """
    required_cols = ["input_prompt", "llm_text", "text"]
    try:
        dfs = load_simulation_dfs(
            data_prefix=data_prefix,
            model_name=model_name,
            eval_model_name="gpt-4o-mini-2024-07-18",
            version="v2",
            filter_strategy="human_only",
            human_file_path=human_file_path,
            llm_file_path=llm_file_path,
        )
        df = dfs[1]  # LLM dataframe

        if df.empty:
            return []

        # Skip rows missing any required field, then iterate plain tuples
        # rather than building a Series per row as iterrows() does.
        rows = df.dropna(subset=required_cols)[required_cols].itertuples(index=False, name=None)

        dpo_examples = []
        for input_prompt, llm_text, human_text in rows:
            input_messages = parse_input_prompt(input_prompt)
            if not input_messages:
                continue

            dpo_examples.append({
                "input": {
                    "messages": input_messages,
                    "tools": [],
                    "parallel_tool_calls": True,
                },
                "preferred_output": [
                    {
                        "role": "assistant",
                        "content": str(human_text),  # Human/preferred response
                    }
                ],
                "non_preferred_output": [
                    {
                        "role": "assistant",
                        "content": str(llm_text),  # LLM/non-preferred response
                    }
                ],
            })

        return dpo_examples

    except Exception as e:
        # Best-effort per prefix, but the failure must be visible. The module
        # logger is configured at INFO, so a DEBUG record here would be dropped.
        logger.warning(
            "Error processing %s with model %s: %s",
            data_prefix,
            model_name,
            e,
            exc_info=True,
        )
        return []


DEFAULT_RESULT_DIR = "/mnt/dv/wid/projects3/XXXX-3-XXXX-5-human-ai/mini-twitter-llm-agent-binary/data/dpo_data"
DEFAULT_OUTPUT_DIR = "/mnt/dv/wid/projects3/XXXX-3-XXXX-5-human-ai/mini-twitter-llm-agent-binary/data/dpo_data_formatted"


def generate_dpo_dataset(
    model: str,
    collection_type: Literal["depth", "breadth"] = "depth",
    split_type: Literal["round", "topic", "group"] = "round",
    set_type: Literal["train", "test"] = "train",
    *,
    result_dir: str = DEFAULT_RESULT_DIR,
    output_dir: str = DEFAULT_OUTPUT_DIR,
):
    """Build one DPO JSONL file for a (model, collection, split, set) combination.

    For every valid data prefix of ``collection_type``, the model CSV is read
    from ``{result_dir}/{model}/{split_type}_split_data_{collection_type}/{set_type}/{prefix}.csv``.
    The matching human CSV is read from the same sub-path under
    ``{result_dir}/human``. Prefixes without a model CSV are skipped.

    All examples from ``process_dp`` are written, one JSON object per line, to
    ``{output_dir}/{model}_{collection_type}_{split_type}_{set_type}.jsonl``.
    An existing file at that path is overwritten.

    Args:
        model: Model directory / name whose generations are the non-preferred side.
        collection_type: Which prefix family to use (``get_typed_data_prefixes``).
        split_type: Split scheme subdirectory.
        set_type: ``"train"`` or ``"test"`` subdirectory.
        result_dir: Root of the split CSV data (keyword-only; defaults to the
            original hard-coded location).
        output_dir: Directory for the JSONL output (keyword-only; defaults to
            the original hard-coded location).
    """
    output_path = Path(output_dir) / f"{model}_{collection_type}_{split_type}_{set_type}.jsonl"
    # parent (not output_dir) so a model name containing "/" still gets its directory.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    split_subdir = Path(f"{split_type}_split_data_{collection_type}") / set_type
    llm_dir = Path(result_dir) / model / split_subdir
    human_dir = Path(result_dir) / "human" / split_subdir

    data_prefixes = get_typed_data_prefixes(get_valid_data_prefixes(), collection_type)

    # One config per prefix that has a model CSV.
    configs = []
    for data_prefix in data_prefixes:
        llm_csv = llm_dir / f"{data_prefix}.csv"
        if not llm_csv.exists():
            continue
        human_csv = human_dir / f"{data_prefix}.csv"
        if not human_csv.exists():
            # Still attempted (process_dp logs the resulting failure), but flag it early.
            logger.warning("Human CSV missing for %s: %s", data_prefix, human_csv)
        configs.append({
            "data_prefix": data_prefix,
            "model_name": model,
            "human_file_path": str(human_csv),
            "llm_file_path": str(llm_csv),
        })

    logger.info(f"Found {len(configs)} configs for Model: {model}; Collection Type: {collection_type}; Split Type: {split_type}; Set Type: {set_type}")

    all_examples = []
    for config in configs:
        logger.debug(f"Processing: {config}")
        examples = process_dp(**config)
        all_examples.extend(examples)
        logger.debug(f"  - Extracted {len(examples)} examples")

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(example, ensure_ascii=False) + "\n" for example in all_examples)

    logger.info(f"Generated DPO dataset with {len(all_examples)} examples")
    logger.info(f"Saved to: {output_path}")


if __name__ == "__main__":
    # Regenerate the DPO file for every (model, collection, split, set)
    # combination. Order: model outermost, set type innermost.
    models = ("gpt-4o-mini-2024-07-18", "gpt-4.1-nano-2025-04-14", "Llama-3.1-8B-Instruct")
    collection_types = ("depth", "breadth")
    split_types = ("round", "topic", "group")
    set_types = ("train", "test")

    jobs = [
        dict(model=m, collection_type=c, split_type=s, set_type=t)
        for m in models
        for c in collection_types
        for s in split_types
        for t in set_types
    ]
    for job in jobs:
        generate_dpo_dataset(**job)
