import torch
from transformers import SamModel, SamProcessor,pipeline
# Hugging Face "mask-generation" pipeline using the facebook/sam-vit-huge
# checkpoint, created on device 0 with 256 points per batch.
generator =  pipeline("mask-generation", device = 0, points_per_batch = 256,model="facebook/sam-vit-huge")

# Class labels indexed by class id; here simply the integers 0..999.
class_names = list(range(1000))

from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt

from torchvision import datasets, models, transforms
# torchvision ResNet-50 loaded with the IMAGENET1K_V2 weights, moved to CUDA
# and switched to eval mode; used by `mymodel` below.
model = models.resnet50(weights='IMAGENET1K_V2').cuda().eval()
def mymodel(images):
    """Classify a batch of images with the module-level ``model``.

    Args:
        images: Non-empty sequence of images. Elements that are numpy
            arrays are converted with ``Image.fromarray``; all other
            elements are passed to the preprocessing transforms unchanged.

    Returns:
        numpy array of softmax probabilities, one row per input image
        (softmax is taken over dimension 1 of the model output).
    """
    # Convert per element so that ndarray subclasses and mixed
    # array / PIL inputs are handled as well.
    images = [
        Image.fromarray(image) if isinstance(image, np.ndarray) else image
        for image in images
    ]

    # Resize -> center crop -> tensor -> channel-wise normalisation.
    test_preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    batch = torch.stack([test_preprocess(image) for image in images]).to('cuda')

    # Inference only: skip building the autograd graph to save GPU memory.
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)
    return probabilities.cpu().numpy()
    


def filter_masks(masks):
    """Drop near-duplicate masks and add a mask for the uncovered area.

    Two masks are considered duplicates when their intersection covers more
    than 90% of the smaller one; of such a pair only the larger mask is kept
    (on a tie the later mask is kept). If some pixels are not covered by any
    input mask, one extra mask containing exactly those pixels is appended
    at the end of the result.

    Args:
        masks: Sequence of boolean arrays of identical shape. The sequence
            itself is not modified.

    Returns:
        New list with the surviving masks in their original order, followed
        by the uncovered-area mask when it is non-empty. An empty input
        yields an empty list.
    """
    masks = list(masks)  # work on a copy so the caller's list is untouched
    if not masks:
        return []

    # Mask areas are loop-invariant: compute them once.
    areas = [int(np.sum(mask)) for mask in masks]

    to_remove = set()
    for i in range(len(masks)):
        for j in range(i + 1, len(masks)):
            smaller = min(areas[i], areas[j])
            if smaller == 0:
                # An empty mask cannot overlap anything; avoid dividing by 0.
                continue
            intersection = np.logical_and(masks[i], masks[j]).sum()
            if intersection / smaller > 0.9:
                # Keep the mask with the larger area.
                if areas[i] > areas[j]:
                    to_remove.add(j)
                else:
                    to_remove.add(i)

    # Pixels that belong to none of the input masks.
    covered = np.zeros(masks[0].shape, dtype=bool)
    for mask in masks:
        covered |= np.asarray(mask, dtype=bool)
    uncovered = ~covered

    kept = [mask for idx, mask in enumerate(masks) if idx not in to_remove]
    if uncovered.any():
        kept.append(uncovered)
    return kept
from cexl.utils import Concept, AbstractConceptPredictor, LocalConceptPredictor, AbstractPredictor
class ImagePredictor(AbstractPredictor):
    """Predictor that feeds images to a batch-classification callable in chunks."""

    # Maximum number of images passed to the model per call.
    batch_szie = 32

    def __init__(self, model):
        """Store ``model``, a callable mapping a list of images to per-image outputs."""
        super().__init__()
        self.model = model

    def _predict(self, images):
        """Run the wrapped model over ``images`` and stack the outputs.

        A list is used as given. A 3-D array is treated as a single image,
        and a 4-D array is split along its first axis into PIL images.
        """
        if not isinstance(images, list) and isinstance(images, np.ndarray):
            if images.ndim == 3:
                images = [images]
            elif images.ndim == 4:
                images = [Image.fromarray(frame) for frame in images]

        step = self.batch_szie
        outputs = []
        for start in range(0, len(images), step):
            chunk = images[start:start + step]
            outputs.extend(self.model(chunk))
        return np.stack(outputs)
    
