# %%
from curses import meta
from hashlib import sha256
import os
import pickle
import token
from typing import List

from networkx import union
from transformers import BertTokenizer, BertForSequenceClassification,GPT2Tokenizer, GPT2ForSequenceClassification, T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration,AutoModelForSequenceClassification, AutoModelForCausalLM, AutoTokenizer,BitsAndBytesConfig
# from auto_gptq import AutoGPTQForCausalLM
from openai import OpenAI

import numpy as np
import torch
import socket
import subprocess
import multiprocessing
import json
import sys
from tqdm.auto import tqdm
import time

# from cexl.utils import AbstractConceptLLM,AbstractPredictor
# from ..lstm import LSTMTokenizer,LSTMforTextClassification
# %%
from abc import ABCMeta, abstractmethod

class AbstractTextPredictor:
    """Common interface for text classifiers that score lists of strings.

    Subclasses implement ``_predict`` (and may override ``_predict_lr``).
    The public ``predict`` / ``predict_lr`` wrappers normalise the input to
    a list of strings and accumulate the wall-clock time spent in the model.

    Note: the class does not use ``ABCMeta``, so the ``@abstractmethod``
    markers below are documentation only and are not enforced at
    instantiation time.
    """

    @classmethod
    def from_pretrained(cls, model_type,model_path,device='cuda',batch_size = 32):
        """Build the predictor registered under ``model_type``.

        Args:
            model_type: One of 'bert', 'gpt2', 'llama2', 't5', 'lstm'.
            model_path: Checkpoint location handed to the predictor
                (ignored for 'llama2').
            device: Device string forwarded to the predictor.
            batch_size: Batch size forwarded to the predictor ('lstm' and
                'llama2' do not receive it).

        Raises:
            Exception: If ``model_type`` is not one of the known names.
        """
        if model_type == 'bert':
            return BertPredictor(model_path,device,batch_size)
        elif model_type == 'gpt2':
            return GPTPredictor(model_path,device,batch_size)
        elif model_type == 'llama2':
            return Llama2Predictor()
        elif model_type == 't5':
            return T5Predictor(model_path,device,batch_size)
        elif model_type == 'lstm':
            return LSTMPreidctor(model_path,device)
        else:
            raise Exception("Unknown model type: %s" % model_type)

    @abstractmethod
    def __init__(self,unk_token):
        """Store the masking token and reset the accumulated timer."""
        self.unk_token = unk_token
        self.__timer = 0

    def clear_timer(self):
        """Reset the accumulated prediction time to zero."""
        self.__timer = 0

    def timer(self):
        """Return the seconds accumulated by ``predict`` / ``predict_lr``."""
        return self.__timer

    @abstractmethod
    def _predict(self, text, **kwargs):
        """Score a list of strings; must be provided by subclasses."""
        raise NotImplementedError

    def _predict_lr(self, text, **kwargs):
        """Return 1 where column 1 of ``_predict`` exceeds column 0, else 0."""
        res:np.ndarray = self._predict(text, **kwargs)
        res = (res[:,1] > res[:,0]).astype(int)
        return res

    @staticmethod
    def _as_text_list(text):
        """Normalise ``text`` to a list.

        A single string is wrapped in a list and a NumPy array is converted
        with ``tolist()``. ``isinstance`` is used so that subclasses of
        ``str`` / ``list`` are accepted as well.

        Raises:
            TypeError: If the value is not a string, list or array that
                converts to a list.
        """
        if isinstance(text, str):
            return [text]
        if isinstance(text, np.ndarray):
            text = text.tolist()
        if not isinstance(text, list):
            raise TypeError('Input must be a string or a list of strings')
        return text

    def _timed(self, fn, text, **kwargs):
        """Call ``fn`` on the normalised input and add its duration to the timer.

        The timer is only updated when ``fn`` returns normally.
        """
        text = self._as_text_list(text)
        bg = time.perf_counter()
        res = fn(text, **kwargs)
        self.__timer += time.perf_counter() - bg
        return res

    def predict(self, text, **kwargs):
        """Return the output of ``_predict`` for a string, list or array of strings.

        Raises:
            TypeError: If ``text`` cannot be normalised to a list.
        """
        return self._timed(self._predict, text, **kwargs)

    def predict_lr(self, text, **kwargs):
        """Return the output of ``_predict_lr`` for a string, list or array of strings.

        Raises:
            TypeError: If ``text`` cannot be normalised to a list.
        """
        return self._timed(self._predict_lr, text, **kwargs)

