import torch
import torchvision
import os.path
from torch.utils.data import Dataset
from PIL import Image
import pickle
from tqdm import tqdm
import numpy as np
import math
import random


def resample_pcd(pcd, n):
    """Drop or duplicate points so that pcd has exactly n points"""
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])
    return pcd[idx[:n]]


def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)


def rotation_z(pts, theta):
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation_matrix = np.array([[cos_theta, -sin_theta, 0.0],
                                [sin_theta, cos_theta, 0.0],
                                [0.0, 0.0, 1.0]])
    return pts @ rotation_matrix.T


def rotation_y(pts, theta):
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation_matrix = np.array([[cos_theta, 0.0, -sin_theta],
                                [0.0, 1.0, 0.0],
                                [sin_theta, 0.0, cos_theta]])
    return pts @ rotation_matrix.T


def rotation_x(pts, theta):
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation_matrix = np.array([[1.0, 0.0, 0.0],
                                [0.0, cos_theta, -sin_theta],
                                [0.0, sin_theta, cos_theta]])
    return pts @ rotation_matrix.T


class ViPCDataLoader(Dataset):
    def __init__(self, filepath, data_path, status, pc_input_num=2048, view_align=False, category='all'):
        super(ViPCDataLoader, self).__init__()
        self.pc_input_num = pc_input_num
        self.status = status
        self.filelist = []
        self.cat = []
        self.key = []
        self.category = category
        self.view_align = view_align
        self.cat_map = {
            'plane': '02691156',
            'bench': '02828884',
            'cabinet': '02933112',
            'car': '02958343',
            'chair': '03001627',
            'monitor': '03211117',
            'lamp': '03636649',
            'speaker': '03691459',
            'firearm': '04090263',
            'couch': '04256520',
            'table': '04379243',
            'cellphone': '04401088',
            'watercraft': '04530566'
        }

        with open(filepath, 'r') as f:
            for line in f:
                self.filelist.append(line)

        self.imcomplete_path = os.path.join(data_path, 'ShapeNetViPC-Partial')
        self.gt_path = os.path.join(data_path, 'ShapeNetViPC-GT')
        self.rendering_path = os.path.join(data_path, 'ShapeNetViPC-View')

        for key in self.filelist:
            if category != 'all':
                if key.split(';')[0] != self.cat_map[category]:
                    continue
            else:
                if key.split(';')[0] == self.cat_map['firearm']:
                    continue
            self.cat.append(key.split(';')[0])
            self.key.append(key)

        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor()
        ])

        print(f'{status} data num: {len(self.key)}')

        if status in ['train', 'valid']:
            self.views_cache = []
            self.pc_cache = []
            self.pc_part_cache = []

            for idx, key in enumerate(self.key):
                views, pc, pc_part = self.load_data_by_key(key)
                self.views_cache.append(views)
                self.pc_cache.append(pc)
                self.pc_part_cache.append(pc_part)

                if idx % 100 == 0:
                    print(f'Loaded {idx}/{len(self.key)} samples to memory.')

    def load_data_by_key(self, key):
        pc_part_path = os.path.join(
            self.imcomplete_path,
            key.split(';')[0] + '/' + key.split(';')[1] + '/' + key.split(';')[-1].strip() + '.dat')

        if self.view_align:
            ran_key = key
        else:
            ran_key = key[:-3] + str(random.randint(0, 23)).rjust(2, '0')

        pc_path = os.path.join(
            self.gt_path,
            ran_key.split(';')[0] + '/' + ran_key.split(';')[1] + '/' + ran_key.split(';')[-1].strip() + '.dat')
        view_path = os.path.join(
            self.rendering_path,
            ran_key.split(';')[0] + '/' + ran_key.split(';')[1] + '/rendering/' + ran_key.split(';')[-1].strip() + '.png'
        )

        if len(ran_key.split(';')[-1]) > 3:
            print("bug")
            print(ran_key.split(';')[-1])
            fin = ran_key.split(';')[-1][-2:]
            interm = ran_key.split(';')[-1][:-2]
            pc_path = os.path.join(self.gt_path,
                                   ran_key.split(';')[0] + '/' + interm + '/' + fin + '.dat')
            view_path = os.path.join(self.rendering_path,
                                     ran_key.split(';')[0] + '/' + interm + '/rendering/' + fin + '.png')

        views = self.transform(Image.open(view_path))
        views = views[:3, :, :]

        with open(pc_part_path, 'rb') as f:
            pc_part = pickle.load(f).astype(np.float32)
        with open(pc_path, 'rb') as f:
            pc = pickle.load(f).astype(np.float32)

        if pc_part.shape[0] < self.pc_input_num:
            pc_part = np.repeat(pc_part,
                                (self.pc_input_num // pc_part.shape[0]) + 1, axis=0)[0:self.pc_input_num]

        image_view_id = view_path.split('.')[0].split('/')[-1]
        part_view_id = pc_part_path.split('.')[0].split('/')[-1]
        view_metadata = np.loadtxt(view_path[:-6] + 'rendering_metadata.txt')
        theta_part = math.radians(view_metadata[int(part_view_id), 0])
        phi_part = math.radians(view_metadata[int(part_view_id), 1])
        theta_img = math.radians(view_metadata[int(image_view_id), 0])
        phi_img = math.radians(view_metadata[int(image_view_id), 1])

        pc_part = rotation_y(rotation_x(pc_part, -phi_part), np.pi + theta_part)
        pc_part = rotation_x(rotation_y(pc_part, np.pi - theta_img), phi_img)

        gt_mean = pc.mean(axis=0)
        pc = pc - gt_mean
        pc_L_max = np.max(np.sqrt(np.sum(abs(pc ** 2), axis=-1)))
        pc = pc / pc_L_max
        pc_part = pc_part - gt_mean
        pc_part = pc_part / pc_L_max

        return views.float(), torch.from_numpy(pc).float(), torch.from_numpy(pc_part).float()

    def __getitem__(self, idx):
        if self.status in ['train', 'valid']:
            return self.views_cache[idx], self.pc_cache[idx], self.pc_part_cache[idx], self.key[idx]
        else:
            views, pc, pc_part = self.load_data_by_key(self.key[idx])
            return views, pc, pc_part, self.key[idx]

    def __len__(self):
        return len(self.key)



class EPNDataLoader(Dataset):
    def __init__(self, filepath, data_path, status, pc_input_num=2048, category='car'):
        super(EPNDataLoader, self).__init__()
        self.cat_map = {
            'plane': '02691156',
            'cabinet': '02933112',
            'car': '02958343',
            'chair': '03001627',
            'lamp': '03636649',
            'couch': '04256520',
            'table': '04379243',
            'watercraft': '04530566'
        }
        self.pc_input_num = pc_input_num
        self.status = status
        self.data_path = data_path
        self.filelist = []
        self.cat = []
        self.key = []
        self.category = category

        with open(filepath, 'r') as f:
            line = f.readline()
            while (line):
                self.filelist.append(line.strip())
                line = f.readline()

        self.incomplete_path = os.path.join(self.data_path, 'part')
        self.gt_path = os.path.join(self.data_path, 'gt')
        self.rendering_path = os.path.join(self.data_path, 'image')

        for key in self.filelist:
            if key.split('/')[0] != self.cat_map[category]:
                continue
            self.cat.append(key.split('/')[0])
            self.key.append(key)

        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor()
        ])

        print(f'{self.status} data num: {len(self.key)}')

        self.pc_part_cache = []
        self.pc_cache = []
        self.image_cache = []

        for idx, key in enumerate(self.key):
            pc_part_path = os.path.join(self.incomplete_path, key + '.dat')
            pc_path = os.path.join(self.gt_path, key + '.dat')
            image_path = os.path.join(self.rendering_path, key + '.png')

            image = Image.open(image_path).convert('RGB')
            image = self.transform(image)
            image = image[:3, :, :]

            with open(pc_part_path, 'rb') as f:
                pc_part = pickle.load(f).astype(np.float32)

            with open(pc_path, 'rb') as f:
                pc = pickle.load(f).astype(np.float32)

            if pc_part.shape[0] < self.pc_input_num:
                pc_part = np.repeat(pc_part,
                                    (self.pc_input_num // pc_part.shape[0]) + 1, axis=0)[0:self.pc_input_num]

            self.image_cache.append(image.float())
            self.pc_part_cache.append(torch.from_numpy(pc_part).float())
            self.pc_cache.append(torch.from_numpy(pc).float())

            if idx % 100 == 0:
                print(f'Loaded {idx}/{len(self.key)} samples to memory.')

    def __getitem__(self, idx):
        return self.image_cache[idx], self.pc_part_cache[idx], self.pc_cache[idx], self.key[idx]

    def __len__(self):
        return len(self.key)


class KITTIDataLoader(Dataset):
    def __init__(self, data_path, pc_input_num=2048):
        super(KITTIDataLoader, self).__init__()
        self.pc_input_num = pc_input_num
        self.filelist = os.listdir(os.path.join(data_path, 'image'))
        self.key = []
        self.incomplete_path = os.path.join(data_path, 'partial')
        self.rendering_path = os.path.join(data_path, 'image')
        for key in self.filelist:
            self.key.append(key[:-4])
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor()
        ])

    def __getitem__(self, idx):
        key = self.key[idx].strip()
        pc_part_path = os.path.join(self.incomplete_path, key + '.npy')
        view_path = os.path.join(self.rendering_path, key + '.jpg')
        views = self.transform(Image.open(view_path))
        views = views[:3, :, :]
        pc_part = np.load(pc_part_path)
        if pc_part.shape[0] != self.pc_input_num:
            pc_part = resample_pcd(pc_part, self.pc_input_num)
        return views.float(), torch.from_numpy(pc_part).float(), key

    def __len__(self):
        return len(self.key)