import transformers
import torch
import numpy as np
from tqdm.auto import tqdm
import openai
import os

#BERT_MODEL_LIST = ["prajjwal1/bert-tiny", "prajjwal1/bert-mini", "prajjwal1/bert-small", "prajjwal1/bert-medium", "bert-base-cased", "bert-base-uncased", "bert-large-cased", "t5-base", "gpt2"]

class BERT:
    """Sentence encoder built on a Hugging Face BERT checkpoint.

    Each sentence is represented by the final-layer hidden state of its first
    token ([CLS]).

    Args:
        model_name: Identifier or path handed to ``from_pretrained`` for both
            the tokenizer and the model.
        device: Torch device that the model and every input batch are moved to.
    """

    def __init__(self, model_name: str, device='cuda'):
        self.device = device
        self.tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
        bert = transformers.BertModel.from_pretrained(model_name)
        self.model = bert.to(self.device)

    def create_embeddings(self, sentences: list, batch_size: int = 8):
        """Encode ``sentences`` in chunks of ``batch_size`` and return their [CLS] vectors.

        Args:
            sentences: Sentences to encode; sliced into consecutive chunks.
            batch_size: Number of sentences tokenized and run per forward pass.

        Returns:
            ``np.ndarray`` with one row per sentence, in input order. An empty
            input yields ``np.array([])``.
        """
        collected = []
        print("Generating embeddings...")

        for start in tqdm(range(0, len(sentences), batch_size)):
            chunk = sentences[start:start + batch_size]
            encoded = self.tokenizer(chunk, padding=True, truncation=True, return_tensors="pt")
            encoded = encoded.to(self.device)

            with torch.no_grad():
                hidden = self.model(**encoded).last_hidden_state
                # Token position 0 is [CLS]; right padding never shifts it.
                cls_vectors = hidden[:, 0, :].detach().cpu().numpy()
                collected.extend(cls_vectors)

        print("Done!")
        return np.array(collected)
    
