import numpy as np
from torch.utils.data import Dataset, DataLoader
from glob import glob
import os
import cv2
import albumentations as A
from utils.flow_utils import *
from glob import glob


class MS2_3modal_dataset(Dataset):
    """Paired-modality optical-flow dataset built from per-frame parameter files.

    The directory layout read by this class is::

        <root_dir>/<sequence>/params/<frame>.npy            calibration dict
        <root_dir>/<sequence>/img_<mod>_L/<frame>.png       image
        <root_dir>/<sequence>/dpt_<mod>_L_unidepth/<frame>.npy   depth
        <root_dir>/<sequence>/flow_<B>2<A>/<frame>.npy      flow B -> A

    Within every sequence the first 80% of the (sorted) frames form the
    training split and the remaining 20% form the test split.

    Args:
        split: ``'train'`` selects the training frames; any other value
            selects the test frames.
        modal: One of the keys of ``MODAL_PAIRS`` (``'thr2rgb'``,
            ``'nir2rgb'``, ``'thr2nir'``).
        args: Namespace-like object; ``args.gen_data`` is read and stored.
        root_dir: Optional dataset root. Defaults to ``DEFAULT_ROOT_DIR``.

    Raises:
        ValueError: If ``modal`` is not a supported modality pair.
    """

    # Fallback dataset location used when no ``root_dir`` is supplied.
    DEFAULT_ROOT_DIR = '/path/to/dataset/'

    # modal string -> (source modality B, target modality A)
    MODAL_PAIRS = {
        'thr2rgb': ('thr', 'rgb'),
        'nir2rgb': ('nir', 'rgb'),
        'thr2nir': ('thr', 'nir'),
    }

    def __init__(self, split, modal, args, root_dir=None):
        # Fail fast: previously an unknown modal only surfaced later as an
        # AttributeError inside __getitem__ (often in a worker process).
        if modal not in self.MODAL_PAIRS:
            raise ValueError(
                "Unsupported modal %r; expected one of %s"
                % (modal, sorted(self.MODAL_PAIRS))
            )
        self.modal_B, self.modal_A = self.MODAL_PAIRS[modal]

        self.args = args
        # Photometric augmentation pipeline; it is stored here but not applied
        # anywhere inside this class.
        self.transform = A.Compose([A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.1, brightness_by_max=False, always_apply=True),
                                    A.Sharpen()])
        self.split = split

        if root_dir is None:
            root_dir = self.DEFAULT_ROOT_DIR
        self.flow_list = self._collect_param_files(root_dir, split)

        self.gen_data = args.gen_data

    @staticmethod
    def _collect_param_files(root_dir, split):
        """Return the sorted parameter-file paths belonging to ``split``."""
        want_train = split == 'train'
        selected = []
        for entry in sorted(os.listdir(root_dir)):
            seq_path = os.path.join(root_dir, entry)
            if not os.path.isdir(seq_path):
                continue
            files = sorted(glob(os.path.join(seq_path, 'params', '*.npy')))
            # First 80% of each sequence -> train, trailing 20% -> test.
            cut = int(len(files) * 0.8)
            selected.extend(files[:cut] if want_train else files[cut:])
        return sorted(selected)

    @staticmethod
    def _sibling_path(param_path, folder, ext=None):
        """Map a ``params`` file to the same frame inside a sibling folder.

        Only the directory component that holds the file is swapped, so a
        dataset root or file name that happens to contain the text ``params``
        is left untouched (a plain ``str.replace`` would rewrite it too).
        """
        seq_dir = os.path.dirname(os.path.dirname(param_path))
        file_name = os.path.basename(param_path)
        if ext is not None:
            file_name = os.path.splitext(file_name)[0] + ext
        return os.path.join(seq_dir, folder, file_name)

    @staticmethod
    def _to_chw_tensor(img):
        """Convert an HxWx3 array with values in 0..255 to a float CHW tensor in 0..1."""
        return torch.from_numpy(img / 255.0).permute(2, 0, 1).float()

    def __len__(self):
        return len(self.flow_list)

    def read_img(self, path):
        """Load the image at ``path`` and return it in RGB channel order.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img = cv2.imread(str(path))
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("Could not read image: %s" % path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return one sample as a dict of tensors.

        Training samples additionally carry the synthetic self-warped views
        produced by ``syn_flow_data``; test samples carry the parameter-file
        path under ``"index"``.
        """
        index = index % len(self.flow_list)
        param_path = self.flow_list[index]
        tag_A, tag_B = self.modal_A, self.modal_B

        imgA = self.read_img(self._sibling_path(param_path, 'img_' + tag_A + '_L', '.png'))  # [H, W, 3]
        imgB = self.read_img(self._sibling_path(param_path, 'img_' + tag_B + '_L', '.png'))

        flow_B2A = np.load(self._sibling_path(param_path, 'flow_' + tag_B + '2' + tag_A))  # [2, H, W]
        # NOTE(review): a pixel counts as valid only when BOTH flow components
        # are non-zero, so purely horizontal/vertical motion is masked out.
        # Confirm this is the intended validity convention.
        mask_B2A = (flow_B2A[0] != 0) & (flow_B2A[1] != 0)

        if self.split != 'train':
            return {"imgA": self._to_chw_tensor(imgA),
                    "imgB": self._to_chw_tensor(imgB),
                    "flow_B2A": torch.from_numpy(flow_B2A).float(),
                    "mask_B2A": torch.from_numpy(mask_B2A),
                    "index": param_path}

        dptA = np.load(self._sibling_path(param_path, 'dpt_' + tag_A + '_L_unidepth'))  # [1, H, W]
        dptB = np.load(self._sibling_path(param_path, 'dpt_' + tag_B + '_L_unidepth'))
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only point
        # this dataset at trusted files.
        params = np.load(param_path, allow_pickle=True).item()

        K_A = params['K_' + tag_A.upper()]  # [3,3]
        K_B = params['K_' + tag_B.upper()]  # [3,3]
        T_B2A = params['T_' + tag_B.upper() + '2' + tag_A.upper()]  # [3,4]

        # Jitter the translation column so every draw synthesises a different
        # virtual viewpoint. ``params`` is freshly loaded per call, so the
        # in-place update does not leak between samples.
        T_B2A[:3, 3] += np.random.randn(3)

        # NOTE(review): both calls receive (K_B, K_A) in the same order;
        # verify against syn_flow_data's signature that this is intended.
        flow_A2self, mask_A2self, imgA_self, _ = syn_flow_data(imgA, dptA, K_B, K_A, T_B2A)
        flow_B2self, mask_B2self, imgB_self, _ = syn_flow_data(imgB, dptB, K_B, K_A, T_B2A)
        # The synthesised views are permuted directly, i.e. they are expected
        # to already be tensors in HWC layout.
        imgA_self = (imgA_self / 255.0).permute(2, 0, 1).float()
        imgB_self = (imgB_self / 255.0).permute(2, 0, 1).float()

        # "imgA" must be the ORIGINAL image A: flow_B2A and flow_A2self are
        # both defined relative to it. (It used to be a duplicate of
        # imgA_self, leaving the converted original unused.)
        return {"imgA": self._to_chw_tensor(imgA), "imgA_self": imgA_self,
                "flow_A2self": flow_A2self, "mask_A2self": mask_A2self,
                "imgB": self._to_chw_tensor(imgB), "imgB_self": imgB_self,
                "flow_B2self": flow_B2self, "mask_B2self": mask_B2self,
                "flow_B2A": torch.from_numpy(flow_B2A).float(),
                "mask_B2A": torch.from_numpy(mask_B2A)}


