import os

from numpy.ma.core import around
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageOps
import numpy as np
import torch
from glob import glob
import random
from datasets import load_dataset
import io
import cv2
from tqdm import tqdm

class Data(Dataset):
    """Base dataset holding parallel lists of image paths, labels and prompts.

    Subclasses fill ``image_list``, ``label_list`` and ``prompt`` and override
    ``init_prompt`` / ``generate_prompts`` to build chat-style messages.
    """

    def __init__(self, cls_id_name=None, return_image_path=False):
        super(Data, self).__init__()

        self.image_list = []  # image file paths, one per sample
        self.label_list = []  # ground-truth label/answer, one per sample
        self.prompt = []  # question text, one per sample

        self.cls_id_name = cls_id_name
        self.return_image_path = return_image_path

        self.index = 0

    def __len__(self):
        return len(self.image_list)

    def __rmul__(self, v):
        """Keep a seeded random fraction ``v`` of the samples (``v * dataset``).

        The subset is applied in place and ``self`` is returned. Both the
        ``random`` and ``numpy`` global RNGs are reseeded with 1234.
        """
        random.seed(1234)
        np.random.seed(1234)
        new_length = int(len(self.label_list) * v)
        indices = random.sample(range(len(self.label_list)), new_length)

        self.label_list = [self.label_list[i] for i in indices]
        self.image_list = [self.image_list[i] for i in indices]
        self.prompt = [self.prompt[i] for i in indices]

        return self

    def load_image(self, image_path):
        """Open an image, apply its EXIF orientation and convert it to RGB."""
        image = Image.open(image_path)
        image = ImageOps.exif_transpose(image)
        image = image.convert("RGB")

        return image

    def __getitem__(self, index):
        """Return ``(image_path, init_prompt(prompt), label)``; the image is not loaded here."""
        image_path, label = self.image_list[index], self.label_list[index]
        lang_x = self.init_prompt(self.prompt[index])
        images = image_path

        return images, lang_x, label

    def collate_fn(self, batch):
        """Transpose a list of ``(image, lang_x, label)`` items into three lists."""
        image = [item[0] for item in batch]
        lang_x = [item[1] for item in batch]
        label = [item[2] for item in batch]

        return image, lang_x, label

    def generate_prompts(self, lang_x, label_list):
        """Build the few-shot message list for one query; overridden by subclasses."""
        pass

    def init_prompt(self, prompt):
        """Wrap a raw prompt into a query message; overridden by subclasses."""
        pass

    def n_shot(self, image_list, prompts, label_list, shot=1, cls='same'):
        """Build ``shot``-shot inputs for a batch of queries.

        For every query, ``shot`` demonstrations are drawn with replacement from
        this dataset. When ``cls == 'same'`` they come from samples whose label
        equals the query's label. Otherwise they come from samples with a
        different label. Both global RNGs are reseeded with 1234, so the draw is
        repeatable for a given dataset order.

        :param image_list: query image paths
        :param prompts: query messages, one per image (e.g. from ``init_prompt``)
        :param label_list: query labels
        :param shot: demonstrations per query; 0 returns the queries alone
        :param cls: ``'same'`` for same-label demonstrations, anything else for
            different-label demonstrations
        :return: ``(images, prompts, labels)``, or
            ``(paths, images, prompts, labels)`` when ``return_image_path`` is
            set. ``images`` is a flat list holding each query's demonstrations
            followed by the query image. ``prompts`` holds one
            ``generate_prompts`` result per query (``[prompt]`` when
            ``shot == 0``).
        :raises ValueError: if ``shot`` is negative, or no sample satisfies
            ``cls`` for some query label.
        """
        random.seed(1234)
        np.random.seed(1234)

        if shot == 0:
            if self.return_image_path:
                return image_list, [self.load_image(i) for i in image_list], [[i] for i in prompts], label_list
            else:
                return [self.load_image(i) for i in image_list], [[i] for i in prompts], label_list

        if shot < 0:
            raise ValueError(f'Error: shot must be 0 or greater than 0, but got {shot}.')

        # Built once per call rather than once per query.
        all_labels = np.array(self.label_list)

        all_image_path = []
        image_list_ = []
        prompts_ = []

        for image, batch_prompt, label in zip(image_list, prompts, label_list):
            mask = (all_labels == label) if cls == 'same' else (all_labels != label)
            candidates = np.where(mask)[0]
            if candidates.size == 0:
                raise ValueError(
                    f'No demonstration candidates for label {label!r} with cls={cls!r}.'
                )
            random_indices = np.random.choice(candidates, shot, replace=True)

            demo_labels = []
            image_path = []
            for i in random_indices:
                # add few-shot image and label
                image_list_.append(self.load_image(self.image_list[i]))
                demo_labels.append(self.label_list[i])
                if self.return_image_path:
                    image_path.append(self.image_list[i])

            # The query image always follows its demonstrations.
            image_list_.append(self.load_image(image))
            prompts_.append(self.generate_prompts(batch_prompt, demo_labels))

            if self.return_image_path:
                image_path.append(image)

            all_image_path.append(image_path)

        if self.return_image_path:
            return all_image_path, image_list_, prompts_, label_list

        return image_list_, prompts_, label_list