# %%
class BertPredictor(AbstractTextPredictor):
    """Predictor wrapping an ``AutoModelForSequenceClassification`` checkpoint."""

    def __init__(self, model_path, device='cuda', batch_size = 32):
        """Load the model and tokenizer from ``model_path`` and move the model to ``device``."""
        classifier = AutoModelForSequenceClassification.from_pretrained(model_path)
        classifier.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        self.model = classifier.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size
        super().__init__(unk_token=tokenizer.unk_token)

    @torch.no_grad()
    def _predict(self,text, logtis = True, **kwargs):
        """Score ``text`` in batches of ``self.batch_size``.

        Args:
            text: List of strings to classify.
            logtis: When true (the default) the raw logits are returned;
                otherwise a softmax over dimension 1 is applied. The
                spelling is kept because callers may pass it by keyword.

        Returns:
            NumPy array with the concatenated per-batch model outputs.
        """
        step = self.batch_size
        collected = []
        for start in range(0, len(text), step):
            encoded = self.tokenizer(
                text[start:start + step], return_tensors="pt", padding=True
            ).to(self.device)
            scores = self.model(**encoded).logits
            collected.append(scores if logtis else scores.softmax(1))
        stacked = torch.concat(collected)
        return stacked.cpu().detach().numpy()


# %%
class GPTPredictor(AbstractTextPredictor):
    """Predictor wrapping a ``GPT2ForSequenceClassification`` checkpoint."""

    def __init__(self, model_path, device='cuda',batch_size = 32):
        """Load the model and tokenizer from ``model_path`` and move the model to ``device``."""
        model = GPT2ForSequenceClassification.from_pretrained(model_path)
        model.eval()
        tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size
        super().__init__(unk_token=tokenizer.unk_token)

    @torch.no_grad()
    def _predict(self,text, **kwargs):
        """Return softmax class probabilities for every string in ``text``.

        Empty strings are replaced by the tokenizer's pad token before
        tokenisation. The replacement is applied to every item (the earlier
        code only looked at the first item of each batch) and works on a
        copy of each batch, so the caller's list is left untouched.

        Args:
            text: List of strings to classify.

        Returns:
            NumPy array of probabilities, one row per input string.
        """
        batch_size = self.batch_size
        model = self.model
        tokenizer = self.tokenizer
        probs = []
        for i in range(0,len(text),batch_size):
            # pad_token is only looked up when an empty string is present.
            batch = [t if len(t) > 0 else tokenizer.pad_token for t in text[i:i+batch_size]]
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to(self.device)
            outputs = model(**inputs,).logits.softmax(1)
            probs.append(outputs)
        probs = torch.concat(probs)
        return probs.cpu().detach().numpy()
    
  
# %%
class LLMSentimentPredictor(AbstractTextPredictor):
    """Base class for prompt-driven sentiment predictors.

    Holds the shared system/user prompt templates and the label order
    (index 0 is 'negative', index 1 is 'positive').
    """

    system_prompt = \
"""From now on, you should act as a sentiment analysis neaural network. You should classify the sentiment of a sentence into positive or negative. The input sentence may be empty. In each task, you will be given the sentences to be classified, which end with ##### , and then you should reply the sentiment of the sentence by positive or negative. 
"""

    user_prompt = \
