import json
import os
from typing import Tuple, List, Dict, Any


def _read_json_file(file_path: str) -> Any:
    """Read and return the complete contents of a (non-JSONL) JSON file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _read_jsonl_file(file_path: str) -> List[Any]:
    """Read a JSONL file, one JSON value per non-blank line.

    Lines that fail to parse are reported and skipped instead of aborting the load.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # Skip empty lines
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Error parsing line in {file_path}: {line[:100]}... Error: {e}")
    return records


def _coerce_int(value: Any, default: int) -> int:
    """Return ``int(value)``, or ``default`` when the value cannot be converted."""
    try:
        return int(value)
    except (ValueError, TypeError):
        return default


def _coerce_json(value: Any, default: Any, expected_type: type) -> Any:
    """Return ``value`` as an ``expected_type`` object, JSON-decoding it if needed.

    Fields in the data files may hold either an already-decoded object or its
    JSON-encoded string. Anything that is neither ``expected_type`` nor a JSON
    string decoding to ``expected_type`` yields ``default``.
    """
    if isinstance(value, expected_type):
        return value
    try:
        decoded = json.loads(value)
    except (ValueError, TypeError):  # JSONDecodeError is a ValueError subclass
        return default
    return decoded if isinstance(decoded, expected_type) else default


def _extract_reference(messages: List[Any]) -> str:
    """Return the last message's content up to the "[User Query]" marker, stripped.

    A malformed last message raises (KeyError/TypeError/AttributeError); the
    caller treats that as an unusable case and skips it.
    """
    if not messages:
        return ""
    return messages[-1]["content"].split("[User Query]")[0].strip()


def _build_poi_dict(case: Dict[str, Any], user_query: Any, day: int, departure: Any,
                    arrive: Any, preference: Dict[str, Any], messages: List[Any]) -> Dict[str, Any]:
    """Assemble the poi_dict shared by every data source; other fields come from ``case``."""
    return {
        "userQuery": user_query,
        "day": day,
        "locale": case.get("locale", "en-US"),
        "departure": departure,
        "arrive": arrive,
        "transportation": case.get("transportation", "Yes"),
        "reference": _extract_reference(messages),
        "transport_pool": case.get("transport_pool", "{}"),
        "preference": preference,
        "poi_pool": case.get("poi_pool", ""),
        "hotel_pool": case.get("hotel_pool", ""),
    }


def _parse_generated_case(case: Dict[str, Any], index: int, data_source: str) -> Dict[str, Any]:
    """Convert one synthesis/generalized record, accepting alternative field names."""
    user_query = case.get("userQuery", case.get("query", ""))
    day = _coerce_int(case.get("day", case.get("days", 3)), 3)
    departure = case.get("departure", case.get("origin", ""))
    arrive = case.get("arrive", case.get("destination", ""))
    # preference / messages may be stored decoded or as JSON strings
    preference = _coerce_json(case.get("preference", "{}"), {}, dict)
    messages = _coerce_json(case.get("messages", "[]"), [], list)

    poi_dict = _build_poi_dict(case, user_query, day, departure, arrive, preference, messages)
    return {
        "poi_dict": poi_dict,
        "messages": messages,
        "case_index": index + 1,
        "original_case": case,
        "message_id": case.get('message_id', case.get('id', f'{data_source}_{index}')),
    }


def _parse_custom_case(case: Dict[str, Any], index: int, data_source: str) -> Dict[str, Any]:
    """Convert one test_cases.json-style record (messages live under ``request``).

    ``data_source`` is accepted only so both parsers share a signature.
    Note: writes the normalized integer ``day`` back into ``case``.
    """
    user_query = case.get("userQuery", "")
    day = _coerce_int(case.get("day", 3), 3)
    preference = _coerce_json(case.get("preference", "{}"), {}, dict)
    # "request" is a JSON object (or its string form) carrying a "messages" list
    request = _coerce_json(case.get("request", "[]"), {}, dict)
    messages = request.get("messages", [])
    if not isinstance(messages, list):
        messages = []

    case['day'] = day
    poi_dict = _build_poi_dict(case, user_query, day, case.get("departure", ""),
                               case.get("arrive", ""), preference, messages)
    return {
        "poi_dict": poi_dict,
        "messages": messages,
        "case_index": index + 1,
        "original_case": case,
        "message_id": case.get('message_id', f'case_{index}'),
    }


def _collect_cases(data: Any, data_source: str, parse_case, error_label: str) -> List[Dict[str, Any]]:
    """Apply ``parse_case`` to every record in ``data``, skipping records that fail.

    Non-list ``data`` yields a single fallback entry that wraps the raw data.
    """
    if not isinstance(data, list):
        # Fallback for unexpected format
        return [{
            "poi_dict": {"userQuery": f"{data_source.title()} fallback", "day": 3},
            "messages": [],
            "case_index": 1,
            "original_case": data,
            "message_id": f"{data_source}_fallback",
        }]

    parsed = []
    for index, case in enumerate(data):
        try:
            parsed.append(parse_case(case, index, data_source))
        except Exception as e:  # best-effort: one bad record must not abort the whole load
            print(f"Error processing {error_label} {index}: {e}")
    return parsed


