import os
import numpy as np
import torch
from torch.utils.data import Dataset
import cv2
import glob
from lightning import LightningDataModule
from torch.utils.data import DataLoader

def get_roi(mask, margin=8, shape=None):
    """Compute a roughly square region of interest around a mask's foreground.

    The bounding box of the non-zero pixels of ``mask`` is padded by
    ``margin`` pixels, and the shorter axis is widened symmetrically (clamped
    to the image) so the box becomes approximately square. If the padded box
    would touch or cross the image border, or the mask has no foreground, the
    full frame is returned instead.

    Args:
        mask: 2-D array; non-zero entries mark the foreground. May be ``None``
            when ``shape`` is given, in which case the full frame is returned.
        margin: Padding (pixels) added around the foreground bounding box.
        shape: Image shape ``(h, w, ...)`` used only when ``mask`` is ``None``.

    Returns:
        ``np.ndarray`` ``[h0, w0, r_s, r_e, c_s, c_e]``: the original height
        and width followed by row and column slice bounds.

    Raises:
        ValueError: if both ``mask`` and ``shape`` are ``None``.
    """
    if mask is None:
        # The original code dereferenced mask.shape before its None check, so
        # a None mask could never reach the full-frame fallback.
        if shape is None:
            raise ValueError("get_roi needs either a mask or an explicit shape")
        h0, w0 = shape[:2]
        return np.array([h0, w0, 0, h0, 0, w0])

    h0, w0 = mask.shape[:2]
    full_frame = np.array([h0, w0, 0, h0, 0, w0])

    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        # No foreground at all: nothing to crop around (np.min would raise).
        return full_frame

    rowmin, rowmax = rows.min(), rows.max()
    colmin, colmax = cols.min(), cols.max()
    row, col = rowmax - rowmin, colmax - colmin

    # Only crop when the padded box stays strictly inside the top/left border
    # and within the bottom/right border; otherwise keep the whole image.
    inside = not (rowmin - margin <= 0 or rowmax + margin > h0 or
                  colmin - margin <= 0 or colmax + margin > w0)
    if not inside:
        return full_frame

    if row > col:
        # Taller than wide: widen the columns by half the difference per side.
        pad = int(0.5 * (row - col))
        r_s, r_e = rowmin - margin, rowmax + margin
        c_s = max(colmin - pad - margin, 0)
        c_e = min(colmax + pad + margin, w0)
    else:
        # Wider than (or as wide as) tall: widen the rows instead.
        pad = int(0.5 * (col - row))
        r_s = max(rowmin - pad - margin, 0)
        r_e = min(rowmax + pad + margin, h0)
        c_s, c_e = colmin - margin, colmax + margin

    return np.array([h0, w0, r_s, r_e, c_s, c_e])

