from LLMAgent import LLMBaseAgent
import json
import os
import time
import uuid
from typing import List, Tuple, Dict, Optional, Literal
from collections import Counter
from pydantic import BaseModel, ValidationError, Field, model_validator
from collections import defaultdict
from faiss_store import FaissStore
from MentalModelTypes import Hypothesis, Evidence, Message
from Evaluator import Confidence, committee_majority_vote

# Selects how tag_message() asks the model for JSON output:
#   True  -> pass the ContentTags schema to agent.generate(..., ret=ContentTags)
#            (structured output, beta API) and fall back to manual parsing if
#            the response carries no parsed object.
#   False -> append a plain-text JSON instruction to the prompt and parse the
#            reply manually.
USE_STRUCTURED_OUTPUT = True  # Set to False to avoid 404 errors on Azure

# Prompt text for each mental-model dimension, keyed by dimension ID (1-6).
# The strings are interpolated verbatim into LLM prompts, so they must not
# contain invisible characters: zero-width spaces (U+200B) are not whitespace
# for str.strip() and would otherwise leak into every prompt built from them.
dim2desc = {
    1: """
        1. Knowledge/Expertise Assessment
        - What: The other party's domain knowledge, technical vocabulary, or familiarity with relevant procedures.
        - Why: Shapes how much detail you need to provide or withhold.
    """.strip(),
    2: """
        2. Request & Behavior Analysis
        - What: The type of information they're asking for, how often, and how urgently.
        - Why: Identifying whether requests align with normal operational needs or hint at ulterior motives.
    """.strip(),
    3: """
        3. Motive & Trust Inference
        - What: Analyzing whether the other party is overreaching or truly needs the data for legitimate reasons.
        - Why: Guides how open or defensive you should be.
    """.strip(),
    4: """
        4. Strategic Direction / Policy Implication
        - What: High-level approach (e.g., tighten disclosures, offer partial data, escalate to legal, propose NDAs).
        - Why: Ensures alignment with IP protection goals and the conversation's collaboration objectives.
    """.strip(),
    5: """
        5. Information Gaps & Next Steps
        - What: Additional data, clarifications, or verifications.
        - Why: Allows you to refine your inferences, test assumptions, or confirm your partner's authenticity before making a decision.
    """.strip(),
    6: """
        6. Privacy/Sensitivity Assessment
        - What: Evaluate whether Party A's to‑be‑shared information is sensitive, and classify sensitivity level (Low/Medium/High).
        - Why: Decide if the information must be disclosed, can be abstracted/deferred, or should be withheld; weigh expected benefits against potential harm to Party A.
    """.strip(),
}

# Module-level hypothesis store shared by the functions in this file. It is
# read with get / list_all / list_all_one and modified with insert / update.
HypFaiss: FaissStore = FaissStore()
# Evidence accumulated for each hypothesis; a missing key starts as an empty list.
Hyp2Evi: Dict[str, List[Evidence]] = defaultdict(list) # Hypothesis ID to List[Evidence]

def strip_code_fences(text: str) -> str:
    """Remove a surrounding triple-backtick code fence and trim whitespace.

    Two fence layouts are handled:

    * multi-line: the whole opening line (fence plus optional language tag)
      is dropped, then a trailing fence is removed if present;
    * single-line (opening and closing fence on one line): only the backticks
      are removed, so the opening fence no longer survives while the closing
      one is stripped.

    Args:
        text: Raw model output. Non-string values are returned unchanged.

    Returns:
        The text without its surrounding fence and outer whitespace.
    """
    if not isinstance(text, str):
        return text
    fence = "```"
    stripped = text.strip()
    if stripped.startswith(fence):
        first_newline = stripped.find("\n")
        if first_newline != -1:
            # Drop the entire opening line: fence plus optional language tag.
            stripped = stripped[first_newline + 1:]
        else:
            # No opening line to drop; remove just the leading backticks so the
            # result is not left with a dangling opening fence.
            stripped = stripped[len(fence):]
        # Remove trailing fence if present
        if stripped.endswith(fence):
            stripped = stripped[:-len(fence)]
    return stripped.strip()