def _dummy_query_result(message_id: str, user_query: str, day: int,
                        departure: str, arrive: str) -> List[Dict[str, Any]]:
    """Build a single placeholder case used when real data is unavailable."""
    dummy_case = {
        "message_id": message_id,
        "userQuery": user_query,
        "day": day,
        "departure": departure,
        "arrive": arrive,
        "transportation": "Yes",
        "transport_pool": "{}",
        "hotel_pool": "{}",
        "poi_pool": "{}",
        "reference": "{}",
        "messages": "[]",
        "locale": "en-US",
        "context_id": "",
    }
    poi_dict = {
        "userQuery": dummy_case["userQuery"],
        "day": day,
        "locale": dummy_case["locale"],
        "departure": dummy_case["departure"],
        "arrive": dummy_case["arrive"],
        "transportation": dummy_case["transportation"],
        "reference": "",
        "transport_pool": dummy_case["transport_pool"],
        "preference": {},
        "poi_pool": dummy_case["poi_pool"],
        "hotel_pool": dummy_case["hotel_pool"],
    }
    return [{
        "poi_dict": poi_dict,
        "messages": [],
        "case_index": 1,
        "original_case": dummy_case,
        "message_id": message_id,
    }]


def _resolve_data_file(args) -> Tuple[str, str]:
    """Map ``args.splits`` to a (data file, data source label) pair."""
    if args.splits == "synthesis":
        return "data_preprocess/Synthesis_data_cleaned.jsonl", "synthesis"
    if args.splits == "generalized":
        return "data_preprocess/Generalization_data_cleaned.jsonl", "generalized"
    if args.splits == "custom":
        return args.custom_data_file, "custom"
    return "test_cases.json", "test"  # default


def load_query(args) -> List[Dict[str, Any]]:
    """
    Load query data based on arguments, compatible with different data sources.
    Returns data in the same format as load_test_data for consistency.

    Data file by ``args.splits``:
        "synthesis"   -> data_preprocess/Synthesis_data_cleaned.jsonl
        "generalized" -> data_preprocess/Generalization_data_cleaned.jsonl
        "custom"      -> args.custom_data_file
        anything else -> test_cases.json
    Relative paths are resolved against the directory two levels above this
    file. Files ending in ".jsonl" are read line by line; others as one JSON document.

    Args:
        args: Arguments containing ``splits`` (and ``custom_data_file`` for "custom")

    Returns:
        List[Dict[str, Any]]: List of dictionaries with structure:
        {
            "poi_dict": {...},
            "messages": [...],
            "case_index": int,       # 1-based position in the source file
            "original_case": {...},
            "message_id": str
        }
        Records that fail to parse are skipped. Non-list file content yields one
        fallback entry. A missing file or any loading error yields one dummy case.
    """
    data_file, data_source = _resolve_data_file(args)

    # Get the project root directory
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    file_path = os.path.join(project_root, data_file)

    print(f"Loading data from: {file_path}")
    print(f"Data source type: {data_source}")

    try:
        if not os.path.exists(file_path):
            print(f"File {file_path} not found, creating dummy data")
            return _dummy_query_result("dummy_001", "Plan a 3-day trip to Beijing", 3, "Shanghai", "Beijing")

        if file_path.endswith('.jsonl'):
            print(f"Reading JSONL format: {file_path}")
            data = _read_jsonl_file(file_path)
        else:
            print(f"Reading JSON format: {file_path}")
            data = _read_json_file(file_path)

        if data_source in ("synthesis", "generalized"):
            print(f"Processing JSONL format: {data_file} (source: {data_source})")
            return _collect_cases(data, data_source, _parse_generated_case, f"{data_source} case")

        # "custom" and the default "test" split share the test_cases.json format
        print(f"Processing {data_file} format")
        return _collect_cases(data, data_source, _parse_custom_case, "case")

    except Exception as e:  # deliberate best-effort: callers always receive usable data
        print(f"Error loading data: {e}")
        return _dummy_query_result("error_fallback", "Plan a sample trip", 2, "City A", "City B")


def save_json_file(json_data: Any, file_path: str):
    """
    Save data to a JSON file (UTF-8, non-ASCII characters kept, 2-space indent).

    Missing parent directories are created. Any failure is reported with
    ``print`` and swallowed, so a failed save never interrupts the caller.

    Args:
        json_data: JSON-serializable data to save (e.g. a dict or a list)
        file_path: Path to save the file
    """
    try:
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # which previously made saving to the current directory silently fail.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)

    except Exception as e:
        print(f"Error saving file {file_path}: {e}")

