
import os
from glob import glob

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm.notebook import tqdm


# batch_size = 8
# path = '/DATA/saketh/datasets/suim/train_val/'
# image_size = 256

class SUIM(torch.utils.data.Dataset):
  """Segmentation dataset laid out as ``<datapath>/images/*.jpg`` with
  matching masks at ``<datapath>/masks/<stem>.bmp``.

  Each item is ``(image, mask)``:
    * ``image`` is the ``'image'`` output of ``transform_img``, promoted to
      floating point (values are not rescaled).
    * ``mask`` is an ``(H, W)`` float32 tensor of integer class codes in
      ``0..7`` produced by :meth:`gen_mask`.

  Args:
    datapath: Root directory containing ``images/`` and ``masks/``.
    transform_img: Albumentations-style callable, invoked as
      ``transform_img(image=...)``; its ``'image'`` entry is returned.
    transform_mask: Albumentations-style callable, invoked as
      ``transform_mask(image=..., mask=...)``; its ``'mask'`` entry must be
      channel-last ``(H, W, 3)`` (e.g. ``ToTensorV2()`` with its default
      ``transpose_mask=False``).
    image_size: Stored as ``self.size``; not used by this class itself.

  Note:
    The image and mask pipelines are called independently, so any random
    augmentation in them is sampled separately and the pair will no longer
    line up. Keep both deterministic (e.g. pad + center crop).
  """

  def __init__(self, datapath, transform_img, transform_mask, image_size):
    self.datapath = datapath
    # glob order is filesystem-dependent; sort so index -> file is stable
    # and a seeded random_split over this dataset is reproducible.
    self.img_paths = sorted(glob(self.datapath + "/images/*.jpg"))
    self.transform_img = transform_img
    self.transform_mask = transform_mask

    self.size = image_size

  def __len__(self):
    return len(self.img_paths)

  def __getitem__(self, idx):
    img_path = self.img_paths[idx]
    # The mask shares the image's file stem, with a .bmp extension.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    mask_path = os.path.join(self.datapath, "masks", stem + ".bmp")

    # cv2.imread returns None rather than raising on a missing/unreadable
    # file; fail here with the offending path instead of deep inside the
    # transform pipeline.
    image = cv2.imread(img_path)
    if image is None:
      raise FileNotFoundError(f"Could not read image: {img_path}")
    dataset_mask = cv2.imread(mask_path)
    if dataset_mask is None:
      raise FileNotFoundError(f"Could not read mask: {mask_path}")

    transformed_img = self.transform_img(image=image)
    # The mask is routed through the 'mask' target so spatial transforms
    # treat it as a label map; 'image' is supplied as well since the
    # pipeline may expect it. NOTE(review): confirm before dropping it.
    transformed_mask = self.transform_mask(image=dataset_mask, mask=dataset_mask)
    mask = self.gen_mask(transformed_mask['mask'])

    # `* 1.0` promotes integer pixel values to floating point.
    return transformed_img['image'] * 1.0, mask

  def gen_mask(self, dataset_mask):
    """Collapse a 3-channel colour mask into per-pixel class codes.

    Each channel is binarised with ``> 100`` and the three bits are packed
    as ``4 * ch0 + 2 * ch1 + ch2``, giving codes ``0..7``. Masks loaded
    through ``cv2.imread`` are in BGR order, so ch0 is blue.
    NOTE(review): SUIM's colour table is usually given in RGB bit order;
    confirm which class-id convention downstream code expects.

    Args:
      dataset_mask: Channel-last ``(H, W, 3)`` tensor or ndarray.

    Returns:
      ``(H, W)`` float32 tensor of class codes, in the same spatial
      orientation as the input so it aligns pixel-for-pixel with the image.
    """
    if torch.is_tensor(dataset_mask):
      dataset_mask = dataset_mask.cpu().numpy()
    bits = np.asarray(dataset_mask) > 100
    codes = np.dot(bits, [4, 2, 1]).astype(np.uint8)
    return torch.Tensor(codes)


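# dataset = SUIM(datapath=path,
#                transform_img=A.Compose([
#                    A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
#                    A.CenterCrop(height=image_size, width=image_size, p=1),
#                    ToTensorV2()]),
#                transform_mask=A.Compose([
#                    A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
#                    A.CenterCrop(height=image_size, width=image_size, p=1),
#                    ToTensorV2()]),
#                image_size=image_size)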

# train_data, test_data = torch.utils.data.random_split(
#     dataset,
#     [1220, 305],
#     generator=torch.Generator().manual_seed(42)
#     )
# train_loader = DataLoader(train_data, batch_size=batch_size,
#                           shuffle=True, num_workers=0,
#                           drop_last=True)
# test_loader = DataLoader(test_data, batch_size=batch_size,
#                           shuffle=False, num_workers=0,
#                           drop_last=False)


# for idx, (data, mask) in enumerate(tqdm(train_loader)):
#         data, mask = data, mask
#         # print(mask.shape)