import torch
from torch import Tensor
import numpy as np
from typing import List, Optional, Union
import torch.nn as nn
from modelscope import AutoModelForSequenceClassification, AutoTokenizer, AutoModel
import torch.nn.functional as F
import sys
sys.path.append('..')
from global_utils.utils import generate_general, generate_general_rm, async_generate_general
# from utils import generate_general, generate_general_rm, async_generate_general
import math
import re
from tqdm import tqdm
import asyncio
import time
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Pick, for every row of the batch, the hidden state of its last attended token.

    Args:
        last_hidden_states: Model output indexed as ``[row, position, ...]``.
        attention_mask: Mask indexed as ``[row, position]``; its per-row sum
            is used as the number of attended tokens.

    Returns:
        One hidden-state vector per row.
    """
    num_rows = attention_mask.shape[0]

    # If every row attends to the final position, the batch is treated as
    # left-padded and the final position is the last real token for all rows.
    if attention_mask[:, -1].sum() == num_rows:
        return last_hidden_states[:, -1]

    # Otherwise the last real token of each row sits at (attended count - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Build the two-line prompt: an ``Instruct:`` line followed by a ``Query:`` line."""
    prompt_lines = (
        f'Instruct: {task_description}',
        f'Query: {query}',
    )
    return '\n'.join(prompt_lines)


class LinqEmbedMistral:
    """Embedding wrapper that loads a model and tokenizer from ``model_path``.

    The model is loaded in bfloat16 as soon as the object is constructed.
    """

    def __init__(self, model_path, name, device):
        """Store the configuration and load the model immediately.

        Args:
            model_path: Passed to ``from_pretrained`` for both the model and
                the tokenizer.
            name: Stored on the instance; not used by this class itself.
            device: Passed to ``AutoModel.from_pretrained`` as ``device_map``.
        """
        self.em_tokenizer = None
        self.em = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the model (bfloat16) and its tokenizer from ``self.model_path``."""
        self.em = AutoModel.from_pretrained(
            pretrained_model_name_or_path=self.model_path,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
        )
        self.em_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path)

    def obtain_embedding(self, sentences: List[str], tasks: Optional[List[str]] = None,
                         batch_size: int = 4, max_length: int = 8192) -> List[List[float]]:
        """Embed every sentence and return the embeddings as plain Python lists.

        Args:
            sentences: Texts to embed.
            tasks: Optional task descriptions, one per sentence. A non-empty
                task wraps its sentence via ``get_detailed_instruct``; an
                empty string leaves the sentence unchanged.
            batch_size: Number of sentences tokenized and encoded per forward
                pass. Must be at least 1.
            max_length: Tokenizer truncation length.

        Returns:
            One embedding per sentence, in input order (``[]`` for empty input).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            AssertionError: If ``tasks`` is given with a different length
                than ``sentences``.
        """
        # A negative step would make range() empty and silently return no
        # embeddings at all, so reject it up front.
        if batch_size < 1:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')
        if tasks is not None:
            assert len(sentences) == len(tasks), 'sentences and tasks must have the same length'
            sentences = [get_detailed_instruct(t, s) if t != '' else s for s, t in zip(sentences, tasks)]

        embedding_list = []
        for st_index in range(0, len(sentences), batch_size):
            # Slice the sequence directly: converting the whole input to a
            # NumPy array for each batch repeats O(n) work per batch and
            # drops trailing NUL characters from the texts.
            input_texts = list(sentences[st_index:st_index + batch_size])
            # Tokenize the input texts
            batch_dict = self.em_tokenizer(input_texts, max_length=max_length, padding=True, truncation=True,
                                           return_tensors="pt").to(self.em.device)
            with torch.no_grad():
                outputs = self.em(**batch_dict)
            embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            embedding_list.extend(embedding.detach().cpu().tolist())
        return embedding_list

class QwenEmbed:
    """Embedding wrapper that loads a model and tokenizer from ``model_path``.

    The model is loaded in float16 with the ``flash_attention_2`` attention
    implementation, and the tokenizer pads on the left.
    """

    def __init__(self, model_path, name, device):
        """Store the configuration and load the model immediately.

        Args:
            model_path: Passed to ``from_pretrained`` for both the model and
                the tokenizer.
            name: Stored on the instance; not used by this class itself.
            device: Passed to ``AutoModel.from_pretrained`` as ``device_map``.
        """
        self.em_tokenizer = None
        self.em = None
        self.model_path = model_path
        self.name = name
        self.device = device
        self.load_model()

    def load_model(self):
        """Load the model (float16, flash attention) and a left-padding tokenizer."""
        self.em = AutoModel.from_pretrained(
            pretrained_model_name_or_path=self.model_path,
            device_map=self.device,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16
        )
        self.em_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_path, padding_side='left')

    def obtain_embedding(self, sentences: List[str], tasks: Optional[List[str]] = None,
                         batch_size: int = 4, max_length: int = 8192) -> List[List[float]]:
        """Embed every sentence and return the embeddings as plain Python lists.

        Args:
            sentences: Texts to embed.
            tasks: Optional task descriptions, one per sentence. A non-empty
                task wraps its sentence via ``get_detailed_instruct``; an
                empty string leaves the sentence unchanged.
            batch_size: Number of sentences tokenized and encoded per forward
                pass. Must be at least 1.
            max_length: Tokenizer truncation length.

        Returns:
            One embedding per sentence, in input order (``[]`` for empty input).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            AssertionError: If ``tasks`` is given with a different length
                than ``sentences``.
        """
        # A negative step would make range() empty and silently return no
        # embeddings at all, so reject it up front.
        if batch_size < 1:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')
        if tasks is not None:
            assert len(sentences) == len(tasks), 'sentences and tasks must have the same length'
            sentences = [get_detailed_instruct(t, s) if t != '' else s for s, t in zip(sentences, tasks)]

        embedding_list = []
        for st_index in range(0, len(sentences), batch_size):
            # Slice the sequence directly: converting the whole input to a
            # NumPy array for each batch repeats O(n) work per batch and
            # drops trailing NUL characters from the texts.
            input_texts = list(sentences[st_index:st_index + batch_size])
            # Tokenize the input texts
            batch_dict = self.em_tokenizer(input_texts, max_length=max_length, padding=True, truncation=True,
                                           return_tensors="pt").to(self.em.device)
            with torch.no_grad():
                outputs = self.em(**batch_dict)
            embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            embedding_list.extend(embedding.detach().cpu().tolist())
        return embedding_list

# Registry mapping a model name to the wrapper class that serves it;
# auto_get_em looks names up here.
# NOTE(review): QwenEmbed is defined in this file but has no entry in this
# registry — confirm whether that omission is intentional.
model_dict = {
    'Linq-Embed-Mistral': LinqEmbedMistral
}

# Path associated with each model name (keys mirror model_dict).
# Presumably the local checkpoint directory to pass as `model_path` — verify
# against callers, since nothing in this file reads this dict.
em_path_dict = {
    "Linq-Embed-Mistral": "/fs-computility/mabasic/shared/models/embedding_model/Linq-Embed-Mistral"
}

def auto_get_em(model_name):
    """Look up the embedding wrapper class registered under ``model_name``.

    Returns the class itself (not an instance). Raises ``KeyError`` when the
    name is not a key of ``model_dict``.
    """
    wrapper_cls = model_dict[model_name]
    return wrapper_cls

# Entry-point guard kept as a placeholder: the guarded body is a no-op.
if __name__ == '__main__':
    pass