def crop_and_resize_img(img, mask, roi, max_image_resolution=4000):
    """Crop an image to ``roi``, resize it to a square grid and scale it to [0, 1].

    The output side length is the larger cropped dimension rounded down to a
    multiple of 512, clamped to ``[512, max_image_resolution]``. uint8 images
    are divided by 255 and uint16 images by 65535; any other dtype is only cast
    to float32.

    Args:
        img: H x W x C image array (indexed with three axes).
        mask: Unused by this function.
        roi: ``[h0, w0, r_s, r_e, c_s, c_e]`` as returned by ``get_roi``.
        max_image_resolution: Upper bound on the output side length.

    Returns:
        float32 array with ``side`` rows and ``side`` columns.
    """
    _, _, row_start, row_end, col_start, col_end = roi
    cropped = img[row_start:row_end, col_start:col_end, :]

    longest = max(cropped.shape[:2])
    side = min(max_image_resolution, longest // 512 * 512)
    side = max(side, 512)
    resized = cv2.resize(cropped, (side, side), interpolation=cv2.INTER_CUBIC)

    # cv2.resize keeps the input dtype, so the full-scale value can be chosen afterwards.
    if resized.dtype == np.uint8:
        full_scale = 255.0
    elif resized.dtype == np.uint16:
        full_scale = 65535.0
    else:
        full_scale = 1.0
    return resized.astype(np.float32) / full_scale

def crop_and_resize_mask(img, mask, roi, max_image_resolution=4000):
    """Crop a mask to ``roi`` and resize it to a binary square grid.

    The grid size is computed from the cropped *image* with the same rule as
    ``crop_and_resize_img`` (larger cropped side rounded down to a multiple of
    512, clamped to ``[512, max_image_resolution]``), so image and mask line up.

    Args:
        img: Image array (2-D or H x W x C) used only for its cropped size.
            If ``None``, the cropped mask's own size is used instead.
        mask: 2-D mask; values above 0.5 after resizing count as foreground.
            Boolean masks are accepted.
        roi: ``[h0, w0, r_s, r_e, c_s, c_e]`` as returned by ``get_roi``.
        max_image_resolution: Upper bound on the output side length.

    Returns:
        float32 array of shape ``(side, side)`` containing only 0.0 and 1.0.
    """
    _, _, r_s, r_e, c_s, c_e = roi
    mask = np.asarray(mask)[r_s:r_e, c_s:c_e]

    # Only the spatial size of the image is needed, so a 2-axis slice works for
    # both grayscale and multi-channel images (the old 3-axis slice rejected 2-D).
    reference = mask if img is None else img[r_s:r_e, c_s:c_e]
    side = max(512, min(max_image_resolution, (max(reference.shape[:2]) // 512) * 512))

    if mask.dtype == np.bool_:
        # cv2.resize has no boolean kernel; float keeps the 0.5 threshold meaningful.
        mask = mask.astype(np.float32)

    resized = cv2.resize(mask, (side, side), interpolation=cv2.INTER_CUBIC)
    return np.float32(resized > 0.5)



class Data(Dataset):
    """Dataset yielding one cropped, normalised image stack per object directory.

    ``data_dir`` must contain one sub-directory per object. Each object
    directory holds images whose file names start with ``L`` (presumably one
    per lighting condition — confirm) and a ``mask.png`` foreground mask.
    """

    def __init__(self, data_dir, numberOfImages=96):
        """
        Args:
            data_dir: Root directory whose sub-directories are the objects.
            numberOfImages: Number of images sampled at random per object;
                ``None`` loads every image in sorted order.
        """
        # Stored as a list so that several roots can be scanned the same way.
        self.data_dir = [data_dir]
        self.numberOfImages = numberOfImages
        objlist = []
        for root in self.data_dir:
            print('Initialize %s' % (root))
            with os.scandir(root) as entries:
                objlist.extend(sorted(entry.path for entry in entries if entry.is_dir()))
        self.objlist = objlist

    def __len__(self):
        return len(self.objlist)

    @staticmethod
    def _read_rgb(img_path):
        """Read an image at its native bit depth and convert it from BGR to RGB.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        img = cv2.imread(img_path, flags=cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        if img is None:
            raise OSError('Failed to read image: %s' % img_path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def load(self, objlist, objid, numberOfImages):
        """Load, crop and normalise the image stack of one object.

        Results are stored on the instance instead of being returned:
        ``self.I`` (h, w, 3, n) float32 images, ``self.mask`` (h, w, 1) float32
        mask, ``self.N`` an all-ones (h, w, 3) array, plus ``self.roi``,
        ``self.h``, ``self.w`` and ``self.directlist`` (sorted image paths).

        Args:
            objlist: List of object directories.
            objid: Index into ``objlist``.
            numberOfImages: Number of images to sample at random; ``None``
                uses all of them in sorted order.

        Raises:
            ValueError: if no image is selected or the mask has no foreground.
            OSError: if an image or ``mask.png`` cannot be read. (Previously a
                failed read returned 0 and ``__getitem__`` silently served the
                previous object's data.)
        """
        obj_dir = objlist[objid]
        directlist = sorted(glob.glob(os.path.join(obj_dir, "L*")))
        if numberOfImages is not None:
            # Random subset, in random order, of the available images.
            indexset = np.random.permutation(len(directlist))[:numberOfImages]
        else:
            indexset = range(len(directlist))
        num_images = len(indexset)
        if num_images == 0:
            raise ValueError(
                'No images selected for %s (%d files match "L*")' % (obj_dir, len(directlist)))

        mask_path = os.path.join(obj_dir, "mask.png")
        I = None
        mask = None
        for i, index in enumerate(indexset):
            img = self._read_rgb(directlist[index])
            if i == 0:
                raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if raw_mask is None:
                    raise OSError('Failed to read mask: %s' % mask_path)
                mask = raw_mask / 255
                # The ROI and output grid come from the first image and mask
                # and are reused for every other image of this object.
                self.roi = get_roi(mask)
                mask = crop_and_resize_mask(img, mask, self.roi)
            # No try/except here: a crop failure must surface instead of
            # leaving an uncropped image that breaks the stack later.
            img = crop_and_resize_img(img, mask, self.roi)
            if I is None:
                h, w = img.shape[:2]
                I = np.zeros((num_images, h, w, 3), np.float32)
            I[i] = img

        I = I.reshape(num_images, h * w, 3)
        foreground = mask.reshape(-1) == 1
        if not foreground.any():
            raise ValueError('Mask %s has no foreground pixels after resizing' % mask_path)
        # Per image: average the channels of every foreground pixel, take the
        # maximum, and divide by it so each image's brightest foreground pixel ~ 1.
        peak = np.mean(I[:, foreground, :], axis=2).max(axis=1)
        I /= (peak.reshape(-1, 1, 1) + 1.0e-6)

        print(num_images)
        # (n, h*w, 3) -> (h*w, 3, n) -> (h, w, 3, n)
        I = np.transpose(I, (1, 2, 0)).reshape(h, w, 3, num_images)
        mask = mask.reshape(h, w, 1).astype(np.float32)

        self.h = h
        self.w = w
        self.I = I
        self.N = np.ones((h, w, 3), np.float32)
        self.mask = mask
        self.directlist = directlist

    def __getitem__(self, idx):
        """Load object ``idx``.

        Returns:
            ``(img, N, mask, directlist, roi)``, where ``img`` is (3, h, w, n),
            ``N`` is the all-ones (h, w, 3) array (presumably a placeholder for
            normals — confirm), ``mask`` is (1, h, w), ``directlist`` lists the
            sorted image paths and ``roi`` is the ``get_roi`` array.
        """
        self.load(self.objlist, idx, self.numberOfImages)
        img = self.I.transpose(2, 0, 1, 3)
        mask = self.mask.transpose(2, 0, 1)
        return img, self.N, mask, self.directlist, self.roi
        
        
    
        



class DataModule(LightningDataModule):
    """Lightning data module that serves the objects under ``data_root`` for testing.

    Only the test stage is wired up: ``setup`` builds a ``Data`` dataset and
    ``test_dataloader`` wraps it in an unshuffled ``DataLoader``.
    """

    def __init__(
        self,
        batch_size_per_device: int = 1,
        num_workers: int = 8,
        data_root: str = None,
        numberofImages: int = None,
        **kwargs
    ):
        """
        Args:
            batch_size_per_device: Batch size handed to the ``DataLoader``.
            num_workers: Number of ``DataLoader`` worker processes.
            data_root: Directory forwarded to ``Data`` as ``data_dir``.
            numberofImages: Forwarded to ``Data`` as ``numberOfImages``.
            **kwargs: Accepted but not used by this class.
        """
        super().__init__()
        # Must be called directly from __init__: it inspects this frame's arguments.
        self.save_hyperparameters()
        self.data_root = data_root
        self.numberofImages = numberofImages
        self.batch_size = batch_size_per_device
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the test dataset for stage ``"test"`` or ``None``; other stages are ignored."""
        if stage is not None and stage != "test":
            return
        self.test_dataset = Data(data_dir=self.data_root, numberOfImages=self.numberofImages)

    def test_dataloader(self):
        """Return an unshuffled ``DataLoader`` over the test dataset."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return DataLoader(self.test_dataset, **loader_options)
