import numpy as np
from cexl_exp.utils.text_predict import AbstractTextPredictor
from cexl.cexl_lore.surrogate.text_decision_tree import TextDecisionTreeSurrogate,CeXLTextDecisionTreeSurrogate
from cexl.cexl_lore.cexl_lore import ConceptLore
from cexl.cexl_lore.encoder_decoder.text_enc import TextEnc
from cexl.cexl_lore.encoder_decoder.concept_enc import IdenticalEnc
from cexl.utils import AbstractConceptLLM
from llm.utils import ChatLlama
import pandas as pd
from cexl.utils import Concept
from hashlib import sha256
from typing import List
from cexl.utils import AbstractConceptPredictor,LocalConceptPredictor
import itertools

class TextConceptPredictor(AbstractConceptPredictor):
    """Score concept assignments by generating text for them and classifying it.

    For each concept assignment the LLM's ``inverse_transform`` is asked for
    ``sample_size`` instances, the base text model predicts on them, and the
    predictions are averaged over axis 0. Results are memoised per assignment.
    """

    def __init__(self, base_model:AbstractTextPredictor,llm:AbstractConceptLLM, sample_size:int = 1):
        """Store the collaborators and start with an empty result cache.

        Args:
            base_model: Text predictor applied to the generated instances.
            llm: Object whose ``inverse_transform`` turns concept assignments
                into instances for ``base_model``.
            sample_size: How many instances are requested per assignment.
        """
        super().__init__()
        self.cache = {}
        self.base_model = base_model
        self.llm = llm
        self.sample_size = sample_size

    def _predict(self, concepts_list:List[List[str]], **kwargs)->np.ndarray:
        """Return one averaged prediction row per concept assignment.

        Args:
            concepts_list: Concept assignments, one list of values per row.
            **kwargs: Forwarded unchanged to ``base_model.predict``. They are
                not part of the cache key.

        Returns:
            Array of shape ``(len(concepts_list), 2)``.
        """
        scores = np.zeros((len(concepts_list), 2))
        for row, concepts in enumerate(concepts_list):
            # The cache key is the SHA-256 digest of the assignment's string form.
            key = sha256(str(concepts).encode()).hexdigest()
            if key in self.cache:
                scores[row] = self.cache[key]
                continue
            generated = self.llm.inverse_transform([concepts] * self.sample_size)
            outputs = self.base_model.predict(generated, **kwargs)
            self.cache[key] = np.average(outputs, axis=0)
            scores[row] = self.cache[key]

        return scores



from typing import Union


class LocalTextConceptPredictor(LocalConceptPredictor):
    """Predict on binary concept masks around one reference concept assignment.

    A truthy mask entry keeps the reference value of that concept; a falsy
    entry replaces it with one of the concept's other possible responses.
    """

    def __init__(self, base_model:TextConceptPredictor, concepts: list[Concept], now_c: Union[list[str], np.ndarray], num_samples = 1):
        """Record the reference assignment and the alternatives for each concept.

        Args:
            base_model: Predictor called with lists of concept values.
            concepts: Concept definitions, aligned position-wise with ``now_c``.
            now_c: Reference concept values of the instance being explained.
            num_samples: Number of concrete assignments drawn per mask row.
        """
        super().__init__(concepts,now_c)
        self.num_samples = num_samples
        self.base_model = base_model
        # Alternatives per concept: every possible response except the current
        # value and the literal 'uncertain'.
        self.alt_concept = [
            [resp for resp in concepts[pos].possible_responses
             if resp != current and resp != 'uncertain']
            for pos, current in enumerate(now_c)
        ]

    def _predict(self, Z:np.ndarray):
        """Return the mean base-model prediction for every mask row in ``Z``.

        For each row, all assignments compatible with the mask are enumerated,
        ``num_samples`` of them are drawn uniformly with replacement, and the
        base model's predictions on the drawn assignments are averaged.
        """
        outputs = []
        for mask in Z:
            options = [
                [self.now_c[pos]] if keep else self.alt_concept[pos]
                for pos, keep in enumerate(mask)
            ]
            combos = list(itertools.product(*options))
            picked = np.random.choice(len(combos), self.num_samples, replace=True)
            batch = [list(combos[k]) for k in picked]
            predictions = self.base_model.predict(batch)
            outputs.append(np.mean(predictions, axis=0))
        return np.stack(outputs)