def parse_hypothesis_action_from_text(text: str) -> "HypothesisAction":
    """Turn raw model output into a validated HypothesisAction.

    The text is unfenced and decoded as JSON; if that fails, the span from the
    first '{' to the last '}' is decoded instead. Fields that belong to the
    other decision type are discarded before the model is constructed.
    """
    payload_text = strip_code_fences(text)
    try:
        payload = json.loads(payload_text)
    except Exception:
        # Fall back to the outermost brace-delimited span, if there is one.
        lo = payload_text.find('{')
        hi = payload_text.rfind('}')
        if lo == -1 or hi == -1 or hi <= lo:
            raise
        payload = json.loads(payload_text[lo:hi+1])

    if not isinstance(payload, dict):
        raise ValueError("Parsed content is not a JSON object")

    # Drop fields the model filled in for the decision type it did not choose,
    # so they do not trip the conditional validation.
    kind = payload.get('decision')
    if kind == 'new':
        irrelevant = ('target_hypothesis_id', 'updated_hypothesis_desc')
    elif kind == 'merge':
        irrelevant = ('new_hypothesis_desc',)
    else:
        irrelevant = ()
    for field_name in irrelevant:
        payload.pop(field_name, None)

    return HypothesisAction(**payload)


class ContentTag(BaseModel):
    # One chunk of a tagged message, as produced by the tagger LLM.
    dimension: int  # mental-model dimension ID assigned to the chunk (used as a key into dim2desc)
    text: str  # text of the chunk

class ContentTags(BaseModel):
    # Tagger output: matches the {"chunks": [{"dimension": int, "text": str}, ...]}
    # JSON shape requested in tag_message's prompt.
    chunks: List[ContentTag]  # dimension-tagged chunks of the last message

class HypothesisAction(BaseModel):
    """
    Parsed payload for `update_hypothesis_function`.

    * decision: `"new"`  → create a fresh hypothesis  
                 `"merge"` → fold the chunk into an existing hypothesis

    * target_hypothesis_id / updated_hypothesis_desc  
        Required **only** when decision == "merge"

    * new_hypothesis_desc  
        Required **only** when decision == "new"
    """
    decision: Literal["new", "merge"] = Field(
        ...,
        description="Action to take with the new chunk",
    )

    # Populated only for decision == "merge".
    target_hypothesis_id: Optional[str] = Field(
        default=None,
        description="ID of an existing hypothesis to merge into",
    )
    updated_hypothesis_desc: Optional[str] = Field(
        default=None,
        description="(MERGE) New description after absorbing the chunk, if the hypothesis text should be updated",
    )

    # Populated only for decision == "new".
    new_hypothesis_desc: Optional[str] = Field(
        default=None,
        description="(NEW) Description for the brand‑new hypothesis derived from the chunk",
    )

    # Cross-field rules: each decision type has one required field and must
    # not carry the fields that belong to the other type.
    @model_validator(mode="after")
    def _enforce_conditional_fields(self):
        if self.decision != "merge":
            # decision == "new"
            if not self.new_hypothesis_desc:
                raise ValueError(
                    "new_hypothesis_desc is required when decision == 'new'"
                )
            merge_fields_absent = (
                self.target_hypothesis_id is None
                and self.updated_hypothesis_desc is None
            )
            if not merge_fields_absent:
                raise ValueError(
                    "target_hypothesis_id and updated_hypothesis_desc must be omitted when decision == 'new'"
                )
            return self

        # decision == "merge"
        if not self.target_hypothesis_id:
            raise ValueError(
                "target_hypothesis_id is required when decision == 'merge'"
            )
        if self.new_hypothesis_desc is not None:
            raise ValueError(
                "new_hypothesis_desc must be omitted when decision == 'merge'"
            )
        return self

