import os
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration,  AutoProcessor, LlavaForConditionalGeneration, AutoModelForCausalLM 
import torch
from domainbed.datasets import DSPRITES, SMALLNORB, SHAPES3D, CELEBA, DEEPFASHION
from domainbed.iwildcam import iwildcam
from domainbed.fmow import FMoW
from domainbed.camelyon17 import Camelyon17 
import numpy as np
from PIL import Image
from tqdm import tqdm
from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader
import argparse
from domainbed import datasets
import sys
from domainbed.lib import misc
import time
import transformers
from torchvision import transforms
import re

class Zeroshot:
    """Zero-shot image classification with vision-language models.

    Wraps one of several pretrained models (InstructBLIP, LLaVA, Phi-3(.5)
    vision, or CLIP), builds a text prompt for the chosen dataset, and
    evaluates the model on the dataset's test split by mapping each model
    output back to a label index.
    """

    # Index at which the iwildcam test set is split into the two groups that
    # inference() reports under the keys '0' and '1'.
    # NOTE(review): the original code split at index 492 but divided by 429
    # when reporting the per-group accuracies; 429 is used here for both.
    # Confirm the correct boundary against the iwildcam test split layout.
    IWILDCAM_SPLIT = 429

    def __init__(self, mode, prompt_type, dataset, batch_size=1):
        """Load the dataset, the model/processor and the prompt.

        Args:
            mode: model family; one of ``mode_list`` below.
            prompt_type: key into the prompt table built by ``get_prompt``.
            dataset: 'iwildcam', 'FMOW', 'CAMELYON17', or the name of a class
                exported by ``domainbed.datasets``.
            batch_size: number of images per ``forward`` call; values >= 2 are
                only supported for InstructBLIP (see ``inference_batch``).
        """
        mode_list = ['InstructBLIP', 'CLIP-base', 'CLIP-large', 'LLAVA', 'Phi-3','Phi-3.5']
        assert mode in mode_list, f'choose one from {mode_list}'
        self.batch_size = batch_size
        self.mode = mode
        self.prompt_type = prompt_type
        self.dataset = dataset
        if dataset == 'iwildcam':
            self.data = iwildcam(root='/data', split='test',algo='zeroshot')
        elif dataset == 'FMOW':
            self.data = FMoW(root='/data', split='test',algo='zeroshot')
        elif dataset == 'CAMELYON17':
            self.data = Camelyon17(root='/data', split='test',algo='zeroshot')
        else:
            self.data = vars(datasets)[dataset](root= "/data/NeurIPS2024", dist_type='SC', dataset_size=1, split='test')

        self.labels = self.data.label_names

        # Label enumeration embedded in the text prompts, e.g.
        # " 'square' 'ellipse' or 'heart'". The leading space and the absence
        # of commas are kept so prompts stay identical to earlier runs.
        last = len(self.labels) - 1
        self.label_list = ''.join(
            f"{' or ' if l_i == last else ' '}'{label.replace('_', ' ')}'"
            for l_i, label in enumerate(self.labels)
        )
        self.get_model()
        self.get_prompt()

    def get_prompt(self):
        """Build the prompt table for the current mode/dataset and set ``self.prompt``.

        For generative models ``self.prompt`` is a string; for CLIP it is a
        list of candidate captions, one per label. As a side effect this may
        overwrite ``self.labels`` (DEEPFASHION -> ['yes', 'no']) and
        ``self.label_list`` (CAMELYON17).

        Raises:
            KeyError: if ``self.prompt_type`` is not defined for this
                mode/dataset combination.
        """
        if self.mode in ['InstructBLIP','LLAVA']:
            if self.dataset == 'CAMELYON17':
                prompt_dict = {
                'general'  :  f"USER: <image>\nPlease answer yes if the image contains any tumor tissue, and no otherwise. Please provide only the name of the label.\nASSISTANT:",
                'general2'  :  f"USER: <image>\nPlease answer yes if the image contains any tumor tissue, and no otherwise. Please respond with a single word.\nASSISTANT:",
                'tailored'  :  f"USER: <image>\nPlease analyze the image and determine if it contains any tumor tissue. Respond with 'Yes' if tumor tissue is present, or 'No' if it is not. \nASSISTANT:",
                        }
                # Index order matches get_pred's CAMELYON17 mapping (no -> 0, yes -> 1).
                self.label_list = ['no', 'yes']
            elif self.dataset == 'iwildcam':
                if self.mode == 'LLAVA':
                    prompt_dict = {
                        'general'  :  f"USER: <image>\nClassify the object or animal in the image. Here is the list of labels to choose from: {self.label_list}. Please provide only the name of the label.\nASSISTANT:",
                        'general2'  :  f"USER: <image>\nClassify the image using the following labels: {self.label_list}. Please provide only the name of the label.\nASSISTANT:",
                        'tailored' :  f"USER: <image>\nChoose a label that best describes the image. Here is the list of labels to choose from: {self.label_list}. Please provide only the name of the label.\nASSISTANT:",
                        }
                else:
                    prompt_dict = {
                        'general'  :  f"USER: <image>\nClassify the object or animal in the image. Please provide only the name of the label.\nASSISTANT:",
                        'general2'  :  f"USER: <image>\nClassify the image. Please provide only the name of the label.\nASSISTANT:",
                        'tailored' :  f"USER: <image>\nChoose a label that best describes the image. Please provide only the name of the label.\nASSISTANT:",
                        }
            else:
                if self.dataset == 'DEEPFASHION':
                    prompt = f"USER: <image>\nIs a person wearing a dress or not? Please answer in yes or no.\nASSISTANT:"
                    self.labels = ['yes', 'no']
                elif self.dataset == 'CELEBA':
                    prompt = f"USER: <image>\nClassify the person in the image into {self.label_list}. Please provide only the name of the label\nASSISTANT:"
                elif self.dataset in ['DSPRITES','SMALLNORB','SHAPES3D']:
                    prompt = f"USER: <image>\nClassify the object in the image into {self.label_list}. Please provide only the name of the label\nASSISTANT:"
                elif self.dataset in ['FMOW']:
                    prompt = f"USER: <image>\nClassify the building or land-use in the image into {self.label_list}. Please provide only the name of the label\nASSISTANT:"
                else:
                    prompt = None
                prompt_dict = {
                    'general'  :  f"USER: <image>\nClassify the image into {self.label_list}. Please provide only the name of the label.\nASSISTANT:",
                    'general2' :  f"Classify the image into {self.label_list}. Please provide only the name of the label.",
                    'tailored' :  f"USER: <image>\nChoose a label that best describes the image. Here is the list of labels to choose from: {self.label_list}. Please provide only the name of the label.\nASSISTANT:",
                    'tailored2':  f"Choose a label that best describes the image. Here is the list of labels to choose from: {self.label_list}. Please provide only the name of the label.",
                    'tailored3':  prompt
                            }

        elif 'Phi' in self.mode:
            if self.dataset == 'iwildcam':
                prompt_dict = {
                    'general'  :  f"USER: <|image_1|>\nClassify the object or animal in the image. Here is the list of labels to choose from: {self.label_list}. Please provide only the name of the label.",
                    'general2'  :  f"USER: <|image_1|>\nClassify the image using the following labels:  {self.label_list}. Please provide only the name of the label.",
                    'tailored' :  f"USER: <|image_1|>\nChoose a label that best describes the image. Here is the list of labels to choose from: {self.label_list}. Please provide only the name of the label.",
                    }
            elif self.dataset == 'CAMELYON17':
                prompt_dict = {
                'general'  :  f"USER: <|image_1|>\nPlease answer yes if the image contains any tumor tissue, and no otherwise. Please provide only the name of the label.",
                'general2'  :  f"USER: <|image_1|>\nPlease answer yes if the image contains any tumor tissue, and no otherwise. Please respond with a single word.",
                'tailored'  :  f"USER: <|image_1|>\nPlease analyze the image and determine if it contains any tumor tissue. Respond with 'Yes' if tumor tissue is present, or 'No' if it is not.",
                        }
                # Index order matches get_pred's CAMELYON17 mapping (no -> 0, yes -> 1).
                self.label_list = ['no', 'yes']
            else:
                if self.dataset == 'DEEPFASHION':
                    prompt = f"USER: <|image_1|>\nIs a person wearing a dress or not? Please answer in yes or no."
                    self.labels = ['yes', 'no']
                elif self.dataset == 'CELEBA':
                    prompt = f"USER: <|image_1|>\nClassify the person in the image into {self.label_list}. Please provide only the name of the label."
                elif self.dataset in ['DSPRITES','SMALLNORB','SHAPES3D']:
                    prompt = f"USER: <|image_1|>\nClassify the object in the image into {self.label_list}. Please provide only the name of the label."
                elif self.dataset in ['FMOW']:
                    prompt = f"USER: <|image_1|>\nClassify the building or land-use in the image into {self.label_list}. Please provide only the name of the label."
                else:
                    prompt = None
                prompt_dict = {
                    'general'  :  f"USER: <|image_1|>\nClassify the image into {self.label_list}. Please provide only the name of the label.",
                    'tailored' :  f"USER: <|image_1|>\nChoose a label that best describes the image. Here is the list of labels to choose from: {self.label_list}. Please provide only the name of the label.",
                    'tailored2':  prompt
                            }

        elif 'CLIP' in self.mode:
            # CLIP prompts are candidate captions, one per label, in label order.
            if self.dataset == 'CAMELYON17':
                prompt_dict = {'general':["no tumor tissue in a photo", "a photo of tumor tissue"]}
            else:
                prompt_dict = {'general':[f"a photo of a {label.replace('_', ' ')}" for label in self.labels]}
                # The 'tailored' captions used to be computed but never stored,
                # so --prompt_type tailored failed with a KeyError for CLIP.
                if self.dataset == 'iwildcam':
                    prompt_dict['tailored'] = ['There is no animal in a photo']+[f"a photo of a {label}" for label in self.labels[1:]]
                    prompt_dict['general2'] = ['a photo of no animal']+[f"a photo of a {label}" for label in self.labels[1:]]
                elif self.dataset == 'CELEBA':
                    prompt_dict['tailored'] = [f'a {label} person' for label in self.labels]
                elif self.dataset in ['DSPRITES', 'SMALLNORB', 'SHAPES3D']:
                    prompt_dict['tailored'] = [f'a {label} shape object' for label in self.labels]

        if self.prompt_type not in prompt_dict:
            raise KeyError(
                f"prompt_type '{self.prompt_type}' is not available for mode "
                f"'{self.mode}' on dataset '{self.dataset}'; choose one of {sorted(prompt_dict)}"
            )
        self.prompt = prompt_dict[self.prompt_type]
        print(self.prompt)

    def get_model(self):
        """Load the pretrained model and processor for ``self.mode`` onto the GPU."""
        if self.mode == 'InstructBLIP':
            self.processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
            self.model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b",load_in_4bit=True, torch_dtype=torch.float16)

        elif self.mode == 'LLAVA':
            model_id = "llava-hf/llava-1.5-7b-hf"
            self.model = LlavaForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            ).to(0)
            self.processor = AutoProcessor.from_pretrained(model_id)

        elif 'Phi' in self.mode:
            model_id = "microsoft/Phi-3-vision-128k-instruct" if self.mode == 'Phi-3' else "microsoft/Phi-3.5-vision-instruct"
            # use _attn_implementation='eager' to disable flash attention
            self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation='flash_attention_2')
            self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        elif 'CLIP' in self.mode:
            id_dict = {'CLIP-base':"openai/clip-vit-base-patch32", 'CLIP-large':"openai/clip-vit-large-patch14"}
            model_id = id_dict[self.mode]
            self.processor = CLIPProcessor.from_pretrained(model_id)
            self.model = CLIPModel.from_pretrained(model_id).to("cuda")

    @staticmethod
    def _normalize_text(s):
        """Replace non-alphanumeric runs with a space, then strip and lowercase."""
        return re.sub(r'[^A-Za-z0-9 ]+', ' ', s).strip().lower()

    def get_pred(self, text):
        """Map a generated text answer to a label index.

        An exact match against ``self.labels`` wins. Otherwise a
        dataset-specific heuristic is used:

        * CAMELYON17: 'yes' anywhere -> 1, else 'no' anywhere -> 0.
        * DEEPFASHION: mentions of 'dress' decide (with 'no' -> 1, without
          -> 0); otherwise the first standalone 'yes'/'no' word decides
          (0/1, matching the ['yes', 'no'] labels set in ``get_prompt``).
        * Others: substring match between the normalized text and each
          normalized label; a label contained in the text wins immediately.

        Returns:
            The predicted label index, or -1 if the text cannot be mapped.
        """
        try:
            return list(self.labels).index(text)
        except ValueError:
            pass

        lowered = text.lower()
        if self.dataset == 'CAMELYON17':
            if 'yes' in lowered:
                return 1
            if 'no' in lowered:
                return 0
            print(f'Warning: unknown gt_answer: {text}')
            return -1

        if self.dataset == 'DEEPFASHION':
            if 'dress' in lowered:
                return 1 if 'no' in lowered else 0
            # Bare yes/no replies, e.g. "Yes." or "No", as requested by the
            # DEEPFASHION yes/no prompt.
            for word in re.findall(r'[a-z]+', lowered):
                if word == 'yes':
                    return 0
                if word == 'no':
                    return 1
            return -1

        norm_text = self._normalize_text(text)
        if not norm_text:
            # '' is a substring of every label; don't let empty output match.
            return -1
        pred = -1
        for i, label in enumerate(self.labels):
            norm_label = self._normalize_text(label)
            if norm_text in norm_label:
                pred = i
            if norm_label in norm_text:
                return i
        return pred

    def _clip_label_embeddings(self):
        """Return L2-normalised CLIP text embeddings for ``self.prompt``.

        The embeddings depend only on ``self.prompt``, so they are computed
        once and cached (keyed on the prompt captions).
        """
        key = tuple(self.prompt)
        cache = getattr(self, '_clip_label_cache', None)
        if cache is None or cache[0] != key:
            label_tokens = self.processor(
                text=self.prompt,
                padding=True,
                images=None,
                return_tensors='pt'
            ).to("cuda")
            with torch.no_grad():
                label_emb = self.model.get_text_features(**label_tokens)
            label_emb = label_emb.detach().cpu().numpy()
            label_emb = label_emb / np.linalg.norm(label_emb, axis=1, keepdims=True)
            cache = (key, label_emb)
            self._clip_label_cache = cache
        return cache[1]

    def forward(self, im, gts=None):
        """Run the model on one image (or, for InstructBLIP, a batch).

        Args:
            im: a single image, or a list of images when batching InstructBLIP.
            gts: ground-truth label indices for the batch (batched InstructBLIP only).

        Returns:
            ``(pred, generated_text)``. In batched InstructBLIP mode ``pred`` is
            the number of correct predictions in the batch and
            ``generated_text`` is the list of decoded outputs; otherwise
            ``pred`` is a label index (-1 if unmappable).
        """
        generated_text = 'not generated yet'

        if self.mode == 'InstructBLIP':
            # One prompt per image; the final batch from inference_batch may
            # be shorter than batch_size.
            n_prompts = len(im) if isinstance(im, (list, tuple)) else 1
            inputs = self.processor(images=im, text=[self.prompt]*n_prompts, return_tensors="pt").to(device="cuda", dtype=torch.float16)
            outputs = self.model.generate(
                    **inputs,
                    num_beams=1,
                    max_new_tokens=20,
                    min_length=1,
                    repetition_penalty=1.5,
                    length_penalty=2,
                    temperature=1
            )
            outputs[outputs == 0] = 2 # this line can be removed once https://github.com/huggingface/transformers/pull/24492 is fixed
            generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)
            if self.batch_size < 2:
                generated_text = generated_text[0].strip()
                pred = self.get_pred(generated_text)
            else:
                pred = self.get_correct(generated_text,gts)

        elif 'Phi' in self.mode:
            messages = [
            {"role": "user", "content": self.prompt}
            ]
            prompt = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = self.processor(prompt, [im], return_tensors="pt").to("cuda:0")

            generation_args = {
                "max_new_tokens": 200,
                "temperature": 0.2
            }
            generation_args['do_sample'] = True if generation_args['temperature'] > 0 else False
            generate_ids = self.model.generate(**inputs, eos_token_id=self.processor.tokenizer.eos_token_id, **generation_args)

            # remove input tokens
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
            generated_text = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
            pred = self.get_pred(generated_text)

        elif self.mode == 'LLAVA':
            inputs = self.processor(self.prompt, im, return_tensors='pt').to(0, torch.float16)
            output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
            generated_text = self.processor.decode(output[0][2:], skip_special_tokens=True).split('ASSISTANT: ')[-1].lower()

            pred = self.get_pred(generated_text)

        elif 'CLIP' in self.mode:
            label_emb = self._clip_label_embeddings()
            image = self.processor(
                text=None,
                images=im,
                return_tensors='pt'
            )['pixel_values'].cuda()
            with torch.no_grad():  # inference only: don't build an autograd graph
                img_emb = self.model.get_image_features(image)
            img_emb = img_emb.detach().cpu().numpy()
            img_emb = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
            # Cosine similarity between the image and every candidate caption.
            scores = np.dot(img_emb, label_emb.T)
            pred = np.argmax(scores)
            generated_text = self.prompt[pred]

        return pred, generated_text

    def get_correct(self, generated_texts, labels):
        """Count generated texts whose predicted label equals the ground truth.

        Uses the same text-to-label mapping as the single-image path
        (``get_pred``), so batched and unbatched accuracies are comparable,
        dataset-specific yes/no handling applies, and empty outputs are never
        counted as correct.

        Args:
            generated_texts: decoded model outputs, one per image.
            labels: ground-truth label indices aligned with ``generated_texts``.

        Returns:
            Number of correct predictions (int).
        """
        return sum(
            int(self.get_pred(text.strip()) == int(label))
            for label, text in zip(labels, generated_texts)
        )

    @staticmethod
    def _summarize(correct, n_total, correct_0=0, n_0=0, correct_1=0, n_1=0):
        """Build the accuracy dict; any group with no samples reports 0.0."""
        def ratio(num, den):
            return num / den if den else 0.0
        return {'total': ratio(correct, n_total), '0': ratio(correct_0, n_0), '1': ratio(correct_1, n_1)}

    def inference(self):
        """Evaluate one image at a time over the whole test split.

        Returns:
            ``{'total': acc, '0': acc_group0, '1': acc_group1}``. The group
            accuracies are only populated for iwildcam (split at
            ``IWILDCAM_SPLIT``); for other datasets they are 0.0.
        """
        n_total = len(self.data)
        correct = correct_0 = correct_1 = 0
        n_0 = n_1 = 0

        for i in range(n_total):
            if self.dataset in ['FMOW', 'iwildcam', 'CAMELYON17']:
                im, gt = self.data[i]
                gt = int(gt)
            else:
                im = Image.fromarray(np.uint8(self.data._imgs[i]*255))
                gt = int(self.data._labels[i])

            pred, text = self.forward(im)

            if pred == -1:
                print(f'wrong output from an image with label {gt}: ', text)
            hit = int(gt == pred)
            correct += hit
            if self.dataset == 'iwildcam':
                if i < self.IWILDCAM_SPLIT:
                    correct_0 += hit
                    n_0 += 1
                else:
                    correct_1 += hit
                    n_1 += 1
            if i % 50 == 0:
                print(f'{i+1}/{len(self.data)}| pred:gt = {pred}:{gt}| {text}')

        return self._summarize(correct, n_total, correct_0, n_0, correct_1, n_1)

    def inference_batch(self):
        """Evaluate in batches of ``self.batch_size`` images (InstructBLIP only).

        The final, possibly shorter, batch is also evaluated, so every sample
        contributes to the reported accuracy.

        Returns:
            ``{'total': acc, '0': 0.0, '1': 0.0}`` (group accuracies are not
            tracked in batched mode).

        Raises:
            NotImplementedError: for modes whose ``forward`` has no batch path.
        """
        if self.mode != 'InstructBLIP':
            raise NotImplementedError(
                f'batched inference (batch_size={self.batch_size}) is only '
                f'implemented for InstructBLIP, not {self.mode}; use batch_size=1'
            )
        n_total = len(self.data)
        correct = 0
        imgs = []
        gts = []
        for i in range(n_total):
            if self.dataset == 'iwildcam':
                im_path, gt = self.data.samples[i]
                im = Image.open(im_path)
            elif self.dataset in ['FMOW', 'CAMELYON17']:
                im, gt = self.data[i]
                gt = int(gt)
            else:
                im = Image.fromarray(np.uint8(self.data._imgs[i]*255))
                gt = int(self.data._labels[i])

            imgs.append(im)
            gts.append(gt)

            if len(imgs) == self.batch_size or i == n_total - 1:
                correct_pred, texts = self.forward(imgs, gts)
                imgs = []
                gts = []
                correct += correct_pred
                print(f'{i}/{len(self.data)}| pred:gt = {correct}/{i+1}| {texts[-1]}')

        return self._summarize(correct, n_total)