# Review data; the first CSV column is used as the DataFrame index.
data = pd.read_csv('imdb-test.csv',index_col=0)
from utils.llama_predict import LlamaSentimentPredictor
llama_predictor = LlamaSentimentPredictor(
    model="llama3.1:8b-instruct-q4_K_M",
    # A second endpoint, 'http://162.105.88.109:11432', is currently disabled.
    base_urls=['http://162.105.88.109:11431'],
)
# NOTE: this rebinds AbstractTextPredictor, which was imported at the top of
# the file from cexl_exp.utils.text_predict.
from utils.text_predict import AbstractTextPredictor
Bertmodel = AbstractTextPredictor.from_pretrained('bert','textattack/bert-base-uncased-imdb')
sentences = data['review'].to_list()
# Concept definitions: one JSON record per concept.
concept_json = pd.read_json('TBM_concept4.json').to_dict(orient='records')
concepts = [Concept.from_dict(record) for record in concept_json]
# Table read by get_nowc: a 'text' column plus one 'concept_<i>' column per concept.
test = pd.read_json('check_res4.json',orient='index')

from typing import Union, List

from cexl.utils import Concept
from llm.utils import ChatLlama
from cexl.utils import AbstractConceptLLM

class ConceptLLM(AbstractConceptLLM):
    """Adapter that asks a ChatLlama model to realise concept assignments."""

    def __init__(self, basemodel:ChatLlama, concepts:List[Concept]):
        """Keep the chat model and the ordered concept definitions."""
        self.basemodel = basemodel
        self.concepts = concepts

    def inverse_transform(self, now_cs: Union[List[List[str]], np.ndarray])->np.ndarray:
        """Delegate generation for each assignment in ``now_cs`` to the chat model.

        Every row of ``now_cs`` is paired position-wise with ``self.concepts``
        and handed to ``basemodel.tbm_inverse_data4`` together with an all-ones
        matrix of shape ``(len(now_cs), len(self.concepts))``. Whatever that
        call returns is returned unchanged.
        """
        paired = [list(zip(self.concepts, row)) for row in now_cs]
        all_ones = np.ones((len(now_cs), len(self.concepts)))
        return self.basemodel.tbm_inverse_data4(all_ones, paired)
import spacy
# Shared spaCy pipeline: remove_punctuation tokenises with it and it is also
# handed to AnchorText in TextUnifiedExplainer.get_anchor_exp.
nlp = spacy.load('en_core_web_sm')
def get_nowc(text):
    """Look up the stored concept values for ``text`` in the module-level ``test`` table.

    The first row whose ``'text'`` column equals ``text`` is used, and the
    values of its ``concept_0 .. concept_{len(concepts) - 1}`` columns are
    returned as a list. Raises ``IndexError`` when no row matches.
    """
    matching_rows = test.loc[test['text'] == text]
    row_label = matching_rows.index[0]
    return [test[f'concept_{pos}'][row_label] for pos in range(len(concepts))]
def remove_punctuation(text):
    """Return ``text`` with every token that spaCy flags as punctuation removed.

    The remaining tokens are concatenated with their original trailing
    whitespace (``token.text_with_ws``).
    """
    kept = []
    # Tokenise with the module-level spaCy pipeline.
    for token in nlp(text):
        if token.is_punct:
            continue
        kept.append(token.text_with_ws)
    return ''.join(kept)
# Ports of the LLM endpoints on the inference host.
ports = [11431,11432]
llama_model = ChatLlama(
    model='llama3.1:70b-instruct-q4_K_M',
    # Alternative local endpoints (disabled): 'http://localhost:30001', 'http://localhost:30002'
    base_urls=[f'http://10.129.162.125:{port}' for port in ports],
)
# Concept-to-text generator used by TextUnifiedExplainer.
cllm = ConceptLLM(basemodel=llama_model,concepts=concepts)
from cexl.cexl_lore.bbox.bbox import AbstractBBox


class BBox(AbstractBBox):
    """Expose a text predictor through the ``AbstractBBox`` interface."""

    def __init__(self, predictor):
        """Wrap ``predictor``, which must provide ``predict`` and ``predict_lr``."""
        self.predictor = predictor

    def predict(self, sample_matrix: list):
        """Return ``predictor.predict_lr(sample_matrix)`` unchanged."""
        labels = self.predictor.predict_lr(sample_matrix)
        return labels

    def predict_proba(self, sample_matrix: list):
        """Return the first column of ``predictor.predict`` as a 2-D slice.

        NOTE(review): CBBox.predict_proba returns column 1 as a 1-D array,
        whereas this keeps column 0 with shape (n, 1) - confirm the difference
        is intended.
        """
        scores = self.predictor.predict(sample_matrix)
        return scores[:, :1]