# ------------------------------------------------------
# Tagger LLM: Splits message into dimension-tagged chunks
# ------------------------------------------------------
def tag_message(
    party_a: str,
    party_b: str,
    relationship: str,
    background_context: str,
    earlier_conversation: List[Message],
    last_message: Message,
    agent: LLMBaseAgent,
) -> ContentTags:
    """
    Ask the LLM to split Party B's last message into dimension-tagged chunks.

    Args:
        party_a: Name of the party whose perspective is taken.
        party_b: Name of the party who sent the last message.
        relationship: Description of the relationship between the parties.
        background_context: What the conversation is about.
        earlier_conversation: Prior turns, used as context only. Each item may
            be an object exposing `sender` / `text` or a dict-like exposing
            `speaker` / `content`.
        last_message: The message to chunk; only its `text` is read.
        agent: LLM client used to generate the response.

    Returns:
        A ContentTags object whose `chunks` carry a dimension ID and text.

    Note:
        Retries indefinitely until the model returns a parseable response.
    """
    _missing = object()

    def _field(msg, attr: str, key: str):
        """Read `attr` from a message object, else `key` from a dict-like one."""
        # The dict fallback must only be evaluated when the attribute is
        # absent; passing msg.get(...) as getattr's default would call .get
        # eagerly and fail on objects that have no such method.
        value = getattr(msg, attr, _missing)
        if value is not _missing:
            return value
        getter = getattr(msg, "get", None)
        return getter(key, "") if callable(getter) else ""

    understanding_mental_model_dims = [dim2desc[i] for i in range(1, 4)]
    # Join explicitly: interpolating the list itself would embed its Python
    # repr (brackets, quotes, escaped newlines) into the prompt.
    dims_block = "\n\n".join(understanding_mental_model_dims)
    earlier_convo_block = "\n".join(
        f"{_field(m, 'sender', 'speaker')}: {_field(m, 'text', 'content')}"
        for m in earlier_conversation
    )
    messages = [
        {
            "role": "system",
            "content": f"""
                [Task] Chunk the last message from Party B ({party_b}) from the perspective of
                Party A ({party_a}) along the {len(understanding_mental_model_dims)} mental model dimensions.
                Party A and B are engaged in a conversation about {background_context}, and their relationship is:
                {relationship}.

                [Understanding Mental Model Dimensions]
                {dims_block}

                Return JSON with "chunks": [{{"dimension": int, "text": str}}, ...].

                *Important*
                - Earlier conversation is for contextualizing the last message only. DO NOT CHUNK.
                - Only divide the last message shown in the end into coherent chunks.
            """
        },
        {
            "role": "user",
            "content": f"""
                [Earlier conversation for context (for context only)]
                {earlier_convo_block}

                [The message text to chunk]
                {last_message.text}
            """
        }]

    if not USE_STRUCTURED_OUTPUT:
        # Don't use structured output to avoid 404 errors; spell out the JSON
        # shape in the prompt instead. Done once, before the retry loop, so
        # retries do not keep appending the same instruction.
        messages[-1]['content'] += '\n\nReturn your response in valid JSON format matching this structure: {"chunks": [{"dimension": <int>, "text": "<string>"}, ...]}'

    # Retry until successful format matching
    while True:
        try:
            if USE_STRUCTURED_OUTPUT:
                resp = agent.generate(messages, ret=ContentTags)
                reply = resp.choices[0].message
                # Prefer the structured result; fall back to parsing the raw
                # content when the response carries none.
                tags = getattr(reply, 'parsed', None)
                if not tags:
                    tags = ContentTags(**json.loads(strip_code_fences(reply.content)))
            else:
                resp = agent.generate(messages)
                content = resp.choices[0].message.content
                tags = ContentTags(**json.loads(strip_code_fences(content)))
        except Exception as e:
            # Deliberately broad: API failures, malformed JSON and schema
            # violations are all handled by asking again.
            print(f"⚠️ Error generating ContentTags: {e}, retrying...")
            continue

        # Verify format is correct
        if (hasattr(tags, 'chunks') and
            isinstance(tags.chunks, list) and
            all(hasattr(chunk, 'dimension') and hasattr(chunk, 'text') for chunk in tags.chunks)):
            return tags
        print("⚠️ Invalid format for ContentTags, retrying...")

