from models import models_ae
import torch
import numpy as np
from torch.utils.data import Dataset
import os
import trimesh
import mcubes
import random
from models.utils import npz_to_pointcloud_noise2noise
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from typing import Optional

class Noise2NoiseDatasetv2(Dataset):
    """
    Dataset used to train the noise2noise model on a collection of shapes
    instead of a single shape.

    For every item a ground-truth point cloud is loaded and two noisy copies
    are produced online (by ``npz_to_pointcloud_noise2noise``), forming a
    training pair in the spirit of 2D noise2noise.

    Expected layout (as read by this class):
        <dataset_folder>/point/<category>/{train,val,test}.lst
        <dataset_folder>/point/<category>/<model>.npz   (eval query points)
        <dataset_folder>/point/<category>/<model>.npy   (scalar scale)
        <dataset_folder>/mesh/<category>/4_pointcloud/<model>.npz
    """

    def __init__(self,
                 dataset_folder: str = "/root/autodl-tmp/dataset/shapenet",
                 categories: list = ['03001627'],
                 pc_size: int = 2048,
                 num_queries: int = 4096,
                 max_samples: Optional[int] = None,
                 noise_type: str = "gaussian",
                 noise_mean: float = 0.0,
                 replica: int = 16,
                 surface_noise_std: float = 0.02,
                 point_noise_std: float = 0.01,
                 uniform_padding: float = 0.05,
                 one_train_shape_only: bool = False,
                 shuffle_seed: int = 42,
                 distributed: bool = False
    ):
        """
        Args:
            dataset_folder: root folder containing the "mesh" and "point" sub-folders.
            categories: category folder names to load. ``None`` discovers every
                directory in ``dataset_folder`` whose name starts with '0'.
                The list passed in is never modified.
            pc_size: forwarded to ``npz_to_pointcloud_noise2noise``.
            num_queries: number of near-surface queries and of uniform queries
                per training item; evaluation samples ``2 * num_queries`` points.
            max_samples: if train + val models exceed this number, the train list
                is truncated to ``int(max_samples * train_fraction)``.
            noise_type: forwarded to ``npz_to_pointcloud_noise2noise``
                (``None`` falls back to "gaussian").
            noise_mean: forwarded to ``npz_to_pointcloud_noise2noise``
                (``None`` falls back to 0.0).
            replica: how many times the train list is repeated in ``__len__``.
            surface_noise_std: std of the Gaussian offset added to surface queries.
            point_noise_std: forwarded as ``noise_std`` to ``npz_to_pointcloud_noise2noise``.
            uniform_padding: stored only; not used by any method of this class.
            one_train_shape_only: skip shuffling and keep only the first train model.
            shuffle_seed: seed for the global RNGs and the train-list shuffle;
                ``None`` disables seeding.
            distributed: when True, the train list is shuffled with
                ``shuffle_seed + RANK`` and global seeding happens on rank 0 only.
        """
        self.dataset_folder = dataset_folder
        self.pc_size = pc_size
        self.num_queries = num_queries
        self.max_samples = max_samples
        self.replica = replica
        self.surface_noise_std = surface_noise_std
        self.point_noise_std = point_noise_std
        self.uniform_padding = uniform_padding
        self.one_train_shape_only = one_train_shape_only
        self.shuffle_seed = shuffle_seed
        self.distributed = distributed

        # noise ablation: None means "use the default"
        self.noise_type = "gaussian" if noise_type is None else noise_type
        self.noise_mean = 0.0 if noise_mean is None else noise_mean

        # Seed the global RNGs (on every process when not distributed, on rank 0 otherwise).
        if self.shuffle_seed is not None:
            if not self.distributed or os.getenv("RANK", "0") == "0":
                random.seed(self.shuffle_seed)
                np.random.seed(self.shuffle_seed)
                torch.manual_seed(self.shuffle_seed)

        self.mesh_path = os.path.join(self.dataset_folder, "mesh")
        self.point_path = os.path.join(self.dataset_folder, "point")

        if categories is None:
            # NOTE(review): discovery lists dataset_folder, while each category is
            # required to exist under point_path below - confirm the intended layout.
            categories = [
                c for c in os.listdir(self.dataset_folder)
                if os.path.isdir(os.path.join(self.dataset_folder, c)) and c.startswith('0')
            ]
        # sorted() copies, so neither the caller's list nor the shared default is mutated.
        categories = sorted(categories)
        self.categories = categories

        self.train_model_list = []
        self.val_model_list = []
        self.test_model_list = []
        for category in categories:
            sub_path = os.path.join(self.point_path, category)
            assert os.path.exists(sub_path), f"point path {sub_path} does not exist"

            # Check that all three split files exist before reading any of them.
            split_files = {
                split: os.path.join(sub_path, f"{split}.lst")
                for split in ("train", "val", "test")
            }
            for split, split_file in split_files.items():
                assert os.path.exists(split_file), f"{split} split file {split_file} does not exist"

            self.train_model_list += self._read_split(split_files["train"], category)
            self.val_model_list += self._read_split(split_files["val"], category)
            self.test_model_list += self._read_split(split_files["test"], category)

        print(f"train_model_list: {len(self.train_model_list)}")
        print(f"val_model_list: {len(self.val_model_list)}")
        print(f"test_model_list: {len(self.test_model_list)}")

        if not self.one_train_shape_only:
            if self.distributed:
                # Each rank shuffles with its own seed; an unset seed yields an
                # OS-seeded generator instead of failing on `None + rank`.
                rank = int(os.getenv("RANK", 0))
                rank_seed = None if self.shuffle_seed is None else self.shuffle_seed + rank
                random.Random(rank_seed).shuffle(self.train_model_list)
            else:
                random.shuffle(self.train_model_list)

        if self.max_samples is not None and len(self.train_model_list) + len(self.val_model_list) > self.max_samples:
            # Only the train list is truncated, keeping its share of train + val.
            percentage = len(self.train_model_list) / (len(self.train_model_list) + len(self.val_model_list))
            self.train_model_list = self.train_model_list[:int(self.max_samples * percentage)]

        if self.one_train_shape_only:
            self.train_model_list = self.train_model_list[:1]

    @staticmethod
    def _read_split(split_file: str, category: str) -> list:
        """Read one split file and return a ``{"category", "model_name"}`` dict per non-empty line."""
        with open(split_file, "r") as f:
            names = [line.strip() for line in f if line.strip()]
        return [{"category": category, "model_name": name.replace(".npz", "")} for name in names]

    def _load_scale(self, category: str, model_point_path: str, fallback_on_error: bool) -> float:
        """
        Return the scale stored next to ``model_point_path`` (same path with '.npy').

        Category "abc" always uses 1.0 and never touches the file. When
        ``fallback_on_error`` is True a failed read is printed and 1.0 is
        returned; otherwise the exception propagates.
        """
        if category == "abc":
            return 1.0
        try:
            with open(model_point_path.replace('.npz', '.npy'), 'rb') as f:
                return np.load(f).item()
        except Exception as e:
            if not fallback_on_error:
                raise
            print(e)
            print(model_point_path)
            return 1.0

    def __len__(self):
        """Number of train models times ``replica``."""
        return len(self.train_model_list) * self.replica

    def __getitem__(self, idx):
        """
        Build one training sample.

        ``idx`` wraps around the train list (the dataset is ``replica`` times longer).

        return:
            gt_points, noise1_points, noise2_points: first three values returned by
                ``npz_to_pointcloud_noise2noise``
            queries: query points from ``_generate_queries``
        """
        entry = self.train_model_list[idx % len(self.train_model_list)]
        category = entry["category"]
        model = entry["model_name"]

        model_point_path = os.path.join(self.point_path, category, model + '.npz')

        # Training is best-effort about the scale: a missing/corrupt file falls back to 1.0.
        scale = self._load_scale(category, model_point_path, fallback_on_error=True)

        model_path = os.path.join(self.mesh_path, category, "4_pointcloud", model + ".npz")

        gt_points, noise1_points, noise2_points, _, _ = npz_to_pointcloud_noise2noise(
            model_path,
            self.pc_size,
            sampling_method='fps',
            noise_std=self.point_noise_std,
            scale=scale,
            noise_type=self.noise_type,
            noise_mean=self.noise_mean,
        )

        queries = self._generate_queries(gt_points, noise1_points, noise2_points)

        return gt_points, noise1_points, noise2_points, queries

    def get_eval_item(self, idx, use_train_data: bool = False, use_test_data: bool = False, use_one_shape_only: bool = False, use_vol_points: bool = True, get_name: bool = False):
        """
        get the eval item for the model

        The split is the test list if ``use_test_data``, else the train list if
        ``use_train_data``, else the val list; ``idx`` wraps around that list.
        ``use_one_shape_only`` forces index 0.

        return:
            gt_points: first value returned by npz_to_pointcloud_noise2noise
            noise1_points: second value returned by npz_to_pointcloud_noise2noise
            queries: float tensor of 2 * num_queries points sampled from 'vol_points'
                (or 'near_points' when use_vol_points is False); None if loading failed
            surface: fourth value returned by npz_to_pointcloud_noise2noise
            labels: float tensor of the labels matching queries; None if loading failed
            normals: fifth value returned by npz_to_pointcloud_noise2noise
            model: the model name, only appended when get_name is True
        """
        # Choose the list once so the modulo and the lookup always agree.
        if use_test_data:
            model_list = self.test_model_list
        elif use_train_data:
            model_list = self.train_model_list
        else:
            model_list = self.val_model_list

        if use_one_shape_only:
            idx = 0
        idx = idx % len(model_list)

        category = model_list[idx]["category"]
        model = model_list[idx]["model_name"]

        point_path = os.path.join(self.point_path, category, model + '.npz')

        # Evaluation is strict about the scale file (except for "abc", which never needs it).
        scale = self._load_scale(category, point_path, fallback_on_error=False)

        model_path = os.path.join(self.mesh_path, category, "4_pointcloud", model + ".npz")
        gt_points, noise1_points, _, surface, normals = npz_to_pointcloud_noise2noise(
            model_path,
            self.pc_size,
            sampling_method='fps',
            noise_std=self.point_noise_std,
            scale=scale,
            noise_type=self.noise_type,
            noise_mean=self.noise_mean,
        )

        queries = None
        labels = None
        points_key, label_key = ('vol_points', 'vol_label') if use_vol_points else ('near_points', 'near_label')
        try:
            # Only the arrays that are actually returned are read and sub-sampled.
            with np.load(point_path) as data:
                points = data[points_key]
                point_labels = data[label_key]

            ind = np.random.default_rng().choice(points.shape[0], self.num_queries * 2, replace=False)
            sampled_queries = torch.from_numpy(points[ind]).float()
            sampled_labels = torch.from_numpy(point_labels[ind]).float()
            queries, labels = sampled_queries, sampled_labels
        except Exception as e:
            # Best effort: report the problem and return None for queries/labels.
            print(e)
            print(point_path)

        if get_name:
            return gt_points, noise1_points, queries, surface, labels, normals, model
        return gt_points, noise1_points, queries, surface, labels, normals

    def _generate_queries(self, gt_pc: torch.Tensor, noise1_pc: torch.Tensor, noise2_pc: torch.Tensor):
        """
        Build the shuffled query set for one training sample.

        gt_pc: the ground truth point cloud as a tensor [N, 3]
        noise1_pc: the first noise point cloud (unused here)
        noise2_pc: the second noise point cloud (unused here)

        Returns up to ``num_queries`` randomly chosen ground-truth points, each offset
        by Gaussian noise of std ``surface_noise_std``, concatenated with
        ``num_queries`` points drawn uniformly from [-1, 1)^3, in random order.
        """
        surface_points = gt_pc[torch.randperm(gt_pc.size(0))[:self.num_queries]]
        perturbed_points = surface_points + torch.randn_like(surface_points) * self.surface_noise_std

        # Uniform samples cover the fixed cube [-1, 1)^3; uniform_padding is not applied.
        uniform_points = torch.rand(self.num_queries, 3) * 2.0 - 1.0

        all_queries = torch.cat([perturbed_points, uniform_points], dim=0)
        return all_queries[torch.randperm(all_queries.size(0))]

    def get_dataloader(self, batch_size: int = 32, num_workers=1, shuffle: bool = True, pin_memory: bool= True, persistent_workers: bool = True, prefetch_factor=3):
        """
        Wrap this dataset in a ``DataLoader``.

        ``persistent_workers`` and ``prefetch_factor`` are only forwarded when
        ``num_workers > 0``; DataLoader rejects them in single-process mode.
        """
        worker_options = {}
        if num_workers > 0:
            worker_options["persistent_workers"] = persistent_workers
            worker_options["prefetch_factor"] = prefetch_factor

        return DataLoader(
            self,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
            **worker_options,
        )


