import os
import json
import pickle
import random
random.seed(42)
import time
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class LlavaTestDataset(Dataset):
    """Dataset yielding ``(image, instruction, id)`` tuples for inference.

    Each annotation record is read via the keys ``'image'`` (a path relative
    to ``vis_root``), ``'text'`` (the prompt) and ``'id'``.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_processor (callable): transform applied to each loaded RGB PIL image
        text_processor: stored as ``self.text_processor``; not used by ``__getitem__``
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the JSON annotation file (presumably a list of
            records, since it is indexed positionally)
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``(processed_image, instruction_text, record_id)``.

        The instruction is returned verbatim: no text processing and no image
        placeholder prefix is applied.
        """
        info = self.ann[index]

        image_path = os.path.join(self.vis_root, info['image'])
        # The context manager closes the underlying file handle deterministically;
        # convert() returns an independent in-memory image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.vis_processor(image)

        return image, info['text'], info['id']
class LlavaEvalDataset(Dataset):
    """Dataset yielding ``(image, instruction, id)`` tuples for evaluation.

    Each annotation record is read via the keys ``'image'`` (a path relative
    to ``vis_root``), ``'text'`` (the prompt) and ``'id'``.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_processor (callable): transform applied to each loaded RGB PIL image
        text_processor: stored as ``self.text_processor``; not used by ``__getitem__``
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the JSON annotation file (presumably a list of
            records, since it is indexed positionally)
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``(processed_image, instruction_text, record_id)``.

        The instruction is returned verbatim: no text processing and no image
        placeholder prefix is applied.
        """
        info = self.ann[index]

        image_path = os.path.join(self.vis_root, info['image'])
        # The context manager closes the underlying file handle deterministically;
        # convert() returns an independent in-memory image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.vis_processor(image)

        return image, info['text'], info['id']
class LlavaDetailDataset(Dataset):
    """Single-turn instruction dataset over ``COCO_train2014_<id>.jpg`` images.

    Each record is read via ``'id'`` (used to build the image filename) and
    ``'conversations'``, whose entries 0 and 1 carry the instruction and the
    answer under ``'value'``.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_processor (callable): transform applied to each loaded RGB PIL image
        text_processor (callable): applied to the cleaned instruction text
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the JSON annotation file (presumably a list of
            records, since it is indexed positionally)
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return a dict with ``image``, ``instruction_input``, ``answer`` and ``image_id``."""
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        # The context manager closes the underlying file handle deterministically.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.vis_processor(image)

        conversation = info['conversations']
        answer = conversation[1]['value']
        # Strip the '<image>' placeholder and newlines; the
        # '<Img><ImageHere></Img>' prefix is prepended below instead.
        instruction = conversation[0]['value'].replace('<image>', '').replace('\n', '').strip()
        instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['id'],
        }

class LlavaReasonDataset(Dataset):
    """Single-turn reasoning dataset over ``COCO_train2014_<id>.jpg`` images.

    Each record is read via ``'id'`` (used to build the image filename) and
    ``'conversations'``, whose entries 0 and 1 carry the instruction and the
    answer under ``'value'``.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_processor (callable): transform applied to each loaded RGB PIL image
        text_processor (callable): applied to the cleaned instruction text
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the JSON annotation file (presumably a list of
            records, since it is indexed positionally)
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return a dict with ``image``, ``instruction_input``, ``answer`` and ``image_id``."""
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        # The context manager closes the underlying file handle deterministically.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.vis_processor(image)

        conversation = info['conversations']
        answer = conversation[1]['value']
        # Strip the '<image>' placeholder and newlines; the
        # '<Img><ImageHere></Img>' prefix is prepended below instead.
        instruction = conversation[0]['value'].replace('<image>', '').replace('\n', '').strip()
        instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['id'],
        }