# ------------------------------------------------------
# Confidence Committee
# ------------------------------------------------------
def evaluate_confidence(
    hypothesis: Hypothesis,
    evidence: List[Evidence],
    agent: LLMBaseAgent,
) -> int:
    """
    Score a hypothesis against its evidence with a three-response LLM committee.

    Args:
        hypothesis: Hypothesis to score; its `dimension_id` and `description`
            are read.
        evidence: Evidence items whose `text` is shown to the model.
        agent: LLM client; `generate_batch` is called with three identical
            prompts.

    Returns:
        The result of `committee_majority_vote` over three scores, each in
        [1, 6].

    Note:
        If any response is unparseable or out of range, the whole batch is
        discarded and requested again; this repeats indefinitely.
    """
    evidence_joined = "\n".join([ev.text for ev in evidence])
    # Adjust subject based on dimension group: 1-3 about Party B, 4-6 about Party A
    subject = "the other party's" if hypothesis.dimension_id in [1, 2, 3] else "Party A's"

    messages = [
        {
            "role": "system",
            "content": f"""
                Given a hypothesis and its set of evidence, evaluate the confidence in the hypothesis.
                The hypothesis is a statement about {subject} {dim2desc[hypothesis.dimension_id]}.
                The evidence is a collection of statements or observations that support or contradict the hypothesis.
                The confidence is a number between 1 and 6, where:
                - 1 indicates "very low confidence (<10%)",
                - 2 indicates "low confidence (~20%)",
                - 3 indicates "medium confidence (~30%)",
                - 4 indicates "medium-high confidence (~50%)",
                - 5 indicates "high confidence (~75%)",
                - and 6 indicates "very high confidence (~90%)".
                The confidence should be based on the strength of the evidence and the clarity of the hypothesis.
                The evidence may be ambiguous or contradictory, and the hypothesis may be uncertain or speculative.

                Return ONLY a raw JSON object with exactly one field and nothing else (no prose, no code fences).
                For example:
                {{"confidence": 5}}
            """
        },
        {
            "role": "user",
            "content": f"""
                [Hypothesis]
                {hypothesis.description}

                [Evidence]
                {evidence_joined}
            """
        }
    ]

    # 3-LLM committee with strict JSON parsing and retry on invalid range
    committee_size = 3
    prompts = [messages] * committee_size
    default_conf = 3  # used when a response omits the "confidence" key

    while True:
        resps = agent.generate_batch(prompts)
        scores: List[int] = []
        for resp in resps:
            # Reuse the module's shared fence stripper rather than a private copy.
            content = strip_code_fences(resp.choices[0].message.content)
            try:
                data = json.loads(content)
                val = int(data.get("confidence", default_conf))
                if not 1 <= val <= 6:
                    raise ValueError("out of range")
            except Exception:
                # Any bad member invalidates the whole round: clear the votes
                # so the length check below fails and the batch is re-run.
                scores = []
                break
            scores.append(val)
        if len(scores) == committee_size:
            return committee_majority_vote(scores, committee_size)

