from functools import partial
import os

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.transforms import functional as F

import PIL
import numpy as np
import cv2
import time
import random
from tqdm import tqdm

import glm
import rgbd_3d






class BaseDataset_warp(Dataset):
    """
    Base RGBD dataset for the warping pipeline.

    Loads an RGB image and its disparity map and returns them concatenated
    along the channel axis, together with the image name and, when labels are
    available, the class label.

    Args:
        root_path (str): path to the dataset
        data_txt: stored on the instance; not read by this base class
        image_size (int): size of the images
        normalize (bool): whether to normalize the images to [-1, 1]
        normalize_depth (bool): whether to normalize the depth maps to [-1, 1]
        prepocess_depth (str): how to preprocess the depth maps (inputs from the dataset are disparity maps)
            - 'none': no preprocessing
            - 'to_depth': disparity map, to depth map
            - 'disparity_minmax': disparity map, min-max normalization, min=0, max=1
            - 'depth_minmax': depth map, min-max normalization, min=0, max=1
            - 'z_buffer': perspective projection to [0, 1]
        near (float): near plane for perspective projection
        far (float): far plane for perspective projection
        depth_cfg (str | None): when equal to 'Midas', the raw disparity is
            divided by 6250 before any other processing
    """

    # Maximum number of replacement samples __getitem__ tries before giving up.
    MAX_RETRIES = 100

    def __init__(self,
        root_path,
        data_txt,
        image_size,
        normalize=False,
        normalize_depth=False,
        prepocess_depth='none',
        near=0.5,
        far=100,
        depth_cfg=None,
    ):
        super().__init__()

        assert prepocess_depth in ['none', 'to_depth', 'disparity_minmax', 'depth_minmax', 'z_buffer'], "Unknown depth preprocessing method"
        assert not (normalize_depth and (prepocess_depth == 'none' or prepocess_depth == 'to_depth')), "Can't normalize depth maps if they are not mapped to [0, 1]"

        self.root_path = root_path
        self.data_txt = data_txt
        self.image_size = image_size
        self.normalize = normalize
        self.normalize_depth = normalize_depth
        self.prepocess_depth = prepocess_depth
        self.near = near
        self.far = far
        self.depth_cfg = depth_cfg

        self.images = None
        self.depths = None
        self.labels = None

        self.get_fileinfo() # set self.images, self.depths, self.labels

        self.num_classes = len(self.labels) if self.labels is not None else None

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=F.InterpolationMode.LANCZOS),
            transforms.ToTensor(),
        ])
        # Nearest-neighbour resampling so resizing never blends depth values.
        self.transform_depth = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=F.InterpolationMode.NEAREST),
        ])

    @staticmethod
    def _image_stem(path):
        """Return the file name of ``path`` without directories or extension."""
        return os.path.splitext(path.split("/")[-1])[0]

    @staticmethod
    def _minmax(values):
        """Rescale ``values`` to [0, 1]; a constant array maps to all zeros."""
        lowest = values.min()
        span = values.max() - lowest
        if span == 0:
            # Avoid 0 / 0 -> NaN on a flat map.
            return np.zeros_like(values)
        return (values - lowest) / span

    def to3channel(self, image):
        """Return a 3-channel version of a channel-first image tensor.

        A single channel is repeated three times and a fourth channel is
        dropped; any other channel count is returned unchanged.
        """
        if image.shape[0] == 1: image = image.repeat(3, 1, 1)
        if image.shape[0] == 4: image = image[:3]
        return image

    def get_fileinfo(self):
        """
        Set labels, images, and depths.
        This function is called when the dataset is initialized.
        Should be implemented in the child class.
        """
        pass

    def get_file(self, index):
        """Load the raw sample at ``index``.

        Returns:
            tuple: ``(image, depth, label, image_name)`` where ``image`` and
            ``depth`` are PIL images, ``label`` is ``None`` when the dataset
            has no labels, and ``image_name`` is the file name without its
            extension.
        """
        rel_path = self.images[index]
        image = PIL.Image.open(os.path.join(self.root_path, rel_path))
        image_name = self._image_stem(rel_path)

        # Close the archive as soon as the array has been read out of it.
        with np.load(os.path.join(self.root_path, self.depths[index])) as archive:
            depth = archive['arr_0'].astype(np.float32)

        if self.depth_cfg == 'Midas':
            depth /= 6250

        # Rescale so the largest disparity never exceeds 1 / near.
        if depth.max() > 1 / self.near:
            depth /= depth.max() * self.near

        # Keep disparities strictly positive so 1 / depth stays finite.
        depth = np.maximum(depth, 1e-3)

        if self.prepocess_depth == 'none':
            pass
        elif self.prepocess_depth == 'to_depth':
            depth = 1 / depth
        elif self.prepocess_depth == 'disparity_minmax':
            depth = self._minmax(depth)
        elif self.prepocess_depth == 'depth_minmax':
            depth = self._minmax(1 / depth)
        elif self.prepocess_depth == 'z_buffer':
            # Maps disparity 1 / near to 0 and 1 / far to 1.
            depth = (depth - 1 / self.near) / (1 / self.far - 1 / self.near)
            depth = np.clip(depth, 0, 1)

        depth = PIL.Image.fromarray(depth)

        # The label key is the name of the directory that contains the image.
        label = self.labels[rel_path.split('/')[-2]] if self.num_classes is not None else None

        return image, depth, label, image_name

    def process_file(self, image, depth, label, image_name):
        """Turn a raw sample into the dictionary returned by the dataset.

        Returns:
            dict: ``'x_0'`` holds the image channels followed by the depth
            channel, ``'img_name'`` the image name, and ``'classes'`` the
            label tensor (only present when ``label`` is not ``None``).
        """
        image = self.transform(image)
        image = self.to3channel(image)
        if self.normalize:
            image = image * 2 - 1

        depth = self.transform_depth(depth)
        depth = transforms.ToTensor()(np.array(depth).astype(np.float32))
        if self.normalize_depth:
            depth = depth * 2 - 1

        data = {
            'x_0': torch.cat([image, depth]),
            'img_name': image_name
        }

        if label is not None:
            data['classes'] = torch.tensor(label)

        return data

    def getitem(self, index):
        """Load and process the sample at ``index`` without any error handling."""
        image, depth, label, image_name = self.get_file(index)

        return self.process_file(image, depth, label, image_name)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return the sample at ``index``, substituting a random one on failure.

        A failing sample is reported with ``print`` and replaced by a randomly
        drawn index. After ``MAX_RETRIES`` replacements have also failed a
        ``RuntimeError`` chained to the last error is raised, so a broken
        dataset surfaces its real cause instead of recursing until the
        interpreter's recursion limit.
        """
        last_error = None
        for _ in range(self.MAX_RETRIES + 1):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                last_error = e
            index = np.random.randint(self.__len__())
        raise RuntimeError(
            f"Could not load any sample after {self.MAX_RETRIES} retries"
        ) from last_error



class WarpDataset(BaseDataset_warp):
    """
    RGBD dataset that additionally renders each sample from a nearby viewpoint.

    On top of the base sample, ``__getitem__`` warps the RGBD image to one
    randomly chosen camera of the configured view set and, when
    ``gen_inpainting_data`` is set, stores the warped colour image and the
    mask of unfilled pixels in the returned dictionary.

    Args:
        root_path, data_txt, image_size, normalize, normalize_depth,
        prepocess_depth, near, far, depth_cfg: forwarded to ``BaseDataset_warp``.
        augments (list | None): stored on the instance; not used in this class.
        std (float): stored on the instance; not used in this class.
        gen_inpainting_data (bool): when true, add ``'color_save'`` and
            ``'mask_save'`` to every sample.
        viewset (str): one of '3x9', '3x10', '3x5', '3x3'; selects the yaw and
            pitch angles of the candidate cameras.
        forward_warp (str): '3daware' or 'my_forward'; selects the warping routine.
    """

    # viewset -> (yaw steps per side, yaw step, pitch steps per side, pitch step).
    # Angles are passed to np.sin / np.cos, i.e. they are in radians.
    # NOTE(review): '3x10' yields 11 yaws and '3x3' yields 5 yaws, which does
    # not match their names - kept as in the original, confirm intent.
    _VIEWSETS = {
        '3x9': (4, 0.15, 1, 0.15),
        '3x10': (5, 0.30, 1, 0.15),
        '3x5': (2, 0.1, 1, 0.1),
        '3x3': (2, 0.1, 1, 0.05),
    }

    def __init__(
        self,
        root_path,
        data_txt,
        image_size,
        normalize=False,
        normalize_depth=False,
        prepocess_depth='none',
        near=0.5,
        far=100,
        augments=None,
        std=0.15,
        gen_inpainting_data=False,
        viewset='3x9',
        forward_warp='3daware',
        depth_cfg=None,
    ):
        super().__init__(root_path, data_txt, image_size, normalize, normalize_depth, prepocess_depth, near, far, depth_cfg)
        self.data_txt = data_txt
        # Created lazily on first access because it needs the current CUDA device.
        self.renderer = None
        # A fresh list per instance; a shared mutable default would leak between datasets.
        self.augments = [] if augments is None else augments
        self.std = std
        self.gen_inpainting_data = gen_inpainting_data
        self.viewset = viewset
        self.forward_warp = forward_warp
        self.depth_cfg = depth_cfg

    def _build_modelviews(self):
        """Return one look-at matrix per (yaw, pitch) pair of the view set.

        Every camera sits on the unit sphere and looks at the origin. The
        list is ordered yaw-major and entry 0 is the frontal view.

        Raises:
            NotImplementedError: if ``self.viewset`` is not a known view set.
        """
        if self.viewset not in self._VIEWSETS:
            raise NotImplementedError
        yaw_count, yaw_step, pitch_count, pitch_step = self._VIEWSETS[self.viewset]

        yaws = [0.0]
        for i in range(yaw_count):
            yaws += [(i + 1) * yaw_step, -(i + 1) * yaw_step]
        pitches = [0.0]
        for i in range(pitch_count):
            pitches += [(i + 1) * pitch_step, -(i + 1) * pitch_step]

        return [
            glm.lookAt(
                glm.vec3(np.sin(yaw) * np.cos(pitch), np.sin(pitch), np.cos(yaw) * np.cos(pitch)),
                glm.vec3(0.0, 0.0, 0.0),
                glm.vec3(0.0, 1.0, 0.0),
            )
            for yaw in yaws
            for pitch in pitches
        ]

    @staticmethod
    def _hole_mask(valid_mask):
        """Return an array like ``valid_mask`` that is 255 where it is < 0.5, else 0."""
        holes = np.zeros_like(valid_mask)
        holes[valid_mask < 0.5] = 255
        return holes

    def __getitem__(self, index):
        """Return the base sample, optionally extended with a warped view.

        Returns:
            dict: the dictionary produced by ``BaseDataset_warp``; when
            ``gen_inpainting_data`` is set it also contains ``'color_save'``
            (warped colour image) and ``'mask_save'`` (255 where the warp
            left a pixel unfilled, 0 elsewhere).

        Raises:
            NotImplementedError: for an unknown ``viewset`` or ``forward_warp``.
        """
        data = super().__getitem__(index)
        if self.renderer is None:
            device_id = torch.cuda.current_device()
            self.renderer = rgbd_3d.AggregationRenderer(self.image_size * 3, self.image_size, device=device_id)

        # Channel-last copy with colour and depth mapped back to [0, 1].
        rgbd = data['x_0'].cpu().permute(1, 2, 0).numpy().copy()
        if self.normalize: rgbd[..., :3] = rgbd[..., :3] * 0.5 + 0.5
        if self.normalize_depth: rgbd[..., 3:] = rgbd[..., 3:] * 0.5 + 0.5

        modelviews = self._build_modelviews()

        # Camera of the input image: on the +z axis, looking at the origin.
        modelview_prev = glm.lookAt(
            glm.vec3(0.0, 0.0, 1.0),
            glm.vec3(0.0, 0.0, 0.0),
            glm.vec3(0.0, 1.0, 0.0)
        )

        # One target view drawn from indices 1..8, so the frontal view is never chosen.
        indices = np.random.choice(range(1, 9), size=1, replace=False)
        for idx in indices:
            if self.forward_warp == '3daware':
                mesh0 = rgbd_3d.utils.depth_to_mesh(
                    rgbd_3d.utils.linearize_depth(rgbd[:, :, 3:], self.near, self.far),
                    padding='frustum',
                    fov=45,
                    modelview=modelview_prev,
                    atol=0.03,
                    rtol=0.03,
                    cal_normal=True,
                )
                res = rgbd_3d.utils.aggregate_conditions(
                    self.renderer,
                    [mesh0],
                    [rgbd[:, :, :3]],
                    modelviews[idx],
                    fov=45,
                    near=self.near,
                    far=self.far,
                    atol=0.03,
                    rtol=0.03,
                )
                # Scale colours by 255 and multiply by the mask to blank unfilled pixels.
                co = res.color * 255
                co = co * res.mask
                mask_ = self._hole_mask(res.mask)

            elif self.forward_warp == 'my_forward':
                res = rgbd_3d.utils.forward_warp(self.renderer, rgbd, modelviews[idx], near=self.near, far=self.far, padding=self.image_size)
                # NOTE(review): unlike the '3daware' branch the colour is neither
                # scaled by 255 nor multiplied by the mask - confirm this is intended.
                co = res.color
                # Previously left undefined here, which raised NameError below.
                # NOTE(review): assumes res.mask marks valid pixels as in the
                # '3daware' branch - confirm against rgbd_3d.utils.forward_warp.
                mask_ = self._hole_mask(res.mask)
            else:
                raise NotImplementedError

        if self.gen_inpainting_data:
            data['color_save'] = co
            data['mask_save'] = mask_

        return data
    


class BaseDataset_outpaint(Dataset):
    """
    Base image dataset for the outpainting pipeline.

    Loads RGB images at their native resolution (no resizing) and returns
    them with the image name. The depth-related arguments are stored on the
    instance but no depth map is loaded by this class.

    Args:
        root_path (str): path to the dataset
        data_txt: stored on the instance; not read by this base class
        image_size (int): size of the images (stored only; images are not resized here)
        normalize (bool): whether to normalize the images to [-1, 1]
        normalize_depth (bool): whether to normalize the depth maps to [-1, 1]
        prepocess_depth (str): how to preprocess the depth maps (inputs from the dataset are disparity maps)
            - 'none': no preprocessing
            - 'to_depth': disparity map, to depth map
            - 'disparity_minmax': disparity map, min-max normalization, min=0, max=1
            - 'depth_minmax': depth map, min-max normalization, min=0, max=1
            - 'z_buffer': perspective projection to [0, 1]
        near (float): near plane for perspective projection
        far (float): far plane for perspective projection
        depth_cfg: stored on the instance; not read by this base class
    """

    # Maximum number of replacement samples __getitem__ tries before giving up.
    MAX_RETRIES = 100

    def __init__(self,
        root_path,
        data_txt,
        image_size,
        normalize=False,
        normalize_depth=False,
        prepocess_depth='none',
        near=0.5,
        far=100,
        depth_cfg=None,
    ):
        super().__init__()

        assert not (normalize_depth and (prepocess_depth == 'none' or prepocess_depth == 'to_depth')), "Can't normalize depth maps if they are not mapped to [0, 1]"

        self.root_path = root_path
        self.data_txt = data_txt
        self.image_size = image_size
        self.normalize = normalize
        self.normalize_depth = normalize_depth
        self.prepocess_depth = prepocess_depth
        self.near = near
        self.far = far
        self.depth_cfg = depth_cfg

        self.images = None
        self.depths = None
        self.labels = None

        self.get_fileinfo() # set self.images, self.depths, self.labels

        self.num_classes = len(self.labels) if self.labels is not None else None

        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    @staticmethod
    def _image_stem(path):
        """Return the file name of ``path`` without directories or extension."""
        return os.path.splitext(path.split("/")[-1])[0]

    def to3channel(self, image):
        """Return a 3-channel version of a channel-first image tensor.

        A single channel is repeated three times and a fourth channel is
        dropped; any other channel count is returned unchanged.
        """
        if image.shape[0] == 1: image = image.repeat(3, 1, 1)
        if image.shape[0] == 4: image = image[:3]
        return image

    def get_fileinfo(self):
        """
        Set labels, images, and depths.
        This function is called when the dataset is initialized.
        Should be implemented in the child class.
        """
        pass

    def get_file(self, index):
        """Open the image at ``index``.

        Returns:
            tuple: ``(image, image_name)`` with the PIL image and its file
            name without the extension.
        """
        rel_path = self.images[index]
        image = PIL.Image.open(os.path.join(self.root_path, rel_path))
        return image, self._image_stem(rel_path)

    def process_file(self, image, image_name):
        """Convert a PIL image into the dictionary returned by the dataset.

        Returns:
            dict: ``'x_0'`` holds the 3-channel image tensor (in [-1, 1] when
            ``normalize`` is set) and ``'img_name'`` the image name.
        """
        image = self.transform(image)
        image = self.to3channel(image)
        if self.normalize:
            image = image * 2 - 1
        data = {
            'x_0': image,
            'img_name': image_name
        }

        return data

    def getitem(self, index):
        """Load and process the sample at ``index`` without any error handling."""
        image, image_name = self.get_file(index)

        return self.process_file(image, image_name)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return the sample at ``index``, substituting a random one on failure.

        A failing sample is reported with ``print`` and replaced by a randomly
        drawn index. After ``MAX_RETRIES`` replacements have also failed a
        ``RuntimeError`` chained to the last error is raised, so a broken
        dataset surfaces its real cause instead of recursing until the
        interpreter's recursion limit.
        """
        last_error = None
        for _ in range(self.MAX_RETRIES + 1):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                last_error = e
            index = np.random.randint(self.__len__())
        raise RuntimeError(
            f"Could not load any sample after {self.MAX_RETRIES} retries"
        ) from last_error