"""Perform the following task, your answer should only be positive or negative:
Sentence:
{}
#####

Sentiment:
"""

    labels = ['negative','positive']

    def __init__(self, examples=None, **kwargs):
        """Initialise the base predictor with '<UNK>' as the masking token.

        Args:
            examples: Optional value kept on the instance as
                ``self.examples``. It is accepted positionally because the
                subclasses in this file pass it as the first argument,
                which previously raised a TypeError.
            **kwargs: Forwarded unchanged to the base initializer.
        """
        super().__init__('<UNK>', **kwargs)
        self.examples = examples

    @classmethod
    def get_user_prompt(cls,texts:str|List[str]):
        """Fill the user prompt template once per input text.

        Args:
            texts: A single string or a list of strings.

        Returns:
            List of formatted prompts, one per text.
        """
        if isinstance(texts, str):
            texts = [texts]
        return [cls.user_prompt.format(text) for text in texts]
        

class QwenPredictor(LLMSentimentPredictor):
    """Sentiment predictor that reads label scores from a Qwen chat model."""

    def __init__(self, examples, model_name="qwen/Qwen2.5-1.5B-Instruct", **kwargs):
        """Load the causal LM and tokenizer named by ``model_name``.

        Args:
            examples: Kept on the instance as ``self.examples``; it is not
                forwarded positionally to the parent initializer, which
                used to raise a TypeError.
            model_name: Hugging Face model identifier.
            **kwargs: Forwarded to the parent initializer.
        """
        super().__init__(**kwargs)
        self.examples = examples
        self.model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype="auto",device_map="sequential")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): presumably yields one token id per label with shape
        # (1, len(labels)) -- confirm for this tokenizer.
        self.encoded_label = self.tokenizer.encode(self.labels, return_tensors="pt")

    @torch.no_grad()
    def _predict(self, texts, **kwargs):
        """Return a softmax over the label-token scores for each text.

        One token is generated per text and the first-step scores at the
        label token ids are normalised with a softmax.

        Args:
            texts: List of strings to classify.

        Returns:
            NumPy array with one row per text; an empty array with
            ``len(self.labels)`` columns when ``texts`` is empty.
        """
        results = []

        for text in texts:
            user_prompt = self.get_user_prompt(text)[0]
            messages =[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": user_prompt}
                ]

            input_texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            model_inputs = self.tokenizer(input_texts, return_tensors="pt").to(self.model.device)
            # NOTE(review): the scores are whatever generate() returns under the
            # model's default generation config -- confirm that sampling
            # processors (top-k/top-p/temperature) are not distorting them.
            res = self.model.generate(
                **model_inputs,
                max_new_tokens=1,
                output_scores=True,
                return_dict_in_generate=True,
                # temperature=0.0
            )
            # Cast to float32 so the later .numpy() also works when the model
            # produces half/bfloat16 scores.
            results.append(res.scores[0][0][self.encoded_label].float().cpu())
        if not results:
            # torch.concatenate rejects an empty list; return an empty result.
            return np.empty((0, len(self.labels)), dtype=np.float32)
        results =  torch.concatenate(results, axis=0)
        results = torch.softmax(results, dim=1).cpu().numpy()
        return results

class QwenGPTQPredictor(QwenPredictor):
    """Variant of ``QwenPredictor`` that loads a GPTQ-quantized checkpoint."""

    def __init__(self, examples, model_name="qwen/Qwen2.5-1.5B-Instruct-GPTQ", **kwargs):
        """Load the quantized model and tokenizer named by ``model_name``.

        Args:
            examples: Kept on the instance as ``self.examples``.
            model_name: Identifier of the quantized checkpoint.
            **kwargs: Forwarded to the ``LLMSentimentPredictor`` initializer.
        """
        # Skip QwenPredictor.__init__ (it would load the unquantized model)
        # but still run the base initializers: without them `unk_token` and
        # the timer are never set and predict() fails.
        super(QwenPredictor, self).__init__(**kwargs)
        self.examples = examples
        # NOTE(review): the `auto_gptq` import is commented out at the top of
        # this file, so AutoGPTQForCausalLM must be brought into scope before
        # this class can be used.
        self.model = AutoGPTQForCausalLM.from_quantized(model_name,torch_dtype="auto",device_map="sequential")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoded_label = self.tokenizer.encode(self.labels, return_tensors="pt")
            