if __name__ == "__main__":
    start_time = time.time()

    transformers.logging.set_verbosity_error()

    # ---- command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(description='Zeroshot classification with large models')
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--mode', type=str, default="LLAVA",
                        choices=['InstructBLIP', 'CLIP-base', 'CLIP-large', 'LLAVA', 'Phi-3', 'Phi-3.5'])
    parser.add_argument('--prompt_type', type=str, default="general",
                        choices=['general', 'general2', 'tailored', 'tailored2', 'tailored3'])
    parser.add_argument('--dataset', type=str, default="iwildcam",
                        choices=['iwildcam', 'DSPRITES', 'SHAPES3D', 'SMALLNORB', 'DEEPFASHION', 'CELEBA', 'FMOW', 'CAMELYON17'])
    parser.add_argument('--output_dir', type=str, default="../rob_exps/zeroshot")
    parser.add_argument('--batch_size', type=int, default=1)
    args = parser.parse_args()

    # Each (dataset, mode, prompt_type) combination gets its own run directory.
    args.output_dir = "/".join([args.output_dir, args.dataset, f"{args.mode}-{args.prompt_type}"])
    os.makedirs(args.output_dir, exist_ok=True)

    # Mirror stdout/stderr into log files inside the run directory.
    sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt'))
    sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt'))

    # GPU selection; only effective if CUDA has not been initialised yet.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # ---- environment / argument report ------------------------------------
    print("Environment:")
    print(f"\tPython: {sys.version.split(' ')[0]}")
    print(f"\tPyTorch: {torch.__version__}")
    print(f"\tCUDA: {torch.version.cuda}")
    print(f"\tCUDNN: {torch.backends.cudnn.version()}")
    print(f"\tNumPy: {np.__version__}")

    print('Args:')
    for arg_name, arg_value in sorted(vars(args).items()):
        print(f'\t{arg_name}: {arg_value}')

    # ---- evaluation -------------------------------------------------------
    exp = Zeroshot(args.mode, args.prompt_type, args.dataset, args.batch_size)
    if args.batch_size > 1:
        accuracies = exp.inference_batch()
    else:
        accuracies = exp.inference()
    total_time = time.time() - start_time

    # One marker file per reported accuracy; the value is also in the filename.
    for key, acc in accuracies.items():
        marker_path = os.path.join(args.output_dir, f'done_testacc_{key}_{acc:.4f}')
        with open(marker_path, 'w') as f:
            f.writelines([
                'done\n',
                f"test accuracy: {acc:.5f}\n",
                f'Elapsed Time: {total_time/3600: .1f} hour\n',
            ])
    print('done', args.output_dir)