# Module-level image predictor wrapping `mymodel`.
ipredictor = ImagePredictor(mymodel)
class LocalImageConceptPredictor(AbstractPredictor):
    """Predict on copies of one image in which selected mask regions are blanked out."""

    def __init__(self, predictor:ImagePredictor,image:np.ndarray, masks:np.ndarray=None, sam_masks:list=None):
        """Keep the predictor, the image and a copy of the masks.

        When ``masks`` is not given, the masks are taken from ``sam_masks``.
        """
        super().__init__()
        self.predictor = predictor
        if masks is None:
            masks = list(sam_masks)
        self.masks = masks.copy()
        self.image = image

    def apply_mask(self, image, mask):
        """Multiply ``image`` by ``mask`` repeated over three channels."""
        return image * np.stack((mask, mask, mask), axis=-1)

    def _predict(self, data: np.ndarray) -> np.ndarray:
        """Render one masked image per row of ``data`` and predict on all of them.

        Each row holds one entry per mask; a falsy entry removes that
        mask's pixels from the image.

        Raises:
            ValueError: if the last dimension of ``data`` differs from the
                number of masks.
        """
        mask_count = len(self.masks)
        if data.shape[-1] != mask_count:
            raise ValueError("data shape is not match with masks")

        def render(row):
            # Start fully visible, then cut out every switched-off mask.
            visible = np.ones_like(self.masks[0])
            for flag, region in zip(row, self.masks):
                if not flag:
                    visible &= ~region
            return self.apply_mask(self.image, visible)

        rows = data.reshape(-1, mask_count)
        return self.predictor._predict([render(row) for row in rows])
    
def apply_mask(image,mask):
    """Scale ``image`` by a per-pixel weight derived from ``mask``, then divide by 255.

    For a boolean mask the weight is 1.0 where the mask is set and 0.1
    elsewhere; the mask is repeated over three channels.
    """
    three_channel = np.stack((mask, mask, mask), axis=-1)
    weight = (three_channel * 9 + 1) / 10
    return image * weight / 255


from lime.lime_image import LimeImageExplainer
from cexl.cexl_lime.lime_concept import LimeConceptExplainer
from anchor.anchor_image import AnchorImage
from cexl.cexl_anchor.anchor_concept import AnchorConcept
from cexl.cexl_kshap import ConceptKernelExplainer
from cexl.cexl_lore.cexl_lore import ConceptLore
from cexl.cexl_lore.encoder_decoder.concept_enc import IdenticalEnc
from cexl.cexl_lore.surrogate.text_decision_tree import CeXLTextDecisionTreeSurrogate
from cexl.cexl_lore.neighgen.genetic_text import TextGeneticGenerator
from cexl.cexl_lore.bbox.bbox import AbstractBBox


class LoreBBox(AbstractBBox):
    """Black-box wrapper exposing ``predict`` / ``predict_proba`` on top of a predictor."""

    def __init__(self, predictor):
        """Keep a reference to ``predictor``, whose ``predict`` returns per-class scores."""
        self.model = predictor

    def predict(self, data):
        """Return the index of the highest score along the last axis."""
        scores = self.model.predict(data)
        return scores.argmax(axis=-1)

    def predict_proba(self, data):
        """Return the predictor's raw output for ``data``."""
        return self.model.predict(data)
import pickle
def save_obj(obj, path):
    """Pickle ``obj`` into the file at ``path`` using the highest protocol."""
    with open(path, 'wb') as handle:
        pickler = pickle.Pickler(handle, pickle.HIGHEST_PROTOCOL)
        pickler.dump(obj)
