"""
Reasoners for data visualization task disambiguation.
Implements TAI (Topological Active Inference) for Ambi-Plot tasks.
"""
import sys
import os
import re
import numpy as np
import pandas as pd
from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_distances

# Prepend the parent of this file's directory to sys.path (highest priority),
# presumably so the local `utils` / `viz_utils` imports below resolve no matter
# what the working directory is — TODO confirm that is where they live.
currentdir = os.path.dirname(os.path.abspath(__file__))
parentdir = os.path.dirname(currentdir)
rootdir = os.path.dirname(parentdir)  # NOTE(review): not referenced in the visible code
sys.path.insert(0, parentdir)

from utils import chat_gpt, obtain_cost, get_embeddings
from viz_utils import (
    extract_code_from_response,
    normalize_viz_code,
    extract_viz_features,
    features_to_text,
    detect_chart_type,
    detect_library,
)


# ============================================================
# Topological Utilities (same as code-generation)
# ============================================================

class UnionFind:
    """Disjoint-set forest over the integers ``0..n-1``.

    Uses path halving in ``find`` and union by rank in ``union``.
    """

    def __init__(self, n):
        self.parent = [idx for idx in range(n)]
        self.rank = [0 for _ in range(n)]

    def find(self, x):
        """Return the root of ``x``'s set, pointing nodes at their grandparents on the way."""
        parent = self.parent
        node = x
        while parent[node] != node:
            grandparent = parent[parent[node]]
            parent[node] = grandparent
            node = grandparent
        return node

    def union(self, a, b):
        """Merge the sets of ``a`` and ``b``; return False if they were already joined."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return False
        rank_a, rank_b = self.rank[root_a], self.rank[root_b]
        if rank_a < rank_b:
            # Hang the shallower tree under the deeper one.
            self.parent[root_a] = root_b
            return True
        self.parent[root_b] = root_a
        if rank_a == rank_b:
            self.rank[root_a] += 1
        return True


def _compute_persistence_deaths(dist_matrix):
    """Return the 0-dimensional persistence death times of a distance matrix.

    Runs Kruskal's algorithm over all pairs ``i < j`` in ascending distance;
    every edge that merges two components contributes its distance as a death
    time. At most ``n - 1`` values are returned, in ascending order.
    """
    size = dist_matrix.shape[0]
    # Row-major upper-triangle order; the stable sort keeps it among equal distances.
    ordered_edges = sorted(
        ((dist_matrix[u, v], u, v) for u in range(size) for v in range(u + 1, size)),
        key=lambda edge: edge[0],
    )

    forest = UnionFind(size)
    deaths = []
    for weight, u, v in ordered_edges:
        if not forest.union(u, v):
            continue
        deaths.append(weight)
        if len(deaths) == size - 1:
            # Spanning tree complete: no further merges are possible.
            break
    return deaths


def _select_tau_from_gaps(death_times):
    """Select persistence threshold via max-gap heuristic."""
    if len(death_times) == 0:
        return 0.0
    deaths = sorted(death_times)
    if len(deaths) == 1:
        return deaths[0]
    gaps = [deaths[i + 1] - deaths[i] for i in range(len(deaths) - 1)]
    max_gap = max(gaps)
    if max_gap <= 0:
        return float(np.median(deaths))
    split_idx = gaps.index(max_gap)
    return 0.5 * (deaths[split_idx] + deaths[split_idx + 1])


def _labels_from_tau(dist_matrix, tau):
    """Cluster points by linking every pair whose distance is ``<= tau``.

    Returns one integer label per row of ``dist_matrix``. Labels are numbered
    0, 1, 2, ... in the order in which each cluster's first member appears.
    """
    size = dist_matrix.shape[0]
    forest = UnionFind(size)
    close_pairs = (
        (u, v)
        for u in range(size)
        for v in range(u + 1, size)
        if dist_matrix[u, v] <= tau
    )
    for u, v in close_pairs:
        forest.union(u, v)

    label_of_root = {}
    labels = []
    for node in range(size):
        root = forest.find(node)
        if root not in label_of_root:
            label_of_root[root] = len(label_of_root)
        labels.append(label_of_root[root])
    return labels


# ============================================================
# Hypothesis Class
# ============================================================

class VizHypothesis:
    """A visualization hypothesis: generated plotting code plus metadata.

    Attributes (set in ``__init__``):
      content  -- the generated visualization code, as passed in
      logp     -- optional log-probability score (``_sample_raw_hypothesis``
                  in this module passes the sum of the completion's token
                  logprobs), or None
      features -- feature dict; computed with ``extract_viz_features(content)``
                  when not supplied
    """

    def __init__(self, content, logp=None, features=None):
        self.content = content  # The generated visualization code
        self.logp = logp
        # `or` (rather than `is None`) also recomputes when an empty dict is passed.
        self.features = features or extract_viz_features(content)

    def to_dict(self):
        """Return a plain-dict snapshot (content, logp, features) of this hypothesis."""
        return {
            "content": self.content,
            "logp": self.logp,
            "features": self.features,
        }

    # NOTE: a *method* named ``__dict__`` shadows the attribute that normally
    # exposes the instance namespace, so ``h.__dict__`` and ``vars(h)`` yield
    # this bound method instead of a dict. It is kept only so existing callers
    # of ``h.__dict__()`` keep working; new code should call ``to_dict()``.
    def __dict__(self):
        return self.to_dict()


# ============================================================
# Oracle Answerer
# ============================================================

class OracleVizAnswerer:
    """
    Oracle that answers questions based on ground truth preferences.
    Ground truth is a dict specifying user's true intent, e.g.:
    {"chart_type": "heatmap", "library": "seaborn", "color_scheme": "viridis"}

    The ground truth is read from ``task_data["ground_truth"]`` (empty dict if
    absent). LLM spend from ``answer`` accumulates in ``self.total_cost``.
    """

    # (ground-truth key, label shown to the simulated user), in prompt order.
    _PREFERENCE_LABELS = (
        ("chart_type", "Preferred chart type"),
        ("library", "Preferred library"),
        ("color_scheme", "Preferred color scheme"),
        ("style", "Style preferences"),
        ("aggregation", "Data aggregation"),
        ("additional", "Additional requirements"),
    )
    # Ground-truth keys that evaluate_hypothesis scores against hypothesis features.
    _SCORED_FIELDS = ("chart_type", "library", "color_scheme")

    def __init__(self, llm_call, task_data, seed, mode="run"):
        self.llm_call = llm_call
        self.task_data = task_data
        self.ground_truth = task_data.get("ground_truth", {})
        self.total_cost = 0
        self.seed = seed
        self.mode = mode

    def answer(self, question: str) -> str:
        """
        Answer a clarifying question based on ground truth.
        Uses LLM to interpret question and provide consistent answer.

        Returns the stripped text of the first choice and adds the call's cost
        to ``self.total_cost``.
        """
        system_prompt = (
            "You are a user with specific visualization preferences. "
            "Answer the question based ONLY on your preferences below. "
            "Give a short, direct answer.\n\n"
            f"Your preferences:\n{self._format_preferences()}"
        )
        user_prompt = f"Question: {question}"

        response = self.llm_call(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            n_used=1,
            seed=self.seed,
        )
        self.total_cost += obtain_cost(response)
        return response.choices[0].message.content.strip()

    def _format_preferences(self) -> str:
        """Format the known ground-truth preferences as ``- label: value`` lines."""
        lines = [
            f"- {label}: {self.ground_truth[key]}"
            for key, label in self._PREFERENCE_LABELS
            if key in self.ground_truth
        ]
        return "\n".join(lines) if lines else "No specific preferences."

    def evaluate_hypothesis(self, hypothesis: "VizHypothesis") -> dict:
        """Score how many ground-truth fields a hypothesis's features match.

        Only fields in ``_SCORED_FIELDS`` that are present in the ground truth
        count. A feature missing from the hypothesis counts as a mismatch
        (reported as ``None``) instead of raising KeyError.

        Returns:
            dict with ``score``, ``max_score``, ``match_rate`` (1.0 when nothing
            is scored) and per-field ``details`` ("match" or
            "mismatch (<actual> vs <expected>)").
        """
        features = hypothesis.features
        score = 0
        max_score = 0
        details = {}

        for field in self._SCORED_FIELDS:
            if field not in self.ground_truth:
                continue
            max_score += 1
            expected = self.ground_truth[field]
            actual = features.get(field)
            if actual == expected:
                score += 1
                details[field] = "match"
            else:
                details[field] = f"mismatch ({actual} vs {expected})"

        return {
            "score": score,
            "max_score": max_score,
            "match_rate": score / max_score if max_score > 0 else 1.0,
            "details": details,
        }


class OracleBinaryVizAnswerer(OracleVizAnswerer):
    """Oracle that answers Yes/No questions about visualization preferences."""

    # Characters stripped from the front of the reply before the yes-check.
    # The prompt itself quotes 'Yes'/'No', and a model that echoes the quotes
    # (or wraps the word in markdown emphasis/code) would otherwise be read
    # as "No".
    _REPLY_WRAPPERS = "'\"*`_ \t\n"

    def answer(self, question: str) -> bool:
        """Answer a binary (Yes/No) question.

        Returns True iff the reply starts with "yes" (case-insensitive) once
        leading quotes/markdown markers are removed; any other reply counts as
        No. Adds the call's cost to ``self.total_cost``.
        """
        system_prompt = (
            "You are a user with specific visualization preferences. "
            "Answer the Yes/No question based ONLY on your preferences. "
            "Reply with exactly 'Yes' or 'No'.\n\n"
            f"Your preferences:\n{self._format_preferences()}"
        )
        user_prompt = f"Question: {question}"

        response = self.llm_call(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            n_used=1,
            seed=self.seed,
        )
        self.total_cost += obtain_cost(response)
        reply = response.choices[0].message.content.strip().lower()
        return reply.lstrip(self._REPLY_WRAPPERS).startswith("yes")


# ============================================================
# Base Reasoner
# ============================================================

class VizReasonerBase:
    """Base class for visualization task reasoners.

    Holds the task context (data description + instruction) read from
    ``task_data``, samples visualization-code hypotheses and clarifying
    questions through ``llm_call``, and accumulates the reported LLM cost in
    ``self.total_cost``.

    Requirements are a dict whose values are ``(question, answer)`` pairs (see
    ``_get_problem_statement``). Answers may be free text or, e.g. from
    ``OracleBinaryVizAnswerer``, a ``bool``.
    """

    # Hypothesis features that _filter_hypothesis can constrain.
    _FILTERABLE_FEATURES = ("chart_type", "library")

    def __init__(self, llm_call, task_data, seed, mode="run", logprobs=False, unique_hs=False):
        self.llm_call = llm_call
        self.task_data = task_data
        self.data_description = task_data.get("data_description", "")
        self.instruction = task_data.get("instruction", "Visualize the data")
        self.total_cost = 0
        self.timeout = 5.0  # NOTE(review): not read by this class
        self.max_retry = 5  # max LLM sampling rounds per generate_* call
        self.total_questions = 1
        self.total_hypothesis = 10
        self.seed = seed
        self.mode = mode
        self.logprobs = logprobs
        self.unique_hs = unique_hs

    def generate_hypothesis(self, requirements: dict) -> list:
        """Generate visualization code hypotheses (none in ``mode == "run"``)."""
        if self.mode == "run":
            return []
        return self._generate_hypothesis(requirements)

    def _generate_hypothesis(self, requirements: dict) -> list:
        """Sample up to ``self.total_hypothesis`` hypotheses in at most ``self.max_retry`` rounds.

        Round ``k`` uses seed ``self.seed + k``. When requirements are present
        each batch is filtered against them; with ``unique_hs`` duplicate code
        (after normalization) is dropped across all rounds, not just within one
        batch.
        """
        all_hypothesis = []
        seen_code = set()
        total_cost = 0

        for it in range(self.max_retry):
            if len(all_hypothesis) >= self.total_hypothesis:
                break
            batch, cost = self._sample_raw_hypothesis(
                requirements, self.total_hypothesis, self.seed + it
            )
            total_cost += cost
            if requirements:
                batch = self._filter_hypothesis(batch, requirements)
            if self.unique_hs:
                batch = self._dedupe_by_code(batch, seen_code)
            all_hypothesis.extend(batch)

        self.total_cost += total_cost
        return all_hypothesis[:self.total_hypothesis]

    def _sample_raw_hypothesis(self, requirements: dict, n_samples: int, seed: int):
        """Sample one batch of visualization code from the LLM.

        Requests ``max(1, n_samples // 2)`` completions and extracts the code of
        each with ``extract_code_from_response``. When ``self.logprobs`` is on
        and a choice carries logprobs, ``logp`` is the sum of its token
        logprobs. With ``unique_hs``, duplicates within the batch (by
        ``normalize_viz_code``) are dropped, keeping the first occurrence.

        Returns:
            (list of VizHypothesis, cost reported by ``obtain_cost``)
        """
        system_prompt = (
            "You are an expert data visualization programmer. "
            "Generate Python code to visualize the given data. "
            "Use matplotlib, seaborn, or plotly as appropriate. "
            "Include necessary imports. Make the visualization complete and ready to run. "
            "Wrap your code in ```python ``` blocks."
        )

        user_prompt = self._get_problem_statement(requirements)

        response = self.llm_call(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            n_used=max(1, n_samples // 2),
            seed=seed,
            logprobs=self.logprobs,
        )
        cost = obtain_cost(response)

        hypothesis_ls = []
        for choice in response.choices:
            code = extract_code_from_response(choice.message.content)
            logp = None
            if self.logprobs and choice.logprobs:
                logp = sum(x.logprob for x in choice.logprobs.content)
            hypothesis_ls.append(VizHypothesis(content=code, logp=logp))

        if self.unique_hs:
            hypothesis_ls = self._dedupe_by_code(hypothesis_ls, set())
        return hypothesis_ls, cost

    @staticmethod
    def _dedupe_by_code(hypotheses: list, seen: set) -> list:
        """Keep hypotheses whose normalized code is not in ``seen`` (updates ``seen``)."""
        fresh = []
        for h in hypotheses:
            key = normalize_viz_code(h.content)
            if key not in seen:
                seen.add(key)
                fresh.append(h)
        return fresh

    @staticmethod
    def _answer_mentions(value, answer_text: str) -> bool:
        """Return True if feature ``value`` equals, or appears as a whole word in, ``answer_text``."""
        if value is None:
            return False
        label = str(value).lower()
        if label == answer_text:
            return True
        return bool(label) and re.search(rf"\b{re.escape(label)}\b", answer_text) is not None

    def _filter_hypothesis(self, hypothesis: list, requirements: dict) -> list:
        """Filter hypotheses based on accumulated requirements.

        A requirement constrains a feature in ``_FILTERABLE_FEATURES`` when its
        question mentions that feature (as ``chart_type`` or ``chart type``).
        A hypothesis satisfies it if the feature's value occurs as a whole word
        in the answer (so "A bar chart, please." selects ``bar``). Answers of
        exactly "any" impose no constraint. Non-string answers (e.g. ``bool``)
        are compared via ``str()`` instead of crashing on ``.lower()``.

        Returns the matching hypotheses, or all of them if none match.
        """
        constraints = []  # (feature name, lower-cased answer text)
        for question, answer in requirements.values():
            q_text = str(question).lower()
            a_text = str(answer).strip().lower()
            if a_text == "any":
                continue
            for feature in self._FILTERABLE_FEATURES:
                if feature in q_text or feature.replace("_", " ") in q_text:
                    constraints.append((feature, a_text))

        if not constraints:
            return hypothesis

        valid = [
            h for h in hypothesis
            if all(self._answer_mentions(h.features.get(f), a) for f, a in constraints)
        ]
        return valid if valid else hypothesis  # Return all if no matches

    def _get_problem_statement(self, requirements: dict) -> str:
        """Build the problem statement with requirements."""
        prompt = f"Data Description:\n{self.data_description}\n\n"
        prompt += f"Task: {self.instruction}\n"

        if requirements:
            prompt += "\nAdditional Requirements:\n"
            for _, (q, a) in requirements.items():
                prompt += f"- {q}: {a}\n"

        return prompt

    def generate_questions(self, requirements: dict, restricted_questions: list) -> list:
        """Generate candidate clarifying questions.

        Up to ``self.max_retry`` rounds (seed ``self.seed + k``) are run until
        at least ``self.total_questions`` new questions are collected.
        Duplicates and ``restricted_questions`` are removed while keeping the
        order questions were generated in. The original ``set``-based dedupe made
        the order, and therefore the selected question, vary between runs.
        """
        restricted = set(restricted_questions or ())
        seen = set()
        q_ls = []
        it = 0
        while len(q_ls) < self.total_questions and it < self.max_retry:
            for q in self._generate_questions(requirements, 3, 1, self.seed + it):
                if q not in restricted and q not in seen:
                    seen.add(q)
                    q_ls.append(q)
            it += 1
        return q_ls

    def _generate_questions(self, requirements: dict, n_samples: int, n_used: int, seed: int) -> list:
        """Generate clarifying questions via LLM.

        Parses lines of the form ``<n>. <text>?`` from every returned choice.
        """
        system_prompt = (
            "You are helping clarify a user's visualization preferences. "
            f"Generate {n_samples} questions to understand what kind of visualization the user wants. "
            "Focus on:\n"
            "- Chart type (bar, line, scatter, heatmap, pie, etc.)\n"
            "- Library preference (matplotlib, seaborn, plotly)\n"
            "- Color scheme\n"
            "- Aggregation method\n"
            "- Layout and styling\n\n"
            "Format as numbered list:\n"
            "1. Question 1?\n"
            "2. Question 2?\n"
            "..."
        )

        user_prompt = self._get_problem_statement(requirements)
        if requirements:
            user_prompt += "\nDo not repeat questions already answered."

        response = self.llm_call(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            n_used=n_used,
            seed=seed,
        )
        self.total_cost += obtain_cost(response)

        questions = []
        for choice in response.choices:
            content = choice.message.content
            # Extract numbered questions
            matches = re.findall(r"\d+\.\s*(.+\?)", content)
            questions.extend(matches)

        return questions

    def select_best_question(self, questions: list, hypothesis: list) -> str:
        """Select the best question (base: return first, or a default)."""
        if questions:
            return questions[0]
        return "What type of chart would you prefer?"

    def q_a_to_requirement(self, question: str, answer: str) -> tuple:
        """Convert Q&A to requirement tuple."""
        return (question, answer)


# ============================================================
# Active Reasoner (Entropy-based)
# ============================================================

class ActiveVizReasoner(VizReasonerBase):
    """Active reasoner using entropy-based question selection.

    Each candidate question is scored by the entropy of the answers the
    current hypotheses would give it (see ``_simulate_answers``). The question
    with the highest entropy is asked.
    """

    # Keyword routing for _simulate_answers: the first entry whose keywords
    # occur in the question decides which hypothesis feature answers it.
    # Specific topics are checked before the generic "chart"/"type" words so
    # that e.g. "Do you want an interactive chart?" is routed to
    # ``interactive`` (not ``chart_type``).
    _QUESTION_TOPICS = (
        (("library", "package"), "library"),
        (("color",), "color_scheme"),
        (("interactive",), "interactive"),
        (("chart", "type"), "chart_type"),
    )
    _DEFAULT_TOPIC = "chart_type"  # used when no keyword matches

    def __init__(self, llm_call, task_data, total_questions, total_hypothesis, seed,
                 mode="run", logprobs=False, unique_hs=False):
        super().__init__(llm_call, task_data, seed, mode, logprobs, unique_hs)
        self.total_questions = total_questions
        self.total_hypothesis = total_hypothesis

    def generate_hypothesis(self, requirements: dict) -> list:
        """Always sample hypotheses (ignores the base class's ``mode == "run"`` short-circuit)."""
        return self._generate_hypothesis(requirements)

    def select_best_question(self, questions: list, hypothesis: list) -> str:
        """Select question that maximizes entropy over hypothesis features.

        With no questions, or fewer than two hypotheses, returns the first
        question (or a default prompt if there are none).
        """
        if len(questions) == 0 or len(hypothesis) < 2:
            return questions[0] if questions else "What chart type do you prefer?"

        best_question = None
        best_entropy = -np.inf

        for q in questions:
            # Estimate answer distribution by simulating on each hypothesis
            answers = self._simulate_answers(q, hypothesis)
            entropy = self._compute_entropy(answers)

            if entropy > best_entropy:
                best_entropy = entropy
                best_question = q

        return best_question if best_question else questions[0]

    def _simulate_answers(self, question: str, hypothesis: list) -> list:
        """Predict each hypothesis's answer to ``question`` from its features.

        Keyword heuristic over ``_QUESTION_TOPICS``. ``interactive`` values are
        stringified; other features are returned as stored. A feature missing
        from a hypothesis yields ``None`` (``"None"`` for ``interactive``)
        rather than raising KeyError.
        """
        q_lower = question.lower()
        feature = self._DEFAULT_TOPIC
        for keywords, topic in self._QUESTION_TOPICS:
            if any(word in q_lower for word in keywords):
                feature = topic
                break

        values = [h.features.get(feature) for h in hypothesis]
        if feature == "interactive":
            return [str(v) for v in values]
        return values

    def _compute_entropy(self, answers: list) -> float:
        """Compute the Shannon entropy (nats) of the empirical answer distribution."""
        if not answers:
            return 0.0
        counts = {}
        for a in answers:
            counts[a] = counts.get(a, 0) + 1
        total = len(answers)
        probs = [c / total for c in counts.values()]
        entropy = -sum(p * np.log(p + 1e-9) for p in probs)
        return entropy


# ============================================================
# TAI Reasoner (Topological Active Inference)
# ============================================================

class TAIVizReasoner(ActiveVizReasoner):
    """
    TAI (Topological Active Inference) reasoner for visualization tasks.
    Uses persistent homology to discover robust intent clusters and
    TEIG (Topological Expected Information Gain) for question selection.
    """

    def __init__(self, llm_call, task_data, total_questions, total_hypothesis, seed,
                 logprobs=False, unique_hs=False, embedding="tfidf",
                 max_features=2048, embedding_model="text-embedding-3-large", tau=None):
        super().__init__(llm_call, task_data, total_questions, total_hypothesis, seed,
                        mode="run", logprobs=logprobs, unique_hs=unique_hs)
        self.embedding = embedding
        self.max_features = max_features
        self.embedding_model = embedding_model
        self.tau = tau  # fixed merge threshold; None -> chosen by the max-gap heuristic

    def _embed_hypothesis(self, hypothesis: list) -> np.ndarray:
        """Embed hypotheses into semantic space.

        Each text is the feature summary followed by the first 500 characters
        of the normalized code. ``embedding`` selects TF-IDF (fit on this batch)
        or ``get_embeddings``.

        Raises:
            ValueError: for an unsupported ``self.embedding``.
        """
        texts = []
        for h in hypothesis:
            feature_text = features_to_text(h.features)
            # Combine normalized code + feature summary
            code_text = normalize_viz_code(h.content)[:500]  # Truncate
            texts.append(f"{feature_text} {code_text}")

        if self.embedding == "tfidf":
            vectorizer = TfidfVectorizer(max_features=self.max_features)
            vectors = vectorizer.fit_transform(texts).toarray()
            return vectors
        elif self.embedding == "openai":
            vectors = np.array(get_embeddings(texts, model_name=self.embedding_model))
            return vectors
        else:
            raise ValueError(f"Unsupported embedding: {self.embedding}")

    def _get_cluster_labels(self, hypothesis: list) -> tuple:
        """Compute cluster labels using persistent homology.

        Returns:
            (labels, number of clusters, tau used). With at most one hypothesis
            everything is cluster 0 and tau is 0.0.
        """
        if len(hypothesis) <= 1:
            return [0 for _ in hypothesis], 1, 0.0

        vectors = self._embed_hypothesis(hypothesis)
        dist_matrix = cosine_distances(vectors)
        death_times = _compute_persistence_deaths(dist_matrix)
        tau = self.tau if self.tau is not None else _select_tau_from_gaps(death_times)
        labels = _labels_from_tau(dist_matrix, tau)

        return labels, len(set(labels)), tau

    def _weights_from_logprobs(self, hypothesis: list) -> np.ndarray:
        """Importance weights proportional to ``1 / P(h) = exp(-logp)``.

        Hypotheses without a ``logp`` get log-weight 0 (weight 1 before scaling).
        ``_compute_teig`` only uses weight ratios (it normalizes by the total),
        so the weights are built in log space and shifted by their maximum
        before exponentiating. This gives the same ratios as the naive
        ``1 / np.exp(logp)``, but does not overflow to ``inf`` (and then ``nan``
        TEIG) when a long completion has ``logp`` below roughly -709.
        Without ``self.logprobs``, all weights are 1.
        """
        if not self.logprobs:
            return np.ones(len(hypothesis))
        log_weights = np.array(
            [0.0 if h.logp is None else -float(h.logp) for h in hypothesis],
            dtype=float,
        )
        if log_weights.size == 0:
            return log_weights
        return np.exp(log_weights - log_weights.max())

    def _compute_teig(self, outputs: list, labels: list, weights: np.ndarray) -> float:
        """
        Compute Topological Expected Information Gain (TEIG).
        TEIG = I(Z; A | q) where Z is the cluster variable.

        ``outputs[i]`` is hypothesis i's simulated answer (None -> "unknown"),
        ``labels[i]`` its cluster and ``weights[i]`` its (unnormalized) weight.
        """
        eps = 1e-9
        cluster_to_outputs = defaultdict(lambda: defaultdict(float))
        cluster_weights = defaultdict(float)

        for out, label, w in zip(outputs, labels, weights):
            if out is None:
                out = "unknown"
            cluster_to_outputs[label][out] += w
            cluster_weights[label] += w

        total_weight = sum(cluster_weights.values())
        if total_weight == 0:
            return 0.0

        # P(Z=z)
        p_z = {z: w / total_weight for z, w in cluster_weights.items()}

        # P(A=a)
        p_a = defaultdict(float)

        # P(A=a | Z=z)
        p_a_given_z = {}
        for z, out_counts in cluster_to_outputs.items():
            z_total = cluster_weights[z]
            if z_total == 0:
                continue
            p_a_given_z[z] = {}
            for a, c in out_counts.items():
                p = c / z_total
                p_a_given_z[z][a] = p
                p_a[a] += p_z[z] * p

        # TEIG = sum_z P(z) sum_a P(a|z) log(P(a|z) / P(a))
        teig = 0.0
        for z, out_probs in p_a_given_z.items():
            for a, p in out_probs.items():
                teig += p_z[z] * p * np.log((p + eps) / (p_a[a] + eps))

        return teig

    def select_best_question(self, questions: list, hypothesis: list) -> str:
        """Select question maximizing TEIG.

        Falls back to the parent's entropy rule when clustering finds at most
        one cluster, and to a default question when there are no questions or
        no hypotheses.
        """
        if len(questions) == 0 or len(hypothesis) == 0:
            return "What chart type would you prefer?"

        labels, num_clusters, tau = self._get_cluster_labels(hypothesis)
        print(f"[TAI] Discovered {num_clusters} intent clusters (tau={tau:.4f})")

        # If only one cluster, fall back to entropy-based
        if num_clusters <= 1:
            return super().select_best_question(questions, hypothesis)

        weights = self._weights_from_logprobs(hypothesis)
        best_question = None
        best_teig = -np.inf

        for q in questions:
            # Simulate answers for each hypothesis
            outputs = self._simulate_answers(q, hypothesis)
            teig = self._compute_teig(outputs, labels, weights)
            print(f"  Question: {q[:50]}... TEIG={teig:.4f}")

            if teig > best_teig:
                best_teig = teig
                best_question = q

        return best_question if best_question else questions[0]


# ============================================================
# Binary Question Variants
# ============================================================

class ActiveBinaryVizReasoner(ActiveVizReasoner):
    """Active reasoner with binary (Yes/No) questions."""

    # Words that mark a question as being about a chart at all; the bar/line
    # rules require one so that e.g. "bar" alone is not enough.
    _PLOT_WORD = r"chart|plot|graph"

    # Rules for _simulate_answers, checked in order (first match wins):
    #   (regex on the lower-cased question, needs a _PLOT_WORD, feature key,
    #    expected value -- or None to report the raw feature value).
    # Word boundaries keep "line" from matching "outline"/"baseline"/"inline"
    # and "pie" from matching "copies"/"pieces".
    _BINARY_RULES = (
        (r"\bbar(?:s|chart|plot|graph)?\b", True, "chart_type", "bar"),
        (r"\bline(?:s|chart|plot|graph)?\b", True, "chart_type", "line"),
        (r"\bscatter", False, "chart_type", "scatter"),
        (r"\bheat ?maps?\b", False, "chart_type", "heatmap"),
        (r"\bpie(?:chart)?s?\b", False, "chart_type", "pie"),
        (r"\bhistograms?\b", False, "chart_type", "histogram"),
        (r"seaborn", False, "library", "seaborn"),
        (r"plotly", False, "library", "plotly"),
        (r"matplotlib", False, "library", "matplotlib"),
        (r"interactive", False, "interactive", None),
        (r"legend", False, "has_legend", None),
        (r"title", False, "has_title", None),
        (r"grid", False, "has_grid", None),
    )

    def _generate_questions(self, requirements: dict, n_samples: int, n_used: int, seed: int) -> list:
        """Generate binary clarifying questions.

        Parses lines of the form ``<n>. <text>?`` from every returned choice.
        """
        system_prompt = (
            "You are helping clarify a user's visualization preferences. "
            f"Generate {n_samples} Yes/No questions to understand what the user wants. "
            "Focus on specific preferences:\n"
            "- 'Do you want a bar chart?'\n"
            "- 'Should I use seaborn for styling?'\n"
            "- 'Do you prefer an interactive chart?'\n\n"
            "Format as numbered list:\n"
            "1. Yes/No question?\n"
            "2. Yes/No question?\n"
        )

        user_prompt = self._get_problem_statement(requirements)
        if requirements:
            # Same nudge the free-form question generator adds.
            user_prompt += "\nDo not repeat questions already answered."

        response = self.llm_call(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            n_used=n_used,
            seed=seed,
        )
        self.total_cost += obtain_cost(response)

        questions = []
        for choice in response.choices:
            content = choice.message.content
            matches = re.findall(r"\d+\.\s*(.+\?)", content)
            questions.extend(matches)

        return questions

    def _simulate_answers(self, question: str, hypothesis: list) -> list:
        """Simulate each hypothesis's Yes/No answer, as the strings "True"/"False".

        The first matching ``_BINARY_RULES`` entry decides the feature checked.
        With an expected value the answer is ``feature == expected``; otherwise
        it is the raw feature value (stringified). Unrecognized questions
        answer "False" for every hypothesis. Missing features read as None.
        """
        q_lower = question.lower()
        names_a_plot = re.search(self._PLOT_WORD, q_lower) is not None

        feature, expected = None, None
        for pattern, needs_plot_word, rule_feature, rule_expected in self._BINARY_RULES:
            if needs_plot_word and not names_a_plot:
                continue
            if re.search(pattern, q_lower):
                feature, expected = rule_feature, rule_expected
                break

        answers = []
        for h in hypothesis:
            if feature is None:
                answer = False
            elif expected is None:
                answer = h.features.get(feature)
            else:
                answer = h.features.get(feature) == expected
            answers.append(str(answer))

        return answers


class TAIBinaryVizReasoner(ActiveBinaryVizReasoner, TAIVizReasoner):
    """TAI reasoner that asks Yes/No questions.

    ActiveBinaryVizReasoner (first base) contributes the binary question
    generation and answer simulation; TAIVizReasoner contributes clustering
    and TEIG-based selection.
    """

    def __init__(self, llm_call, task_data, total_questions, total_hypothesis, seed,
                 logprobs=False, unique_hs=False, embedding="tfidf",
                 max_features=2048, embedding_model="text-embedding-3-large", tau=None):
        tai_options = dict(
            logprobs=logprobs,
            unique_hs=unique_hs,
            embedding=embedding,
            max_features=max_features,
            embedding_model=embedding_model,
            tau=tau,
        )
        TAIVizReasoner.__init__(
            self, llm_call, task_data, total_questions, total_hypothesis, seed,
            **tai_options
        )

    def select_best_question(self, questions: list, hypothesis: list) -> str:
        """Delegate explicitly to TAIVizReasoner's TEIG-based selection."""
        tai_select = TAIVizReasoner.select_best_question
        return tai_select(self, questions, hypothesis)
