from dataclasses import dataclass
from typing import Union, Literal
import logging

import torch

import tiktoken
from openai import OpenAI, AzureOpenAI


@dataclass
class OpenAIRetrieverArgs:
    """Configuration for :class:`OpenAIRetriever`."""

    # Which SDK client to build: "openai" -> OpenAI(), "azure" -> AzureOpenAI().
    client: Literal["openai", "azure"] = "openai"
    # Embedding model passed to the embeddings endpoint. The previous default
    # ("text-embedding-small-3") is not a valid OpenAI model name.
    # NOTE(review): for Azure this is presumably the deployment name — confirm.
    model_name: str = "text-embedding-3-small"
    # Maximum number of best-scoring chunks returned by OpenAIRetriever.retrieve().
    top_k: int = 5
    # Chunking settings. Not read by OpenAIRetriever; presumably consumed by a
    # separate chunking step (TODO confirm units, e.g. tokens vs characters).
    chunk_size: int = 100
    overlap: int = 10
    # Similarity used to rank chunks against the query.
    measure: Literal["cosine", "dot"] = "cosine"
    # L2-normalize query and chunk embeddings before scoring.
    normalize_embeddings: bool = True
    # Presumably selects a recursive splitter in the chunking step — TODO confirm.
    use_recursive_text_splitter: bool = False


class OpenAIRetriever:
    """Rank text chunks against a query by embedding similarity.

    Embeddings come from the OpenAI (or Azure OpenAI) embeddings endpoint;
    scoring and top-k selection are done locally with torch.
    """

    # Similarity measures understood by _similarity().
    _MEASURES = ("cosine", "dot")

    # Max number of inputs sent per embeddings request. The API caps a request
    # at 2048 inputs and also limits total tokens per request, so this stays
    # well below the input cap. NOTE(review): lower it for very long chunks.
    _EMBED_BATCH_SIZE = 256

    def __init__(self, args: "OpenAIRetrieverArgs"):
        """Create the retriever and its API client.

        Args:
            args: Retriever configuration; ``args.client`` selects the backend.

        Raises:
            ValueError: if ``args.client`` is neither "openai" nor "azure".
        """
        self.args = args
        self.model_name = args.model_name

        # Clients are built without arguments; credentials/endpoint are
        # presumably picked up from environment variables by the SDK.
        if args.client == "openai":
            self.client = OpenAI()
        elif args.client == "azure":
            self.client = AzureOpenAI()
        else:
            # Fail fast: otherwise self.client is never set and the first
            # encode() call dies with an opaque AttributeError.
            raise ValueError(
                f"Unknown client {args.client!r}; expected 'openai' or 'azure'"
            )

    def encode(self, text: Union[str, list[str]]):
        """Embed a string or a list of strings.

        Args:
            text: A single string, or a list of strings (sent in batches).

        Returns:
            For a ``str``, its embedding as a list of floats. For a list, a
            list of such embeddings in the same order as ``text``; an empty
            list is returned for empty input without calling the API.
        """
        if isinstance(text, str):
            response = self.client.embeddings.create(input=[text], model=self.model_name)
            return response.data[0].embedding

        texts = list(text)
        embeddings = []
        for start in range(0, len(texts), self._EMBED_BATCH_SIZE):
            batch = texts[start:start + self._EMBED_BATCH_SIZE]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            # Each result carries the position of its input within `batch`;
            # sort on it so output order is guaranteed to match input order.
            ordered = sorted(response.data, key=lambda item: item.index)
            embeddings.extend(item.embedding for item in ordered)
        return embeddings

    def retrieve(self, query: str, chunks: Union[str, list[str]]):
        """Return the top-k chunks most similar to ``query``.

        Args:
            query: Query text.
            chunks: Candidate texts; a single string is treated as one chunk.

        Returns:
            ``(scores, indices)``: 1-D tensors holding the best
            ``min(args.top_k, len(chunks))`` scores in descending order and
            their positions in ``chunks``. Both are empty when ``chunks`` is
            empty.

        Raises:
            ValueError: if ``args.measure`` is not "cosine" or "dot".
        """
        # Validate before spending any API calls.
        self._check_measure(self.args.measure)

        if isinstance(chunks, str):
            # Iterating a bare str would embed it one character at a time.
            chunks = [chunks]
        else:
            chunks = list(chunks)
        if not chunks:
            # torch.stack / the embeddings API both reject empty input.
            return torch.empty(0), torch.empty(0, dtype=torch.long)

        logging.info("Encoding %d chunks...", len(chunks))
        chunk_embeddings = torch.tensor(self.encode(chunks))  # shape (N, D)
        logging.info("Encoding query: %s", query)
        query_embedding = torch.tensor(self.encode(query))  # shape (D,)

        similarity_scores = self._similarity(
            query_embedding,
            chunk_embeddings,
            measure=self.args.measure,
            normalize=self.args.normalize_embeddings,
        )
        scores, indices = torch.topk(similarity_scores, k=min(self.args.top_k, len(chunks)))
        return scores, indices

    @staticmethod
    def _check_measure(measure: str) -> None:
        """Raise ValueError unless ``measure`` is a supported similarity."""
        if measure not in OpenAIRetriever._MEASURES:
            raise ValueError(
                f"Unknown measure {measure!r}; expected one of {OpenAIRetriever._MEASURES}"
            )

    @staticmethod
    def _similarity(
        query_embedding: torch.Tensor,
        chunk_embeddings: torch.Tensor,
        measure: str = "cosine",
        normalize: bool = True,
    ) -> torch.Tensor:
        """Score each row of ``chunk_embeddings`` against ``query_embedding``.

        Args:
            query_embedding: 1-D tensor of length D.
            chunk_embeddings: 2-D tensor of shape (N, D).
            measure: "cosine" or "dot".
            normalize: L2-normalize both inputs first. Only affects "dot"
                (making it equal to cosine); cosine is scale-invariant.

        Returns:
            1-D tensor of N scores.

        Raises:
            ValueError: for an unsupported ``measure``.
        """
        OpenAIRetriever._check_measure(measure)
        if normalize:
            # F.normalize divides by max(norm, eps), so an all-zero vector
            # yields zeros instead of NaNs.
            query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)
            chunk_embeddings = torch.nn.functional.normalize(chunk_embeddings, dim=-1)
        if measure == "cosine":
            return torch.nn.functional.cosine_similarity(
                query_embedding.unsqueeze(0), chunk_embeddings, dim=-1
            )
        return chunk_embeddings @ query_embedding
