import json
import numpy as np
import os
import random
import wandb
from tqdm import tqdm
from typing import List
from pathlib import Path

# from data.utils import _get_img_from_path
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, RandomHorizontalFlip
from data.utils import init_tokenizer
from data.collate_fns import BLIPPaddingCollateFunctionTest4CIRR, BLIPPaddingCollateFunctionTest4FIQ
import torchvision.transforms.functional as F
import torch

import PIL
import PIL.Image

from data.utils import _get_img_from_path


# Root directory of the FashionIQ dataset; the loaders below read the
# `captions/`, `image_splits/` and `images/` sub-directories of it.
# An empty string resolves to the current working directory.
_DEFAULT_FASHION_IQ_DATASET_ROOT: str = ''

base_path: Path = Path(_DEFAULT_FASHION_IQ_DATASET_ROOT)
# Aspect-ratio threshold used by TargetPad and the square output size of
# targetpad_transform.
target_ratio: float = 1.5
input_dim: int = 224

# Credits: preprocessing code taken from CLIP4Cir
# https://github.com/ABaldrati/CLIP4Cir

def collate_fn(batch):
    '''
    Collate a DataLoader batch after dropping the ``None`` samples in it.

    Datasets in this module return ``None`` for items that failed to load, so
    this filter lets a batch be built from the remaining samples.

    :param batch: list of samples, possibly containing ``None`` entries
    :return: ``default_collate`` applied to the non-``None`` samples
    '''
    kept_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept_samples)

def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode via its ``convert`` method."""
    target_mode = "RGB"
    return image.convert(target_mode)

class TargetPad:
    """
    Zero-pad an image whose aspect ratio reaches a target ratio.

    Images whose long-side / short-side ratio is below ``target_ratio`` are
    returned untouched. Otherwise both sides are padded symmetrically with a
    constant 0 fill. This enlarges the short side to roughly
    ``long_side / target_ratio``, so the resulting ratio is about ``target_ratio``.
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: aspect-ratio threshold that triggers padding
        :param size: preprocessing output dimension (stored; not used by ``__call__``)
        """
        self.target_ratio = target_ratio
        self.size = size

    def __call__(self, image):
        width, height = image.size
        long_side = max(width, height)
        short_side = min(width, height)
        if long_side / short_side < self.target_ratio:
            # Below the threshold: leave the image as it is.
            return image
        # Side length the padded short side should reach for the target ratio.
        goal_side = long_side / self.target_ratio
        pad_x = max(int((goal_side - width) / 2), 0)
        pad_y = max(int((goal_side - height) / 2), 0)
        return F.pad(image, [pad_x, pad_y, pad_x, pad_y], 0, 'constant')

def image_transform(config, mode='train'):
    """
    Build an ImageNet-normalized torchvision transform.

    :param config: mapping providing ``config['img_size']``, the output side length
    :param mode: 'train' -> random resized crop + random horizontal flip;
        'val' or 'test' -> deterministic resize to (img_size, img_size)
    :return: torchvision ``Compose`` transform
    :raises ValueError: if ``mode`` is not 'train', 'val' or 'test'
        (unknown modes used to silently return ``None``)
    """
    IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
    img_size = config['img_size']

    if mode == 'train':
        # NOTE(review): RandomResizedCrop's ``scale`` bounds the crop *area* as a
        # fraction of the original image, so an upper bound of 1.33 cannot be
        # realised; this may have been meant as ``ratio`` -- confirm before changing.
        return Compose([RandomResizedCrop(size=img_size, scale=(0.75, 1.33)),
                        RandomHorizontalFlip(),
                        ToTensor(),
                        Normalize(**IMAGENET_STATS)])
    if mode in ('val', 'test'):
        # Evaluation splits share the same deterministic transform.
        return Compose([Resize((img_size, img_size)),
                        ToTensor(),
                        Normalize(**IMAGENET_STATS)])
    raise ValueError(f"mode should be in ['train', 'val', 'test'], got {mode!r}")

def targetpad_transform(target_ratio: float, dim: int):
    """
    CLIP-like preprocessing applied after TargetPad padding.

    :param target_ratio: aspect-ratio threshold forwarded to TargetPad
    :param dim: side length of the square output image
    :return: torchvision Compose transform producing a normalized tensor
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = [
        TargetPad(target_ratio, dim),                  # pad overly elongated images
        Resize(dim, interpolation=PIL.Image.BICUBIC),  # smaller edge -> dim
        CenterCrop(dim),                               # square dim x dim crop
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(clip_mean, clip_std),
    ]
    return Compose(pipeline)



