from os.path import join
import pytorch_lightning as pl
import numpy as np
from PIL import Image
from sympy import im
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from pycocotools.coco import COCO
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import functional as F
from slot_attention.model.model_utils import coco_rescale

# Hard-coded absolute root of the COCO data. The code in this file joins it
# with 'images/<split>' for the image folders and 'annotations/<file>.json'
# for the annotation files.
COCO_ROOT_PATH = '/system/user/publicdata/coco'

class CocoDataset(Dataset):
    """COCO images paired with a stack of per-instance binary masks.

    Wraps ``torchvision.datasets.CocoDetection``. Every sample is passed
    through ``coco_rescale``, centre-cropped to a square whose side is the
    shorter image side, and resized to ``params['resolution']``; the masks
    receive the same crop and a nearest-neighbour resize so they stay aligned
    with the image.

    Processed samples are memoised per index in plain Python lists.
    NOTE: entries are never evicted, so memory grows with the number of
    distinct indices requested.

    Args:
        params: Mapping providing ``params['resolution']`` as ``(height, width)``.
        root: Directory containing the images.
        annFile: Path to the COCO annotation JSON file.
        transform: Image transform forwarded to ``CocoDetection``. It has to
            yield a channel-first tensor, because ``__getitem__`` reads
            ``img.shape[1:]`` and ``img[0]``.
        target_transform: Target transform forwarded to ``CocoDetection``.
    """

    def __init__(self, params, root, annFile, transform, target_transform) -> None:
        super().__init__()
        
        s = '''
        IMPORTANT: As reported here (Table 21 and Figure 5), we do Object Discovery on the COCO dataset with central crops of size 320x320.
        https://arxiv.org/pdf/2209.14860.pdf
        '''
        print(s)

        self.cocods = dset.CocoDetection(
            root=root,
            annFile=annFile,
            transform=transform,
            target_transform=target_transform,
        )

        height, width = params['resolution']

        # Per-index cache of already processed samples (None = not computed yet).
        self.imgs = [None] * len(self.cocods)
        self.masks = [None] * len(self.cocods)

        # Separate COCO handle used to rasterise annotations into binary masks.
        self.coco = COCO(annFile)
        self.to_tensor = transforms.ToTensor()
        # Table 21 from https://arxiv.org/pdf/2209.14860.pdf
        # For central crops, the short side is brought to the target size and
        # the most centred square is kept. Here the square crop is taken first
        # (in __getitem__) and these transforms then resize it.
        self.resize = transforms.Resize([height, width], antialias=True)
        # Nearest-neighbour keeps the masks binary (no blended values).
        self.resize_mask = transforms.Resize(
            [height, width],
            interpolation=transforms.InterpolationMode.NEAREST,
        )
        self.rescale = transforms.Lambda(coco_rescale)

    def __len__(self):
        """Return the number of images in the wrapped ``CocoDetection`` dataset."""
        return len(self.cocods)

    def __getitem__(self, index):
        """Return the processed ``(image, masks)`` pair for ``index``.

        Returns:
            A tuple ``(img, masks)``. ``img`` is the rescaled, centre-cropped
            and resized image tensor. ``masks`` is a float tensor with one
            channel per non-crowd annotation plus a leading all-zero channel,
            cropped and resized like the image.
        """
        # Serve from the cache *before* touching the underlying dataset so a
        # hit does not pay for loading and decoding the image again.
        cached_img = self.imgs[index]
        if cached_img is not None:
            return cached_img, self.masks[index]

        img, target = self.cocods[index]

        # https://arxiv.org/pdf/2209.14860.pdf
        # Training uses COCO 2017 (118 287 images), evaluation the validation
        # set (5 000 images). Instance masks are used for object discovery;
        # overlaps between instances are ignored during metric evaluation and
        # crowd instance annotations are not used.
        #
        # Channel 0 is an all-zero placeholder, so the stack is never empty
        # even when an image has no usable annotation. No background mask is
        # computed since FG-ARI is used (Figure 5 of the paper above).
        mask_list = [np.zeros_like(img[0])]
        for ann in target:
            if ann['iscrowd'] == 1:
                continue
            mask_list.append(self.coco.annToMask(ann))

        masks = torch.from_numpy(np.stack(mask_list, axis=0)).float()

        # rescale between -1 and 1 (behaviour of coco_rescale, defined elsewhere)
        img = self.rescale(img)

        # Central square crop. The cropped tensors must be kept (center_crop
        # is not in-place) and the masks need the identical crop, otherwise
        # they no longer line up with the image.
        crop_size = min(img.shape[1:])
        img = F.center_crop(img, crop_size)
        masks = F.center_crop(masks, crop_size)

        # resize to target resolution
        img = self.resize(img)
        masks = self.resize_mask(masks)

        self.imgs[index] = img
        self.masks[index] = masks

        return img, masks