def _seed_numpy_worker(worker_id):
    """Give each DataLoader worker its own NumPy RNG stream.

    PyTorch assigns every worker a distinct torch seed but does not reseed
    NumPy, so forked workers would otherwise all replay the same
    ``np.random`` sequence (and replay it again every epoch).
    """
    import torch  # local import: this helper only runs inside worker processes

    np.random.seed(torch.initial_seed() % 2 ** 32)


def fetch_dataloader(args, split='train'):
    """Build the DataLoader for the dataset named by ``args.dataset``.

    Args:
        args: Namespace-like object providing ``dataset`` and, for the
            training split, ``batch_size``. It is forwarded to the dataset.
        split: ``'train'`` builds a shuffled training loader using
            ``args.batch_size``; any other value builds an unshuffled test
            loader with batch size 1.

    Returns:
        A ``DataLoader`` over ``MS2_3modal_dataset``.

    Raises:
        ValueError: If ``args.dataset`` is not a known dataset name
            (previously this surfaced as an ``UnboundLocalError``).
    """
    modal_by_dataset = {
        'MS2_TIR2RGB': 'thr2rgb',
        'MS2_NIR2RGB': 'nir2rgb',
        'MS2_TIR2NIR': 'thr2nir',
    }
    if args.dataset not in modal_by_dataset:
        raise ValueError(
            "Unknown dataset %r; expected one of %s"
            % (args.dataset, sorted(modal_by_dataset))
        )

    is_train = split == 'train'
    dataset = MS2_3modal_dataset(split='train' if is_train else 'test',
                                 modal=modal_by_dataset[args.dataset],
                                 args=args)
    loader = DataLoader(dataset,
                        batch_size=args.batch_size if is_train else 1,
                        pin_memory=True,
                        shuffle=is_train,
                        num_workers=4,
                        drop_last=False,
                        worker_init_fn=_seed_numpy_worker)
    if is_train:
        print('Training with %d image pairs' % len(dataset))
    else:
        print('Test with %d image pairs' % len(dataset))
    return loader