class CBBox(AbstractBBox):
    """Expose a local concept predictor through the ``AbstractBBox`` interface."""

    def __init__(self, predcitor:LocalTextConceptPredictor, now_c:List[str]):
        """Store the concept predictor and the reference concept values.

        The parameter keeps its original spelling (``predcitor``) so existing
        keyword callers continue to work.
        """
        self.now_c = now_c
        self.predictor = predcitor

    def predict(self, x:np.ndarray):
        """Return, for each row of ``x``, the index of the highest predicted score."""
        return np.argmax(self.predictor.predict(x), axis=1)

    def predict_proba(self, x:np.ndarray)->np.ndarray:
        """Return column 1 of the predictor output for each row of ``x``."""
        scores = self.predictor.predict(x)
        return scores[:, 1]
import pickle
def save_obj(obj, path):
    """Pickle ``obj`` to ``path`` using the highest available pickle protocol.

    An existing file at ``path`` is overwritten.
    """
    with open(path, 'wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)
# LIME
from functools import partial
from lime.lime_text import LimeTextExplainer
from cexl.cexl_lime.lime_concept import LimeConceptExplainer
import time
from anchor.anchor_text import AnchorText
from cexl.cexl_anchor.anchor_concept import AnchorConcept
from cexl.cexl_kshap import ConceptKernelExplainer
from cexl.cexl_lore.neighgen.genetic_text import TextGeneticGenerator
import shap
import os


# text = test['text'][0]
class TextExplanation:
    """Plain record holding one explanation together with its timings."""

    # Attribute names in the order they appear in ``to_dict``.
    _FIELDS = ('text', 'exp', 'exp_tm', 'pred_tm', 'type')

    def __init__(self, text, exp, exp_tm, pred_tm,type="Unknown"):
        """Store the explanation payload.

        Args:
            text: Text the explanation refers to.
            exp: Explanation payload produced by the explainer.
            exp_tm: Time value recorded for producing the explanation.
            pred_tm: Time value reported by the predictor's timer.
            type: Label identifying the explanation method.
        """
        self.text, self.exp = text, exp
        self.exp_tm, self.pred_tm = exp_tm, pred_tm
        self.type = type

    def to_dict(self):
        """Return the stored fields as a dict keyed by attribute name."""
        return {field: getattr(self, field) for field in self._FIELDS}
        
class TextUnifiedExplainer:
    """Run word-level and concept-level explainers on a single text.

    Each ``get_*_exp`` method runs one explainer, pickles its scores into
    ``save_path`` and returns a ``TextExplanation`` carrying the explanation,
    the wall-clock time spent and the value of the predictor's timer.

    The class relies on module-level objects defined elsewhere in this file:
    ``concepts``, ``cllm``, ``llama_model``, ``nlp``, ``get_nowc``,
    ``remove_punctuation`` and ``save_obj``.
    """

    def __init__(self, predictor:AbstractTextPredictor, text, save_path):
        """Build the predictors and black-box wrappers for ``text``.

        Args:
            predictor: Text model to explain.
            text: Instance to explain; it must be present in the module-level
                ``test`` table because its concept values are looked up there.
            save_path: Existing directory that receives the pickled results.
        """
        self.text = text
        self.predictor = predictor
        self.now_c = get_nowc(text)
        self.tcpredictor = TextConceptPredictor(predictor, cllm, 1)
        self.lcpredictor = LocalTextConceptPredictor(self.tcpredictor,concepts,self.now_c,1)
        self.lore_bbox = BBox(predictor)
        self.clore_bbox = CBBox(self.lcpredictor,self.now_c)
        self.save_path = save_path

    def _save(self, obj, filename):
        """Pickle ``obj`` as ``filename`` inside ``self.save_path``."""
        save_obj(obj, os.path.join(self.save_path, filename))

    def _save_lore(self, lore_exp, prefix):
        """Pickle the rule and the first counterfactual of a LORE explanation.

        Writes ``<prefix>_rule.pkl`` with ``(index, variable)`` pairs and
        ``<prefix>_counterfactual.pkl`` with ``(pairs, consequences)``, where
        the index is the integer after the first ``'_'`` of the premise's
        variable name.
        """
        rule = lore_exp['rule']
        rule_scores = [
            (int(premise.variable.split('_')[1]), premise.variable)
            for premise in rule.premises
        ]
        self._save(rule_scores, f'{prefix}_rule.pkl')

        counterfactual = lore_exp['counterfactuals'][0]
        cf_scores = [
            (int(premise.variable.split('_')[1]), str(premise))
            for premise in counterfactual.premises
        ]
        self._save((cf_scores, counterfactual.consequences), f'{prefix}_counterfactual.pkl')

    def get_lime_exp(self):
        """Explain the punctuation-stripped text with LIME (word level)."""
        predictor = self.predictor
        predictor.clear_timer()
        tm = time.time()
        text = remove_punctuation(self.text)
        lime_explainer = LimeTextExplainer(class_names=['0','1'])
        lime_exp = lime_explainer.explain_instance(text,predictor.predict)
        tm = time.time() - tm
        scores = lime_exp.as_list()
        self._save(scores, 'lime.pkl')
        return TextExplanation(text,scores,tm,predictor.timer(),type="LIME")

    def get_lime_concept_exp(self):
        """Explain the instance with the concept-level LIME explainer."""
        text = self.text
        tm = time.time()
        lime_explainer = LimeConceptExplainer()
        lcpredictor = self.lcpredictor
        lcpredictor.clear_timer()
        cexl_lime_exp = lime_explainer.explain_instance(concepts,lcpredictor,num_features=len(concepts),num_samples=50)
        llama_model.cleanup()
        tm = time.time() - tm
        scores = cexl_lime_exp.as_list()
        self._save(scores, 'clime.pkl')
        return TextExplanation(text,scores,tm,lcpredictor.timer(),type="LIME_CONCEPT")

    def get_anchor_exp(self):
        """Explain the punctuation-stripped text with Anchor (word level)."""
        predictor = self.predictor
        # Reset the prediction timer as every other method does; otherwise the
        # reported pred_tm would include time accumulated by earlier explainers.
        predictor.clear_timer()
        tm = time.time()
        text = remove_punctuation(self.text)
        explainer = AnchorText(nlp,['neg','pos'])
        anchor_exp = explainer.explain_instance(text,predictor.predict_lr)
        tm = time.time() - tm
        scores = anchor_exp.names()
        self._save(scores, 'anchor.pkl')
        return TextExplanation(text, scores, tm, predictor.timer(),type="ANCHOR")

    def get_anchor_concept_exp(self):
        """Explain the instance with the concept-level Anchor explainer."""
        text = self.text
        tm = time.time()
        lcpredictor = self.lcpredictor
        explainer = AnchorConcept(['neg','pos'])
        lcpredictor.clear_timer()
        cexl_anchor_exp = explainer.explain_instance(lcpredictor, concepts)
        llama_model.cleanup()
        tm = time.time() - tm
        scores = cexl_anchor_exp.names()
        self._save(scores, 'canchor.pkl')
        return TextExplanation(text,scores,tm,lcpredictor.timer(),type="ANCHOR_CONCEPT")

    def get_kshap_exp(self):
        """Explain the text with KernelSHAP over word-presence masks."""
        text = self.text

        def shap_predict(predict, words, x):
            # Each mask row selects the words that stay; the kept words are
            # re-joined with single spaces before prediction.
            words = np.array(words)
            x = np.array(x).astype(bool)
            x = [words[z] for z in x]
            x = [' '.join(z).strip() for z in x]
            return predict(x)

        words = remove_punctuation(text).split()
        predictor = self.predictor

        tm = time.time()
        # Background is the all-words-removed mask; the explained instance is
        # the all-words-present mask.
        back_ground = np.zeros((1,len(words)),dtype=int)
        instance = np.ones(len(words),dtype=int)
        predict = partial(shap_predict,predictor.predict,words)
        predictor.clear_timer()
        explainer = shap.KernelExplainer(predict, back_ground, link="logit")
        explanation = explainer.shap_values(instance,res_col=[1])
        explanation = list(zip(words,explanation[1]))
        tm = time.time() - tm
        self._save(explanation.copy(), 'kshap.pkl')
        return TextExplanation(text,explanation,tm,predictor.timer(),type="KSHAP")

    def get_kshap_concept_exp(self):
        """Explain the instance with the concept-level KernelSHAP explainer."""
        now_c = self.now_c
        text = self.text
        lcpredictor = self.lcpredictor

        tm = time.time()
        lcpredictor.clear_timer()

        cexl_kshap_explainer = ConceptKernelExplainer(lcpredictor,np.zeros((1,len(concepts))),concepts)
        cexl_kshap_exp = cexl_kshap_explainer.shap_values(np.ones((1,len(concepts))),nsamples=50,model_out = lcpredictor.predict(np.zeros((1,len(now_c)))),res_col=[1])
        llama_model.cleanup()
        tm = time.time() - tm
        self._save(cexl_kshap_exp.copy(), 'ckshap.pkl')
        return TextExplanation(text,cexl_kshap_exp,tm,lcpredictor.timer(),type="KSHAP_CONCEPT")

    def get_lore_exp(self):
        """Explain the text with LORE (word level) and save rule + counterfactual."""
        bbox = self.lore_bbox
        text = self.text

        encoder = TextEnc(text,{})
        generator = TextGeneticGenerator(metric=1,encoder=encoder,ngen=2,bbox=bbox)
        surrogate = TextDecisionTreeSurrogate()
        bbox.predictor.clear_timer()
        tm = time.time()
        explainer = ConceptLore(bbox = bbox, surrogate=surrogate, encoder=encoder, generator=generator)
        lore_exp = explainer.explain(text)
        tm = time.time() - tm
        self._save_lore(lore_exp, 'lore')

        # Build the result from this explanation; the previous code indexed the
        # module-level ``res`` list by string, which raises TypeError.
        return TextExplanation(text,[str(lore_exp['rule']), str(lore_exp['counterfactuals'])], tm, bbox.predictor.timer(),type="LORE")

    def get_lore_concept_exp(self):
        """Explain the instance with concept-level LORE and save rule + counterfactual."""
        # ``text`` must be bound before any use: the previous code called
        # get_nowc(text) first, which raised UnboundLocalError.
        text = self.text

        bbox = self.clore_bbox
        encoder = IdenticalEnc()
        generator = TextGeneticGenerator(metric=1,encoder=encoder,ngen=2,bbox=bbox)
        surrogate = CeXLTextDecisionTreeSurrogate()
        bbox.predictor.clear_timer()
        tm = time.time()
        explainer = ConceptLore(bbox = bbox, surrogate=surrogate, encoder=encoder, generator=generator)
        # NOTE(review): the raw text is passed to explain() although the encoder
        # is IdenticalEnc and CBBox is annotated to take arrays - confirm this
        # is the intended input.
        lore_exp = explainer.explain(text)
        tm = time.time() - tm
        self._save_lore(lore_exp, 'clore')

        # Same fix as in get_lore_exp: use ``lore_exp`` rather than the
        # module-level ``res`` list.
        return TextExplanation(text,[str(lore_exp['rule']), str(lore_exp['counterfactuals'])], tm, bbox.predictor.timer(),type="LORE_CONCEPT")
        


import os
# Root output directory; each processed row gets a sub-directory named after
# its index label.
res_path = 'results-bert'

# Collects the dict form of every explanation produced by the loop below.
res = []

# Driver: run the enabled explainers on every row of the `test` table.
for i in range(len(test)):
    idx = test.index[i]
    
    text = test['text'][idx]
    # Skip texts that tokenise to more than 512 input ids
    # (presumably the model's maximum sequence length - confirm).
    if len(Bertmodel.tokenizer(text)['input_ids']) > 512:
        continue
    # Optional resume guard: skip rows whose LIME result already exists.
    # if os.path.exists(os.path.join(res_path,str(idx),'lime.pkl')):
        # continue
    os.makedirs(os.path.join(res_path,str(idx)),exist_ok=True)
    
    explainer = TextUnifiedExplainer(Bertmodel,text,os.path.join(res_path,str(idx)))

    # Explainers are toggled by (un)commenting the pairs below; only the
    # KernelSHAP variants are currently enabled.
    # res.append(explainer.get_lime_exp().to_dict())
    # print('lime')
    # res.append(explainer.get_lime_concept_exp().to_dict())
    # print('lime_concept')
    # res.append(explainer.get_anchor_exp().to_dict())
    # print('anchor')
    # res.append(explainer.get_anchor_concept_exp().to_dict())
    # print('anchor_concept')
    res.append(explainer.get_kshap_exp().to_dict())
    print('kshap')
    res.append(explainer.get_kshap_concept_exp().to_dict())
    print('kshap_concept')
    # res.append(explainer.get_lore_exp().to_dict())
    # print('lore')
    # res.append(explainer.get_lore_concept_exp().to_dict())
    # print('lore_concept')
    # break