class CocoDataModule(pl.LightningDataModule):
    """Lightning data module that serves a random train/val split of COCO.

    ``params.mode`` selects which COCO 2017 split is loaded ('train' ->
    train2017, 'test' -> val2017). The loaded dataset is then divided with
    ``torch.utils.data.random_split`` according to ``params.train_val_percent``.

    Args:
        params: Configuration object. It is read through attribute access here
            (``mode``, ``train_val_percent``) and through item access inside
            ``CocoDataset`` (``params['resolution']``), so it has to support both.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: Worker count used by both loaders.

    Raises:
        ValueError: If ``params.mode`` is neither 'train' nor 'test'.
    """

    def __init__(
        self,
        params: dict,
        train_batch_size: int,
        val_batch_size: int,
        num_workers: int,
    ):
        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers

        # Reject unsupported modes up front, then derive the split name.
        mode = params.mode
        if mode not in ('train', 'test'):
            raise ValueError(f'Unknown mode: {mode}')
        split_name = 'train2017' if mode == 'train' else 'val2017'

        image_dir = join(COCO_ROOT_PATH, 'images', split_name)
        annotation_path = join(
            COCO_ROOT_PATH, 'annotations', f'instances_{split_name}.json'
        )

        full_dataset = CocoDataset(
            params=params,
            root=image_dir,
            annFile=annotation_path,
            transform=transforms.ToTensor(),
            target_transform=None,
        )

        # The first part goes to training, the remainder to validation.
        total = len(full_dataset)
        train_size = int(params.train_val_percent * total)
        split_lengths = [train_size, total - train_size]
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset, split_lengths
        )

    def _make_loader(self, dataset, batch_size, shuffle):
        """Build a DataLoader with the settings shared by train and val."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        return self._make_loader(self.train_dataset, self.train_batch_size, True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation subset."""
        return self._make_loader(self.val_dataset, self.val_batch_size, False)
        
def collate_fn(batch):
    """
    Collate (image, mask) samples whose masks differ in channel count.

    Masks are zero-padded along the channel axis up to the largest channel
    count found in the batch, so they can be stored in a single tensor.

    Args:
        batch (list): List of (image, mask) tuples. Spatial sizes are taken
            from the first sample; images go into a 3-channel buffer.

    Returns:
        A tuple (batch_images, batch_masks) of float32 tensors created on the
        devices of the first image and the first mask respectively.
    """
    images, masks = zip(*batch)
    first_image, first_mask = images[0], masks[0]

    # Widest mask stack decides the padded channel dimension.
    padded_channels = max(m.shape[0] for m in masks)

    image_buffer_shape = (len(images), 3, first_image.shape[1], first_image.shape[2])
    mask_buffer_shape = (
        len(masks),
        padded_channels,
        first_mask.shape[1],
        first_mask.shape[2],
    )

    batch_images = torch.zeros(
        image_buffer_shape, dtype=torch.float32, device=first_image.device
    )
    batch_masks = torch.zeros(
        mask_buffer_shape, dtype=torch.float32, device=first_mask.device
    )

    # Copy each sample into its slot; unused mask channels remain zero.
    for idx, (image, mask) in enumerate(zip(images, masks)):
        batch_images[idx] = image
        batch_masks[idx, :mask.shape[0]] = mask

    return batch_images, batch_masks
        
if __name__ == '__main__':
    # Manual smoke test: build the training dataset and print shape and value
    # range of the first sample.
    params = {'mode': 'train', 'resolution': [128, 128]}

    image_root = join(COCO_ROOT_PATH, 'images', 'train2017')
    annotation_file = join(COCO_ROOT_PATH, 'annotations', 'instances_train2017.json')
    train_dataset = CocoDataset(
        params=params,
        root=image_root,
        annFile=annotation_file,
        transform=transforms.ToTensor(),
        target_transform=None,
    )

    print('len', len(train_dataset))

    img, target = train_dataset[0]  # first sample (index 0)

    print('img', img.shape)
    # Largest, then smallest value of the first image channel.
    for reducer in (max, min):
        print('img', reducer(img[0].flatten()))

    print('target', target)
    print('target', target.shape)
    # Largest, then smallest value of the first mask channel.
    for reducer in (max, min):
        print('target', reducer(target[0].flatten()))