# ------------------------------------------------------
# Hypothesis Updater LLM
# ------------------------------------------------------
def update_or_create_hypothesis(
    party_a: str,
    party_b: str,
    relationship: str,
    background_context: str,
    dimension_id: int,
    source_text: str,
    source_type: Literal["message", "hypothesis"],
    source_id: str,
    candidate_hypotheses: List[Hypothesis],
    agent: LLMBaseAgent,
) -> Tuple[HypothesisAction, Hypothesis, List[Evidence]]:
    """Generic routine for merging / spawning hypotheses, agnostic to source origin.

    Asks the LLM whether `source_text` should be merged into one of
    `candidate_hypotheses` or should spawn a new hypothesis, then applies that
    decision to the module-level store (`HypFaiss`) and evidence map
    (`Hyp2Evi`) and re-evaluates the hypothesis' confidence.

    Args:
        party_a: Name of the party whose perspective is taken.
        party_b: Name of the other party.
        relationship: Description of the relationship between the parties.
        background_context: What the conversation is about.
        dimension_id: Dimension the hypothesis belongs to (key into dim2desc).
        source_text: Text that serves as evidence.
        source_type: Whether the text came from a message or from hypotheses.
        source_id: Identifier recorded on the created Evidence.
        candidate_hypotheses: Existing hypotheses the text may be merged into.
            When empty, the model is only asked for a new hypothesis.
        agent: LLM client used to generate the decision.

    Returns:
        (decision, hypothesis, evidence list for that hypothesis).

    Raises:
        ValueError: If the merge target is not found in the store for
            `dimension_id`.

    Note:
        Retries indefinitely until the model returns a valid decision.
    """
    # Perspective and extra guidance depend on dimension groups
    if dimension_id in [1, 2, 3]:
        perspective_line = (
            f"The hypothesis should be written in a concise manner from the perspective of Party A ({party_a}) "
            f"about Party B ({party_b})."
        )
    else:
        perspective_line = (
            f"The hypothesis should be written in a concise manner from the perspective of Party A ({party_a}) "
            f"about Party A's own disclosure policy, next steps, or privacy considerations."
        )
    if len(candidate_hypotheses) < 1:
        messages = [
            {
                "role": "system",
                "content": f"""
                    Write a concise new hypothesis that is supported/evidenced by the text.
                    {perspective_line}
                    It should belong in the following category: {dim2desc[dimension_id]}.
                    
                    Party A and B are engaged in a conversation about {background_context}, and their relationship is:
                    {relationship}.

                    Return ONLY a raw JSON object with the following two fields, and nothing else. Do NOT include code fences, backticks, markdown, or any extra text:
                    - "decision": "new"
                    - "new_hypothesis_desc": a concise description of the newly generated hypothesis
                """
            },
            {
                "role": "user",
                "content": f"""[Text]\n{source_text}"""
            }
        ]
    else:
        candidate_strings = "\n".join([
            f"(ID: {h.hypothesis_id}) Description: {h.description}"
            for h in candidate_hypotheses
        ])
        messages = [
            {
                "role": "system",
                "content": f"""
                    Decide for the text if it:
                    - should merge into an existing hypothesis as evidence (if it obviously fits, either as supporting or refuting).
                    - should create a new hypothesis, as it doesn't fit any of the existing hypotheses.
                    In case of "new", write a concise new hypothesis statement that is evidenced by the text.

                    {perspective_line}
                    Party A and B are engaged in a conversation about {background_context}.
                    Their relationship is: {relationship}.
                    
                    The hypothesis considered belongs (and should belong) in the following category: {dim2desc[dimension_id]}.

                    Return ONLY a raw JSON object with these fields and nothing else (no code fences, backticks, markdown, or extra prose):
                    - If merging: {{"decision": "merge", "target_hypothesis_id": "<ID>", "updated_hypothesis_desc": "<updated text>"}}
                    - If new: {{"decision": "new", "new_hypothesis_desc": "<new hypothesis>"}}
                """
            },
            {
                "role": "user",
                "content": f"""
                    Existing hypotheses:\n{candidate_strings}\n\n[Text]\n{source_text}
                """
            }
        ]

    # IDs the model was actually offered; any other merge target is invented.
    offered_ids = {str(h.hypothesis_id) for h in candidate_hypotheses}

    # Retry until successful format matching (unstructured → manual parse → validate)
    while True:
        try:
            raw_resp = agent.generate(messages)  # no structured schema; avoid early validation
            content = raw_resp.choices[0].message.content
            # HypothesisAction's own validator already guarantees that the
            # decision is "new"/"merge" and that its required field is set.
            decision: HypothesisAction = parse_hypothesis_action_from_text(content)
        except Exception as e2:
            print(f"⚠️ Error generating HypothesisAction: {e2}, retrying...")
            continue

        if decision.decision == "merge" and decision.target_hypothesis_id not in offered_ids:
            # Treat a target that was never offered like any other malformed
            # reply instead of letting it reach the store.
            print("⚠️ Invalid format for HypothesisAction (merge into unknown target_hypothesis_id), retrying...")
            continue
        break

    evidence = Evidence(text=source_text, source_type=source_type, source_id=source_id)

    # Apply the decision
    if decision.decision == "merge": # merge into an existing hypothesis
        target_id = decision.target_hypothesis_id

        hyp = HypFaiss.get(dimension_id, target_id)
        if hyp is None:
            raise ValueError(f"Cannot find hypothesis {target_id} in dimension {dimension_id}")

        # Record the evidence only once the target is known to exist, so a
        # failed lookup cannot leave orphan entries in Hyp2Evi.
        Hyp2Evi[target_id].append(evidence)

        # The updated description is optional; keep the current text when the
        # model did not supply one.
        if decision.updated_hypothesis_desc:
            hyp.update_description(decision.updated_hypothesis_desc)
        hyp.update_confidence(evaluate_confidence(hyp, Hyp2Evi[target_id], agent))
        HypFaiss.update(hyp)

    else: # new hypothesis
        hyp = Hypothesis(
            dimension_id=dimension_id,
            description=decision.new_hypothesis_desc,
            confidence=3,
        )
        hyp.update_confidence(evaluate_confidence(hyp, [evidence], agent))
        Hyp2Evi[hyp.hypothesis_id].append(evidence)
        HypFaiss.insert(hyp)

    return decision, hyp, Hyp2Evi[hyp.hypothesis_id]