# class Cifar10(Data):
#     def __init__(self, root='data/cifar10/', cls_id_name=CIFAR10, test=False, return_image_path=False):
#         super(Cifar10, self).__init__()
#
#         self.root = root
#         self.cls_id_name = cls_id_name
#         self.return_image_path = return_image_path
#
#         if test:
#             self.root = os.path.join(self.root, 'test')
#             self.image_list = glob(os.path.join(self.root, '*.png'))
#         else:
#             self.root = os.path.join(self.root, 'train')
#             self.image_list = glob(os.path.join(self.root, '*.jpg'))
#
#         self.label_list = [int(i[i.index('_')-1]) for i in self.image_list]
#
#     def generate_prompts(self, lang_x, label_list):
#         mesg = []
#         for i in label_list:
#             usr = {
#                 "role": "user",
#                 "content": [
#                     {
#                         "type": "image",
#                     },
#                     {"type": "text", "text": "An image of what?"},
#                 ]
#             }
#
#             ass = {
#                 "role": "assistant",
#                 "content": [
#                     {"type": "text", "text": f"{self.cls_id_name[i]}"},
#                 ]
#
#             }
#
#             mesg.append(usr)
#             mesg.append(ass)
#
#         mesg.append(lang_x)
#
#         return mesg
#
#     def init_prompt(self):
#         message = {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown in the image? Answer using one lowercase word."}]}
#         return message


class Animals(Data):
    """Animals with Attributes 2 species-recognition dataset.

    Images are read from ``<root>/JPEGImages/<class>/*.jpg`` for the classes
    listed in ``testclasses.txt`` (``test=True``) or ``trainclasses.txt``.
    The label is the file-name prefix before the first ``_`` with ``+``
    replaced by a space, so ``giant+panda_0001.jpg`` gets ``giant panda``.
    """

    def __init__(self, root='data/Animals_with_Attrubutes2/', test=False, return_image_path=False):
        super(Animals, self).__init__()
        self.root = root
        self.return_image_path = return_image_path

        # test: 6985 samples, train: 30337 samples
        split_file = 'testclasses.txt' if test else 'trainclasses.txt'
        with open(os.path.join(root, split_file), 'r') as f:
            # Skip blank lines so they cannot turn into a glob over JPEGImages/ itself.
            classes = [line.strip() for line in f if line.strip()]

        self.image_list = []
        for c in classes:
            # glob() order is arbitrary; sort it so the seeded sampling in
            # __rmul__ / n_shot selects the same samples on every filesystem.
            self.image_list.extend(sorted(glob(os.path.join(self.root, 'JPEGImages', c, '*.jpg'))))

        self.label_list = [
            os.path.basename(path).split('_')[0].replace('+', ' ') for path in self.image_list
        ]

        self.prompt = ['What is the name of the species in the image? Answer one word or phrases in lower case.'] * len(self.image_list)

        # label -> indices of all samples with that label
        self.label_to_indices = {}
        for idx, label in enumerate(self.label_list):
            self.label_to_indices.setdefault(label, []).append(idx)

    def init_prompt(self, prompt): #What is the name of the species in the image? Answer in lower case.
        """Wrap ``prompt`` into a single user message containing one image slot."""
        message = {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": f"{prompt}"}]}
        return message

    def generate_prompts(self, lang_x, label_list):
        """Prepend one user/assistant demonstration pair per label to ``lang_x``.

        NOTE(review): the demonstration question asks for the *scientific* name,
        but the answers are the common names parsed from file names, and the
        query prompt asks for "the name of the species". Confirm the mismatch is
        intended.
        """
        mesg = []
        for i in label_list:
            usr = {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                    },
                    {"type": "text", "text": "What is the scientific name of the species in the image?"},
                ]
            }

            ass = {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": f"{i}"},
                ]

            }

            mesg.append(usr)
            mesg.append(ass)

        # The query message always comes last.
        mesg.append(lang_x)

        return mesg