class LLamaPredictor(LLMSentimentPredictor):
    """Sentiment predictor that reads label scores from a Llama chat model."""

    def __init__(self, examples, model_name="meta-llama/Llama-3.1-8B-Instruct", quantize = False, **kwargs):
        """Load the causal LM and tokenizer named by ``model_name``.

        Args:
            examples: Kept on the instance as ``self.examples``; it is not
                forwarded positionally to the parent initializer, which
                used to raise a TypeError.
            model_name: Hugging Face model identifier.
            quantize: When true, load the model in 8-bit through
                ``BitsAndBytesConfig``.
            **kwargs: Forwarded to the parent initializer.
        """
        super().__init__(**kwargs)
        self.examples = examples
        if quantize:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,device_map="sequential",quantization_config=quantization_config)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype="auto",device_map="sequential")
        # self.model.generate = torch.compile(self.model.generate, mode="reduce-overhead", fullgraph=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): presumably yields one token id per label with shape
        # (1, len(labels)) -- confirm for this tokenizer.
        self.encoded_label = self.tokenizer.encode(self.labels, return_tensors="pt",  add_special_tokens=False)

    @torch.no_grad()
    def _predict(self, texts, **kwargs):
        """Return a softmax over the label-token scores for each text.

        One token is generated per text and the first-step scores at the
        label token ids are normalised with a softmax.

        Args:
            texts: List of strings to classify.

        Returns:
            NumPy array with one row per text; an empty array with
            ``len(self.labels)`` columns when ``texts`` is empty.
        """
        results = []

        for text in texts:
            user_prompt = self.get_user_prompt(text)[0]
            messages =[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": user_prompt}
                ]

            input_texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            model_inputs = self.tokenizer(input_texts, return_tensors="pt").to(self.model.device)
            res = self.model.generate(
                **model_inputs,
                max_new_tokens=1,
                output_scores=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # Cast to float32 so the later .numpy() also works when the model
            # produces half/bfloat16 scores.
            results.append(res.scores[0][0][self.encoded_label].float().cpu())
        if not results:
            # torch.concatenate rejects an empty list; return an empty result.
            return np.empty((0, len(self.labels)), dtype=np.float32)
        results =  torch.concatenate(results, axis=0)
        results = torch.softmax(results, dim=1).cpu().numpy()
        return results
            
class GPT4oPredictor(LLMSentimentPredictor):
    """Sentiment predictor that reads label log-probabilities from an OpenAI chat model."""

    # Number of alternative tokens requested for the single generated token.
    top_logprobs = 20
    # Log-probability used for a label that is absent from the returned
    # alternatives; keeps the softmax finite.
    missing_logprob = -100.0

    def __init__(self, examples, model_name="gpt-4o", *, api_key=None,
                 base_url='https://api.openai-proxy.org/v1', **kwargs):
        """Create the API client.

        Args:
            examples: Kept on the instance as ``self.examples``; it is not
                forwarded positionally to the parent initializer, which
                used to raise a TypeError.
            model_name: Name of the chat model to query.
            api_key: API key; when omitted it is read from the
                ``OPENAI_API_KEY`` environment variable. The key is no
                longer embedded in the source file.
            base_url: Endpoint the client talks to.
            **kwargs: Forwarded to the parent initializer.

        Raises:
            ValueError: If no API key is given and the environment variable
                is unset or empty.
        """
        super().__init__(**kwargs)
        self.examples = examples

        if api_key is None:
            api_key = os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError(
                'No API key supplied: pass api_key=... or set OPENAI_API_KEY'
            )

        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )
        self.model_name=model_name

    def _label_logprobs(self, response):
        """Extract one log-probability per label from a chat completion.

        A returned alternative counts for a label when its text, stripped
        and lower-cased, is a non-empty prefix of the label. The best
        matching log-probability is kept; labels with no match receive
        ``missing_logprob``.
        """
        info = response.choices[0].logprobs
        candidates = info.content[0].top_logprobs if info and info.content else []
        scores = [self.missing_logprob] * len(self.labels)
        for candidate in candidates:
            piece = candidate.token.strip().lower()
            if not piece:
                continue
            for idx, label in enumerate(self.labels):
                if label.startswith(piece):
                    scores[idx] = max(scores[idx], candidate.logprob)
        return scores

    def _predict(self, texts, **kwargs):
        """Return a softmax over the label log-probabilities for each text.

        Args:
            texts: List of strings to classify.

        Returns:
            NumPy array with one row per text and one column per label; an
            empty array with ``len(self.labels)`` columns when ``texts`` is
            empty.
        """
        results = []

        for text in texts:
            user_prompt = self.get_user_prompt(text)[0]
            messages =[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": user_prompt}
                ]

            # Chat completions take a boolean `logprobs` plus an integer
            # `top_logprobs`, and return objects rather than dictionaries.
            response = self.client.chat.completions.create(
                messages=messages,
                max_tokens=1,
                logprobs=True,
                top_logprobs=self.top_logprobs,
                model=self.model_name,
            )
            results.append(self._label_logprobs(response))

        if not results:
            return np.empty((0, len(self.labels)), dtype=np.float32)
        results = torch.tensor(results, dtype=torch.float32)
        results = torch.softmax(results, dim=1).cpu().numpy()
        return results

         