def process_future_forward_hypotheses(
    party_a: str,
    party_b: str,
    relationship: str,
    background_context: str,
    agent: LLMBaseAgent,
) -> List[Tuple[HypothesisAction, Hypothesis, List[Evidence]]]:
    """Leverages understanding‑dimension hypotheses as evidence for future‑forward dimensions.

    The descriptions of all stored hypotheses in dimensions 1-3 are joined
    into one text, which is then offered as evidence to
    `update_or_create_hypothesis` once for each of dimensions 4, 5 and 6.

    Args:
        party_a: Name of the party whose perspective is taken.
        party_b: Name of the other party.
        relationship: Description of the relationship between the parties.
        background_context: What the conversation is about.
        agent: LLM client forwarded to `update_or_create_hypothesis`.

    Returns:
        One (decision, hypothesis, evidence list) tuple per future-forward
        dimension, or an empty list when no understanding hypotheses exist.
    """
    # Step 3: Based on the understanding hypotheses, update or create hypotheses along the two
    #         future-forward dimensions
    # --> Maybe there is an iterative RAG loop here (generate initial hypotheses about next steps
    #     using existing hypotheses, calculate coverage & quality, then iterate until both metrics are
    #     saturated). How to define the "quality"? Is conciseness part of it?
    # Alternatively, if performance is a concern, we can summarize the hypotheses into a more concise form
    # first, before supplying it for new hypothesis generation.
    understanding_dims, future_dims = [1, 2, 3], [4, 5, 6]
    understanding_hyps_by_dim = HypFaiss.list_all(understanding_dims)
    future_forward_hyps_by_dim = HypFaiss.list_all(future_dims)

    results: List[Tuple[HypothesisAction, Hypothesis, List[Evidence]]] = []
    all_understanding_hyps = [
        hyp for hyps in understanding_hyps_by_dim.values() for hyp in hyps
    ]
    # If we have no understanding hypotheses yet, skip future-forward generation to avoid empty prompts
    if not all_understanding_hyps:
        return results

    # Identical for every dimension, so build it once.
    source_text = '\n'.join(h.description for h in all_understanding_hyps)

    for dim_id in future_dims:
        decision, new_hyp, evs = update_or_create_hypothesis(
            party_a=party_a,
            party_b=party_b,
            relationship=relationship,
            background_context=background_context,
            dimension_id=dim_id,
            source_text=source_text,
            source_type="hypothesis",
            source_id=str(uuid.uuid4()),
            candidate_hypotheses=future_forward_hyps_by_dim.get(dim_id, []),
            agent=agent,
        )
        # The lookup above tolerates a missing dimension, so the bookkeeping
        # must too: plain indexing would raise KeyError for a dimension that
        # has no hypotheses yet.
        future_forward_hyps_by_dim.setdefault(dim_id, []).append(new_hyp)
        results.append((decision, new_hyp, evs))

    return results

