from pathlib import Path
from typing import List, Dict, Any
import numpy as np
from tqdm import tqdm
import time

from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import (
    VectorStoreIndex,
    Document,
    StorageContext,
    load_index_from_storage,
    Settings
)
from llama_index.core.vector_stores.types import VectorStoreQuery
from openai import OpenAI

from .config import RAGConfig, RetrievalConfig
from .embedding_service import create_embedding_service
from .llamaindex_adapter import SiliconFlowEmbeddingAdapter


class VectorRetriever:
    """Build or load a LlamaIndex vector index for a game corpus and query it.

    Accepts either a ``RetrievalConfig`` (embeddings come from the service
    returned by ``create_embedding_service``) or a ``RAGConfig`` (embeddings
    come from an OpenAI-compatible endpoint). Initialization assigns the
    module-level ``Settings.embed_model``, which LlamaIndex uses as the default
    embedding model for indexes built or loaded afterwards.
    """

    def __init__(self, config):
        """Store *config* and configure the embedding backend for it.

        Args:
            config: A ``RetrievalConfig``; any other object is treated as a
                ``RAGConfig``.
        """
        self.config = config
        self.index = None
        self.retriever = None

        # RetrievalConfig supports pluggable embedding services; RAGConfig is OpenAI-only.
        if isinstance(config, RetrievalConfig):
            self._init_with_retrieval_config(config)
        else:
            self._init_with_rag_config(config)

    def _init_with_retrieval_config(self, config: RetrievalConfig):
        """Create the configured embedding service and register a matching LlamaIndex embed model."""
        self.embedding_service = create_embedding_service(config)

        if config.embedding_service.lower() == "siliconflow":
            # Wrap the project's SiliconFlow service so LlamaIndex can call it.
            Settings.embed_model = SiliconFlowEmbeddingAdapter(self.embedding_service)
        else:
            Settings.embed_model = OpenAIEmbedding(
                api_key=config.api_key,
                api_base=config.base_url,
                model=config.embedding_model
            )

    def _init_with_rag_config(self, config: RAGConfig):
        """Register an OpenAI embed model and create a raw OpenAI client for batch embedding."""
        Settings.embed_model = OpenAIEmbedding(
            api_key=config.api_key,
            api_base=config.base_url,
            model=config.embedding_model
        )

        # Used directly by batch_get_embeddings on the RAGConfig path.
        self.openai_client = OpenAI(
            api_key=config.api_key,
            base_url=config.base_url
        )

        if config.verbose:
            print("VectorRetriever initialized")

    def build_index(self, documents: List[Document]) -> bool:
        """Load the persisted index for the current config, or build and persist a new one.

        An existing index under ``index_dir/game_name/<index name>`` is reused
        unless ``config.force_rebuild`` is set. If loading fails, a new index
        is built from *documents*.

        Returns:
            True when an index and retriever are ready, False on failure.
        """
        index_name = self._get_index_name()
        index_path = Path(self.config.index_dir) / self.config.game_name / index_name

        if index_path.exists() and not self.config.force_rebuild:
            if self._load_existing_index(index_path):
                return True

        return self._build_new_index(documents, index_path)

    def _get_index_name(self) -> str:
        """Return the relative index directory, keyed by embedding model and corpus scope."""
        # Make the model name filesystem-friendly ('-' and '/' become '_').
        model_name = self.config.embedding_model.replace('-', '_').replace('/', '_')

        # Prefix the service name so different providers serving the same model don't collide.
        if isinstance(self.config, RetrievalConfig) and hasattr(self.config, 'embedding_service'):
            model_name = f"{self.config.embedding_service}_{model_name}"

        if self.config.target_segment_id:
            suffix = "_with_timeless" if self.config.include_timeless else ""
            return f"{model_name}/segment_{self.config.target_segment_id}{suffix}"
        return f"{model_name}/full_corpus"

    def _load_existing_index(self, index_path: Path) -> bool:
        """Load a persisted index from *index_path*; return False (never raise) on failure."""
        try:
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
            self.index = load_index_from_storage(storage_context)
            self.retriever = self.index.as_retriever(similarity_top_k=self.config.top_k)
            return True
        except Exception as e:
            if self.config.verbose:
                print(f"Error loading existing index: {e}")
            return False

    def _build_new_index(self, documents: List[Document], index_path: Path) -> bool:
        """Embed *documents* into a new index, persist it to *index_path*, and create a retriever.

        Returns False (never raises) on failure.
        """
        try:
            self.index = VectorStoreIndex.from_documents(documents, show_progress=self.config.verbose)

            index_path.mkdir(parents=True, exist_ok=True)
            self.index.storage_context.persist(persist_dir=str(index_path))

            self.retriever = self.index.as_retriever(similarity_top_k=self.config.top_k)
            return True
        except Exception as e:
            if self.config.verbose:
                print(f"Error building new index: {e}")
            return False

    @staticmethod
    def _format_result(node_id, content, metadata, score) -> Dict[str, Any]:
        """Build the result dict shared by ``retrieve`` and ``batch_retrieve``.

        A missing (``None``) score is reported as 0.0.
        """
        return {
            'id': node_id,
            'content': content,
            'metadata': metadata,
            'score': 0.0 if score is None else float(score),
        }

    def retrieve(self, query: str) -> List[Dict[str, Any]]:
        """Return the top-k matches for *query* as result dicts.

        Raises:
            ValueError: if no index has been built or loaded yet.

        Retrieval errors are reported (when verbose) and yield an empty list.
        """
        if not self.retriever:
            raise ValueError("Retriever not initialized")

        try:
            nodes = self.retriever.retrieve(query)
            return [
                self._format_result(node.id_, node.get_content(), node.metadata,
                                    getattr(node, 'score', 0.0))
                for node in nodes
            ]
        except Exception as e:
            if self.config.verbose:
                print(f"Error retrieving: {e}")
            return []

    def batch_get_embeddings(self, texts: List[str], batch_size: int = 100,
                             fallback_dim: int = 1536) -> np.ndarray:
        """Embed *texts* in batches and return them stacked as one array.

        On the ``RetrievalConfig`` path this delegates to the embedding service
        (presumably returning one row per text; confirm against
        ``embedding_service``). Otherwise it calls the OpenAI embeddings API
        directly.

        A batch that fails is replaced by zero vectors so the output stays
        aligned with *texts*. Their width is taken from any successful batch.
        *fallback_dim* is used only when every batch fails.

        Args:
            texts: Strings to embed.
            batch_size: Number of texts per API request.
            fallback_dim: Zero-vector width used when no batch succeeds.

        Returns:
            An array with one embedding per input text.
        """
        if isinstance(self.config, RetrievalConfig) and hasattr(self, 'embedding_service'):
            return self.embedding_service.get_embeddings(texts, batch_size)

        # One entry per text; None marks a text whose batch failed.
        embeddings: List[Any] = []
        progress_bar = tqdm(range(0, len(texts), batch_size), desc="Getting embeddings",
                            disable=not self.config.verbose)

        for i in progress_bar:
            batch_texts = texts[i:i + batch_size]

            try:
                response = self.openai_client.embeddings.create(
                    model=self.config.embedding_model,
                    input=batch_texts
                )

                # Materialise first so a failure cannot leave a partially extended list.
                batch_embeddings = [item.embedding for item in response.data]
                embeddings.extend(batch_embeddings)

                # Small pause between requests to go easy on the API rate limit.
                time.sleep(0.05)

            except Exception as e:
                if self.config.verbose:
                    print(f"Error getting embeddings for batch {i}: {e}")
                embeddings.extend([None] * len(batch_texts))

        # Pad failures with zero vectors that match the real model dimension.
        dim = next((len(vec) for vec in embeddings if vec is not None), fallback_dim)
        return np.array([vec if vec is not None else [0.0] * dim for vec in embeddings])

    def _collect_query_results(self, query_result) -> List[Dict[str, Any]]:
        """Convert a raw vector-store query result into result dicts.

        If the store returned node objects, they are used directly. If it
        returned only ids and similarities, each id is resolved through the
        index docstore. An id that cannot be resolved gets a placeholder entry.
        """
        nodes = getattr(query_result, 'nodes', None)
        similarities = getattr(query_result, 'similarities', None)

        if nodes:
            # VectorStoreQueryResult.nodes holds plain nodes, not NodeWithScore;
            # their scores live in the parallel `similarities` list.
            scores = list(similarities) if similarities is not None else []
            results = []
            for j, node in enumerate(nodes):
                score = scores[j] if j < len(scores) else getattr(node, 'score', 0.0)
                results.append(self._format_result(node.id_, node.get_content(),
                                                   node.metadata, score))
            return results

        if not similarities:
            return []

        results = []
        for j, (node_id, similarity) in enumerate(zip(query_result.ids, similarities)):
            if j >= self.config.top_k:
                break
            try:
                node = self.index.docstore.get_node(node_id)
                results.append(self._format_result(node_id, node.get_content(),
                                                   node.metadata, similarity))
            except Exception:
                # Best effort: keep the hit even when its node can't be loaded.
                results.append(self._format_result(node_id, f"Document {node_id}", {},
                                                   similarity))
        return results

    def batch_retrieve(self, queries: List[str]) -> List[List[Dict[str, Any]]]:
        """Retrieve top-k matches for many queries using one batched embedding pass.

        The returned list is aligned with *queries*. A query whose lookup fails
        gets an empty list. If batch embedding fails, or returns the wrong
        number of vectors, every query falls back to ``retrieve``.

        Raises:
            ValueError: if no index has been built or loaded yet.
        """
        if not self.index:
            raise ValueError("Index not initialized")

        try:
            query_embeddings = self.batch_get_embeddings(queries)
            # zip() would silently drop queries on a count mismatch and misalign results.
            if len(query_embeddings) != len(queries):
                raise ValueError(
                    f"Got {len(query_embeddings)} embeddings for {len(queries)} queries"
                )

            vector_store = self.index.vector_store
            all_results = []
            for i, query_embedding in enumerate(query_embeddings):
                if self.config.verbose and i % 10 == 0:
                    print("Batch retrieving")
                try:
                    # Query the vector store directly with the precomputed embedding.
                    query_obj = VectorStoreQuery(
                        query_embedding=np.asarray(query_embedding).tolist(),
                        similarity_top_k=self.config.top_k
                    )
                    query_result = vector_store.query(query_obj)
                    all_results.append(self._collect_query_results(query_result))
                except Exception as e:
                    if self.config.verbose:
                        print(f"Error retrieving for query {i}: {e}")
                    all_results.append([])
            return all_results

        except Exception as e:
            if self.config.verbose:
                print(f"Error retrieving: {e}")
            return [self.retrieve(query) for query in queries]