class FashionIQDataset(Dataset):
    """
    FashionIQ dataset class which manages FashionIQ data.
    The dataset can be used in 'relative' or 'classic' mode:
        - In 'classic' mode the dataset yields tuples made of (image, image_name)
        - In 'relative' mode the dataset yields tuples made of:
            - (reference_image, target_image, negative_target_image, target_name, rel_caption) when split == train
            - (reference_image, rel_caption, target_name) when split == val
            - (reference_name, reference_image, rel_caption) when split == test
    ``rel_caption`` joins the two relative captions of a triplet with " and ".
    If building an item raises, the exception is printed and ``None`` is returned
    (``collate_fn`` in this module drops such items).
    The dataset manages an arbitrary number of FashionIQ categories, e.g. only dress, dress+toptee+shirt, dress+shirt...
    """

    def __init__(self, split, dress_types, mode, preprocess: callable, config):
        """
        :param split: dataset split, should be in ['test', 'train', 'val']
        :param dress_types: list of FashionIQ categories (subset of ['dress', 'shirt', 'toptee'])
        :param mode: dataset mode, should be in ['relative', 'classic'] (item layout: see class docstring)
        :param preprocess: function which preprocesses the image.
            NOTE(review): currently ignored -- images are always preprocessed with
            ``targetpad_transform(target_ratio, input_dim)``; confirm this is intended.
        :param config: mapping; ``config["negative"]`` selects the negative sampling method
            used by the train split ('random' or 'random_rerank')
        :raises ValueError: on an unknown mode, split or dress type
        """
        self.preprocess = targetpad_transform(target_ratio, input_dim)
        self.mode = mode
        self.split = split
        self.config = config
        self.dress_types = dress_types
        self.negative = self.config["negative"]

        if mode not in ['relative', 'classic']:
            raise ValueError("mode should be in ['relative', 'classic']")
        if split not in ['test', 'train', 'val']:
            raise ValueError("split should be in ['test', 'train', 'val']")
        for dress_type in dress_types:
            if dress_type not in ['dress', 'shirt', 'toptee']:
                raise ValueError("dress_type should be in ['dress', 'shirt', 'toptee']")

        # Triplets (candidate, target, captions) of every requested category, concatenated
        # in ``dress_types`` order; each triplet dict is tagged with its category.
        self.triplets: List[dict] = []
        self.domain_triplets = dict()
        # dress_type -> global indices into self.triplets for that category
        self.dress_type_indices_relative = dict()
        for dress_type in dress_types:
            with open(base_path / 'captions' / f'cap.{dress_type}.{split}.json', encoding='utf-8') as f:
                triplets = json.load(f)
            for triplet in triplets:
                triplet['dress_type'] = dress_type
            start_idx = len(self.triplets)
            self.triplets.extend(triplets)
            self.dress_type_indices_relative[dress_type] = list(range(start_idx, len(self.triplets)))
            self.domain_triplets[dress_type] = triplets

        # dress_type -> target name of every triplet of that category (duplicates kept,
        # same order as domain_triplets). Built once instead of per item / per query.
        # ``.get`` because the code never reads 'target' for the test split.
        self._domain_target_names = {
            dress_type: [triplet.get('target') for triplet in triplets]
            for dress_type, triplets in self.domain_triplets.items()
        }

        # Image names of the split, concatenated in ``dress_types`` order.
        self.domain_image_names = dict()
        self.image_names: list = []
        self.dress_type_indices_classic = dict()
        for dress_type in dress_types:
            with open(base_path / 'image_splits' / f'split.{dress_type}.{split}.json', encoding='utf-8') as f:
                images = json.load(f)
            start_idx = len(self.image_names)
            self.image_names.extend(images)
            self.dress_type_indices_classic[dress_type] = list(range(start_idx, len(self.image_names)))
            self.domain_image_names[dress_type] = images

        # Per-triplet lists of hard-negative target names, filled by rerank_score().
        self.hard_images = []

        print(f"FashionIQ {split} - {dress_types} dataset in {mode} mode initialized")

    def _load_image(self, image_name):
        """Open ``<base_path>/images/<image_name>.png`` and return ``self.preprocess`` applied to it."""
        image_path = base_path / 'images' / f"{image_name}.png"
        # The context manager closes the file handle even if preprocessing raises.
        with PIL.Image.open(image_path) as image:
            return self.preprocess(image)

    def _sample_negative_name(self, index, target_name, dress_type):
        """
        Pick the name of a negative target image for training triplet ``index``.

        'random': uniform choice among same-category target names != ``target_name``.
        'random_rerank': uniform choice among the hard negatives mined by ``rerank_score``.
        If none are available for ``index`` (e.g. every ranked candidate was the target
        itself), it falls back to 'random' instead of failing the sample.

        :raises ValueError: for an unknown ``self.negative``
        :raises IndexError: if no candidate negative exists at all
        """
        if self.negative == 'random_rerank':
            hard_names = self.hard_images[index] if index < len(self.hard_images) else []
            if hard_names:
                return random.choice(hard_names)
        elif self.negative != 'random':
            raise ValueError("Undefined Negative Sampling Method")
        candidates = [name for name in self._domain_target_names[dress_type] if name != target_name]
        return random.choice(candidates)

    def __getitem__(self, index):
        """
        Return item ``index``; the tuple layout depends on mode/split (see class docstring).

        Any exception raised while building the item is printed and ``None`` is returned,
        so a broken sample can be dropped by a None-filtering collate function.
        """
        try:
            if self.mode == 'relative':
                triplet = self.triplets[index]
                image_captions = triplet['captions']
                reference_name = triplet['candidate']
                # e.g. ["is red.", "has long sleeves"] -> "Is red and has long sleeves"
                rel_caption = image_captions[0].strip('.?, ').capitalize() + " and " + image_captions[1].strip('.?, ')

                if self.split == 'train':
                    reference_image = self._load_image(reference_name)
                    target_name = triplet['target']
                    target_image = self._load_image(target_name)

                    cur_dress_type = triplet['dress_type']
                    # Kept as an attribute for backward compatibility with external readers.
                    self.non_target_pool = self._domain_target_names[cur_dress_type]
                    negative_name = self._sample_negative_name(index, target_name, cur_dress_type)
                    negative_target_img = self._load_image(negative_name)

                    return reference_image, target_image, negative_target_img, target_name, rel_caption

                elif self.split == 'val':
                    reference_image = self._load_image(reference_name)
                    target_name = triplet['target']
                    return reference_image, rel_caption, target_name

                elif self.split == 'test':
                    reference_image = self._load_image(reference_name)
                    return reference_name, reference_image, rel_caption

            elif self.mode == 'classic':
                image_name = self.image_names[index]
                image = self._load_image(image_name)
                return image, image_name

            else:
                raise ValueError("mode should be in ['relative', 'classic']")
        except Exception as e:
            print(f"Exception: {e}")

    def __len__(self):
        if self.mode == 'relative':
            return len(self.triplets)
        elif self.mode == 'classic':
            return len(self.image_names)
        else:
            raise ValueError("mode should be in ['relative', 'classic']")

    @torch.no_grad()
    def rerank_score(self, model, device, topk, txt_processors, configs, mode):
        """
        Mine hard negatives for every triplet and switch sampling to 'random_rerank'.

        Extracts query features (reference image + caption) and target features for
        every triplet, then scores them with ``model.score``. For each triplet it keeps
        the distinct same-category target names among up to ``topk`` ranked candidates
        (excluding its own target) in ``self.hard_images``. Only on success are
        ``self.hard_images`` replaced and ``self.negative`` set to 'random_rerank'.
        ``model.train()`` is called on exit, also when an error occurs.

        Assumes ``scores[i]`` holds query ``i``'s scores against all targets in
        ``self.triplets`` order -- TODO confirm against ``model.score``.

        :param model: object exposing ``extract_query_features_fiq``,
            ``extract_target_features`` and ``score`` (as called below)
        :param device: forwarded to the feature extractors
        :param topk: number of ranked candidates considered per query
        :param txt_processors: forwarded to ``extract_query_features_fiq``
        :param configs: mapping; ``configs['use_temp']`` is forwarded to the extractors
        :param mode: 'topk' -- the ``topk`` highest-scoring targets;
            'lower_target_topk' -- ``topk`` targets starting at the rank of the
            ground-truth target (i.e. ranked at or below it)
        :raises ValueError: if ``mode`` is not one of the above
        """
        if mode not in ('topk', 'lower_target_topk'):
            raise ValueError("mode should be in ['topk', 'lower_target_topk']")

        model.eval()
        print(f"\n[START] [MODE: {mode}] Negative Method : Rerank Score")
        try:
            sample_dataloader = self.get_sample_loader()
            query_dataloader = self.get_query_loader()

            predicted_features, _ = model.extract_query_features_fiq(query_dataloader, configs['use_temp'], txt_processors, device)
            index_features, _ = model.extract_target_features(sample_dataloader, configs['use_temp'], device)

            scores = model.score(predicted_features, index_features)

            # dress_type -> {target_name: position of its FIRST occurrence in the category},
            # matching the semantics of list.index() without an O(N) scan per query.
            first_target_pos = {}
            for dress_type, names in self._domain_target_names.items():
                positions = {}
                for pos, name in enumerate(names):
                    positions.setdefault(name, pos)
                first_target_pos[dress_type] = positions

            hard_images = []
            for index in tqdm(range(len(self.triplets)), desc=f"[MODE: {mode}]Reranking with score and image similarity"):
                triplet = self.triplets[index]
                target_name = triplet['target']
                cur_dress_type = triplet['dress_type']
                domain_names = self._domain_target_names[cur_dress_type]

                # Restrict this query's scores to targets of the same category; local
                # positions then line up with ``domain_names``.
                domain_scores = scores[index][self.dress_type_indices_relative[cur_dress_type]]
                ranked = torch.argsort(domain_scores, dim=-1, descending=True).cpu().tolist()

                if mode == 'topk':
                    candidates = ranked[:topk]
                else:  # 'lower_target_topk'
                    target_rank = ranked.index(first_target_pos[cur_dress_type][target_name])
                    candidates = ranked[target_rank:target_rank + topk]

                # Deduplicate while keeping rank order; a set would make the order (and
                # therefore random.choice results) depend on string hash randomization.
                hard_images.append(list(dict.fromkeys(
                    domain_names[i] for i in candidates if domain_names[i] != target_name)))
        finally:
            model.train()

        self.hard_images = hard_images
        self.negative = "random_rerank"
        print("Done\n")

    def get_sample_loader(self):
        """DataLoader over the target image of every triplet, in ``self.triplets`` order."""
        all_targets = [triplet['target'] for triplet in self.triplets]
        sample_dataset = FashionIQSampleDataset(base_path, all_targets, self.preprocess)
        collate = BLIPPaddingCollateFunctionTest4CIRR()
        sample_dataloader = DataLoader(sample_dataset, batch_size=64, shuffle=False, num_workers=16, pin_memory=True, collate_fn=collate)

        return sample_dataloader

    def get_query_loader(self):
        """DataLoader over (target_name, reference_image, rel_caption) of every triplet, in order."""
        tokenizer = init_tokenizer()
        query_dataset = FashionIQQueryDataset(base_path, self.triplets, tokenizer, self.preprocess)
        collate = BLIPPaddingCollateFunctionTest4FIQ()
        query_dataloader = DataLoader(query_dataset, batch_size=64, shuffle=False, num_workers=16, pin_memory=True, collate_fn=collate)

        return query_dataloader

    @classmethod
    def code(cls):
        return 'fashionIQ'

    @classmethod
    def all_codes(cls):
        return ['fashionIQ']

    @classmethod
    def all_subset_codes(cls):
        return ['dress', 'shirt', 'toptee']

    @classmethod
    def vocab_path(cls):
        return None
        
