from dataclasses import dataclass, field
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Literal, TypeAlias

# Type alias for the model-name strings used by this module's callers.
# Static type checkers restrict values to the listed literals; nothing here
# enforces this at runtime.
SupportedModels: TypeAlias = Literal["gpt-4o", "o1"]

@dataclass
class SymbolDefinition:
    """A named symbol definition located in a file, together with its body text."""
    filePath: str  # path of the file that contains the definition
    start: int  # start position of the definition; unit (line vs. char offset) not shown here — TODO confirm
    end: int  # end position; presumably same unit as ``start`` — confirm whether inclusive
    name: str  # the symbol's name
    body: str  # text of the definition (presumably its source code — confirm)

def dictToSymbolDefinition(d: dict) -> SymbolDefinition:
    """Build a SymbolDefinition from a mapping that holds its five fields.

    Only the keys ``filePath``, ``start``, ``end``, ``name`` and ``body`` are
    read (in that order); any other keys are ignored. A missing key raises
    KeyError.
    """
    fieldNames = ("filePath", "start", "end", "name", "body")
    kwargs = {fieldName: d[fieldName] for fieldName in fieldNames}
    return SymbolDefinition(**kwargs)

@dataclass
class GlobalCtxAgentState:
    """Dataclass to store the state of the agent.

    Holds the mutable state carried across iterations of the agent's ReAct
    loop. Every field has a default, so ``GlobalCtxAgentState()`` is a valid
    empty state.
    """
    # Used during the React Loop
    feedback: str = ""
    steps: int = 0  # step counter; presumably incremented by the loop — confirm with caller
    # Definitions opened so far, grouped by key (presumably file path — confirm).
    # Annotated as Dict, but the default is a defaultdict(list), so indexing an
    # unseen key on a fresh state yields a new empty list instead of KeyError.
    openDefinitions: Dict[str, List[SymbolDefinition]] = field(default_factory=lambda: defaultdict(list))
    prevQueries: List[Tuple[str, List[str]]] = field(default_factory=list) # (query, results) from the previous step
    pastQueries: List[Tuple[str, List[str]]] = field(default_factory=list) # (query, results) from all previous steps
    memory: List[str] = field(default_factory=list)  # free-form notes kept across steps (semantics not shown here)
    justification: Optional[str] = None  # None until set
    # Token usage counters; presumably accumulated over LLM calls — confirm.
    promptTokens: int = 0
    completionTokens: int = 0

def dictToGlobalCtxAgentState(d: dict) -> GlobalCtxAgentState:
    """
    Convert a dictionary (e.g. one obtained by calling dataclasses.asdict on a
    GlobalCtxAgentState object) back into a GlobalCtxAgentState object.

    ``openDefinitions`` is rebuilt as a ``defaultdict(list)``, matching the
    default factory of ``GlobalCtxAgentState.openDefinitions``. As a result, a
    restored state behaves like a freshly constructed one when an unseen key is
    indexed. Query entries are converted from ``[query, results]`` sequences
    (tuples become lists after a JSON round-trip) back into
    ``(query, results)`` tuples.

    Raises:
        KeyError: if any expected key is missing from ``d`` or from a nested
            symbol-definition dict.
        ValueError / TypeError: if a query entry is not a two-element pair.
    """
    # defaultdict (a dict subclass) keeps parity with the dataclass default.
    openDefinitions: Dict[str, List[SymbolDefinition]] = defaultdict(list)
    for filePath, symbolDefs in d['openDefinitions'].items():
        openDefinitions[filePath] = [dictToSymbolDefinition(sDef) for sDef in symbolDefs]

    # Unpacking enforces that each entry is exactly a (query, results) pair.
    prevQueries: List[Tuple[str, List[str]]] = [(q, results) for q, results in d['prevQueries']]
    pastQueries: List[Tuple[str, List[str]]] = [(q, results) for q, results in d['pastQueries']]

    return GlobalCtxAgentState(
        feedback=d['feedback'],
        steps=d['steps'],
        openDefinitions=openDefinitions,
        prevQueries=prevQueries,
        pastQueries=pastQueries,
        memory=d['memory'],
        justification=d['justification'],
        promptTokens=d['promptTokens'],
        completionTokens=d['completionTokens'],
    )

@dataclass
class GlobalCtxAgentResult:
    """Outcome of a global-context agent run.

    ``finalReactState`` holds the agent state as it stood when the run ended.
    """
    steps: int
    relevanceExplanation: str
    relevantFunctions: List[SymbolDefinition]
    relevantDefinitions: Dict[str, List[SymbolDefinition]]
    # NOTE(review): typed as (query, str) pairs, whereas GlobalCtxAgentState
    # stores queries as (query, List[str]) pairs — confirm which is intended.
    relevantQueries: List[Tuple[str, str]]
    promptTokens: int
    completionTokens: int
    finalReactState: GlobalCtxAgentState

def dictToGlobalCtxAgentResult(result: dict) -> GlobalCtxAgentResult:
    """
    Rebuild a GlobalCtxAgentResult from a dictionary, typically the output of
    dataclasses.asdict applied to a GlobalCtxAgentResult object.

    The nested symbol-definition dictionaries are turned back into
    SymbolDefinition objects, and ``finalReactState`` is turned back into a
    GlobalCtxAgentState. Each ``relevantQueries`` entry is unpacked into a
    ``(query, result)`` tuple. A missing key raises KeyError.
    """
    def toSymbols(rawDefs):
        # Convert one list of SymbolDefinition dicts.
        return [dictToSymbolDefinition(raw) for raw in rawDefs]

    relevantFunctions = toSymbols(result['relevantFunctions'])
    relevantDefinitions = {
        path: toSymbols(rawDefs)
        for path, rawDefs in result['relevantDefinitions'].items()
    }
    finalReactState = dictToGlobalCtxAgentState(result["finalReactState"])

    return GlobalCtxAgentResult(
        steps=result['steps'],
        relevanceExplanation=result['relevanceExplanation'],
        relevantFunctions=relevantFunctions,
        relevantDefinitions=relevantDefinitions,
        relevantQueries=[(query, res) for query, res in result['relevantQueries']],
        promptTokens=result['promptTokens'],
        completionTokens=result['completionTokens'],
        finalReactState=finalReactState,
    )