class QA(Data):
    """Base class for visual question-answering benchmarks.

    Subclasses populate ``image_list``, ``label_list`` (answers) and ``prompt``
    (questions). Few-shot demonstrations are drawn uniformly from the whole
    dataset, regardless of label.
    """

    def __init__(self, root='', return_image_path=False):
        super(QA, self).__init__()
        self.image_list = []
        self.label_list = []
        self.prompt = []
        self.return_image_path = return_image_path

    def init_prompt(self, prompt):
        """Wrap ``prompt`` into a single user message containing one image slot."""
        message = {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": f"{prompt}"}]}
        return message

    def n_shot(self, image_list, prompts, label_list, shot=1, cls='same'):
        """Build ``shot``-shot inputs for a batch of queries.

        Demonstrations (image, question, answer) are drawn uniformly with
        replacement from the whole dataset. ``cls`` is accepted for interface
        compatibility with ``Data.n_shot`` but is ignored. Both global RNGs are
        reseeded with 1234.

        :param image_list: query image paths
        :param prompts: query messages, one per image (e.g. from ``init_prompt``)
        :param label_list: query answers (returned unchanged)
        :param shot: demonstrations per query; 0 returns the queries alone
        :param cls: unused
        :return: ``(images, prompts, labels)``, or
            ``(paths, images, prompts, labels)`` when ``return_image_path`` is
            set. ``images`` is a flat list holding each query's demonstrations
            followed by the query image.
        :raises ValueError: if ``shot`` is negative.
        """
        random.seed(1234)
        np.random.seed(1234)

        if shot == 0:
            if self.return_image_path:
                return image_list, [self.load_image(i) for i in image_list], [[i] for i in prompts], label_list
            else:
                return [self.load_image(i) for i in image_list], [[i] for i in prompts], label_list

        if shot < 0:
            raise ValueError(f'Error: shot must be 0 or greater than 0, but got {shot}.')

        num_samples = len(self.image_list)
        all_image_path = []
        image_list_ = []
        prompts_ = []

        for image, batch_prompt in zip(image_list, prompts):
            demo_labels = []
            demo_prompts = []
            image_path = []

            # Passing the population size draws the same indices as
            # choice(list(range(n))) without building an n-element list per query.
            random_indices = np.random.choice(num_samples, shot, replace=True)
            for i in random_indices:
                # add few-shot image, question and answer
                image_list_.append(self.load_image(self.image_list[i]))
                demo_labels.append(self.label_list[i])
                demo_prompts.append(self.prompt[i])

                if self.return_image_path:
                    image_path.append(self.image_list[i])

            # The query image always follows its demonstrations.
            image_list_.append(self.load_image(image))
            prompts_.append(self.generate_prompts(batch_prompt, demo_labels, demo_prompts))

            if self.return_image_path:
                image_path.append(image)

            all_image_path.append(image_path)

        if self.return_image_path:
            return all_image_path, image_list_, prompts_, label_list

        return image_list_, prompts_, label_list

    def generate_prompts(self, lang_x, label_list, prompt_list):
        """Prepend one user/assistant pair per demonstration to ``lang_x``.

        ``prompt_list[k]`` is the question and ``label_list[k]`` the answer of
        demonstration ``k``.
        """
        mesg = []
        for index, answer in enumerate(label_list):
            usr = {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                    },
                    {"type": "text", "text": f"{prompt_list[index]}"},
                ]
            }

            ass = {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": f"{answer}"},
                ]

            }

            mesg.append(usr)
            mesg.append(ass)

        # The query message always comes last.
        mesg.append(lang_x)

        return mesg

