import os
import json
import math
import torch
from torch.utils.data import Dataset
from config.utils import *
from config.options import *
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from transformers import BertTokenizer
import clip

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px):
    """Build the image preprocessing pipeline used for every dataset image.

    Steps, in order: bicubic resize to *n_px*, center crop to *n_px*,
    conversion to RGB, conversion to a tensor, and per-channel normalization
    with fixed mean/std constants (the values used by OpenAI CLIP's
    preprocessing).
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline_steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(pipeline_steps)

def init_tokenizer():
    """Load the ``bert-base-uncased`` tokenizer with extra special tokens.

    Registers '[DEC]' as the BOS token and '[ENC]' as an additional special
    token. The id of '[ENC]' is then exposed on the tokenizer as
    ``enc_token_id``.
    """
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    # Applied one at a time, in the same order as before.
    for special_tokens in ({'bos_token': '[DEC]'},
                           {'additional_special_tokens': ['[ENC]']}):
        tok.add_special_tokens(special_tokens)
    tok.enc_token_id = tok.additional_special_tokens_ids[0]
    return tok

class rank_pair_dataset(Dataset):
    """Preference-pair dataset: each item pairs one prompt with a better and a worse image.

    Every item is a dict with the keys:
      'clip_text'  -- ``clip.tokenize(prompt, truncate=True)``
      'text_ids'   -- BERT input ids (padded/truncated to 35 tokens)
      'text_mask'  -- the matching BERT attention mask
      'img_better' -- preprocessed tensor of the preferred image
      'img_worse'  -- preprocessed tensor of the other image

    When ``opts.load_pair_store`` is set, items are loaded from a ``.pth``
    store (the format :meth:`store_dataset` writes). Otherwise they are built
    from ``<data_base>/<dataset>.json`` with :meth:`make_data_all_score`.
    The other ``make_data*`` methods are alternative builders for other
    annotation layouts.
    """

    def __init__(self, dataset):
        """Load or build the pairs for the split named *dataset*."""
        self.preprocess = _transform(config['BLIP']['image_size'])
        self.tokenizer = init_tokenizer()
        self.rank = config["rank"]
        self.models_use = config["models_use"]

        if opts.load_pair_store:
            self.dataset_path = os.path.join(config['pair_store_base'], f"{dataset}.pth")
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only point pair_store_base at trusted stores.
            self.data = torch.load(self.dataset_path)
        else:
            self.dataset_path = os.path.join(config['data_base'], f"{dataset}.json")
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
            # Replace the raw annotation records with the expanded pair items.
            self.data = self.make_data_all_score()

        self.iters_per_epoch = int(math.ceil(len(self.data) * 1.0 / opts.batch_size))

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def store_dataset(self, dataset):
        """Save ``self.data`` to ``<pair_store_base>/<dataset>.pth``.

        The directory is created first via the project's ``makedir`` helper.
        """
        makedir(config['pair_store_base'])
        torch.save(self.data, os.path.join(config['pair_store_base'], f"{dataset}.pth"))

    # ------------------------------------------------------------------ helpers

    def _load_image(self, path):
        """Open the image at *path*, apply ``self.preprocess`` and close the file."""
        # preprocess eagerly produces a tensor, so the file can be closed right away.
        with Image.open(path) as pil_image:
            return self.preprocess(pil_image)

    def _load_model_images(self, item):
        """Load ``item['image_<name>']`` for each name in ``self.models_use``, in order."""
        return [self._load_image(item[f"image_{name}"]) for name in self.models_use]

    def _encode_prompt(self, prompt):
        """Tokenize *prompt* for both text encoders.

        This runs once per record. The returned tensors are shared by every
        pair built from that record, just as ``text_ids``/``text_mask`` were
        already shared before.
        """
        text_input = self.tokenizer(prompt, padding='max_length', truncation=True,
                                    max_length=35, return_tensors="pt")
        return {
            'clip_text': clip.tokenize(prompt, truncate=True),
            'text_ids': text_input.input_ids,
            'text_mask': text_input.attention_mask,
        }

    @staticmethod
    def _pair_item(text_fields, img_better, img_worse):
        """Return a new pair dict: *text_fields* plus the two image entries."""
        pair = dict(text_fields)
        pair['img_better'] = img_better
        pair['img_worse'] = img_worse
        return pair

    @staticmethod
    def _parse_rank(rank):
        """Turn a rank string such as ``"312"`` into ``[3, 1, 2]`` (one digit per id)."""
        return [int(char) for char in rank]

    @staticmethod
    def _pairs_from_labels(labels):
        """Yield ``(better_idx, worse_idx)`` for every untied pair of *labels*.

        A smaller label means a better image. Equal labels yield no pair.
        """
        for i in range(len(labels)):
            for j in range(i + 1, len(labels)):
                if labels[i] < labels[j]:
                    yield i, j
                elif labels[i] > labels[j]:
                    yield j, i

    @staticmethod
    def _pairs_from_order(order, focus=None):
        """Yield ``(better, worse)`` for every pair of entries of *order*, which is sorted best first.

        If *focus* is given, only pairs with at least one entry in *focus*
        are yielded.
        """
        for i, better in enumerate(order):
            for worse in order[i + 1:]:
                if focus is None or better in focus or worse in focus:
                    yield better, worse

    # ----------------------------------------------------------------- builders

    def make_data(self):
        """Build pairs from records with 'prompt', 'generations' and 'ranking'.

        'generations' holds image paths relative to ``config['image_base']``.
        'ranking' holds one label per generation. The image with the smaller
        label is treated as better, and tied labels produce no pair.
        """
        data_items = []
        for item in tqdm(self.data, desc='making dataset: '):
            img_set = [self._load_image(os.path.join(config['image_base'], generation))
                       for generation in item["generations"]]
            text_fields = self._encode_prompt(item["prompt"])
            for better, worse in self._pairs_from_labels(item["ranking"]):
                data_items.append(self._pair_item(text_fields, img_set[better], img_set[worse]))
        return data_items

    def make_data_change(self):
        """Build one pair per record, with 'image_pixart' always better than 'image_sd15'."""
        data_items = []
        for item in tqdm(self.data, desc='making dataset: '):
            image_sd15 = self._load_image(item["image_sd15"])
            image_pixart = self._load_image(item["image_pixart"])
            text_fields = self._encode_prompt(item["prompt"])
            data_items.append(self._pair_item(text_fields, image_pixart, image_sd15))
        return data_items

    def make_data_multi_model(self):
        """Build pairs from per-model images ordered best-to-worst by ``self.rank``.

        Each record needs an 'image_<name>' key for every name in
        ``config["rank"]``. An earlier name is always preferred over a later one.
        """
        data_items = []
        for item in tqdm(self.data, desc='making dataset: '):
            img_set = [self._load_image(item[f"image_{name}"]) for name in self.rank]
            text_fields = self._encode_prompt(item["prompt"])
            for better, worse in self._pairs_from_order(range(len(img_set))):
                data_items.append(self._pair_item(text_fields, img_set[better], img_set[worse]))
        return data_items

    def make_data_all_score(self):
        """Build pairs from each record's best-to-worst 'rank' string.

        Images are loaded from 'image_<name>' for each name in
        ``config["models_use"]``. 'rank' lists 1-based indices into that list,
        best first, one digit per entry. Every earlier entry is paired as
        better than every later one.
        """
        data_items = []
        for item in tqdm(self.data, desc='making dataset: '):
            img_set = self._load_model_images(item)
            text_fields = self._encode_prompt(item["prompt"])
            for better, worse in self._pairs_from_order(self._parse_rank(item["rank"])):
                data_items.append(
                    self._pair_item(text_fields, img_set[better - 1], img_set[worse - 1]))
        return data_items

    def make_data_all_score_iterative(self, focus_models=(5, 6)):
        """Like :meth:`make_data_all_score`, but keep only pairs that involve *focus_models*.

        focus_models: 1-based model ids, using the same numbering as the
            'rank' string. A pair is kept when either of its two images
            belongs to one of these ids. Defaults to ``(5, 6)``.
        """
        focus = set(focus_models)
        data_items = []
        for item in tqdm(self.data, desc='making dataset: '):
            img_set = self._load_model_images(item)
            text_fields = self._encode_prompt(item["prompt"])
            ranking = self._parse_rank(item["rank"])
            for better, worse in self._pairs_from_order(ranking, focus):
                data_items.append(
                    self._pair_item(text_fields, img_set[better - 1], img_set[worse - 1]))
        return data_items
