# coding=utf-8
# Copyright 2022 Gen Luo. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import  json, re,random
import torch.utils.data as Data
from torchvision.transforms import transforms
import os
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from PIL import Image
from util.base_prompt import *
import torch
from lavin import Tokenizer
import copy
from datasets import load_dataset
from transformers import AutoTokenizer
import random
class ScienceQADataSet(Data.Dataset):
    """ScienceQA split yielding tokenized prompt/answer pairs plus an image.

    The constructor reads ``problems.json`` and ``pid_splits.json`` from
    ``args.data_root`` and the ``"captions"`` mapping from
    ``args.caption_file``.  Each problem gets a ``'caption'`` entry (empty
    string when no caption exists for its id).

    Args:
        args: Configuration object; ``data_root`` and ``caption_file`` are read
            here, and the whole object is forwarded to ``build_prompt``.
        split: Key into ``pid_splits.json`` and name of the image sub-folder.
        model_path: Directory containing ``tokenizer.model``.
        max_words: Fixed length of every tokenized sequence.
        max_image_feats: Stored on the instance; not otherwise used by this
            class.
    """

    def __init__(self, args,split,model_path,max_words=512,max_image_feats=1):
        super(ScienceQADataSet, self).__init__()
        self.args = args
        # --------------------------
        # ---- Raw data loading ---
        # --------------------------
        # Context managers close the JSON files deterministically.
        with open(os.path.join(args.data_root, 'problems.json')) as f:
            self.problems = json.load(f)
        with open(os.path.join(args.data_root, 'pid_splits.json')) as f:
            pid_splits = json.load(f)
        with open(args.caption_file) as f:
            captions = json.load(f)["captions"]
        self.image_path=os.path.join(args.data_root,'images',split)
        self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model')
        self.max_words = max_words
        self.max_image_feats=max_image_feats
        self.split=split
        for qid in self.problems:
            self.problems[qid]['caption'] = captions[qid] if qid in captions else ""

        self.qids = pid_splits['%s' % (split)]

        print(f"number of problems in split {split}: {len(self.qids)}\n")

        # Resize to 224x224 (bicubic), convert to tensor, normalize with the
        # ImageNet default mean/std constants imported from timm.
        self.transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])

    def tokenize(self,prompt,answer):
        """Encode ``prompt + answer`` into fixed-length ids, labels and masks.

        Args:
            prompt: Prompt text; its token positions are zeroed in the labels.
            answer: Answer text, concatenated directly after ``prompt``.

        Returns:
            ``(example, labels, example_mask, label_mask)``, each a 1-D tensor
            of length ``self.max_words``.  ``example`` holds the int64 token
            ids with padding set to 0; ``labels`` is the same sequence with the
            prompt positions and padding set to 0; the two masks are float
            tensors that are 1.0 where the corresponding tensor kept a token.
        """
        prompt_ids = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64)
        example = torch.tensor(self.tokenizer.encode(prompt + answer, bos=True, eos=True), dtype=torch.int64)
        padding = self.max_words - example.shape[0]
        if padding > 0:
            # Pad with -1 so padding can be told apart from real ids below.
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[:self.max_words]
        labels = example.clone()
        # Prompt tokens are marked -1 so they drop out of label_mask.
        labels[:len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = 0
        return example, labels, example_mask.float(), label_mask.float()

    def __getitem__(self, idx):
        """Return ``(example, labels, example_mask, image, indicator, qid)``.

        ``indicator`` is 1 when the problem has an image (loaded from
        ``<image_path>/<qid>/image.png`` and transformed) and 0 when an
        all-zero ``3x224x224`` placeholder is returned instead.
        """
        qid = self.qids[idx]
        # build_prompt is not defined in this file; presumably it comes from
        # the star import of util.base_prompt.
        prompt_question, prompt_answer = build_prompt(self.problems, qid, self.args)

        if self.problems[qid]['image'] is not None:
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(os.path.join(self.image_path, qid, 'image.png')) as img:
                image = self.transforms(img.convert('RGB'))
            indicator = 1
        else:
            image = torch.zeros(3, 224, 224)
            indicator = 0

        example, labels, example_mask, _ = self.tokenize(prompt_question, prompt_answer)

        return example, labels, example_mask, image, indicator, qid

    def __len__(self):
        """Return the number of problem ids in the split."""
        return len(self.qids)

    def shuffle_list(self, list):
        """Shuffle ``list`` in place; returns ``None``."""
        random.shuffle(list)



class InstrcutDataSet(Data.Dataset):
    """Instruction-tuning dataset read from ``all_data.json``.

    The constructor loads ``all_data.json`` from ``args.data_root`` and keeps
    the list stored under ``split``.  Each entry is read through the keys
    ``'qid'``, ``'instruction'``, ``'answer'``, ``'image'`` and (when an image
    is present) ``'image_source'``.

    Args:
        args: Configuration object; only ``data_root`` is read here.
        split: Key selecting the sample list inside ``all_data.json``.
        model_path: Directory containing ``tokenizer.model``.
        max_words: Fixed length of every tokenized sequence.
        max_image_feats: Stored on the instance; not otherwise used by this
            class.
    """

    def __init__(self, args,split,model_path,max_words=512,max_image_feats=1):
        super(InstrcutDataSet, self).__init__()
        self.args = args
        # --------------------------
        # ---- Raw data loading ---
        # --------------------------
        # Context manager closes the JSON file deterministically.
        with open(os.path.join(args.data_root, 'all_data.json')) as f:
            self.data = json.load(f)[split]

        self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model')
        self.max_words = max_words
        self.max_image_feats=max_image_feats
        self.split=split
        self.qids = [item['qid'] for item in self.data]

        print(f"number of problems in split {split}: {len(self.qids)}\n")

        # Resize to 224x224 (bicubic), convert to tensor, normalize with the
        # ImageNet default mean/std constants imported from timm.
        self.transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])

    def tokenize(self,prompt,answer,max_words=512):
        """Encode ``prompt + answer`` into fixed-length ids, labels and masks.

        Args:
            prompt: Prompt text; its token positions are zeroed in the labels.
            answer: Answer text, concatenated directly after ``prompt``.
            max_words: Target sequence length, used for both padding and
                truncation so the output length is always ``max_words``.

        Returns:
            ``(example, labels, example_mask, label_mask)``, each a 1-D tensor
            of length ``max_words``.  ``example`` holds the int64 token ids
            with padding set to 0; ``labels`` is the same sequence with the
            prompt positions and padding set to 0; the two masks are float
            tensors that are 1.0 where the corresponding tensor kept a token.
        """
        prompt_ids = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64)
        example = torch.tensor(self.tokenizer.encode(prompt + answer, bos=True, eos=True), dtype=torch.int64)
        padding = max_words - example.shape[0]
        if padding > 0:
            # Pad with -1 so padding can be told apart from real ids below.
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            # Truncate to the same length used for padding; mixing it with
            # self.max_words produced sequences of two different lengths.
            example = example[:max_words]
        labels = example.clone()
        # Prompt tokens are marked -1 so they drop out of label_mask.
        labels[:len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = 0
        return example, labels, example_mask.float(), label_mask.float()

    def __getitem__(self, idx):
        """Return ``(example, labels, example_mask, image, indicator)``.

        ``indicator`` is 1 when the sample has an image and 0 when an all-zero
        ``3x224x224`` placeholder is returned instead.
        """
        sample = self.data[idx]
        prompt_question = sample['instruction']
        prompt_answer = sample['answer']

        if sample['image'] is not None:
            # NOTE(review): these image roots are relative to the working
            # directory and ignore args.data_root - confirm this is intended.
            if sample['image_source'] == 'sqa':
                path = os.path.join('../data/images/train', self.qids[idx], 'image.png')
            else:
                path = os.path.join('../data/images/train2014',   'COCO_train2014_'+sample['image'])
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(path) as img:
                image = self.transforms(img.convert('RGB'))
            indicator = 1
        else:
            image = torch.zeros(3, 224, 224)
            indicator = 0

        # Use the dataset's configured length rather than tokenize's default.
        example, labels, example_mask, _ = self.tokenize(
            prompt_question, prompt_answer, max_words=self.max_words)

        return example, labels, example_mask, image, indicator

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.qids)

    def shuffle_list(self, list):
        """Shuffle ``list`` in place; returns ``None``."""
        random.shuffle(list)