import json
from tqdm.auto import tqdm
# Sorted listing of the input image directory.
tests = sorted(os.listdir('imagenet-test'))
# Alias for the module-level `generator` pipeline object.
mask_generator = generator
from skimage.segmentation import quickshift
# Superpixel segmentation: quickshift with fixed kernel_size / max_dist / ratio.
segmentation_fn = lambda x: quickshift(x, kernel_size=20, # noqa
                                        max_dist=200, ratio=0.2)
# Root directory under which the per-image result folders are written.
res_path = "resutls3-resnet"
# Main experiment loop. For every input image:
#   1. build superpixel masks (quickshift) and concept masks (filtered SAM),
#   2. record the model prediction and both mask sets in config.pkl,
#   3. run LIME, Anchor, Kernel SHAP and LORE once on superpixels and once
#      on the SAM concepts, pickling each result into the image's folder.
# Iterates the sorted listing so the processing order is deterministic.
for path in tqdm(tests):
    filename = path.split('.')[0]
    out_dir = os.path.join(res_path, filename)

    # Resume support: 'cexl_lore_counterfactual.pkl' is the last artifact
    # written below, so its presence means the image was fully processed.
    if os.path.exists(os.path.join(out_dir, 'cexl_lore_counterfactual.pkl')):
        continue

    # Open inside a context manager so the file handle is always released.
    with Image.open(os.path.join('./imagenet-test', path)) as raw_image:
        # Skip images with a side longer than 1024 pixels.
        if raw_image.size[0] > 1024 or raw_image.size[1] > 1024:
            continue
        image = np.array(raw_image.convert("RGB"))

    os.makedirs(out_dir, exist_ok=True)

    outputs = mask_generator(Image.fromarray(image), points_per_batch=256)
    segments = segmentation_fn(image)

    # One boolean mask per superpixel label.
    masks = [segments == i for i in range(np.max(segments) + 1)]
    sam_masks = filter_masks(outputs['masks'])
    lpredictor = LocalImageConceptPredictor(ipredictor, image, sam_masks=sam_masks)

    # Prediction with every concept switched on.
    pred = lpredictor.predict(np.ones((1, len(sam_masks))))
    prediction = pred.argmax()
    # NOTE(review): the key "prediction:" contains a trailing colon; it is
    # kept unchanged because readers of config.pkl may depend on it.
    config = {"prediction:": (prediction, class_names[prediction]), "masks": masks, "sam_masks": sam_masks}
    save_obj(config, os.path.join(out_dir, 'config.pkl'))

    # LIME on superpixels.
    # NOTE(review): only 10 perturbation samples are used here versus 1000
    # for the concept variant below - confirm this is intended.
    lime_explainer = LimeImageExplainer()
    exp = lime_explainer.explain_instance(
        image, ipredictor.predict, top_labels=1000, batch_size=1000,
        num_features=len(masks), num_samples=10, segmentation_fn=segmentation_fn)
    scores = exp.local_exp
    save_obj(scores, os.path.join(out_dir, 'lime.pkl'))

    # LIME on SAM concepts.
    clime_explainer = LimeConceptExplainer()
    cexp = clime_explainer.explain_instance(
        sam_masks, lpredictor, top_labels=1000, num_samples=1000,
        num_features=len(sam_masks))
    scores = cexp.local_exp
    save_obj(scores, os.path.join(out_dir, 'clime.pkl'))

    # Anchor on superpixels.
    anchor_explainer = AnchorImage(segmentation_fn=segmentation_fn)
    anchor_segments, exp = anchor_explainer.explain_instance(image, ipredictor.predict)
    scores = [(e[0], e[0]) for e in exp]
    save_obj(scores, os.path.join(out_dir, 'anchor.pkl'))

    # Anchor on SAM concepts; the concept index is parsed from each name.
    canchor_explainer = AnchorConcept(None)
    cexp = canchor_explainer.explain_instance(lpredictor, sam_masks)
    scores = [(int(name.split('_')[1]), name) for name in cexp.names()]
    save_obj(scores, os.path.join(out_dir, 'canchor.pkl'))

    # Kernel SHAP: explain the five highest-scoring classes, best first.
    res_col = list(np.argsort(pred[0])[-5:][::-1])

    shap_predictor = LocalImageConceptPredictor(ipredictor, image, masks)
    shap_explainer = ConceptKernelExplainer(shap_predictor, np.zeros((1, len(masks))), concepts=masks)
    exp = shap_explainer.shap_values(
        np.ones((1, len(masks))), nsamples=1000,
        model_out=shap_predictor.predict(np.ones((1, len(masks)))),
        res_col=res_col)
    scores = exp.copy()
    save_obj(scores, os.path.join(out_dir, 'shap.pkl'))

    cexl_shap_explainer = ConceptKernelExplainer(lpredictor, np.zeros((1, len(sam_masks))), concepts=sam_masks)
    cexl_shap_exp = cexl_shap_explainer.shap_values(
        np.ones((1, len(sam_masks))), nsamples=1000,
        model_out=lpredictor.predict(np.ones((1, len(sam_masks)))),
        res_col=res_col)
    scores = cexl_shap_exp.copy()
    save_obj(scores, os.path.join(out_dir, 'cexl_shap.pkl'))

    # LORE on superpixels. The generator gets its own name so that the
    # module-level `generator` (the SAM pipeline) is not rebound.
    lorepredictor = LoreBBox(shap_predictor)
    encoder = IdenticalEnc()
    surrogate = CeXLTextDecisionTreeSurrogate()
    lore_generator = TextGeneticGenerator(bbox=lorepredictor, encoder=encoder, surrogate=surrogate, metric=1, ngen=5)
    lore_explainer = ConceptLore(lorepredictor, encoder=encoder, surrogate=surrogate, generator=lore_generator)
    lore_exp = lore_explainer.explain(np.ones(len(masks)))

    lore_rule = lore_exp['rule']
    print(lore_rule)
    scores = [(int(premise.variable.split('_')[1]), premise.variable)
              for premise in lore_rule.premises]
    save_obj(scores, os.path.join(out_dir, 'lore_rule.pkl'))

    lore_cf = lore_exp['counterfactuals'][0]
    print(lore_cf)
    scores = [(int(premise.variable.split('_')[1]), str(premise))
              for premise in lore_cf.premises]
    save_obj((scores, lore_cf.consequences), os.path.join(out_dir, 'lore_counterfactual.pkl'))

    # LORE on SAM concepts.
    # NOTE(review): unlike the superpixel LORE block above, the parsed
    # indices are shifted by -1 here - confirm the concept numbering.
    cexllorepredictor = LoreBBox(lpredictor)
    cgenerator = TextGeneticGenerator(bbox=cexllorepredictor, encoder=encoder, surrogate=surrogate, metric=1, ngen=5)
    cexl_lore_explainer = ConceptLore(cexllorepredictor, encoder=encoder, surrogate=surrogate, generator=cgenerator)
    cexl_lore_exp = cexl_lore_explainer.explain(np.ones(len(sam_masks)))

    cexl_rule = cexl_lore_exp['rule']
    print(cexl_rule)
    scores = [(int(premise.variable.split('_')[1]) - 1, premise.variable)
              for premise in cexl_rule.premises]
    save_obj(scores, os.path.join(out_dir, 'cexl_lore_rule.pkl'))

    cexl_cf = cexl_lore_exp['counterfactuals'][0]
    print(cexl_cf)
    scores = [(int(premise.variable.split('_')[1]) - 1, str(premise))
              for premise in cexl_cf.premises]
    # Written last: the resume check at the top of the loop looks for it.
    save_obj((scores, cexl_cf.consequences), os.path.join(out_dir, 'cexl_lore_counterfactual.pkl'))


