import json
import random
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

from utils.poi_analyzer import POIAnalyzer
from .data_model import TravelRequest

class TravelRequestGenerator:
    """Generates synthetic travel data for fine-tuned LLMTranslator training.

    Travel requests are built from the POI pool held by a ``POIAnalyzer``:
    destination cities are sampled from the districts found in the pool, and
    must-have attractions are sampled from the POIs of those cities.
    """

    # Stop scanning the POI pool once this many distinct districts are found.
    _MAX_CITIES = 5
    # Destination cities sampled per request (fewer if the pool has fewer).
    _NUM_DESTINATIONS = 2
    # Must-have POIs sampled from the chosen destination cities.
    _NUM_MUST_HAVE_POIS = 3
    # Inclusive (low, high) day bounds passed to ``randint``.
    _START_OFFSET_DAYS = (1, 30)
    _TRIP_LENGTH_DAYS = (1, 7)

    def __init__(self, **kwargs):
        """Create a generator.

        Keyword Args:
            poi_analyzer: analyzer whose POI pool is sampled. A fresh
                ``POIAnalyzer()`` is constructed only when this is omitted.
        """
        # ``kwargs.get(key, POIAnalyzer())`` would evaluate the default (and
        # build an analyzer) even when the caller supplies one; do it lazily.
        if "poi_analyzer" in kwargs:
            analyzer = kwargs["poi_analyzer"]
        else:
            analyzer = POIAnalyzer()
        self.poi_analyzer: POIAnalyzer = analyzer

    def populate_data(self, poi_dict: Any) -> None:
        """Load a POI pool into the analyzer.

        Args:
            poi_dict: POI mapping, given either as a dict or as a JSON
                ``str``/``bytes`` encoding a JSON object.

        Raises:
            ValueError: if ``poi_dict`` cannot be parsed into a dict.
        """
        poi_dict_json = self._parse_poi_dict(poi_dict)
        if poi_dict_json is None:
            # Message means "failed to parse POI dictionary".
            raise ValueError("POI字典解析失败")
        self.poi_analyzer.load_pool_from_dict(poi_dict_json)

    def _parse_poi_dict(self, poi_dict: Any) -> Optional[Dict[str, Any]]:
        """Parse a POI dictionary.

        Args:
            poi_dict: a dict (returned unchanged) or a JSON ``str``/``bytes``.

        Returns:
            The parsed dict, or ``None`` if the input has an unsupported type,
            is not valid JSON, or decodes to something other than a JSON object.
        """
        if isinstance(poi_dict, dict):
            return poi_dict

        try:
            parsed = json.loads(poi_dict)
        except (TypeError, ValueError) as e:
            # TypeError: input is not str/bytes/bytearray.
            # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"POI dictionary parsing failed: {e}")
            return None

        if not isinstance(parsed, dict):
            # Valid JSON (e.g. an array or a number) but not a POI mapping.
            print(
                "POI dictionary parsing failed: expected a JSON object, "
                f"got {type(parsed).__name__}"
            )
            return None
        return parsed

    def _group_pois_by_city(self, locale: str) -> Dict[str, List[Tuple[Any, str]]]:
        """Group ``(poi_id, poi_name)`` pairs from the POI pool by district name.

        POIs without a ``districtName`` (or with no info at all) are skipped.
        Names come from ``cname`` for locales starting with ``"zh"`` and from
        ``ename`` otherwise (``""`` when missing). Scanning stops as soon as
        ``_MAX_CITIES`` distinct districts have been seen, so only a prefix of
        the pool is read and the last district found holds a single POI.
        """
        use_chinese_name = locale.startswith("zh")
        poi_cities: Dict[str, List[Tuple[Any, str]]] = {}
        for poi_id in self.poi_analyzer.poi_pool:
            poi_info, _, _ = self.poi_analyzer.read_poi_api_info(poi_id, locale)
            if not poi_info:
                continue
            city_name = poi_info.get("districtName")
            if not city_name:
                continue
            poi_name = poi_info.get("cname", "") if use_chinese_name else poi_info.get("ename", "")
            poi_cities.setdefault(city_name, []).append((poi_id, poi_name))
            if len(poi_cities) >= self._MAX_CITIES:
                break
        return poi_cities

    def generate_travel_request(
        self, locale: str = "zh-CN", *, rng: Optional[random.Random] = None
    ) -> "TravelRequest":
        """Generate a symbolic travel request from the POI analyzer's pool.

        Args:
            locale: locale passed to ``read_poi_api_info``; it also selects
                Chinese (``cname``) vs. English (``ename``) POI names.
            rng: optional ``random.Random`` for reproducible output; when
                omitted the module-level ``random`` functions are used.

        Returns:
            A ``TravelRequest`` with up to 2 destination cities, a start date
            1-30 days from now, an end date 1-7 days after the start, and up to
            3 must-have POI names (empty names are dropped).
        """
        rand = random if rng is None else rng

        # Trip starts 1-30 days from now and lasts 1-7 days.
        start_date = datetime.now() + timedelta(days=rand.randint(*self._START_OFFSET_DAYS))
        end_date = start_date + timedelta(days=rand.randint(*self._TRIP_LENGTH_DAYS))

        poi_cities = self._group_pois_by_city(locale)

        # ``sample`` with k=0 returns [], so an empty pool needs no special case.
        destination_cities = rand.sample(
            list(poi_cities), min(self._NUM_DESTINATIONS, len(poi_cities))
        )

        pois_in_cities = [poi for city in destination_cities for poi in poi_cities[city]]

        # Must-have attractions are referenced by name, not by id.
        selected_pois = rand.sample(
            pois_in_cities, min(self._NUM_MUST_HAVE_POIS, len(pois_in_cities))
        )
        must_have_pois = {poi_name for _, poi_name in selected_pois if poi_name}

        return TravelRequest(
            destination_cities_ordered=destination_cities,
            start_date=start_date,
            end_date=end_date,
            must_have_pois=must_have_pois,
        )