class Llama2Predictor(AbstractTextPredictor):
    """Sentiment predictor that queries a chat server on localhost:9999.

    Requests are JSON-encoded message lists sent over a TCP socket; the
    reply is read until the server closes the connection and parsed as JSON.
    """

    def __init__(self, **kwargs):
        """Initialise the base predictor and verify that the server is reachable.

        Raises:
            ConnectionRefusedError: If nothing accepts connections on
                localhost:9999.
        """
        super().__init__('<UNK>')
        # check if a server is listening at localhost:9999
        try:
            # The context manager closes the probe socket even when connect fails.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(('localhost', 9999))
        except ConnectionRefusedError:
            print("Please The Starting the llama Server")
            raise

    @staticmethod
    def chat_completion(messages):
        """Send ``messages`` to the server and return the decoded JSON reply.

        Raises:
            Exception: If the server answers with the literal bytes ``error``.
        """
        # connect to server localhost:9999
        chunks = []
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect(('localhost', 9999))
            s.sendall(json.dumps(messages).encode())
            while True:
                data = s.recv(1024)
                if not data:
                    break
                chunks.append(data)
        response = b"".join(chunks)
        if response == b"error":
            raise Exception("error")
        return json.loads(response.decode())


    def _predict_lr(self, text, **kwargs):
        """Return one 0/1 label per input sentence.

        Each sentence is sent up to three times; the first reply whose last
        character is '0' or '1' is used. When no attempt succeeds the label
        defaults to 0 and a message is written to stderr.

        Earlier versions pre-filled the result with ``range(len(text))`` and
        repeated the whole pass three times while rebinding ``text``, so the
        output did not line up with the input; the result now has exactly
        one entry per sentence.

        Args:
            text: List of sentences.

        Returns:
            Integer NumPy array with one label per sentence.
        """
        message = [
            {"role": "system", "content": "From now on, you should act as a sentiment analysis neaural network. You should classify the sentiment of a sentence into positive or negative. If the sentence is positive, you should reply 1. Otherwise, if it's negative, you should reply 0. There may be some words that are masked in the sentence, which are represented by <UNK>. The input sentence may be empty, which is represented by <EMPTY>.You will be given the sentences to be classified, and you should reply the sentiment of the sentence by 1 or 0.\nThere are two examples:\nSentence:\nI am good\nSentiment:\n1\nSentence:\nThe movie is bad.\nSentiment:\n0\n You must follow this format. Then I'll give you the sentence. Remember Your reply should be only 1 or 0. Do not contains any other content in your response. The input sentence may be empty."}
        ]
        res = []
        for sentence in tqdm(text,position=1):
            if len(sentence.strip())==0:
                sentence = '<EMPTY>'
            label = None
            for _ in range(3):
                try:
                    respond = self.chat_completion([message+[{"role": "user", "content": 'Sentence:\n'+sentence}]])[0]['generation']['content'][-1]
                except Exception:
                    # Best-effort retry: transport or parsing failures just
                    # consume one attempt. KeyboardInterrupt is not an
                    # Exception subclass and still propagates.
                    continue
                if respond == '1' or respond == '0':
                    label = int(respond)
                    break
            if label is None:
                label = 0
                print(f'{sentence} is not replied correctly.',file=sys.stderr)
            res.append(label)
        return np.array(res, dtype=int)

    def _predict(self, text, **kwargs):
        """Return ``[1 - p, p]`` per sentence, where ``p`` is the replied number.

        Each sentence is sent up to three times; the text after the last
        ':' of the reply is parsed as a float. When no attempt yields a
        parsable number the row defaults to ``[0.5, 0.5]`` and a message is
        written to stderr.

        Args:
            text: List of sentences.

        Returns:
            Float NumPy array with two columns and one row per sentence.
        """
        message = [
            {"role": "system", "content": "From now on, you should act as a sentiment analysis neaural network. You should classify the sentiment of a sentence into positive or negative. You should give the probability of the sentence to be a positive sentense. There may be some words that are masked in the sentence, which are represented by <UNK>. The input sentence may be empty, which is represented by <EMPTY>. You will be given the sentences to be classified, and you should reply the sentiment of the sentence by 1 or 0.\nThere are two examples:\nInput:\nSentence:\nI am good\nReply:\nSentiment:\n0.934\nSentence:\nThe movie is bad.\nnReply:\nSentiment:\n0.003\n The response probability value is just an example, not a real reault of sentiment. You must follow this format. Then I'll give you the sentence. Remember Your reply should only contains a single number, which is the probability of the sentence to be a positive sentence. Do not contains any other content in your response."}
        ]
        res = []
        for sentence in tqdm(text,position=1):
            if len(sentence.strip())==0:
                sentence = '<EMPTY>'
            row = None
            for _ in range(3):
                try:
                    respond = self.chat_completion([message+[{"role": "user", "content": 'Sentence:\n'+sentence}]])[0]['generation']['content'].split(':')[-1]
                    positive = float(respond)
                except Exception:
                    # Best-effort retry, including replies that are not a number.
                    continue
                row = [1-positive,positive]
                break
            if row is None:
                row = [0.5,0.5]
                print(f'{sentence} is not replied correctly.',file=sys.stderr)
            res.append(row)
        # reshape keeps the two-column layout even for empty input.
        return np.array(res, dtype=float).reshape(-1, 2)

