import json
import re
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import nltk
import numpy as np
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from scipy.spatial.distance import cosine, euclidean, cityblock, chebyshev
from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import cosine_similarity

# Install a process-wide "ignore" filter: warnings issued after this line are
# suppressed for every category and every module, not just this one.
warnings.filterwarnings("ignore")

# Try to import transformers for embedding models
# The import is optional; TRANSFORMERS_AVAILABLE records whether it succeeded so
# that SemanticAnalyzer can pick a backend at model-load time.
try:
    from transformers import AutoTokenizer, AutoModel

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    # NOTE(review): nothing in this module switches to a text-only mode on its
    # own; SemanticAnalyzer raises RuntimeError when embeddings are requested
    # and neither backend could be imported.
    print("Warning: transformers not available, falling back to text-only mode")

# Try to import vLLM as fallback
# VLLM_AVAILABLE is only consulted when transformers is not importable.
try:
    import vllm
    from vllm import LLM

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False
    print("Warning: vLLM not available, using transformers or text-only mode")


class SemanticAnalyzer:
    """Compare research texts with embedding distances and lexical overlap metrics.

    Embeddings come from a lazily loaded model (Hugging Face transformers when
    importable, otherwise vLLM). Lexical metrics (BLEU, ROUGE, Jaccard) are
    computed directly on the raw strings and never need the model.
    """

    # Width of the placeholder vectors returned when ``use_embeddings`` is False
    # (documented in the original code as the Qwen3-Embedding-0.6B output size).
    _DUMMY_EMBEDDING_DIM = 1024

    # Token limit handed to the transformers tokenizer (longer inputs are truncated).
    _MAX_TOKENS = 512

    # Instruction sentence used for any analysis type not listed in _INSTRUCTIONS.
    _DEFAULT_INSTRUCTION = (
        "You are an expert research retrieval system analyzing semantic similarity between research texts. "
        "Your task is to retrieve and assess conceptual alignment and semantic coherence for research retrieval purposes."
    )

    # Instruction sentence per analysis type; the surrounding prompt frame is shared.
    _INSTRUCTIONS = {
        "rediscovery": (
            "You are an expert research retrieval system analyzing conceptual alignment between research abstracts and proposed solutions. "
            "Your task is to retrieve and assess semantic similarity for solution rediscovery retrieval."
        ),
        "sr_solver": (
            "You are an expert research retrieval system analyzing problem-solution alignment. "
            "Your task is to retrieve and assess semantic similarity for evaluating how well proposed solutions address problem statements."
        ),
        "sr_baseline": (
            "You are an expert research retrieval system analyzing problem generalization quality. "
            "Your task is to retrieve and assess semantic similarity for evaluating how well problem statements capture original abstract concepts."
        ),
        "sr_generalizer": (
            "You are an expert research retrieval system analyzing problem generalization quality. "
            "Your task is to retrieve and assess semantic similarity for evaluating the quality of generalized problem statements."
        ),
        "solution_equivalence": (
            "You are an expert research retrieval system analyzing methodological equivalence. "
            "Your task is to retrieve and assess semantic similarity for evaluating conceptual alignment between proposed and actual methodologies."
        ),
    }

    def __init__(
        self, model_name: str = "Qwen/Qwen3-Embedding-0.6B", use_embeddings: bool = True
    ):
        """Store configuration and build the lexical scorers; the model loads lazily.

        Args:
            model_name: Model identifier passed to the embedding backend.
            use_embeddings: When False, ``get_embeddings`` returns random
                placeholder vectors and no model is ever loaded.
        """
        self.model_name = model_name
        self.embedding_model = None
        # Populated by _get_embedding_model(); initialised here so attribute
        # access never depends on whether the model has been loaded yet.
        self.tokenizer = None
        self.model_type = None
        self.use_embeddings = use_embeddings
        self.rouge_scorer = rouge_scorer.RougeScorer(
            ["rouge1", "rouge2", "rougeL"], use_stemmer=True
        )
        self.smoothing = SmoothingFunction().method1

    def _get_embedding_model(self):
        """Load the embedding backend on first use and cache it on the instance.

        transformers is tried first; vLLM is used only when transformers could
        not be imported.

        Returns:
            The loaded model object (also stored in ``self.embedding_model``).

        Raises:
            RuntimeError: If neither transformers nor vLLM is importable.
        """
        if self.embedding_model is not None:
            return self.embedding_model

        print(f"Loading {self.model_name} embedding model...")

        # Clear GPU cache first
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if TRANSFORMERS_AVAILABLE:
            print("Using transformers library...")
            # Drop unreachable Python objects before loading so any tensors
            # they held can be released from the CUDA cache as well.
            import gc

            gc.collect()
            torch.cuda.empty_cache()

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.embedding_model = AutoModel.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            self.model_type = "transformers"
            print(
                f"Transformers model loaded successfully! Hidden size: {self.embedding_model.config.hidden_size}"
            )
            return self.embedding_model

        if VLLM_AVAILABLE:
            self.embedding_model = LLM(model=self.model_name, task="embed")
            self.model_type = "vllm"
            print("vLLM model loaded successfully!")
            return self.embedding_model

        raise RuntimeError("Neither transformers nor vLLM available")

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Return ``query`` prefixed with an ``Instruct:`` line built from ``task_description``."""
        return f"Instruct: {task_description}\nQuery: {query}"

    def get_qwen3_instruct(
        self, query_data: str, analysis_type: str = "general"
    ) -> str:
        """Wrap ``query_data`` in the instruction prompt for ``analysis_type``.

        Args:
            query_data: Text to embed.
            analysis_type: One of ``rediscovery``, ``sr_solver``, ``sr_baseline``,
                ``sr_generalizer`` or ``solution_equivalence``; any other value
                selects the general-purpose instruction.

        Returns:
            The prompt string: a leading newline, the ``Instruct:`` line, a blank
            line, ``Research Text:``, the text itself and a trailing newline.
        """
        instruction = self._INSTRUCTIONS.get(analysis_type, self._DEFAULT_INSTRUCTION)
        return f"\nInstruct: {instruction}\n\nResearch Text:\n{query_data}\n"

    def get_embeddings(
        self,
        texts: List[str],
        task_description: str = "Analyze semantic similarity between research texts",
        analysis_type: str = "general",
    ) -> torch.Tensor:
        """Embed ``texts`` and return one row per text.

        Args:
            texts: Texts to embed.
            task_description: Accepted for backward compatibility; the prompt is
                selected by ``analysis_type`` and this value is not used.
            analysis_type: Prompt selector, see ``get_qwen3_instruct``.

        Returns:
            A 2-D tensor with ``len(texts)`` rows. When ``use_embeddings`` is
            False the rows are random placeholders, so distances derived from
            them carry no meaning.
        """
        if not self.use_embeddings:
            # Return dummy embeddings for text-only analysis
            return torch.randn(len(texts), self._DUMMY_EMBEDDING_DIM)

        model = self._get_embedding_model()

        # Format texts with specific analysis type instructions
        formatted_texts = [
            self.get_qwen3_instruct(text, analysis_type) for text in texts
        ]

        if getattr(self, "model_type", None) == "transformers":
            return self._get_transformers_embeddings(formatted_texts)

        # vLLM path: each output carries its embedding as a plain list of floats.
        outputs = model.embed(formatted_texts)
        return torch.tensor([o.outputs.embedding for o in outputs])

    @staticmethod
    def _masked_mean_pool(
        last_hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Average token states over the sequence axis, ignoring padded positions.

        Args:
            last_hidden_state: Tensor indexed as (batch, sequence, hidden).
            attention_mask: Tensor indexed as (batch, sequence) with non-zero
                entries for real tokens; when None every position is averaged.

        Returns:
            A float32 tensor indexed as (batch, hidden).
        """
        # Accumulate in float32: summing many half-precision values loses
        # precision and can overflow.
        hidden = last_hidden_state.float()
        if attention_mask is None:
            return hidden.mean(dim=1)

        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        token_sum = (hidden * mask).sum(dim=1)
        # clamp avoids 0/0 for a row whose mask is entirely zero.
        token_count = mask.sum(dim=1).clamp(min=1.0)
        return token_sum / token_count

    def _get_transformers_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Embed already formatted ``texts`` with the transformers model.

        Texts are tokenized as one padded batch and truncated to the token
        limit; each embedding is the mean of the final hidden states over the
        real (non-padding) tokens.

        NOTE(review): the Qwen3-Embedding model card reportedly recommends
        last-token pooling; mean pooling is kept to stay comparable with scores
        already produced by this module -- confirm which is intended.
        """
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self._MAX_TOKENS,
        )

        # Move to same device as model
        device = next(self.embedding_model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.embedding_model(**inputs)
            # Without the mask, padding added for shorter texts in the batch
            # would be averaged in and skew their embeddings.
            embeddings = self._masked_mean_pool(
                outputs.last_hidden_state, inputs.get("attention_mask")
            )

        return embeddings

    @staticmethod
    def _to_flat_vector(emb: torch.Tensor) -> np.ndarray:
        """Return ``emb`` as a 1-D float64 NumPy array on the CPU."""
        # float64 keeps the distance arithmetic from running in half precision,
        # where a sum of squares can overflow.
        return emb.detach().cpu().double().numpy().ravel()

    def calculate_basic_distances(
        self, emb1: torch.Tensor, emb2: torch.Tensor
    ) -> Dict[str, float]:
        """Compute vector distances between two embeddings.

        Args:
            emb1: First embedding; flattened to one dimension before use.
            emb2: Second embedding; must flatten to the same length as ``emb1``.

        Returns:
            Dict with ``euclidean``, ``cosine`` (distance), ``manhattan``,
            ``chebyshev`` and ``cosine_similarity`` as Python floats.
        """
        vec1 = self._to_flat_vector(emb1)
        vec2 = self._to_flat_vector(emb2)

        return {
            "euclidean": float(euclidean(vec1, vec2)),
            "cosine": float(cosine(vec1, vec2)),
            "manhattan": float(cityblock(vec1, vec2)),
            "chebyshev": float(chebyshev(vec1, vec2)),
            "cosine_similarity": float(cosine_similarity([vec1], [vec2])[0][0]),
        }

    def calculate_advanced_metrics(self, text1: str, text2: str) -> Dict[str, float]:
        """Compute lexical overlap metrics between two texts.

        ``text1`` is the reference and ``text2`` the candidate for BLEU and ROUGE.

        Returns:
            Dict with ``bleu``, ``rouge1``, ``rouge2``, ``rougeL`` (F-measures)
            and ``jaccard`` (case-insensitive overlap of whitespace tokens).
        """
        metrics = {}

        # BLEU on whitespace tokens, smoothed so missing n-gram orders do not
        # force the score to zero.
        metrics["bleu"] = float(
            sentence_bleu(
                [text1.split()], text2.split(), smoothing_function=self.smoothing
            )
        )

        rouge_scores = self.rouge_scorer.score(text1, text2)
        for rouge_name in ("rouge1", "rouge2", "rougeL"):
            metrics[rouge_name] = rouge_scores[rouge_name].fmeasure

        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())
        union_size = len(words1 | words2)
        # Two empty texts share no words; report 0.0 instead of dividing by zero.
        metrics["jaccard"] = len(words1 & words2) / union_size if union_size > 0 else 0.0

        return metrics

    def _drift_from_embeddings(self, embeddings: torch.Tensor) -> Dict[str, Any]:
        """Build the drift report from embeddings ordered input, intermediates..., final."""
        input_emb = embeddings[0]

        drift_metrics = {
            "input_to_final": self.calculate_basic_distances(input_emb, embeddings[-1]),
            "intermediate_distances": [
                {
                    "step": step,
                    "distances": self.calculate_basic_distances(input_emb, step_emb),
                }
                for step, step_emb in enumerate(embeddings[1:-1], start=1)
            ],
            "convergence_analysis": {},
        }

        cosine_distances = [
            entry["distances"]["cosine"]
            for entry in drift_metrics["intermediate_distances"]
        ]
        # A trend needs at least two intermediate points to compare.
        if len(cosine_distances) > 1:
            drift_metrics["convergence_analysis"] = {
                "variance": float(np.var(cosine_distances)),
                "trend": (
                    "converging"
                    if cosine_distances[-1] < cosine_distances[0]
                    else "diverging"
                ),
                "stability": float(np.std(cosine_distances)),
            }

        return drift_metrics

    def calculate_semantic_drift(
        self,
        input_text: str,
        intermediate_texts: List[str],
        final_text: str,
        analysis_type: str = "general",
    ) -> Dict[str, Any]:
        """Measure how far each iteration moved away from the input text.

        Args:
            input_text: Starting text; every distance is measured from it.
            intermediate_texts: Texts produced between input and final, in order.
            final_text: Last text of the sequence.
            analysis_type: Prompt selector, see ``get_qwen3_instruct``.

        Returns:
            Dict with ``input_to_final`` (distance dict),
            ``intermediate_distances`` (one ``{"step", "distances"}`` entry per
            intermediate text, steps numbered from 1) and
            ``convergence_analysis`` (variance, trend and standard deviation of
            the intermediate cosine distances; empty when fewer than two
            intermediate texts are given).
        """
        embeddings = self.get_embeddings(
            [input_text, *intermediate_texts, final_text],
            "Track semantic evolution from input to final output",
            analysis_type,
        )
        return self._drift_from_embeddings(embeddings)

    def _coherence_from_embeddings(self, embeddings: torch.Tensor) -> Dict[str, Any]:
        """Build the coherence report from embeddings of consecutive texts."""
        pairwise_similarities = [
            self.calculate_basic_distances(current, following)["cosine_similarity"]
            for current, following in zip(embeddings[:-1], embeddings[1:])
        ]

        return {
            "coherence_score": float(np.mean(pairwise_similarities)),
            "pairwise_similarities": pairwise_similarities,
            "coherence_variance": float(np.var(pairwise_similarities)),
        }

    def analyze_semantic_coherence(
        self, texts: List[str], analysis_type: str = "general"
    ) -> Dict[str, Any]:
        """Score how similar each text is to the one that follows it.

        Args:
            texts: Ordered sequence of texts.
            analysis_type: Prompt selector, see ``get_qwen3_instruct``.

        Returns:
            Dict with ``coherence_score`` (mean cosine similarity of adjacent
            pairs), ``pairwise_similarities`` (one value per adjacent pair) and
            ``coherence_variance``. With fewer than two texts there are no
            pairs: the score is 1.0, the list is empty and the variance is 0.0.
        """
        if len(texts) < 2:
            # Same keys as the regular result so callers can read them unconditionally.
            return {
                "coherence_score": 1.0,
                "pairwise_similarities": [],
                "coherence_variance": 0.0,
            }

        embeddings = self.get_embeddings(
            texts, "Analyze semantic coherence across text sequence", analysis_type
        )
        return self._coherence_from_embeddings(embeddings)

    def comprehensive_semantic_analysis(
        self,
        input_text: str,
        output_text: str,
        intermediate_texts: Optional[List[str]] = None,
        analysis_type: str = "general",
    ) -> Dict[str, Any]:
        """Run every available comparison between ``input_text`` and ``output_text``.

        Args:
            input_text: Source text.
            output_text: Text derived from the source.
            intermediate_texts: Optional ordered texts produced between the two;
                when non-empty, drift and coherence reports are added.
            analysis_type: Prompt selector, see ``get_qwen3_instruct``.

        Returns:
            Dict with ``basic_distances``, ``advanced_metrics`` and
            ``semantic_similarity_score`` (the cosine similarity), plus
            ``drift_analysis`` and ``coherence_analysis`` when intermediate
            texts were supplied.
        """
        intermediates = list(intermediate_texts) if intermediate_texts else []

        # Embed the whole sequence once; the drift and coherence reports below
        # reuse these rows instead of running the model again on the same texts.
        embeddings = self.get_embeddings(
            [input_text, *intermediates, output_text],
            "Comprehensive semantic analysis of research text transformation",
            analysis_type,
        )

        basic_distances = self.calculate_basic_distances(embeddings[0], embeddings[-1])
        advanced_metrics = self.calculate_advanced_metrics(input_text, output_text)

        analysis = {
            "basic_distances": basic_distances,
            "advanced_metrics": advanced_metrics,
            "semantic_similarity_score": basic_distances["cosine_similarity"],
        }

        if intermediates:
            analysis["drift_analysis"] = self._drift_from_embeddings(embeddings)
            analysis["coherence_analysis"] = self._coherence_from_embeddings(embeddings)

        return analysis


def extract_iteration_data_from_solver_log(
    solver_log: Dict,
) -> Tuple[str, List[str], str]:
    """Pull the attempted and final solutions out of a solver workflow log.

    Returns a ``(input_text, intermediate_texts, final_solution)`` tuple.
    ``input_text`` is always the empty string: the problem statement is not
    read from the solver log here, so the caller supplies it separately.
    """
    # The final answer sits under summary -> final_solution; missing keys yield "".
    summary = solver_log.get("summary", {})
    final_solution = summary.get("final_solution", "")

    # Flatten external iterations -> internal attempts, keeping only the
    # attempts that actually recorded a "solution".
    attempted_solutions = [
        attempt["solution"]
        for outer_round in solver_log.get("external_iterations", [])
        for attempt in outer_round.get("internal_attempts", [])
        if "solution" in attempt
    ]

    return "", attempted_solutions, final_solution


def extract_iteration_data_from_generalizer_log(
    generalizer_log: Dict,
) -> Tuple[str, List[str], str]:
    """Pull the attempted and final problem statements out of a generalizer log.

    Returns a ``(input_text, intermediate_texts, final_problem_statement)``
    tuple. ``input_text`` is always the empty string: the abstract is not read
    from the generalizer log here, so the caller supplies it separately.
    """
    final_statement = generalizer_log.get("summary", {}).get(
        "final_problem_statement", ""
    )

    # Walk every internal attempt of every external iteration in log order.
    # Attempts are read from the "solution" key, as in the solver log.
    attempts_in_order = (
        attempt
        for outer_round in generalizer_log.get("external_iterations", [])
        for attempt in outer_round.get("internal_attempts", [])
    )
    recorded_texts = [
        attempt["solution"] for attempt in attempts_in_order if "solution" in attempt
    ]

    placeholder_input = ""
    return placeholder_input, recorded_texts, final_statement


def analyze_paper_semantic_metrics(
    paper_data: Dict, analyzer: SemanticAnalyzer
) -> Dict[str, Any]:
    """Run the solver and generalizer semantic analyses for one paper record.

    Args:
        paper_data: Paper record; everything is read from its ``custom_log``
            entry (``original_abstract``, ``problem_statement``,
            ``raw_solver_workflow_log``, ``raw_generalizer_workflow_log``).
        analyzer: Analyzer used for both comparisons.

    Returns:
        Dict that may contain ``solver_semantic_analysis`` (problem statement
        vs. final solution) and ``generalizer_semantic_analysis`` (abstract vs.
        final problem statement). A key is present only when its workflow log
        and both texts being compared are non-empty.
    """
    results: Dict[str, Any] = {}

    # ``or {}`` also covers a record whose custom_log is present but null.
    custom_log = paper_data.get("custom_log") or {}
    abstract = custom_log.get("original_abstract", "")
    problem_statement = custom_log.get("problem_statement", "")

    # Solver workflow: how the final solution relates to the problem statement.
    solver_log = custom_log.get("raw_solver_workflow_log", {})
    if solver_log:
        _, intermediate_texts, final_solution = extract_iteration_data_from_solver_log(
            solver_log
        )
        if problem_statement and final_solution:
            results["solver_semantic_analysis"] = (
                analyzer.comprehensive_semantic_analysis(
                    problem_statement, final_solution, intermediate_texts, "sr_solver"
                )
            )

    # Generalizer workflow: how the final problem statement relates to the abstract.
    generalizer_log = custom_log.get("raw_generalizer_workflow_log", {})
    if generalizer_log:
        _, intermediate_texts, final_problem_statement = (
            extract_iteration_data_from_generalizer_log(generalizer_log)
        )
        if abstract and final_problem_statement:
            results["generalizer_semantic_analysis"] = (
                analyzer.comprehensive_semantic_analysis(
                    abstract,
                    final_problem_statement,
                    intermediate_texts,
                    "sr_generalizer",
                )
            )

    return results


def analyze_evaluation_semantic_metrics(
    original_abstract: str,
    problem_statement: str,
    proposed_solution: str,
    analyzer: SemanticAnalyzer,
) -> Dict[str, Any]:
    """Run the four evaluation comparisons (rediscovery, sr_baseline, sr_solver, sr_generalizer).

    Args:
        original_abstract: Abstract of the source paper.
        problem_statement: Problem statement derived from the abstract.
        proposed_solution: Solution proposed for the problem statement.
        analyzer: Analyzer that performs each pairwise comparison.

    Returns:
        Dict keyed ``<analysis_type>_semantic_analysis`` holding the result of
        ``analyzer.comprehensive_semantic_analysis`` for each comparison, in the
        order listed below.
    """
    # (analysis type, first text, second text) -- order fixes both the call
    # order and the key order of the result.
    comparisons = (
        ("rediscovery", original_abstract, proposed_solution),
        ("sr_baseline", problem_statement, original_abstract),
        ("sr_solver", problem_statement, proposed_solution),
        ("sr_generalizer", original_abstract, problem_statement),
    )

    results: Dict[str, Any] = {}
    for analysis_type, first_text, second_text in comparisons:
        results[f"{analysis_type}_semantic_analysis"] = (
            analyzer.comprehensive_semantic_analysis(
                first_text, second_text, analysis_type=analysis_type
            )
        )

    return results


if __name__ == "__main__":
    # Demo run: compare three sample research texts under each analysis type,
    # then run the combined evaluation helper on the same texts.
    analyzer = SemanticAnalyzer()

    original_abstract = "This paper presents a novel approach to machine learning using deep neural networks for image classification."
    problem_statement = "How can we develop effective machine learning methods for visual recognition tasks?"
    proposed_solution = "We introduce a new deep learning methodology using convolutional neural networks for image classification."

    # (banner, label, analysis type, first text, second text), in print order.
    demo_cases = [
        (
            "=== Rediscovery Analysis ===",
            "Rediscovery",
            "rediscovery",
            original_abstract,
            proposed_solution,
        ),
        (
            "\n=== SR_baseline Analysis ===",
            "SR_baseline",
            "sr_baseline",
            problem_statement,
            original_abstract,
        ),
        (
            "\n=== SR_solver Analysis ===",
            "SR_solver",
            "sr_solver",
            problem_statement,
            proposed_solution,
        ),
        (
            "\n=== SR_generalizer Analysis ===",
            "SR_generalizer",
            "sr_generalizer",
            original_abstract,
            problem_statement,
        ),
    ]

    for banner, label, kind, first_text, second_text in demo_cases:
        print(banner)
        outcome = analyzer.comprehensive_semantic_analysis(
            first_text, second_text, analysis_type=kind
        )
        print(f"{label} Similarity: {outcome['semantic_similarity_score']:.3f}")

    print("\n=== Full Evaluation Analysis ===")
    full_analysis = analyze_evaluation_semantic_metrics(
        original_abstract, problem_statement, proposed_solution, analyzer
    )
    print("Full evaluation semantic metrics computed successfully!")