class OutpaintDataset(BaseDataset_outpaint):
    """
    Dataset that turns every image into an outpainting training pair.

    A random square crop of the sample image is resized and pasted at a
    random position on a black canvas; a matching mask marks the pasted
    region in black and the area to be outpainted in white.

    Args:
        root_path, data_txt, image_size, normalize, normalize_depth,
        prepocess_depth, near, far, depth_cfg: forwarded to
            ``BaseDataset_outpaint``.
        augments (list | None): stored on the instance; not used in this class.
        std (float): stored on the instance; not used in this class.
        gen_inpainting_data (bool): stored on the instance; not used in this class.
        viewset (str): stored on the instance; not used in this class.
        forward_warp (str): stored on the instance; not used in this class.
    """

    # Side length in pixels of the square output canvas and mask.
    CANVAS_SIZE = 512
    # Inclusive range of the side length the crop is resized to before pasting.
    PATCH_SIZE_RANGE = (200, 300)
    # Side of the square crop as a fraction of the image height.
    CROP_RATIO_RANGE = (0.7, 0.8)

    def __init__(
        self,
        root_path,
        data_txt,
        image_size,
        normalize=False,
        normalize_depth=False,
        prepocess_depth='none',
        near=0.5,
        far=100,
        augments=None,
        std=0.15,
        gen_inpainting_data=False,
        viewset='3x9',
        forward_warp='3daware',
        depth_cfg=None,
    ):
        super().__init__(root_path, data_txt, image_size, normalize, normalize_depth, prepocess_depth, near, far, depth_cfg)
        self.data_txt = data_txt
        self.renderer = None
        # A fresh list per instance; a shared mutable default would leak between datasets.
        self.augments = [] if augments is None else augments
        self.std = std
        self.gen_inpainting_data = gen_inpainting_data
        self.viewset = viewset
        self.forward_warp = forward_warp
        self.depth_cfg = depth_cfg

    def __getitem__(self, index):
        """Return the base sample plus an outpainting canvas and mask.

        Returns:
            dict: the dictionary produced by ``BaseDataset_outpaint`` with two
            extra entries, both ``CANVAS_SIZE`` x ``CANVAS_SIZE`` RGB PIL
            images: ``'color_save'`` (black canvas with the resized crop
            pasted in) and ``'mask_save'`` (white canvas, black where the
            crop was pasted).
        """
        data = super().__getitem__(index)

        # Rebuild an 8-bit PIL image from the sample tensor so the crop comes
        # from the sample itself.
        rgb = data['x_0'].cpu().permute(1, 2, 0).numpy().copy()
        if self.normalize:
            rgb = rgb * 0.5 + 0.5
        image = PIL.Image.fromarray(np.clip(rgb * 255.0 + 0.5, 0, 255).astype(np.uint8))

        width, height = image.size
        # Clamp to the width so portrait images cannot yield a box wider than
        # the image, which would make the random offset range empty.
        box_size = min(height * random.uniform(*self.CROP_RATIO_RANGE), float(width))

        left = np.random.randint(0, max(1, int(width - box_size)))
        upper = np.random.randint(0, max(1, int(height - box_size)))
        right = int(round(left + box_size))
        lower = int(round(upper + box_size))
        patch = image.crop((int(left), int(upper), right, lower))

        patch_size = random.randint(*self.PATCH_SIZE_RANGE)
        patch = patch.resize((patch_size, patch_size))

        canvas_size = (self.CANVAS_SIZE, self.CANVAS_SIZE)
        canvas = PIL.Image.new('RGB', canvas_size, "black")
        mask = PIL.Image.new('RGB', canvas_size, "white")
        known_region = PIL.Image.new('RGB', (patch_size, patch_size), "black")

        offset = (
            int(np.random.randint(0, self.CANVAS_SIZE - patch_size)),
            int(np.random.randint(0, self.CANVAS_SIZE - patch_size)),
        )

        # Image.paste modifies the target in place and returns None, so the
        # canvases themselves are stored, not the return value of paste.
        canvas.paste(patch, offset)
        mask.paste(known_region, offset)

        data['color_save'] = canvas
        data['mask_save'] = mask

        return data