class ShapeDataModule(pl.LightningDataModule):
    """
    Lightning data module that builds a ``Noise2NoiseDatasetv2`` from a config
    dict and serves it through a ``DistributedSampler``-backed DataLoader.
    """

    def __init__(self, dataset_config: dict, batch_size: int = 32, num_workers: int = 0):
        """
        Args:
            dataset_config: keyword arguments forwarded to ``Noise2NoiseDatasetv2``.
            batch_size: batch size of the training DataLoader.
            num_workers: number of DataLoader worker processes.
        """
        super().__init__()
        self.dataset_config = dataset_config
        self._dataset = None
        self.batch_size = batch_size
        self.num_workers = num_workers

    @property
    def dataset(self):
        """The dataset created by ``setup``; raises ``ValueError`` before that."""
        if self._dataset is None:
            raise ValueError("Dataset has not been initialized")
        return self._dataset

    def setup(self, stage: Optional[str] = None):
        """Instantiate the dataset from ``dataset_config`` (``stage`` is ignored)."""
        self._dataset = Noise2NoiseDatasetv2(**self.dataset_config)

    def _create_dataloader(self, shuffle: bool = True):
        """
        Build a DataLoader over the dataset using a ``DistributedSampler``.

        World size and rank are read from the attached trainer and default to
        1 / 0 when the trainer does not provide them.
        """
        num_replicas = getattr(self.trainer, "world_size", 1)
        rank = getattr(self.trainer, "global_rank", 0)

        # The dataset accepts shuffle_seed=None, but DistributedSampler needs an
        # integer seed when shuffling; fall back to the sampler's default of 0.
        seed = self.dataset.shuffle_seed
        if seed is None:
            seed = 0

        sampler = DistributedSampler(
            self.dataset,
            shuffle=shuffle,
            num_replicas=num_replicas,
            rank=rank,
            seed=seed
        )

        # Worker-only options are enabled solely when worker processes are used.
        use_workers = self.num_workers > 0
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=use_workers,
            prefetch_factor=2 if use_workers else None,
            shuffle=False  # ordering is delegated to the sampler
        )

    def train_dataloader(self):
        """Training DataLoader with sampler-level shuffling."""
        return self._create_dataloader(shuffle=True)

    def val_dataloader(self):
        """
        virtual dataloader for validation

        Yields a single dummy batch built from ``torch.zeros(1)``.
        """
        return DataLoader(
            [torch.zeros(1)],
            batch_size=1,
            num_workers=0,
            shuffle=False
        )
