import numpy as np
import os.path as osp

import torch
import torchaudio
import pickle


class ReplicaDataset(torch.utils.data.Dataset):
    """Impulse-response dataset for a single Replica/Matterport room.

    Each item pairs a source id and a listener id (parsed from the ``.wav``
    basename ``<source>_<listener>.wav``) with their 2-D positions, an
    orientation index taken from the parent directory (``"0"``/``"90"``/
    ``"180"``/``"270"`` -> 0..3) and the impulse response, resampled to
    ``fs`` and cropped or zero-padded to a fixed per-room length
    (``max_len`` samples).

    Args:
        data_root: Directory holding the ``0``/``90``/``180``/``270``
            sub-directories of ``.wav`` files. The room name is taken as its
            second-to-last ``'/'``-separated component, so a trailing slash
            is presumably expected (e.g. ``'.../apartment_1/'``) -- confirm
            against callers.
        split: Split of ``data.pkl`` to use; ``'train'`` selects entry 0,
            anything else selects entry 1.
        norm: If true, also compute ``norm_patches`` scaled to [-1, 1].
        visible: If true, ``__getitem__`` additionally returns the sub-room
            indices of source and listener. Only ``apartment_1`` builds the
            sub-room table (``separate_room``); for other rooms
            ``__getitem__`` raises ``AttributeError`` in this mode.
        dense: Round bounce points to 0.1 instead of 1.0 before
            de-duplication (ignored for office/room_2 rooms).
        nrays: Stored only; not read by any method of this class.
        time_bin_size: Stored only; not read by any method of this class.
        fs: Target sample rate of the returned waveforms.
    """

    def __init__(self,
                 data_root,
                 split='train',  # ['train', 'test']
                 norm=True,
                 visible=False,
                 dense=False,
                 nrays=None,
                 time_bin_size=0.001, fs=22050,
                 ):
        super(ReplicaDataset, self).__init__()
        self.data_root = data_root
        self.split = split
        self.norm = norm
        self.visible = visible
        self.dense = dense
        self.nrays = nrays
        self.time_bin_size = time_bin_size
        self.fs = fs
        # Resample transforms keyed by source sample rate; the kernel only
        # depends on the two rates, so rebuilding it for every item is waste.
        self._resamplers = {}
        self.load_ir()

    def load_ir(self):
        """Load the file list, room geometry and normalization bounds.

        Sets ``file_list``, ``room_id``, ``mean_std``, ``mesh_points``,
        ``height``, ``points``, ``orientations``, ``max_len``, ``patches``,
        ``min_pos``/``max_pos`` and, when ``norm`` is set, ``norm_patches``.
        """
        room_name = self.data_root.split('/')[-2]
        self.room_id = room_name
        self.file_list = self._load_file_list(room_name)
        print(len(self.file_list))
        self.mean_std = None

        # NOTE(review): data.pkl is read from a relative path while the
        # geometry comes from this absolute one -- confirm both are intended.
        meta_dir = f'/neural_acoustic_field/data/matterport_data/{room_name}'
        meta_points = np.loadtxt(osp.join(meta_dir, 'points.txt'))
        self.mesh_points = np.loadtxt(osp.join(meta_dir, 'mesh.xyz'))
        # Columns read here: 0 = point id, 1-2 = x/y; column 3 of the first
        # row is used as the height of the horizontal slice through the mesh.
        self.height = meta_points[0, 3]
        self.points = {int(k): [x, y] for k, x, y in meta_points[..., :3]}
        self.orientations = {"0": 0, "90": 1, "180": 2, "270": 3}

        print(self.room_id)
        room_len = {"apartment_1": 8256, "apartment_2": 7727, "frl_apartment_2": 9564,
                    "frl_apartment_4": 9375, "office_4": 7118, "room_2": 7829}
        self.max_len = room_len[self.room_id]

        ori_patches = self.get_patch(self.mesh_points, self.height)
        if self.visible and self.room_id == 'apartment_1':
            self.separate_room()
            # NOTE: x == 2 / y == 3 boundary points land in two groups here,
            # whereas separate_room uses strict '<' -- confirm this is intended.
            bounce1 = ori_patches[ori_patches[..., 0] <= 2]
            bounce2 = ori_patches[(ori_patches[..., 0] >= 2) & (ori_patches[..., 1] <= 3)]
            bounce3 = ori_patches[(ori_patches[..., 0] >= 2) & (ori_patches[..., 1] >= 3)]
            print(bounce1.shape, bounce2.shape, bounce3.shape)
            self.patches = [bounce1, bounce2, bounce3, ori_patches]
        else:
            self.patches = ori_patches
            print('number of patch', self.patches.shape)

        # bounce1..3 are subsets of ori_patches, so ori_patches alone yields
        # the same bounds in both modes. Concatenating the list form of
        # self.patches directly fails because its arrays have different lengths.
        all_pts = np.concatenate([meta_points[:, 1:3], ori_patches], axis=0)
        self.max_pos = all_pts.max(axis=0)
        self.min_pos = all_pts.min(axis=0)
        if self.norm:
            if isinstance(self.patches, list):
                self.norm_patches = [self._normalize(p) for p in self.patches]
            else:
                self.norm_patches = self._normalize(self.patches)

    def _load_file_list(self, room_name):
        """Return the cached ``.wav`` path list for ``self.split``, building it if needed.

        The list is cached in ``data_root`` as ``move_<split>_list_complete.npy``;
        the same name is used for writing and reading so any split works.
        """
        cache_path = osp.join(self.data_root, f'move_{self.split}_list_complete.npy')
        if not osp.exists(cache_path):
            # NOTE: pickle can execute arbitrary code; data.pkl must be trusted.
            with open(f'data/matterport_data/{room_name}/data.pkl', 'rb') as f:
                data = pickle.load(f)
            data = data[0] if self.split == 'train' else data[1]
            file_list = []
            for orientation_dir in ('0', '90', '180', '270'):
                for cur_str in data[orientation_dir]:
                    path = osp.join(self.data_root, orientation_dir, cur_str + '.wav')
                    # Only keep entries whose audio file actually exists.
                    if osp.exists(path):
                        file_list.append(path)
            np.save(cache_path, np.array(file_list))
        return list(np.load(cache_path))

    def _normalize(self, pts):
        """Map coordinates from [min_pos, max_pos] to [-1, 1] per axis."""
        return ((pts - self.min_pos) / (self.max_pos - self.min_pos) - 0.5) * 2.0

    def separate_room(self):
        """Assign each source/listener id to one of three apartment_1 sub-rooms.

        Sub-room 0: x < 2; sub-room 1: x >= 2 and y < 3; sub-room 2: otherwise.
        Result is stored in ``self.points_room_idx`` (id -> sub-room index).
        """
        self.points_room_idx = {}
        for k, (x, y) in self.points.items():
            if x < 2:
                self.points_room_idx[k] = 0
            elif x >= 2 and y < 3:
                self.points_room_idx[k] = 1
            else:
                self.points_room_idx[k] = 2

    def get_patch(self, mesh_points, height, h_range=1.5, space_height=150):
        """Return de-duplicated (x, y) mesh points lying near ``height``.

        Points are kept when their z is strictly inside a band centred on
        ``height`` of width ``h_range * (height - min_z) / space_height``.
        Their x/y are then rounded (0.1 for office/room_2 or ``dense``, else
        1.0) and duplicates dropped; office/room_2 rooms keep every 5th
        unique point. The result is an ``(n, 2)`` array (``n`` may be 0).
        """
        z = mesh_points[..., 2]
        band = h_range / (space_height / (height - z.min()))
        in_band = (z > height - band / 2) & (z < height + band / 2)
        bounces = mesh_points[in_band][..., :2]
        # remove redundant bounce points
        if 'office' in self.room_id or 'room_2' in self.room_id:
            return self._dedup_rounded(bounces, 1)[::5]
        return self._dedup_rounded(bounces, 1 if self.dense else 0)

    @staticmethod
    def _dedup_rounded(xy, ndigits):
        """Round (x, y) pairs to ``ndigits`` decimals and drop duplicates (set order)."""
        unique = {(round(x, ndigits), round(y, ndigits)) for x, y in xy}
        return np.array(list(unique)).reshape(-1, 2)

    def _fit_length(self, waveform):
        """Crop or zero-pad a ``(channels, samples)`` tensor to ``max_len`` samples.

        The channel count and dtype of the input are preserved in both branches.
        """
        n_channels, n_samples = waveform.shape
        if n_samples >= self.max_len:
            return waveform[:, :self.max_len]
        padded = torch.zeros(n_channels, self.max_len, dtype=waveform.dtype)
        padded[:, :n_samples] = waveform
        return padded

    def _load_sample(self, index):
        """Load the data shared by ``__getitem__`` and ``get_item_test``.

        Returns:
            ``(source_id, target_id, source_xy, listener_xy, norm_source_xy,
            norm_listener_xy, orientation, waveform)``: the four ``*_xy``
            entries are ``(1, 2)`` tensors, ``orientation`` is a ``(1,)``
            numpy array and ``waveform`` is a ``(channels, max_len)`` tensor.
        """
        file_name = self.file_list[index]
        parts = file_name.split('/')
        orientation = np.array([self.orientations[parts[-2]]])
        # Parse ids from the basename only, not the whole path.
        base_name = parts[-1]
        source_id = int(base_name.split('_')[0])
        target_id = int(base_name.split('_')[-1].split('.')[0])
        source_points = np.array(self.points[source_id]).reshape(1, -1)
        points = np.array(self.points[target_id]).reshape(1, -1)
        norm_source_points = self._normalize(source_points)
        norm_points = self._normalize(points)

        waveform_ori, orig_fs = torchaudio.load(file_name)
        resampler = self._resamplers.get(orig_fs)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_fs, self.fs)
            self._resamplers[orig_fs] = resampler
        new_waveform = self._fit_length(resampler(waveform_ori))
        if self.mean_std is not None:
            # non_zero / min_ / max_ are never set in this class; presumably
            # they are assigned externally together with mean_std.
            nz = self.non_zero
            new_waveform[:, nz] = (new_waveform[:, nz] - self.min_[nz]) / (self.max_[nz] - self.min_[nz]) * 2 - 1
        return (source_id, target_id,
                torch.from_numpy(source_points), torch.from_numpy(points),
                torch.from_numpy(norm_source_points), torch.from_numpy(norm_points),
                orientation, new_waveform)

    def __getitem__(self, index):
        """Return one sample; includes the (source, listener) sub-room pair when ``visible``."""
        (source_id, target_id, source_points, points, norm_source_points,
         norm_points, orientation, new_waveform) = self._load_sample(index)
        if self.visible:
            rooms = np.array([self.points_room_idx[source_id], self.points_room_idx[target_id]])
            return source_points, points, norm_source_points, norm_points, rooms, orientation, new_waveform
        return source_points, points, norm_source_points, norm_points, orientation, new_waveform

    def get_item_test(self, index):
        """Return one sample without sub-room indices, regardless of ``visible``."""
        (_, _, source_points, points, norm_source_points,
         norm_points, orientation, new_waveform) = self._load_sample(index)
        return source_points, points, norm_source_points, norm_points, orientation, new_waveform

    def __len__(self):
        return len(self.file_list)