class T5Predictor(AbstractTextPredictor):
    """Predictor wrapping a ``T5ForConditionalGeneration`` checkpoint."""

    def __init__(self, model_path, device='cuda',batch_size = 32):
        """Load the model and tokenizer from ``model_path`` and move the model to ``device``."""
        seq2seq = T5ForConditionalGeneration.from_pretrained(model_path)
        seq2seq.eval()
        tokenizer = T5Tokenizer.from_pretrained(model_path)

        self.model = seq2seq.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size
        super().__init__(unk_token=tokenizer.unk_token)

    @torch.no_grad()
    def _predict(self,text, **kwargs):
        """Return a two-column softmax read from the second generation step.

        For every batch the scores of generation step index 1 are taken at
        token ids 29 and 102 (annotated by the original author as 'n' and
        'p') and normalised with a softmax.

        Args:
            text: List of strings to classify.

        Returns:
            NumPy array with two columns and one row per input string.
        """
        step = self.batch_size
        label_token_ids = np.array([29,102])  # 29:n 102:p
        collected = []
        for start in range(0, len(text), step):
            encoded = self.tokenizer(
                text[start:start + step], return_tensors="pt", padding=True
            ).to(self.device)
            generated = self.model.generate(
                encoded.input_ids,
                output_scores=True,
                return_dict_in_generate=True,
            )
            second_step = generated.scores[1]
            collected.append(second_step[:, label_token_ids].softmax(-1))
        stacked = torch.concat(collected)
        return stacked.cpu().detach().numpy()

