"""
RAG Agent for Banner - Retrieves questions from fixed format experiment results using semantic search.

RAG baseline:
- Retrieves questions from fixed format experiment results (Fixed_Binary, Fixed_MultiChoice, Fixed_OpenText)
- Uses sentence-transformers embedding for semantic search
- Used with Flexible and Adaptive question formats
- Data source: rag_data directory containing fixed format experiment results
"""

import json
import pickle
import re
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any

import numpy as np

try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
    SENTENCE_TRANSFORMERS_AVAILABLE = False
    print("Warning: sentence-transformers not available. Install with: pip install sentence-transformers")

try:
    from sklearn.metrics.pairwise import cosine_similarity
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    print("Warning: scikit-learn not available. Install with: pip install scikit-learn")

from poster_config import PosterInteractionConfig, QuestionFormat


class PosterRAGAgent:
    """RAG agent that retrieves questions from fixed format experiment results using semantic search.

    On construction the agent scans ``rag_data_dir`` for
    ``sample_*/<experiment_dir>`` directories whose name contains one of
    ``_FIXED_FORMAT_CONFIGS``, keeps the first parsed question of every
    step/turn, and embeds the question texts with a sentence-transformers model
    (cached in a pickle file inside ``rag_data_dir``). ``retrieve_questions``
    then returns the most similar not-yet-returned questions formatted as
    prompt context, falling back to sequential order when semantic search is
    unavailable.
    """

    # Experiment-directory name fragments whose results feed the RAG database.
    _FIXED_FORMAT_CONFIGS = ("Fixed_Binary", "Fixed_MultiChoice", "Fixed_OpenText")

    # Format markers such as "[FORMAT_OPEN_TEXT]" added by the questioner agent;
    # they are not part of the actual question.
    _FORMAT_MARKER_RE = re.compile(r'\[FORMAT_[^\]]+\]\s*\n?', re.IGNORECASE)
    # A text made of exactly one bracketed token (e.g. "[FORMAT_BINARY]").
    _BARE_MARKER_RE = re.compile(r'\[[^\[\]]*\]')
    # "### Question 1: ..." / "## Question 1: ..." sections.
    _MARKDOWN_QUESTION_RE = re.compile(
        r'(?:###|##)\s*Question\s+(?P<num>\d+)[:\-]\s*(?P<title>.*?)\n'
        r'(?P<content>.*?)(?=(?:###|##)\s*Question\s+\d+[:\-]|$)',
        re.DOTALL | re.IGNORECASE
    )
    _DIRECT_QUESTION_RE = re.compile(
        r'\*\s*\*?\*?Direct Question[:\-]?\s*\*?\*?\*\s*(.+?)(?=\*|$)',
        re.IGNORECASE | re.DOTALL
    )
    _OPTIONS_RE = re.compile(
        r'\*\s*\*?\*?Options?[:\-]?\s*\*?\*?\*\s*(.+?)(?=\*|$)',
        re.IGNORECASE | re.DOTALL
    )
    # First "A) ..." style option line of a multi-choice question.
    _MULTI_CHOICE_RE = re.compile(r'([A-Z]\)\s+[^\n]+)', re.MULTILINE)

    _CONTEXT_HEADER = (
        "**Retrieved Reference Questions from Similar Design Scenarios:**",
        "(These questions were asked in similar design contexts. Use them as reference, but generate a NEW question appropriate for the current design.)",
        "",
    )

    def __init__(self, rag_data_dir: Path, question_format: "QuestionFormat", embedding_model_name: str = 'all-MiniLM-L6-v2'):
        """
        Initialize RAG agent.

        Loads the embedding model (if sentence-transformers is installed), builds
        the question database from ``rag_data_dir`` and loads or computes the
        question embeddings.

        Args:
            rag_data_dir: Path (or path string) to the rag_data directory containing
                fixed format experiment results
            question_format: Question format (FLEXIBLE or ADAPTIVE); its ``value``
                is stored as ``format_type`` on every parsed question
            embedding_model_name: Sentence-BERT model name for embeddings
        """
        self.rag_data_dir = Path(rag_data_dir)
        self.question_format = question_format
        self.embedding_model_name = embedding_model_name
        self._qa_database: List[Dict] = []
        self._embeddings: Optional[np.ndarray] = None
        self._embedding_model: Optional[Any] = None
        # Retrieval cursors: database indices already returned by semantic search,
        # and the next position for the sequential fallback.
        self._used_indices: set = set()
        self._current_index: int = 0

        if SENTENCE_TRANSFORMERS_AVAILABLE:
            try:
                self._embedding_model = SentenceTransformer(embedding_model_name)
            except Exception as e:
                print(f"Warning: Failed to load embedding model {embedding_model_name}: {e}")
                self._embedding_model = None

        self._load_database()
        self._load_or_compute_embeddings()

    # ------------------------------------------------------------------ loading

    def _load_database(self):
        """Load question database from fixed format experiment results in rag_data directory.

        Visits ``rag_data_dir/sample_*/<experiment_dir>`` in sorted order, so the
        database order (which the embedding cache check and the sequential
        fallback depend on) does not vary with filesystem listing order.
        """
        if not self.rag_data_dir.is_dir():
            return

        for sample_dir in sorted(self.rag_data_dir.iterdir()):
            if not sample_dir.is_dir() or not sample_dir.name.startswith('sample_'):
                continue
            for exp_dir in sorted(sample_dir.iterdir()):
                if not exp_dir.is_dir():
                    continue
                if not any(name in exp_dir.name for name in self._FIXED_FORMAT_CONFIGS):
                    continue
                self._load_questions_from_experiment(exp_dir, sample_dir.name)

    def _load_questions_from_experiment(self, exp_dir: Path, sample_id: str):
        """Load questions from an experiment directory.

        Uses the per-step JSON files in ``exp_dir/steps`` (Web format) when that
        directory exists; otherwise falls back to ``qa_conversation.json``
        (Banner format) in ``exp_dir`` or one of its subdirectories.
        """
        source = f"{sample_id}/{exp_dir.name}"

        steps_dir = exp_dir / "steps"
        if steps_dir.is_dir():
            self._load_from_steps_dir(steps_dir, source)
            return

        qa_conversation_file = self._find_qa_conversation_file(exp_dir)
        if qa_conversation_file is not None:
            self._load_from_qa_conversation(qa_conversation_file, source)

    @staticmethod
    def _step_sort_key(step_file: Path) -> Tuple[int, str]:
        """Sort key for ``step_N*.json``: numeric N (0 if not an integer), then file name."""
        suffix = step_file.stem.split('_')[1]
        # isdecimal (not isdigit) guarantees int() accepts the string.
        return (int(suffix) if suffix.isdecimal() else 0, step_file.name)

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Return the parsed JSON content of *path*, or None (with a warning) if unreadable."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            print(f"Warning: Failed to read {path}: {e}")
            return None

    @staticmethod
    def _question_from_step(step_data: Dict) -> Optional[str]:
        """Extract the question text of one step record, or None if it carries none.

        ``ask_questions`` steps store it in ``output`` (falling back to
        ``question``); other steps contribute the last entry of ``qa_log``.
        """
        if step_data.get('action') == 'ask_questions':
            text = step_data.get('output', step_data.get('question', ''))
        else:
            qa_log = step_data.get('qa_log')
            if not isinstance(qa_log, list) or not qa_log or not isinstance(qa_log[-1], dict):
                return None
            text = qa_log[-1].get('question', '')
        return text if isinstance(text, str) else None

    def _load_from_steps_dir(self, steps_dir: Path, source: str):
        """Add one question per ``step_*.json`` file in *steps_dir*, in step order."""
        for step_file in sorted(steps_dir.glob("step_*.json"), key=self._step_sort_key):
            step_data = self._read_json(step_file)
            if not isinstance(step_data, dict):
                continue
            self._add_question(
                self._question_from_step(step_data),
                source,
                step_data.get('step_index', 0),
            )

    @staticmethod
    def _find_qa_conversation_file(exp_dir: Path) -> Optional[Path]:
        """Locate ``qa_conversation.json`` directly in *exp_dir* or in a subdirectory.

        Banner saves to ``{index}_{brand_name}/`` subdirectories; the first match
        in sorted order is used. Returns None if no file is found.
        """
        direct = exp_dir / "qa_conversation.json"
        if direct.exists():
            return direct
        for subdir in sorted(exp_dir.iterdir()):
            candidate = subdir / "qa_conversation.json"
            if subdir.is_dir() and candidate.exists():
                return candidate
        return None

    def _load_from_qa_conversation(self, qa_file: Path, source: str):
        """Add one question per entry of ``conversation_history`` in *qa_file*.

        Entries may be dicts with a ``question`` key or lists/tuples whose first
        item is the question string; the entry's position becomes ``step_index``.
        Malformed entries are skipped individually.
        """
        qa_data = self._read_json(qa_file)
        if not isinstance(qa_data, dict):
            return
        conversation_history = qa_data.get('conversation_history', [])
        if not isinstance(conversation_history, list):
            return

        for idx, qa_pair in enumerate(conversation_history):
            if isinstance(qa_pair, dict):
                question_text = qa_pair.get('question', '')
            elif isinstance(qa_pair, (list, tuple)) and qa_pair:
                question_text = qa_pair[0]
            else:
                continue
            if isinstance(question_text, str):
                self._add_question(question_text, source, idx)

    @classmethod
    def _is_bare_marker(cls, text: str) -> bool:
        """True if *text* is nothing but one bracketed token, e.g. ``[FORMAT_OPEN_TEXT]``."""
        return cls._BARE_MARKER_RE.fullmatch(text.strip()) is not None

    def _add_question(self, question_text: Optional[str], source: str, step_index: Any):
        """Parse *question_text* and append its first question to the database.

        Only the first parsed question is kept, to align with
        max_questions_per_batch (one question per step). Texts that are just a
        format marker are ignored; texts that merely start and end with brackets
        (e.g. "[FORMAT_BINARY] Is it clear? [Yes/No]") are kept.
        """
        if not question_text or self._is_bare_marker(question_text):
            return
        questions = self._parse_questions(question_text)
        if not questions:
            return
        q = questions[0]
        if self._is_bare_marker(q.get('question_text', '')):
            return
        q['source'] = source
        q['step_index'] = step_index
        self._qa_database.append(q)

    def _parse_questions(self, questions_text: str) -> List[Dict]:
        """Parse questions text into structured list.

        ``[FORMAT_...]`` markers are removed first; if fewer than 10 characters
        remain, ``[]`` is returned. Markdown ``### Question N:`` sections yield one
        entry each, using their ``Direct Question`` / ``Options`` fields when
        present (the section title otherwise). Any other text is a single
        question, split at the first ``X)`` option line into question and options.
        Each entry has ``question_text``, ``options`` and ``format_type``
        (``self.question_format.value``).
        """
        text = self._FORMAT_MARKER_RE.sub('', questions_text.strip()).strip()
        if len(text) < 10:
            return []

        format_type = self.question_format.value

        sections = list(self._MARKDOWN_QUESTION_RE.finditer(text))
        if sections:
            questions = []
            for match in sections:
                title = match.group('title').strip()
                content = match.group('content').strip()
                direct = self._DIRECT_QUESTION_RE.search(content)
                options = self._OPTIONS_RE.search(content)
                questions.append({
                    'question_text': direct.group(1).strip() if direct else title,
                    'options': options.group(1).strip() if options else "",
                    'format_type': format_type,
                })
            return questions

        # Single question (most common case for Banner), possibly multi-choice.
        first_option = self._MULTI_CHOICE_RE.search(text)
        if first_option:
            question_part = text[:first_option.start()].strip()
            return [{
                'question_text': question_part if question_part else text,
                'options': text[first_option.start():].strip(),
                'format_type': format_type,
            }]

        return [{
            'question_text': text,
            'options': '',
            'format_type': format_type,
        }]

    # --------------------------------------------------------------- embeddings

    def _load_or_compute_embeddings(self):
        """Load embeddings from file or compute and save them.

        Does nothing when the database is empty or no embedding model is loaded.
        The cache file ``embeddings_<model>.pkl`` in ``rag_data_dir`` is reused
        only if it matches the current database (see ``_load_cached_embeddings``).
        """
        if not self._qa_database or self._embedding_model is None:
            return

        embeddings_file = self.rag_data_dir / f"embeddings_{self.embedding_model_name.replace('/', '_')}.pkl"

        cached = self._load_cached_embeddings(embeddings_file)
        if cached is not None:
            self._embeddings = cached
            print(f"✓ Loaded {len(self._qa_database)} question embeddings from {embeddings_file}")
            return

        print(f"Computing embeddings for {len(self._qa_database)} questions...")
        # Debug: question sources breakdown, to understand where questions come from
        print("[DEBUG] Question sources breakdown:")
        source_counts: Dict[str, int] = {}
        for q in self._qa_database:
            source = q.get('source', 'unknown')
            source_counts[source] = source_counts.get(source, 0) + 1
        for source, count in sorted(source_counts.items()):
            print(f"  {source}: {count} questions")

        question_texts = [q.get('question_text', '') for q in self._qa_database]
        try:
            self._embeddings = self._embedding_model.encode(
                question_texts,
                show_progress_bar=True,
                convert_to_numpy=True
            )
        except Exception as e:
            print(f"Error computing embeddings: {e}")
            self._embeddings = None
            return

        self._save_embeddings(embeddings_file)

    def _load_cached_embeddings(self, embeddings_file: Path) -> Optional[np.ndarray]:
        """Return cached embeddings if they match the current database, else None.

        A cache matches when it has one embedding per database entry and its
        stored question texts equal the database's, in order.

        NOTE(security): pickle.load can execute arbitrary code; only point
        rag_data_dir at trusted data.
        """
        if not embeddings_file.exists():
            return None
        try:
            with open(embeddings_file, 'rb') as f:
                saved_data = pickle.load(f)
            saved_db = saved_data.get('database')
            saved_embeddings = saved_data.get('embeddings')
            n = len(self._qa_database)
            if not saved_db or saved_embeddings is None:
                return None
            if len(saved_db) != n or len(saved_embeddings) != n:
                return None
            if any(saved.get('question_text') != current.get('question_text')
                   for saved, current in zip(saved_db, self._qa_database)):
                return None
            return np.array(saved_embeddings)
        except Exception as e:
            print(f"Warning: Failed to load embeddings from {embeddings_file}: {e}")
            return None

    def _save_embeddings(self, embeddings_file: Path):
        """Pickle the database and embeddings to *embeddings_file*; failures only warn."""
        try:
            embeddings_file.parent.mkdir(parents=True, exist_ok=True)
            with open(embeddings_file, 'wb') as f:
                pickle.dump({
                    'database': self._qa_database,
                    'embeddings': self._embeddings.tolist(),
                    'model_name': self.embedding_model_name
                }, f)
            print(f"✓ Saved embeddings to {embeddings_file}")
        except Exception as e:
            print(f"Warning: Failed to save embeddings: {e}")

    # ---------------------------------------------------------------- retrieval

    @staticmethod
    def _build_query_text(current_plan: Optional[str], qa_history: Optional[List[Any]]) -> str:
        """Build the semantic-search query from the plan and at most the last 5 Q&As.

        Q&A entries may be ``(question, answer, ...)`` tuples/lists or dicts with
        ``question``/``question_text`` and ``answer`` keys; anything else is
        ignored. With no usable context a generic design query is returned.
        """
        query_parts = []

        if current_plan:
            # Full plan text (not truncated) so the query carries all information.
            query_parts.append(f"Current design plan: {current_plan}")

        if qa_history:
            qa_lines = []
            for qa in qa_history[-5:]:
                if isinstance(qa, dict):
                    q_text = qa.get('question', qa.get('question_text', ''))
                    a_text = qa.get('answer', '')
                elif isinstance(qa, (tuple, list)) and len(qa) >= 2:
                    q_text, a_text = qa[0], qa[1]
                else:
                    continue
                if q_text:
                    qa_lines.append(f"Q: {q_text}")
                if a_text:
                    qa_lines.append(f"A: {a_text}")
            if qa_lines:
                qa_context = "\n".join(qa_lines)
                query_parts.append(f"Recent Q&A:\n{qa_context}")

        if not query_parts:
            return "design question about layout, colors, typography, or content"
        return "\n".join(query_parts)

    def _select_top_unused(self, similarities: np.ndarray, target_count: int) -> List[int]:
        """Return up to *target_count* unused database indices, highest similarity first.

        Equal scores keep ascending index order (stable sort). Returns [] when
        *target_count* is not positive or every index is already in
        ``self._used_indices``. Does not modify ``self._used_indices``.
        """
        if target_count <= 0:
            # Guard: slicing with [-0:] would otherwise select *everything*.
            return []
        scores = np.asarray(similarities)
        available = np.array(
            [i for i in range(len(scores)) if i not in self._used_indices], dtype=int
        )
        if available.size == 0:
            return []
        order = np.argsort(-scores[available], kind='stable')[:target_count]
        return [int(i) for i in available[order]]

    @classmethod
    def _build_context_text(cls, retrieved_questions: List[Dict], show_similarity: bool) -> str:
        """Format retrieved questions as the reference-question block for the prompt."""
        lines = list(cls._CONTEXT_HEADER)
        for i, rq in enumerate(retrieved_questions, 1):
            if show_similarity:
                lines.append(f"Reference Question {i} (similarity: {rq['similarity_score']:.3f}):")
            else:
                lines.append(f"Reference Question {i}:")
            if rq['options']:
                lines.append(f"  Question: {rq['question_text']}")
                lines.append(f"  Options: {rq['options']}")
            else:
                lines.append(f"  {rq['question_text']}")
            lines.append("")
        return "\n".join(lines)

    def retrieve_questions(
        self,
        current_plan: Optional[str] = None,
        qa_history: Optional[List[Tuple[str, str]]] = None,
        target_count: int = 3,
    ) -> Optional[Dict]:
        """
        Retrieve questions using semantic search based on current plan and QA history.

        Questions returned by an earlier semantic call are excluded, so repeated
        calls walk down the ranking without duplicates. Falls back to sequential
        retrieval when no embedding model/embeddings are loaded, scikit-learn is
        unavailable, or encoding/similarity computation raises.

        Args:
            current_plan: Current design plan (for semantic matching)
            qa_history: Recent Q&A history; at most the last 5 entries are used.
                Entries may be (question, answer) tuples/lists or dicts with
                'question'/'question_text' and 'answer' keys.
            target_count: Maximum number of questions to retrieve (default: 3)

        Returns:
            Dictionary with 'retrieved_questions' (list of dicts with question_text,
            options, similarity_score, source) and 'context_text' (formatted string
            for the prompt), or None if the database is empty, target_count is not
            positive, or all questions have already been returned.
        """
        if not self._qa_database:
            return None

        if self._embedding_model is None or self._embeddings is None or not SKLEARN_AVAILABLE:
            return self._retrieve_sequential_for_context(target_count)

        query_text = self._build_query_text(current_plan, qa_history)
        try:
            query_embedding = self._embedding_model.encode(
                [query_text],
                convert_to_numpy=True
            )
            similarities = cosine_similarity(query_embedding, self._embeddings)[0]
        except Exception as e:
            print(f"Error in semantic retrieval: {e}")
            return self._retrieve_sequential_for_context(target_count)

        top_indices = self._select_top_unused(similarities, target_count)
        if not top_indices:
            return None

        retrieved_questions = []
        for idx in top_indices:
            q = self._qa_database[idx]
            retrieved_questions.append({
                'question_text': q.get('question_text', ''),
                'options': q.get('options', ''),
                'similarity_score': float(similarities[idx]),
                'source': q.get('source', 'unknown'),
            })
            # Mark as used so later calls do not return duplicates.
            self._used_indices.add(idx)

        return {
            'retrieved_questions': retrieved_questions,
            'context_text': self._build_context_text(retrieved_questions, show_similarity=True),
        }

    def _retrieve_sequential_for_context(self, target_count: int = 3) -> Optional[Dict]:
        """Fallback sequential retrieval for context.

        Returns the next *target_count* questions in database order and advances
        a cursor so the next call continues from there; similarity scores are
        reported as 0.0. Returns None once the database is exhausted or when
        *target_count* is not positive.
        """
        if not self._qa_database or target_count <= 0:
            return None

        start = self._current_index
        batch = self._qa_database[start:start + target_count]
        if not batch:
            return None
        self._current_index = start + len(batch)

        retrieved_questions = [
            {
                'question_text': q.get('question_text', ''),
                'options': q.get('options', ''),
                'similarity_score': 0.0,  # No similarity score for sequential
                'source': q.get('source', 'unknown'),
            }
            for q in batch
        ]

        return {
            'retrieved_questions': retrieved_questions,
            'context_text': self._build_context_text(retrieved_questions, show_similarity=False),
        }

    def has_data(self) -> bool:
        """Check if RAG database has data."""
        return len(self._qa_database) > 0