class FashionIQSampleDataset(Dataset):
    """
    Dataset over a list of FashionIQ image names.

    Items are ``(preprocessed_image, image_name)``; images are read from
    ``<base_path>/images/<image_name>.png``.
    """

    def __init__(self, base_path, image_names, preprocess):
        """
        :param base_path: dataset root (``str`` or ``Path``) containing an ``images`` directory
        :param image_names: image names without extension; duplicates are kept
        :param preprocess: callable applied to each opened PIL image
        """
        self.base_path = base_path
        self._image_names = image_names
        self.preprocess = preprocess

    def __len__(self):
        return len(self._image_names)

    def _image_path(self, image_name):
        """Return the PNG path of ``image_name`` under this instance's ``base_path``."""
        # Deliberately uses the constructor argument rather than the module-level
        # ``base_path`` global, so a caller-supplied root is honoured.
        return Path(self.base_path) / 'images' / f"{image_name}.png"

    def __getitem__(self, idx):
        image_name = self._image_names[idx]
        # The context manager closes the file handle even if preprocessing raises.
        with PIL.Image.open(self._image_path(image_name)) as im:
            image = self.preprocess(im)

        return image, image_name


class FashionIQQueryDataset(Dataset):
    """
    Dataset of FashionIQ queries built from relative-caption triplets.

    Items are ``(target_name, preprocessed_reference_image, rel_caption)``, where
    ``rel_caption`` joins the triplet's two captions with " and ".
    """

    def __init__(self, base_path, triplets, tokenizer, preprocess):
        """
        :param base_path: dataset root (``str`` or ``Path``) containing an ``images`` directory
        :param triplets: list of dicts with 'candidate', 'target' and 'captions' keys
        :param tokenizer: stored on the instance; not used by ``__getitem__``
        :param preprocess: callable applied to each opened reference PIL image
        """
        self.base_path = base_path
        self.triplets = triplets
        self.tokenizer = tokenizer
        self.preprocess = preprocess

    def __len__(self):
        return len(self.triplets)

    def _image_path(self, image_name):
        """Return the PNG path of ``image_name`` under this instance's ``base_path``."""
        # Deliberately uses the constructor argument rather than the module-level
        # ``base_path`` global, so a caller-supplied root is honoured.
        return Path(self.base_path) / 'images' / f"{image_name}.png"

    def __getitem__(self, index):
        triplet = self.triplets[index]
        reference_name = triplet['candidate']
        target_name = triplet['target']
        image_captions = triplet['captions']

        # The context manager closes the file handle even if preprocessing raises.
        with PIL.Image.open(self._image_path(reference_name)) as im:
            reference_image = self.preprocess(im)
        rel_caption = f"{image_captions[0].strip('.?, ').capitalize()} and {image_captions[1].strip('.?, ')}"

        return target_name, reference_image, rel_caption