class LSTMPreidctor(AbstractTextPredictor):
    """Predictor wrapping the project's LSTM text classifier.

    NOTE(review): ``LSTMforTextClassification`` and ``LSTMTokenizer`` are
    referenced here but their import is commented out at the top of this
    file -- they must be in scope for this class to work.
    """

    def __init__(self, model_path, device='cuda'):
        """Load the model and tokenizer from ``model_path``."""
        model = LSTMforTextClassification.from_pretrained(model_path,device)
        model.eval()
        self.tokenizer = LSTMTokenizer.from_pretrained(model_path)
        self.model = model
        self.device = device
        super().__init__(unk_token='<unk>')

    def collect_fn(self,batch):
        """Tokenise every item of ``batch`` and pad the results into one int64 tensor.

        Sequences are padded with 0 and stacked batch-first.
        """
        data = [torch.tensor(self.tokenizer(item),dtype=torch.int64) for item in batch]
        return torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=0)

    @torch.no_grad()
    def _predict(self,text,batch_size = 32, **kwargs):
        """Return softmax class probabilities for every string in ``text``.

        Empty strings are replaced by the tokenizer's pad token before
        tokenisation. The replacement is applied to every item (the earlier
        code only looked at the first item of each batch) and works on a
        copy of each batch, so the caller's list is left untouched.

        Args:
            text: List of strings to classify.
            batch_size: Number of strings scored per forward pass.

        Returns:
            NumPy array of probabilities, one row per input string.
        """
        model = self.model
        tokenizer = self.tokenizer
        probs = []
        for i in range(0,len(text),batch_size):
            # pad_token is only looked up when an empty string is present.
            batch = [t if len(t) > 0 else tokenizer.pad_token for t in text[i:i+batch_size]]
            inputs = self.collect_fn(batch).to(self.device)
            outputs = model(inputs)[0].softmax(-1)
            probs.append(outputs)
        probs = torch.concat(probs)
        return probs.cpu().detach().numpy()



        

# %%
def get_predictor(model_type,model_path,device='cuda',batch_size = 32):
    if model_type == 'bert':
        return BertPredictor(model_path,device,batch_size)
    elif model_type == 'gpt2':
        return GPTPredictor(model_path,device,batch_size)
    elif model_type == 'llama2':
        return Llama2Predictor()
    elif model_type == 't5':
        return T5Predictor(model_path,device,batch_size)
    elif model_type == 'lstm':
        return LSTMPreidctor(model_path,device)
    else:
        raise Exception("Unknown model type: %s" % model_type)



# %%
if __name__ == '__main__':
    # Manual smoke check: build the LSTM predictor from a local checkpoint.
    model = get_predictor(
        model_type='lstm',
        model_path='/home/XXXX-4/github-repo/ReX/models/LSTM-sst2/lstm_res',
    )

# %%
