import json
import numpy as np
import os
import random
import wandb
from tqdm import tqdm

# from data.utils import _get_img_from_path
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from data.utils import init_tokenizer
from data.collate_fns import BLIPPaddingCollateFunctionTest4CIRR
import torchvision.transforms.functional as F
import torch

import PIL
import PIL.Image

# Credits: Preprocess taken from CLIP4CIR
# https://github.com/ABaldrati/CLIP4Cir

# Root that every dataset path in this module is joined onto. With the empty string the
# resulting paths stay relative (e.g. 'cirr/captions/...').
# NOTE(review): the constant is named after FashionIQ but is used for the CIRR files — confirm
# the name is intentional.
_DEFAULT_FASHION_IQ_DATASET_ROOT = ''
base_path = _DEFAULT_FASHION_IQ_DATASET_ROOT
# Arguments handed to targetpad_transform() when CIRRDataset builds its preprocessing pipeline.
target_ratio = 1.5
input_dim = 224

def collate_fn(batch):
    '''
    Collate a batch for a torch DataLoader after dropping the samples that are None.
    :param batch: iterable of samples, some of which may be None
    :return: torch's default collation of the samples that are not None
    '''
    kept_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept_samples)

def _convert_image_to_rgb(image):
    """Return ``image`` converted to the "RGB" mode through its ``convert`` method."""
    rgb_image = image.convert("RGB")
    return rgb_image

class TargetPad:
    """
    Transform that pads an image whose aspect ratio (longer side / shorter side) is not
    below a target ratio; images below the target ratio are returned untouched.
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: aspect-ratio threshold that triggers the padding
        :param size: preprocessing output dimension (stored; not read by ``__call__``)
        """
        self.target_ratio = target_ratio
        self.size = size

    def __call__(self, image):
        """
        :param image: object exposing ``size`` as (width, height)
        :return: ``image`` itself when no padding is needed, otherwise the padded image
        """
        width, height = image.size
        longest_side = max(width, height)
        shortest_side = min(width, height)
        if longest_side / shortest_side < self.target_ratio:
            # Aspect ratio is already below the target: nothing to pad.
            return image

        # Length the shorter side should reach so that longest / shortest hits the target.
        padded_side = longest_side / self.target_ratio
        horizontal = max(int((padded_side - width) / 2), 0)
        vertical = max(int((padded_side - height) / 2), 0)
        # Same amount on opposite borders, filled with the constant value 0.
        return F.pad(image, [horizontal, vertical, horizontal, vertical], 0, 'constant')

