# -*- coding: utf-8 -*-
import os
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from itertools import combinations
import hashlib

from datasets.custom_transforms import get_custom_timm_rand_augment, RandomJPEG, RandomGaussianBlur, CustomResizeKeepRatio


class OmniFakeEpisode(Dataset):
    """
    A minimal dataset that returns the episode data for an N-way K-shot task.

    This dataset does not support DDP-style distributed sampling; a DataLoader with
    shuffle=False must be used. During training, to ensure that different processes
    read distinct data, you must assign a unique random seed to each process (the
    class-combination shuffle in ``set_epoch`` draws from torch's global RNG).
    Every n_way * (k_shot + n_query) consecutive indices form one episode: n_way
    consecutive groups of (k_shot + n_query) images, one group per class.

    Item data:
        1. Random crop (train) / center crop (val) of the augmented image.
        2. A resized view of the augmented image: short edge resized to
           ``output_size``, then cropped to ``output_size`` (random crop in train,
           center crop in val).
        3. The dataset index ``idx`` as a tensor (NOT the class id). The
           episode-relative class position can be recovered as
           ``(idx % (n_way * (k_shot + n_query))) // (k_shot + n_query)``.
    """

    # File suffixes (compared case-insensitively) that are treated as images.
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(
        self,
        data_root: str,
        class_file_path: str,
        mode: str = "train",
        output_size: int = 224,
        n_way: int = 5,
        k_shot: int = 5,
        n_query: int = 5,
        episodes_lengh: int = 10000,  # total episodes in one epoch (name kept for compatibility)
    ):
        """
        Args:
            data_root: Directory containing a ``train`` / ``val`` subfolder, each holding
                one subfolder per class name.
            class_file_path: Text file with one ``<int_id> <class_name>`` entry per line.
            mode: ``"train"`` or ``"val"``; selects the subfolder and the augmentations.
            output_size: Side length of both returned crops.
            n_way: Number of classes per episode.
            k_shot: Support images per class.
            n_query: Query images per class.
            episodes_lengh: Number of episodes in one epoch.

        Total data length = episodes_lengh * n_way * (k_shot + n_query).

        Raises:
            FileNotFoundError: If a class directory does not exist.
            ValueError: If a class-file line is malformed, or a class has fewer
                than ``k_shot + n_query`` images.
        """
        super().__init__()
        assert mode in {"train", "val"}, "mode must be 'train' or 'val'"
        self.mode = mode
        self.data_root = os.path.join(data_root, self.mode)
        self.output_size = output_size

        # Parse class file: "<int_id> <class_name>"
        self.idx_to_class = self._parse_class_file(class_file_path)
        self.class_to_idx = {v: k for k, v in self.idx_to_class.items()}

        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query

        # total data length in one epoch (just for program training based on epochs)
        self.total_length = episodes_lengh * self.n_way * (self.k_shot + self.n_query)

        assert self.n_way <= len(self.idx_to_class), f"Not enough classes ({len(self.idx_to_class)}) in dataset for N-way ({self.n_way}) task"

        # Scan image paths. Every class appears in some combination, so each one must
        # supply k_shot + n_query distinct images; fail early instead of raising an
        # IndexError deep inside __getitem__.
        samples_per_class = self.k_shot + self.n_query
        self.samples_by_idx = {}
        for cls_id, cls_name in self.idx_to_class.items():
            cls_dir = os.path.join(self.data_root, cls_name)
            if not os.path.isdir(cls_dir):
                raise FileNotFoundError(f"{cls_dir} does not exist")

            class_samples = self._scan_class_dir(cls_dir)
            if len(class_samples) < samples_per_class:
                raise ValueError(
                    f"Class '{cls_name}' ({cls_dir}) has {len(class_samples)} images, "
                    f"but k_shot + n_query = {samples_per_class} are required per episode"
                )
            self.samples_by_idx[cls_id] = class_samples

        # Build transforms. Train uses random crops; val uses center crops so that
        # evaluation is deterministic (as the class docstring promises).
        if self.mode == "train":
            self.img_augment = self.get_train_transforms(self.output_size + 32)  # with a little margin
            crop = transforms.RandomCrop
        else:
            self.img_augment = self.get_val_transforms(self.output_size + 32)
            crop = transforms.CenterCrop

        # for path 1
        self.transform_crop = crop(self.output_size)

        # for path 2
        self.transform_resize = transforms.Compose([
            transforms.Resize(self.output_size),  # resize short edge with bilinear
            crop(self.output_size),
        ])

        # Warning: the number of combinations is C(num_classes, n_way) and must not be
        # too large. We just use 5 in 15, 15 in 15, and 30 in 30.
        self.uniform_class_samples = [list(c) for c in combinations(self.samples_by_idx, self.n_way)]

        self.epoch = 0

    @staticmethod
    def _parse_class_file(class_file_path: str) -> dict:
        """Parse ``<int_id> <class_name>`` lines into ``{int_id: class_name}``.

        Blank lines are skipped; the class name is everything after the first run of
        whitespace (so names may contain spaces).

        Raises:
            ValueError: If a non-blank line has no class name or a non-integer id.
        """
        idx_to_class = {}
        with open(class_file_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split(maxsplit=1)
                if not parts:
                    continue
                if len(parts) != 2:
                    raise ValueError(
                        f"{class_file_path}:{line_no}: expected '<int_id> <class_name>', got {line!r}"
                    )
                idx_to_class[int(parts[0])] = parts[1].strip()
        return idx_to_class

    @classmethod
    def _scan_class_dir(cls, cls_dir: str) -> list:
        """Return the image paths in ``cls_dir``, sorted by file name.

        Sorting makes the seeded sample selection in ``__getitem__`` reproducible
        across filesystems, whose ``os.listdir`` order is arbitrary.
        """
        return [
            os.path.join(cls_dir, file)
            for file in sorted(os.listdir(cls_dir))
            if file.lower().endswith(cls.IMG_EXTENSIONS)
        ]

    def __len__(self) -> int:
        return self.total_length  # fake length

    def __getitem__(self, idx):
        """
        To ensure that every n_way * (k_shot + n_query) consecutive samples within
        a single process form a complete episode task, we adopt sequential sampling
        and require that batch_size be an integer multiple of n_way * (k_shot + n_query).

        Returns:
            (img_crop, img_resize, label) where ``label`` is ``torch.tensor(idx)``.
        """
        samples_per_class = self.k_shot + self.n_query
        episode_size = self.n_way * samples_per_class

        # get class: which combination this episode uses, and which class within it
        episode_id = (idx // episode_size) % len(self.uniform_class_samples)
        class_pos = (idx % episode_size) // samples_per_class
        sampled_class = self.uniform_class_samples[episode_id][class_pos]

        # get sample index: all indices in one (k_shot + n_query) group share a seeded
        # permutation, so the group draws distinct images of the same class.
        group_id = idx // samples_per_class
        pos_in_group = idx % samples_per_class
        gen = torch.Generator()
        gen.manual_seed(self.epoch_batch_seed(self.epoch, group_id))
        class_paths = self.samples_by_idx[sampled_class]
        sample_idx = torch.randperm(len(class_paths), generator=gen)[pos_in_group].item()

        # read img; the context manager releases the file handle once augmented
        with Image.open(class_paths[sample_idx]) as pil_img:  # convert to RGB in RandomJPEG
            img = self.img_augment(pil_img)
        img_crop = self.transform_crop(img)
        img_resize = self.transform_resize(img)
        label = torch.tensor(idx)  # dataset index; see class docstring for class recovery

        return img_crop, img_resize, label

    @staticmethod
    def get_train_transforms(min_size):
        return transforms.Compose([
            RandomJPEG(compress_module=["pil"]),
            CustomResizeKeepRatio(min_size=min_size, scale_range=(0.5, 2.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            get_custom_timm_rand_augment(),
            RandomGaussianBlur(),
            transforms.ToTensor(),
        ])

    @staticmethod
    def get_val_transforms(min_size):
        return transforms.Compose([
            RandomJPEG(p=0.0),                                # no compression, just transforms to RGB
            CustomResizeKeepRatio(min_size=min_size, p=0.0),  # no modification if possible
            transforms.ToTensor(),
        ])

    @staticmethod
    def epoch_batch_seed(epoch: int, batch: int) -> int:
        """Derive a deterministic 64-bit seed for one group of (k_shot + n_query) samples.

        The key packs ``epoch`` into the high 32 bits and the low 32 bits of ``batch``
        into the low 32 bits, then hashes it with BLAKE2b (8-byte digest).
        """
        key = (epoch << 32) | (batch & 0xFFFFFFFF)
        return int.from_bytes(
            hashlib.blake2b(key.to_bytes(8, 'little'), digest_size=8).digest(),
            'little'
        )

    # set when starting a new epoch
    def set_epoch(self, epoch):
        """Record ``epoch`` (it feeds the per-group seeds) and reshuffle the order of
        class combinations.

        The shuffle uses torch's global RNG: remember to set different seeds for each
        process so they walk different combinations.
        """
        self.epoch = epoch
        order = torch.randperm(len(self.uniform_class_samples)).tolist()  # random is ok
        self.uniform_class_samples = [self.uniform_class_samples[i] for i in order]