class WikiTextDataset(Data.Dataset):
    """WikiText-2 (raw) corpus cut into consecutive ``max_words``-token chunks.

    The constructor loads the ``'text'`` column of ``wikitext-2-raw-v1`` for
    ``split``, joins the rows with blank lines, and encodes the result once
    (with BOS, without EOS) into ``self.textenc``.  Trailing tokens that do
    not fill a whole chunk are never returned.

    Args:
        args: Configuration object; stored but not read by this class.
        split: Split name passed to ``load_dataset``.
        model_path: Directory containing ``tokenizer.model``.
        max_words: Number of tokens per chunk.
    """

    def __init__(self, args, split, model_path, max_words=512):
        super(WikiTextDataset, self).__init__()
        self.args = args
        # --------------------------
        # ---- Raw data loading ---
        # --------------------------
        text = load_dataset('wikitext', 'wikitext-2-raw-v1', split=split)['text']
        self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model')
        self.textenc = torch.tensor(self.tokenizer.encode("\n\n".join(text), bos=True, eos=False))
        self.max_words = max_words
        self.split=split
        # Floor division: a partial final chunk is dropped.
        self.sizes = len(self.textenc) // self.max_words
        print(f"number of Wikitext in split {split}: {self.sizes}\n")

    def __getitem__(self, idx):
        """Return ``(inp, label, 0, 0, 0)`` for the ``idx``-th chunk.

        ``inp`` is a view into ``self.textenc``; ``label`` is an independent
        copy holding the same token ids.  The three trailing zeros are
        constant placeholders.
        """
        start = idx * self.max_words
        inp = self.textenc[start:start + self.max_words]
        # clone() copies only this chunk.  copy.deepcopy() on a view would
        # duplicate the whole underlying corpus storage for every item.
        label = inp.clone()
        return inp, label, 0, 0, 0

    def __len__(self):
        """Return the number of whole chunks."""
        return self.sizes

    def shuffle_list(self, list):
        """Shuffle ``list`` in place; returns ``None``."""
        random.shuffle(list)