class RealworldQA(QA):
    """RealWorldQA benchmark (``test`` split of ``load_dataset(root)``).

    Images are cached as ``<root>/images/<index>.jpg``; any missing file is
    (re)written on construction.
    """

    def __init__(self, root='data/realworldQA/data', return_image_path=False):
        super(RealworldQA, self).__init__()

        self.root = root
        self.return_image_path = return_image_path

        dataset = load_dataset(root)
        self.all_data = dataset['test']

        self.label_list = self.all_data['answer']
        # remove the prompt, keep only question (drop the last line of each question)
        self.prompt = [p.rsplit('\n', 1)[0] for p in self.all_data['question']]

        image_dir = os.path.join(root, 'images')
        os.makedirs(image_dir, exist_ok=True)
        # Sample i is cached as images/{i}.jpg. Setting the list unconditionally
        # also covers the first run, when the cache is still being written.
        self.image_list = [os.path.join(image_dir, f'{i}.jpg') for i in range(len(self.label_list))]

        missing = [i for i, path in enumerate(self.image_list) if not os.path.exists(path)]
        if missing:
            print('Save images to jpg')
            for i in missing:
                # Row access decodes only this sample's image.
                self._save_jpg(self.all_data[i]['image'], self.image_list[i])

    @staticmethod
    def _save_jpg(image, path):
        """Write a PIL image to ``path`` as JPEG (PIL encode, OpenCV decode + write)."""
        # PIL's JPEG writer rejects alpha/palette modes such as RGBA, LA and P.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG')
        image_np = cv2.imdecode(np.frombuffer(buffer.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
        if not cv2.imwrite(path, image_np):
            raise OSError(f'Failed to write image to {path}')


class MMBench(QA):
    """MMBench benchmark (``validation`` split of ``load_dataset(root)``).

    Each prompt is the question followed by its lettered choices. Images are
    cached as ``<root>/images/<index>.jpg``; any missing file is (re)written on
    construction.
    """

    def __init__(self, root='/data/mmbench/data', return_image_path=False):
        super(MMBench, self).__init__()
        self.root = root
        dataset = load_dataset(root)
        self.all_data = dataset['validation']
        self.return_image_path = return_image_path

        # Column access reads only the text fields; no image is decoded here.
        choice_columns = [self.all_data[letter] for letter in ('A', 'B', 'C', 'D')]
        self.prompt = []
        for question, *choices in zip(self.all_data['question'], *choice_columns):
            # NOTE(review): questions with fewer than four choices presumably
            # store the unused slots as None/NaN; those slots are skipped rather
            # than rendered as "C. None.". Rows with four choices are unchanged.
            options = ''.join(
                f' {letter}. {choice}.'
                for letter, choice in zip('ABCD', choices)
                if not self._is_missing(choice)
            )
            self.prompt.append(question + options)

        self.label_list = self.all_data['answer']

        image_dir = os.path.join(root, 'images')
        os.makedirs(image_dir, exist_ok=True)
        # Sample i is cached as images/{i}.jpg. Setting the list unconditionally
        # also covers the first run, when the cache is still being written.
        self.image_list = [os.path.join(image_dir, f'{i}.jpg') for i in range(len(self.label_list))]

        missing = [i for i, path in enumerate(self.image_list) if not os.path.exists(path)]
        if missing:
            print('Save images to jpg')
            for i in missing:
                # Row access decodes only this sample's image.
                self._save_jpg(self.all_data[i]['image'], self.image_list[i])

    @staticmethod
    def _is_missing(value):
        """Return True for an absent choice (``None`` or a float NaN)."""
        import math
        return value is None or (isinstance(value, float) and math.isnan(value))

    @staticmethod
    def _save_jpg(image, path):
        """Write a PIL image to ``path`` as JPEG (PIL encode, OpenCV decode + write)."""
        # PIL's JPEG writer rejects alpha/palette modes such as RGBA, LA and P.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG')
        image_np = cv2.imdecode(np.frombuffer(buffer.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
        if not cv2.imwrite(path, image_np):
            raise OSError(f'Failed to write image to {path}')

class MMStar(QA):
    """MMStar benchmark (``val`` split of ``load_dataset(root)``).

    Images are cached as ``<root>/images/<index>.jpg``; any missing file is
    (re)written on construction.
    """

    def __init__(self, root='/data/mmstar', return_image_path=False):
        super(MMStar, self).__init__()
        self.root = root
        dataset = load_dataset(root)
        self.all_data = dataset['val']
        self.return_image_path = return_image_path

        self.prompt = self.all_data['question']
        self.label_list = self.all_data['answer']

        image_dir = os.path.join(root, 'images')
        os.makedirs(image_dir, exist_ok=True)
        # Sample i is cached as images/{i}.jpg. Setting the list unconditionally
        # also covers the first run, when the cache is still being written.
        self.image_list = [os.path.join(image_dir, f'{i}.jpg') for i in range(len(self.label_list))]

        missing = [i for i, path in enumerate(self.image_list) if not os.path.exists(path)]
        if missing:
            print('Save images to jpg')
            for i in missing:
                # Row access decodes only this sample's image.
                self._save_jpg(self.all_data[i]['image'], self.image_list[i])

    @staticmethod
    def _save_jpg(image, path):
        """Write a PIL image to ``path`` as JPEG (PIL encode, OpenCV decode + write)."""
        # PIL's JPEG writer rejects alpha/palette modes such as RGBA, LA and P.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG')
        image_np = cv2.imdecode(np.frombuffer(buffer.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
        if not cv2.imwrite(path, image_np):
            raise OSError(f'Failed to write image to {path}')


class SeedBench(QA):
    """SEED-Bench benchmark (``test`` split of ``load_dataset(root)``).

    Each prompt is the question followed by choices A-D. Each row's ``image``
    field is indexed with ``[0]`` (a list of images, keeping the first). Images
    are cached as ``<root>/images/<index>.jpg``; any missing file is (re)written
    on construction.
    """

    def __init__(self, root='/data/seedbench', return_image_path=False):
        super(SeedBench, self).__init__()
        self.root = root
        dataset = load_dataset(root)
        self.all_data = dataset['test']
        self.return_image_path = return_image_path

        # Column access reads only the text fields. Iterating rows would decode
        # every image just to build the prompts.
        self.prompt = [
            question + " A. {}. B. {}. C. {}. D. {}.".format(a, b, c, d)
            for question, a, b, c, d in zip(
                self.all_data['question'],
                self.all_data['choice_a'],
                self.all_data['choice_b'],
                self.all_data['choice_c'],
                self.all_data['choice_d'],
            )
        ]

        self.label_list = self.all_data['answer']

        image_dir = os.path.join(root, 'images')
        os.makedirs(image_dir, exist_ok=True)
        # Sample i is cached as images/{i}.jpg.
        self.image_list = [os.path.join(image_dir, f'{i}.jpg') for i in range(len(self.label_list))]

        missing = [i for i, path in enumerate(self.image_list) if not os.path.exists(path)]
        if missing:
            print('Save images to jpg')
            for i in tqdm(missing):
                # Row access decodes only this sample's images; keep the first one.
                self._save_jpg(self.all_data[i]['image'][0], self.image_list[i])

    @staticmethod
    def _save_jpg(image, path):
        """Write a PIL image to ``path`` as JPEG (PIL encode, OpenCV decode + write)."""
        # PIL's JPEG writer rejects alpha/palette modes such as RGBA, LA and P.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG')
        image_np = cv2.imdecode(np.frombuffer(buffer.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
        if not cv2.imwrite(path, image_np):
            raise OSError(f'Failed to write image to {path}')

class ScienceQA(QA):
    """ScienceQA benchmark (``test`` split of ``load_dataset(root)``).

    Each prompt is the question followed by one ``<letter>. <choice>`` line per
    choice. The integer answer index is mapped to its letter (0 -> ``A``).
    Images are cached as ``<root>/images/<index>.jpg``; any missing file is
    (re)written on construction.
    """

    def __init__(self, root='/data/science', return_image_path=False):
        # Initialise base attributes (cls_id_name, index, ...) like the sibling classes.
        super(ScienceQA, self).__init__()
        self.root = root
        dataset = load_dataset(root)
        self.all_data = dataset['test']
        self.return_image_path = return_image_path

        # Column access reads only the text fields; no image is decoded here.
        self.prompt = []
        for question, choices in zip(self.all_data['question'], self.all_data['choices']):
            en_options = '\n' + '\n'.join([f'{chr(65 + i)}. {o}' for i, o in enumerate(choices)])
            self.prompt.append(question + en_options)

        self.label_list = [chr(65 + i) for i in self.all_data['answer']]

        image_dir = os.path.join(root, 'images')
        os.makedirs(image_dir, exist_ok=True)
        # Sample i is cached as images/{i}.jpg.
        self.image_list = [os.path.join(image_dir, f'{i}.jpg') for i in range(len(self.label_list))]

        missing = [i for i, path in enumerate(self.image_list) if not os.path.exists(path)]
        if missing:
            print('Save images to jpg')
            for i in tqdm(missing):
                image = self.all_data[i]['image']
                if image is None:
                    # NOTE(review): text-only questions (if the split has any)
                    # carry no image; fail clearly instead of with AttributeError.
                    raise ValueError(
                        f'ScienceQA sample {i} has no image; use an image-only subset.'
                    )
                self._save_jpg(image, self.image_list[i])

    @staticmethod
    def _save_jpg(image, path):
        """Write a PIL image to ``path`` as JPEG (PIL encode, OpenCV decode + write)."""
        # PIL's JPEG writer rejects alpha/palette modes such as RGBA, LA and P.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG')
        image_np = cv2.imdecode(np.frombuffer(buffer.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
        if not cv2.imwrite(path, image_np):
            raise OSError(f'Failed to write image to {path}')