class T5:
    """Sentence encoder that mean-pools the T5 encoder's last hidden states.

    Args:
        model_name: Identifier or path handed to ``from_pretrained`` for both
            the tokenizer and the model.
        device: Torch device for the model and the tokenized inputs. Defaults
            to ``'cpu'``.
    """

    def __init__(self, model_name: str, device='cpu'):
        self.device = device
        self.tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
        self.model = transformers.T5Model.from_pretrained(model_name).to(self.device)

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Average ``hidden_states`` over the token axis (dim 1), skipping masked-out positions.

        Args:
            hidden_states: Tensor indexed ``[batch, token, feature]``.
            attention_mask: Tensor indexed ``[batch, token]``; 1 marks a real token.

        Returns:
            Tensor indexed ``[batch, feature]``.
        """
        weights = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * weights).sum(dim=1)
        # clamp guards against division by zero for a row with no real tokens.
        counts = weights.sum(dim=1).clamp(min=1)
        return summed / counts

    def create_embedding(self, sentence: str):
        """Encode one sentence and return its mean-pooled encoder output as a NumPy array."""
        inputs = self.tokenizer(
            sentence, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(self.device)

        with torch.no_grad():
            # Only the encoder stack is needed to produce sentence features.
            outputs = self.model.encoder(**inputs)

        # Pool with the attention mask so pad positions (if any) do not dilute the mean.
        pooled = self._masked_mean(outputs.last_hidden_state, inputs["attention_mask"])
        return pooled.squeeze().cpu().numpy()

    def create_embeddings(self, sentences: list):
        """Encode each sentence independently and stack the results row-wise.

        Returns:
            ``np.ndarray`` of shape ``(len(sentences), d_model)``. For an empty
            input this is an empty ``(0, d_model)`` float32 array; ``np.stack``
            alone would raise on an empty list.
        """
        # len() rather than truthiness so array-like inputs also work.
        if len(sentences) == 0:
            return np.empty((0, self.model.config.d_model), dtype=np.float32)

        print("Generating embeddings...")
        embeddings = [self.create_embedding(sentence) for sentence in tqdm(sentences)]
        print("Done!")

        return np.stack(embeddings, axis=0)

class GPT2:
    """Sentence encoder that uses GPT-2's final hidden state at the last real token.

    Args:
        model_name: Identifier or path handed to ``from_pretrained`` for both
            the tokenizer and the model.
        device: Torch device that the model and every input batch are moved to.
    """

    def __init__(self, model_name: str, device='cuda'):
        self.device = device
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)

        # Reuse EOS as the pad token so batched calls with padding=True can pad.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = transformers.GPT2Model.from_pretrained(model_name).to(self.device)

    @staticmethod
    def _last_token_states(hidden_states, attention_mask):
        """Pick, for each row, the hidden state at the last position where the mask is 1.

        Works for both right and left padding.

        Args:
            hidden_states: Tensor indexed ``[batch, token, feature]``.
            attention_mask: Tensor indexed ``[batch, token]``; 1 marks a real token.

        Returns:
            Tensor indexed ``[batch, feature]``.
        """
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        # Highest index among real tokens = last real token (0 if the row is all padding).
        last_idx = (positions * attention_mask).max(dim=1).values
        rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[rows, last_idx]

    def create_embedding(self, sentence: str):
        """Encode one (unpadded) sentence and return its last-token hidden state as a NumPy array."""
        inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Move to host memory first: converting a CUDA tensor to NumPy raises a TypeError.
        return outputs.last_hidden_state[0, -1].view(-1).detach().cpu().numpy()

    def create_embeddings(self, sentences: list, batch_size: int = 8):
        """Encode ``sentences`` in padded batches and return one last-token vector per sentence.

        Returns:
            ``np.ndarray`` with one row per sentence, in input order. An empty
            input yields ``np.array([])``.
        """
        embeddings = []
        print("Generating embeddings...")

        for i in tqdm(range(0, len(sentences), batch_size)):
            batch_sentences = sentences[i:i + batch_size]
            batch_inputs = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors="pt").to(self.device)

            with torch.no_grad():
                batch_outputs = self.model(**batch_inputs)
                # Padding means position -1 is a pad token for every shorter
                # sentence; use the attention mask to find each real last token.
                last_states = self._last_token_states(
                    batch_outputs.last_hidden_state, batch_inputs["attention_mask"]
                )
                embeddings.extend(last_states.detach().cpu().numpy())

        print("Done!")
        return np.array(embeddings)
    
class GPT_API:
    """Sentence encoder backed by the OpenAI embeddings endpoint.

    The API key is read from the ``OPENAI_API_KEY`` environment variable.
    """

    BATCH_SIZE = 20  # sentences sent per embeddings request
    SUPPORTED_MODELS = (
        "text-embedding-3-small",
        "text-embedding-3-large",
        "text-embedding-ada-002",
    )

    def __init__(self, model_name: str):
        """Validate ``model_name`` and create the OpenAI client.

        Raises:
            ValueError: if ``model_name`` is not in ``SUPPORTED_MODELS``.
            KeyError: if ``OPENAI_API_KEY`` is not set in the environment.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported embedding model {model_name!r}; "
                f"expected one of {self.SUPPORTED_MODELS}"
            )

        self.model_name = model_name
        openai.api_key = os.environ["OPENAI_API_KEY"]
        self.client = openai.OpenAI()

    def create_embedding(self, sentence: str):
        """Embed a single sentence.

        Returns the response's ``embedding`` field unchanged (not converted to
        a NumPy array).
        """
        response = self.client.embeddings.create(input=[sentence], model=self.model_name)
        return response.data[0].embedding

    def create_batch_embedding(self, sentences: list):
        """Embed ``sentences`` in one request.

        Returns:
            List of 1-D ``np.ndarray`` vectors, aligned with ``sentences``.
        """
        data = self.client.embeddings.create(input=sentences, model=self.model_name).data
        # Each item carries the index of the input it embeds; order by it rather
        # than relying on the response list order.
        ordered = sorted(data, key=lambda item: item.index)
        return [np.array(item.embedding) for item in ordered]

    def create_embeddings(self, sentences: list):
        """Embed ``sentences`` in requests of ``BATCH_SIZE`` and stack the vectors row-wise.

        Raises:
            ValueError: from ``np.stack`` if ``sentences`` is empty.
        """
        total_embeddings = []

        print("Generating embeddings...")
        for start in tqdm(range(0, len(sentences), self.BATCH_SIZE)):
            batch_sentences = sentences[start:start + self.BATCH_SIZE]
            total_embeddings.extend(self.create_batch_embedding(batch_sentences))
        print("Done!")

        return np.stack(total_embeddings, axis=0)

def print_model_parameters(model):
    """Print the name and size of every entry in ``model.state_dict()``.

    Args:
        model: Object exposing ``state_dict()``, e.g. a ``torch.nn.Module``.
    """
    print("Model's state_dict:")
    # Build the state dict once; calling state_dict() per entry rebuilds it every time.
    for name, tensor in model.state_dict().items():
        print(name, "\t", tensor.size())