class MixDataset(Data.Dataset):
    """Concatenation of WikiText-2 chunks followed by ScienceQA problems.

    Indices ``[0, wiki_size)`` address ``max_words``-token chunks of the
    encoded WikiText-2 corpus; indices ``[wiki_size, wiki_size + qa_size)``
    address the selected ScienceQA problems.

    Args:
        args: Configuration object.  ``data_root``, ``caption_file``,
            ``idx_path``, ``need_img`` and ``mix_ratio`` are read here; the
            whole object is forwarded to ``build_prompt``.
        split: WikiText split name, key into ``pid_splits.json`` and name of
            the image sub-folder.
        model_path: Directory containing ``tokenizer.model``.
        max_words: Fixed length of every returned token sequence.
        max_image_feats: Stored on the instance; not otherwise used by this
            class.
    """

    def __init__(self, args, split, model_path, max_words=512, max_image_feats = 1):
        super(MixDataset, self).__init__()
        self.args = args
        # --------------------------
        # ---- Raw data loading ---
        # --------------------------
        text = load_dataset('wikitext', 'wikitext-2-raw-v1', split=split)['text']
        self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model')
        self.textenc = torch.tensor(self.tokenizer.encode("\n\n".join(text), bos=True, eos=False))
        # Floor division: a partial final chunk is dropped.
        self.wiki_size = len(self.textenc) // max_words
        print(f"number of Wikitext in split {split}: {self.wiki_size}\n")

        # Context managers close the JSON files deterministically.
        with open(os.path.join(args.data_root, 'problems.json')) as f:
            self.problems = json.load(f)
        with open(os.path.join(args.data_root, 'pid_splits.json')) as f:
            pid_splits = json.load(f)
        with open(args.caption_file) as f:
            captions = json.load(f)["captions"]
        self.image_path=os.path.join(args.data_root,'images',split)
        for qid in self.problems:
            self.problems[qid]['caption'] = captions[qid] if qid in captions else ""

        self.qids = self._select_qids(args, pid_splits, split)

        print(f"number of Science QA problems in split {split}: {len(self.qids)}\n")
        # Resize to 224x224 (bicubic), convert to tensor, normalize with the
        # ImageNet default mean/std constants imported from timm.
        self.transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
        self.qa_size = len(self.qids)

        self.max_words = max_words
        self.max_image_feats = max_image_feats
        self.split=split
        self.sizes = self.wiki_size + self.qa_size

    def _select_qids(self, args, pid_splits, split):
        """Pick the ScienceQA problem ids that are mixed into the dataset.

        Candidates come from ``args.idx_path`` (a saved sequence of
        ``(qid, _)`` pairs, kept in file order) when it is set, otherwise from
        ``pid_splits[split]`` in random order.  With ``args.need_img`` only
        problems that have an image are kept.  The result is cut to
        ``int(wiki_size // args.mix_ratio)`` ids when ``args.mix_ratio`` is
        set; otherwise every candidate is used.
        """
        from_idx_file = args.idx_path is not None
        if from_idx_file:
            print("load idx from %s...\n" % args.idx_path)
            # NOTE(review): torch.load unpickles the file - only point
            # idx_path at trusted files.
            candidates = [qid for qid, _ in torch.load(args.idx_path)]
        else:
            print("load idx from random...\n")
            candidates = list(pid_splits['%s' % (split)])

        if args.need_img:
            print("must have image...\n")
            candidates = [qid for qid in candidates if self.problems[qid]['image'] is not None]
        else:
            print("probilitily have image...\n")

        if not from_idx_file:
            random.shuffle(candidates)

        if args.mix_ratio is not None:
            qid_size = int(self.wiki_size // args.mix_ratio)
        else:
            qid_size = len(candidates)
        return candidates[0:qid_size]

    def __getitem__(self, idx):
        """Return ``(example, labels, example_mask, image, indicator)``.

        For WikiText indices the mask and indicator are the int ``0`` and the
        image is an all-zero ``3x224x224`` tensor.  For ScienceQA indices the
        mask is a float tensor and ``indicator`` is 1 when the problem has an
        image, 0 otherwise.
        """
        if idx < self.wiki_size:
            start = idx * self.max_words
            inp = self.textenc[start:start + self.max_words]
            # clone() copies only this chunk.  copy.deepcopy() on a view would
            # duplicate the whole underlying corpus storage for every item.
            label = inp.clone()
            image = torch.zeros(3, 224, 224)
            # NOTE(review): the mask here is the int 0 while ScienceQA items
            # return a tensor; presumably the training loop never collates
            # both kinds with the default collate function - confirm.
            return inp, label, 0, image, 0

        qid = self.qids[idx - self.wiki_size]
        # build_prompt is not defined in this file; presumably it comes from
        # the star import of util.base_prompt.
        prompt_question, prompt_answer = build_prompt(self.problems, qid, self.args)

        if self.problems[qid]['image'] is not None:
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(os.path.join(self.image_path, qid, 'image.png')) as img:
                image = self.transforms(img.convert('RGB'))
            indicator = 1
        else:
            image = torch.zeros(3, 224, 224)
            indicator = 0

        example, labels, example_mask, _ = self.tokenize(prompt_question, prompt_answer)

        return example, labels, example_mask, image, indicator

    def tokenize(self,prompt,answer):
        """Encode ``prompt + answer`` into fixed-length ids, labels and masks.

        Args:
            prompt: Prompt text; its token positions are zeroed in the labels.
            answer: Answer text, concatenated directly after ``prompt``.

        Returns:
            ``(example, labels, example_mask, label_mask)``, each a 1-D tensor
            of length ``self.max_words``.  ``example`` holds the int64 token
            ids with padding set to 0; ``labels`` is the same sequence with the
            prompt positions and padding set to 0; the two masks are float
            tensors that are 1.0 where the corresponding tensor kept a token.
        """
        prompt_ids = torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64)
        example = torch.tensor(self.tokenizer.encode(prompt + answer, bos=True, eos=True), dtype=torch.int64)
        padding = self.max_words - example.shape[0]
        if padding > 0:
            # Pad with -1 so padding can be told apart from real ids below.
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[:self.max_words]
        labels = example.clone()
        # Prompt tokens are marked -1 so they drop out of label_mask.
        labels[:len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = 0
        return example, labels, example_mask.float(), label_mask.float()

    def __len__(self):
        """Return the number of WikiText chunks plus ScienceQA problems."""
        return self.sizes

    def shuffle_list(self, list):
        """Shuffle ``list`` in place; returns ``None``."""
        random.shuffle(list)


if __name__ == '__main__':
    # Manual smoke test: iterate the ScienceQA training split one sample at a
    # time and drop into the debugger on every sample that has no image.
    import pdb
    from torch.utils.data import DataLoader

    class Cfg():
        """Minimal stand-in for the argument object the dataset expects."""

        def __init__(self):
            super(Cfg, self).__init__()
            self.options = ["A", "B", "C", "D", "E"]
            self.use_caption = True
            self.prompt_format = 'CQM-A'
            self.data_root = './data'
            self.output_root = './output'
            self.caption_file = './data/captions.json'

    cfg=Cfg()
    dataset=ScienceQADataSet(cfg,'train','./data/weights',max_words=512)
    data_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
                             pin_memory=True)
    # Author's note from an earlier run: "406 max question".
    # ScienceQADataSet.__getitem__ yields six values (the last is the problem
    # id), so six names are needed here; unpacking into five raised ValueError.
    for examples, labels, example_mask, images, indicators, qids in data_loader:
        if indicators[0] == 0:
            pdb.set_trace()
            print(images)
