import json
import re
from datetime import datetime, timedelta
from typing import get_args

from utils.llms import AbstractLLM

from .data_model import TravelRequest

def _get_annotation_json(dataclass_obj):
    annotations = getattr(dataclass_obj, '__annotations__', {})
    result = {}
    for key, typ in annotations.items():
        if isinstance(typ, str):
            result[key] = typ
        else:
            try:
                result[key] = typ.__name__
            except AttributeError:
                result[key] = str(typ)
    return result

class InstructionTranslator:
    """Translate travel requests between natural language and ``TravelRequest``.

    Both directions delegate the language work to a backbone LLM; this class
    builds the prompts and post-processes the model's output.
    """

    # strptime fallbacks tried when datetime.fromisoformat rejects a date string.
    _DATE_FORMATS = ('%Y-%m-%d', '%Y-%m-%d %H:%M:%S', '%Y-%m-%dT%H:%M:%S')

    def __init__(self, **kwargs):
        """Store the backbone LLM.

        Args:
            backbone_llm: the ``AbstractLLM`` used for both translation
                directions. Read with ``kwargs.get``, so it is ``None`` when
                omitted and the translation methods will then fail on use.
        """
        self.model: AbstractLLM = kwargs.get("backbone_llm")

    def natural_to_symbolic(self, natural_text: str) -> TravelRequest:
        """Convert a natural-language request into a ``TravelRequest``.

        The LLM is shown the ``TravelRequest`` field/type schema and asked for
        JSON. The first JSON object in its reply is parsed, missing dates are
        filled in from ``maximum_num_of_days`` where possible, date strings
        are parsed into ``datetime`` objects, and a ``must_have_pois`` list is
        converted to a set before constructing the request.

        Raises:
            ValueError: if the model returns a non-string response, the reply
                contains no parseable JSON object, required dates are missing,
                or a date / day count cannot be parsed.
        """
        # Serialise the schema as JSON (not a Python dict repr with single
        # quotes) so the example the model sees matches the "valid JSON" ask.
        schema = json.dumps(_get_annotation_json(TravelRequest), ensure_ascii=False)
        system_prompt = f"""
You are a travel request translator that converts natural language travel requests into a structured symbolic format.
The symbolic format follows this exact structure (TravelRequest dataclass):
{schema}
Return ONLY valid JSON that matches the exact structure above.
    """.strip()

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": natural_text},
        ]

        response = self.model._get_response(messages, True, False)
        if not isinstance(response, str):
            # A non-string reply is surfaced as the error payload itself.
            raise ValueError(response)

        data = self._extract_json_object(response)
        self._resolve_dates(data)

        # JSON has no set type, so a list is converted to a set here
        # (presumably the type TravelRequest declares for this field).
        if isinstance(data.get('must_have_pois'), list):
            data['must_have_pois'] = set(data['must_have_pois'])

        return TravelRequest(**data)

    @staticmethod
    def _extract_json_object(response: str) -> dict:
        """Parse the first JSON object embedded in ``response``.

        Decoding starts at the first ``{`` and stops at the end of that
        object, so surrounding prose or markdown fences -- including trailing
        text that itself contains braces -- is ignored.

        Raises:
            ValueError: "No JSON found in response" when there is no ``{...}``
                span, "Invalid JSON in response" when it does not decode.
        """
        match = re.search(r'\{.*\}', response, re.DOTALL)
        if match is None:
            raise ValueError("No JSON found in response")
        try:
            data, _end = json.JSONDecoder().raw_decode(response[match.start():])
        except json.JSONDecodeError as err:
            raise ValueError("Invalid JSON in response") from err
        return data

    @classmethod
    def _resolve_dates(cls, data: dict) -> None:
        """Fill in missing trip dates and parse them into ``datetime`` (in place).

        * No ``start_date`` but ``maximum_num_of_days`` given: the trip starts
          today and spans that many days inclusive (any ``end_date`` is
          replaced).
        * ``start_date`` given, ``end_date`` missing, ``maximum_num_of_days``
          given: ``end_date`` is derived from ``start_date`` the same way.
        * Otherwise both ``start_date`` and ``end_date`` must be present.

        Date strings are then parsed with ``_parse_date``; non-string values
        are left untouched.

        Raises:
            ValueError: if required dates are missing or cannot be parsed, or
                ``maximum_num_of_days`` is not a positive integer.
        """
        start = data.get('start_date')
        end = data.get('end_date')
        max_days = data.get('maximum_num_of_days')

        if start is None and max_days is not None:
            days = cls._trip_days(max_days)
            # Single clock read so start and end cannot straddle midnight.
            today = datetime.now()
            data['start_date'] = today.strftime("%Y-%m-%d")
            data['end_date'] = (today + timedelta(days=days - 1)).strftime("%Y-%m-%d")
        elif start is not None and end is None and max_days is not None:
            days = cls._trip_days(max_days)
            if isinstance(start, str):
                start = cls._parse_date(start, 'start_date')
            data['start_date'] = start
            try:
                data['end_date'] = start + timedelta(days=days - 1)
            except TypeError as err:
                raise ValueError(
                    f"Could not derive end_date from start_date: {start!r}"
                ) from err
        elif start is None or end is None:
            raise ValueError("start_date and end_date are required")

        for key in ('start_date', 'end_date'):
            if isinstance(data.get(key), str):
                data[key] = cls._parse_date(data[key], key)

    @classmethod
    def _parse_date(cls, value: str, field_name: str) -> datetime:
        """Parse a date string, trying ISO 8601 first, then ``_DATE_FORMATS``.

        A trailing ``Z`` is rewritten to ``+00:00`` for ISO parsing because
        ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11.

        Raises:
            ValueError: "Could not parse <field_name>: <value>" if no format
                matches.
        """
        try:
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError as err:
            for fmt in cls._DATE_FORMATS:
                try:
                    return datetime.strptime(value, fmt)
                except ValueError:
                    continue
            raise ValueError(f"Could not parse {field_name}: {value}") from err

    @staticmethod
    def _trip_days(value) -> int:
        """Coerce ``maximum_num_of_days`` to a positive int.

        LLM output may carry the number as a string (e.g. ``"5"``); a value
        below 1 would put the end date before the start date.

        Raises:
            ValueError: if ``value`` is not an integer-like value >= 1.
        """
        try:
            days = int(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Invalid maximum_num_of_days: {value!r}") from err
        if days < 1:
            raise ValueError(f"maximum_num_of_days must be at least 1, got {days}")
        return days

    def symbolic_to_natural(self, request: TravelRequest) -> str:
        """Convert a ``TravelRequest`` into a natural-language description.

        The request is serialised with ``TravelRequest.to_dict`` to JSON and
        sent to the LLM with a system prompt asking for a conversational
        summary. The LLM response is returned unchanged (no type check is
        done here, unlike ``natural_to_symbolic``).
        """
        system_prompt = """
You are a travel request translator that converts structured symbolic travel requests into natural language.
Convert the given JSON travel request into a clear, natural language description that a human would understand.
Include all relevant details like origin, destinations, dates, budget, and any constraints in a conversational tone.
""".strip()

        # to_dict handles serialisation of non-JSON-native fields.
        request_dict = TravelRequest.to_dict(request)
        user_query = json.dumps(request_dict, ensure_ascii=False)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_query},
        ]

        return self.model._get_response(messages, True, False)

    def evaluate_translation(self, original: TravelRequest,
                             translated: TravelRequest) -> float:
        """Return the fraction of ``TravelRequest`` fields equal in both requests.

        Fields are the names in ``TravelRequest.__annotations__`` (the class's
        own annotations), compared with ``==``. Returns 1.0 when there are no
        fields to compare rather than dividing by zero.
        """
        field_names = list(TravelRequest.__annotations__)
        if not field_names:
            return 1.0
        matching = sum(
            1 for name in field_names
            if getattr(original, name) == getattr(translated, name)
        )
        return matching / len(field_names)