class LlavaConversationDataset(Dataset):
    """Multi-turn conversation dataset over ``<id>.jpg`` images.

    All human turns and all assistant turns of a record are each joined into
    one string with ``connect_sym`` as separator.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_processor (callable): transform applied to each loaded RGB PIL image
        text_processor: stored as ``self.text_processor``; not applied to any turn here
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the JSON annotation file (presumably a list of
            records, since it is indexed positionally)
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

        # Separator used to join turns; also returned with each sample so the
        # consumer can split the joined strings back apart.
        self.connect_sym = "!@#"

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return a dict with ``image``, joined ``conv_q``/``conv_a``, ``image_id`` and ``connect_sym``."""
        info = self.ann[index]

        image_file = '{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        # The context manager closes the underlying file handle deterministically.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.vis_processor(image)

        turns = info['conversations']
        first_instruction = turns[0]['value'].replace('<image>', '').replace('\n', '').strip()
        first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)

        # Turns are treated as strictly alternating: odd indices are assistant
        # answers, even indices >= 2 are follow-up human instructions.
        questions = [first_instruction] + [turn["value"] + " " for turn in turns[2::2]]
        answers = [turn["value"] for turn in turns[1::2]]

        return {
            "image": image,
            "conv_q": self.connect_sym.join(questions),
            'conv_a': self.connect_sym.join(answers),
            "image_id": info['id'],
            "connect_sym": self.connect_sym
        }


class MaliciousQA(Dataset):
    """Text-only dataset pairing each question with one accepted and one rejected reply.

    Every ``__getitem__`` call draws a fresh reply from each of the question's
    accepted and rejected lists using the module-level ``random`` state.
    """

    def __init__(self, question_path, confirm_path, negative_path):
        """
        question_path (string): text file with one question per line; blank lines
            are skipped and surrounding whitespace is stripped
        confirm_path (string): JSON file mapping each question to a sequence of
            accepted responses
        negative_path (string): JSON file mapping each question to a sequence of
            rejected responses
        """
        self.question_path = question_path

        self.confirm_path = confirm_path
        self.negative_path = negative_path

        # All files are read as UTF-8 and closed deterministically.
        with open(question_path, 'r', encoding='utf-8') as f:
            self.question = [line.strip() for line in f if line.strip()]
        with open(confirm_path, 'r', encoding='utf-8') as f:
            self.acc_response = json.load(f)
        with open(negative_path, 'r', encoding='utf-8') as f:
            self.rej_response = json.load(f)

        # Not used within this class; retained as a public attribute.
        self.connect_sym = "!@#"

    def __len__(self):
        return len(self.question)

    def __getitem__(self, index):
        """Return ``{'question', 'acc_ans', 'rej_ans'}`` for the question at *index*.

        A question missing from either response mapping raises ``KeyError``;
        an empty response list raises ``IndexError`` (from ``random.choice``).
        """
        question = self.question[index]
        # Sampled into locals: fetching an item should not mutate dataset state.
        acc_ans = random.choice(self.acc_response[question])
        rej_ans = random.choice(self.rej_response[question])

        return {
            "question": question,
            "acc_ans": acc_ans,
            'rej_ans': rej_ans,
        }
class BenignQA(Dataset):
    """Image-grounded QA dataset pairing each instruction with an accepted and a rejected answer.

    Each record is read via ``'id'`` (image filename stem, ``<id>.jpg``) and
    ``'conversations'``. Entry 0 holds the instruction and entry 1 the accepted
    answer, both under ``'value'``. Entry 2 holds the candidate rejected answers
    under ``'rej_answer'``.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_processor (callable): transform applied to each loaded RGB PIL image
        text_processor: stored as ``self.text_processor``; not used by ``__getitem__``
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the JSON annotation file (presumably a list of
            records, since it is indexed positionally)
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

        # Not used within this class; retained as a public attribute.
        self.connect_sym = "!@#"

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return a dict with ``image``, ``question``, ``acc_ans`` and ``rej_ans``."""
        info = self.ann[index]

        image_file = '{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        # The context manager closes the underlying file handle deterministically.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.vis_processor(image)

        conversation = info['conversations']
        instruction = conversation[0]['value'].replace('<image>', '').replace('\n', '').strip()
        instruction = '<Img><ImageHere></Img> {} '.format(instruction)
        acc_ans = conversation[1]['value']
        # A fresh rejected answer is drawn on every access from the module-level random state.
        rej_ans = random.choice(conversation[2]['rej_answer'])

        return {
            "image": image,
            "question": instruction,
            "acc_ans": acc_ans,
            "rej_ans": rej_ans,
        }