def targetpad_transform(target_ratio: float, dim: int):
    """
    Build the CLIP-like preprocessing pipeline that starts with a TargetPad step.
    :param target_ratio: target ratio forwarded to TargetPad
    :param dim: image output dimension (used for TargetPad, Resize and CenterCrop)
    :return: torchvision Compose transform applying pad, resize, crop, RGB conversion,
             tensor conversion and normalization, in that order
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = [
        TargetPad(target_ratio, dim),
        Resize(dim, interpolation=PIL.Image.BICUBIC),
        CenterCrop(dim),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(pipeline)


class CIRRDataset(Dataset):
    """
       CIRR dataset class which manages CIRR data.
       The dataset can be used in 'relative' or 'classic' mode:
           - In 'classic' mode the dataset yields tuples made of (image, image_name)
           - In 'relative' mode the dataset yields tuples made of:
                - (reference_image, target_image, negative_target_image, target_name, rel_caption)
                  when split == train
                - (reference_image, reference_image_path, target_image, target_image_path, None,
                  rel_caption) when split == val
                - (reference_name, reference_image, rel_caption, pair_id, group_members)
                  when split == test1

       Any exception raised while building a sample is printed and the sample becomes None.
    """

    # Candidate-selection strategies understood by rerank_score().
    _RERANK_MODES = ('topk', 'lower_target_topk')

    def __init__(self, split: str, mode: str, preprocess: callable, config):
        """
        :param split: dataset split, should be in ['test1', 'train', 'val']
        :param mode: dataset mode, should be in ['relative', 'classic'] (see the class docstring
                     for the tuples yielded in each mode)
        :param preprocess: kept for interface compatibility. NOTE: it is currently ignored, the
                     pipeline is always targetpad_transform(target_ratio, input_dim)
        :param config: mapping whose "negative" entry selects the negative sampling method used
                     for training samples ('random' or 'random_rerank')
        :raises ValueError: if split or mode is not one of the supported values
        """
        # Fail fast on bad arguments, before any file is touched.
        if split not in ['test1', 'train', 'val']:
            raise ValueError("split should be in ['test1', 'train', 'val']")
        if mode not in ['relative', 'classic']:
            raise ValueError("mode should be in ['relative', 'classic']")

        self.preprocess = targetpad_transform(target_ratio, input_dim)
        self.mode = mode
        self.split = split
        self.config = config
        self.negative = self.config["negative"]

        # get triplets made by (reference_image, target_image, relative caption)
        captions_file = os.path.join(base_path, 'cirr', 'captions', f'cap.rc2.{split}.json')
        with open(captions_file, encoding='utf-8') as f:
            self.triplets = json.load(f)

        # get a mapping from image name to relative path
        splits_file = os.path.join(base_path, 'cirr', 'image_splits', f'split.rc2.{split}.json')
        with open(splits_file, encoding='utf-8') as f:
            self.name_to_relpath = json.load(f)

        self.image_names = list(self.name_to_relpath.keys())
        # Per-triplet lists of hard negative names; filled by rerank_score().
        self.hard_images = []
        print(f"CIRR {split} dataset in {mode} mode initialized")

    def _load_image(self, image_name):
        """Return (path, preprocessed image) for the image registered under ``image_name``."""
        image_path = os.path.join(base_path, self.name_to_relpath[image_name])
        return image_path, self.preprocess(PIL.Image.open(image_path))

    def __getitem__(self, index):
        """
        Build the sample at ``index`` (tuple layout depends on mode/split, see class docstring).
        :return: the sample tuple, or None when anything went wrong while building it
        """
        try:
            if self.mode == 'relative':
                triplet = self.triplets[index]
                group_members = triplet['img_set']['members']
                reference_name = triplet['reference']
                rel_caption = triplet['caption']

                if self.split == 'train':
                    _, reference_image = self._load_image(reference_name)
                    target_hard_name = triplet['target_hard']
                    _, target_image = self._load_image(target_hard_name)

                    if self.negative == 'random':  # random negative sampling
                        negative_name = random.choice(self.image_names)
                    elif self.negative == 'random_rerank':  # hard negative sampling after reranking
                        negative_name = random.choice(self.hard_images[index])
                    else:
                        raise ValueError("Undefined Negative Sampling Method")
                    _, negative_target_img = self._load_image(negative_name)
                    return reference_image, target_image, negative_target_img, target_hard_name, rel_caption

                elif self.split == 'val':
                    reference_image_path, reference_image = self._load_image(reference_name)
                    target_hard_name = triplet['target_hard']
                    target_image_path, target_image = self._load_image(target_hard_name)
                    return reference_image, reference_image_path, target_image, target_image_path, None, rel_caption

                elif self.split == 'test1':
                    _, reference_image = self._load_image(reference_name)
                    pair_id = triplet['pairid']
                    return reference_name, reference_image, rel_caption, pair_id, group_members

            elif self.mode == 'classic':
                # image_names is the key list computed once in __init__; no per-item rebuild.
                image_name = self.image_names[index]
                _, image = self._load_image(image_name)
                return image, image_name

            else:
                raise ValueError("mode should be in ['relative', 'classic']")

        except Exception as e:
            # Deliberate best effort: a broken sample is reported and yielded as None.
            print(f"Exception: {e}")
        return None

    def __len__(self):
        """
        :return: number of triplets in 'relative' mode, number of images in 'classic' mode
        :raises ValueError: if the mode is not supported
        """
        if self.mode == 'relative':
            return len(self.triplets)
        elif self.mode == 'classic':
            return len(self.name_to_relpath)
        else:
            raise ValueError("mode should be in ['relative', 'classic']")

    @staticmethod
    def _select_candidate_indices(ranked_indices, topk, mode, target_index=None):
        """
        Pick the candidate image indices of one query from its ranking.
        :param ranked_indices: list of image indices, best score first
        :param topk: number of entries to keep
        :param mode: 'topk' keeps the first ``topk`` entries; 'lower_target_topk' keeps the
                     ``topk`` entries starting at the position of ``target_index``
        :param target_index: index of the target image (only read in 'lower_target_topk' mode)
        :raises ValueError: for an unsupported mode, or when target_index is not in the ranking
        """
        if mode == 'topk':
            return ranked_indices[:topk]
        if mode == 'lower_target_topk':
            target_position = ranked_indices.index(target_index)
            return ranked_indices[target_position:][:topk]
        raise ValueError("mode should be in ['topk', 'lower_target_topk']")

    @torch.no_grad()
    def rerank_score(self, model, device, topk, txt_processors, configs, mode):
        """
        Compute, for every triplet, a list of hard negative image names from the model scores
        and switch the dataset to 'random_rerank' negative sampling.

        The model is put in eval mode for the computation and always set back to train mode
        afterwards. ``self.negative`` and ``self.hard_images`` are only replaced once the whole
        computation succeeded, so a failure leaves the dataset usable as it was.

        :param model: object providing eval(), train(), extract_target_features(),
                      extract_query_features_cirr() and score()
        :param device: forwarded to the model feature extractors
        :param topk: number of ranked candidates considered per query
        :param txt_processors: forwarded to model.extract_query_features_cirr()
        :param configs: mapping read for its 'use_temp' entry
        :param mode: 'topk' or 'lower_target_topk' (see _select_candidate_indices)
        :raises ValueError: if mode is not supported, or (in 'lower_target_topk' mode) when a
                      target image is not among the dataset images
        """
        if mode not in self._RERANK_MODES:
            # Without this check the loop below would fail late with an unbound variable.
            raise ValueError("mode should be in ['topk', 'lower_target_topk']")

        model.eval()
        print("\n[START] Negative Method : Rerank Score")
        try:
            sample_dataloader = self.get_sample_loader()
            query_dataloader = self.get_query_loader()

            index_features, _ = model.extract_target_features(sample_dataloader, configs['use_temp'], device)
            predicted_features, _, _, _ = model.extract_query_features_cirr(query_dataloader, configs['use_temp'],
                                                                            txt_processors, device)

            scores = model.score(predicted_features, index_features)
            score_sorted_indices = torch.argsort(scores, dim=-1, descending=True).cpu()

            # Position of the first occurrence of every image name, built once instead of
            # scanning image_names for each query.
            name_to_index = {}
            for position, name in enumerate(self.image_names):
                name_to_index.setdefault(name, position)

            hard_images = []
            progress = tqdm(self.triplets, desc=f"[MODE: {mode}]Reranking with score and image similarity")
            for index, triplet in enumerate(progress):
                target_hard_name = triplet['target_hard']

                # Ranking computed from the score (it considers not only the image, but also text)
                ranked_indices = score_sorted_indices[index].tolist()

                target_index = None
                if mode == 'lower_target_topk':
                    if target_hard_name not in name_to_index:
                        raise ValueError(f"{target_hard_name!r} is not in list")
                    target_index = name_to_index[target_hard_name]

                candidates = self._select_candidate_indices(ranked_indices, topk, mode, target_index)
                # Drop the target itself; dict.fromkeys de-duplicates while keeping rank order,
                # so the result does not depend on the hash seed.
                hard_images.append(list(dict.fromkeys(
                    self.image_names[i] for i in candidates if self.image_names[i] != target_hard_name)))
        finally:
            model.train()

        # Commit the new state only after everything succeeded.
        self.negative = "random_rerank"
        self.hard_images = hard_images
        print("Done\n")

    def get_sample_loader(self):
        """
        :return: DataLoader over all dataset images (CIRRSampleDataset), batches of 64,
                 not shuffled, 16 workers
        """
        sample_dataset = CIRRSampleDataset(base_path, self.image_names, self.name_to_relpath, self.preprocess)
        padding_collate = BLIPPaddingCollateFunctionTest4CIRR()
        sample_dataloader = DataLoader(sample_dataset, batch_size=64, shuffle=False, num_workers=16,
                                       pin_memory=True, collate_fn=padding_collate)

        return sample_dataloader

    def get_query_loader(self):
        """
        :return: DataLoader over all triplets (CIRRQueryDataset), batches of 64,
                 not shuffled, 16 workers
        """
        tokenizer = init_tokenizer()
        query_dataset = CIRRQueryDataset(base_path, self.triplets, tokenizer, self.name_to_relpath, self.preprocess)
        padding_collate = BLIPPaddingCollateFunctionTest4CIRR()
        query_dataloader = DataLoader(query_dataset, batch_size=64, shuffle=False, num_workers=16,
                                      pin_memory=True, collate_fn=padding_collate)

        return query_dataloader

    @classmethod
    def code(cls):
        """:return: the identifier string of this dataset"""
        return 'cirr'

    @classmethod
    def all_codes(cls):
        """:return: list with every identifier string of this dataset"""
        return ['cirr']

    @classmethod
    def vocab_path(cls):
        """:return: None (no vocabulary path is provided for this dataset)"""
        return None

class CIRRSampleDataset(Dataset):
    """Dataset over a list of image names; item ``idx`` is (preprocessed image, image name)."""

    def __init__(self, base_path, image_names, name_to_relpath, preprocess):
        """
        :param base_path: root that the relative image paths are joined onto
        :param image_names: sequence of image names, one per item
        :param name_to_relpath: mapping from image name to relative path
        :param preprocess: callable applied to the opened PIL image
        """
        self.preprocess = preprocess
        self.name_to_relpath = name_to_relpath
        self.image_names = image_names
        self.base_path = base_path

    def __len__(self):
        """:return: number of image names"""
        return len(self.image_names)

    def __getitem__(self, idx):
        """:return: (preprocessed image, image name) for the name at position ``idx``"""
        name = self.image_names[idx]
        relative_path = self.name_to_relpath[name]
        opened = PIL.Image.open(os.path.join(self.base_path, relative_path))
        return self.preprocess(opened), name


class CIRRQueryDataset(Dataset):
    """
    Dataset over CIRR triplets; item ``index`` is
    (reference_name, reference_image, rel_caption, pair_id, group_members).
    """

    def __init__(self, base_path, triplets, tokenizer, name_to_relpath, preprocess):
        """
        :param base_path: root that the relative image paths are joined onto
        :param triplets: sequence of triplet dicts; the keys 'img_set' -> 'members',
                         'reference', 'caption' and 'pairid' are read
        :param tokenizer: stored on the instance (not used by __getitem__)
        :param name_to_relpath: mapping from image name to relative path
        :param preprocess: callable applied to the opened PIL image
        """
        self.base_path = base_path
        self.triplets = triplets
        self.tokenizer = tokenizer
        self.name_to_relpath = name_to_relpath
        self.preprocess = preprocess

    def __len__(self):
        """:return: number of triplets"""
        return len(self.triplets)

    def __getitem__(self, index):
        """
        :return: (reference_name, reference_image, rel_caption, pair_id, group_members)
                 for the triplet at ``index``
        """
        triplet = self.triplets[index]
        group_members = triplet['img_set']['members']
        reference_name = triplet['reference']
        rel_caption = triplet['caption']
        # Resolve against the root given to the constructor, not the module-level global,
        # so the base_path argument is actually honoured.
        reference_image_path = os.path.join(self.base_path, self.name_to_relpath[reference_name])
        reference_image = self.preprocess(PIL.Image.open(reference_image_path))
        pair_id = triplet['pairid']

        return reference_name, reference_image, rel_caption, pair_id, group_members

# Script entry point placeholder: running this module directly does nothing.
if __name__ == '__main__':
    pass