def tag_n_upsert(
    party_a: str,
    party_b: str,
    relationship: str,
    background_context: str,
    earlier_conversation: List[Dict],
    last_message: Dict,
    agent: LLMBaseAgent,
) -> Tuple[List[HypothesisAction], List[Hypothesis], List[List[Evidence]]]:
    """
    1. Call the tagger LLM to split the last message into dimension-tagged chunks.
    2. For each chunk, call the hypothesis updater LLM to decide how to incorporate it.
    3. Process future-forward hypotheses based on understanding dimensions.
    4. Return decisions, hypotheses, and evidence lists.

    Args:
        party_a: Name of the party whose perspective is taken.
        party_b: Name of the party who sent `last_message`.
        relationship: Description of the relationship between the parties.
        background_context: What the conversation is about.
        earlier_conversation: Prior turns, passed to the tagger as context.
        last_message: Dict whose 'content' entry is the text to tag.
        agent: LLM client used for every model call.

    Returns:
        Three parallel lists (decisions, hypotheses, evidence lists): one
        entry per processed chunk, followed by one per future-forward result.
    """
    # Step 1: Tag the message along the three understanding schema dimensions
    last_message_cast = Message(
        text=last_message['content'], sender=party_b
    )
    tagged_chunks: List[ContentTag] = tag_message(
        party_a, party_b, relationship, background_context,
        earlier_conversation, last_message_cast, agent
    ).chunks

    # Step 2: Update or create "understanding" hypotheses based on the tagged content
    decisions, hyps, evs = [], [], []
    for chunk in tagged_chunks:
        # Skip empty chunks to avoid prompting the model to ask for text
        if not getattr(chunk, 'text', '').strip():
            continue
        dimension_id = chunk.dimension
        # The dimension comes from model output. An ID with no description
        # would raise KeyError while building the updater prompt, after
        # earlier chunks have already modified the store, so drop it here.
        if dimension_id not in dim2desc:
            print(f"⚠️ Skipping chunk tagged with unknown dimension {dimension_id!r}")
            continue
        candidate_hypotheses = HypFaiss.list_all_one(dimension_id)
        decision, hyp, _evs = update_or_create_hypothesis(
            party_a, party_b, relationship, background_context, dimension_id,
            chunk.text, "message", str(uuid.uuid4()), candidate_hypotheses, agent
        )
        decisions.append(decision)
        hyps.append(hyp)
        evs.append(_evs)

    # Step 3: Process "future-forward" hypotheses based on the understanding dimensions
    res = process_future_forward_hypotheses(party_a, party_b, relationship, background_context, agent)
    for decision, hyp, _evs in res:
        decisions.append(decision)
        hyps.append(hyp)
        evs.append(_evs)

    return decisions, hyps, evs