import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterator, List, Optional

import numpy as np
import torch


@dataclass
class Minibatch:
    """One training minibatch of context and target point sets.

    Minibatch generators may fill the point/normal fields with NumPy arrays;
    ``MinibatchGeneratorInterface.transform`` converts them with
    ``_torch_from_numpy`` and moves them to the generator's device, after
    which they hold ``torch.Tensor`` values as annotated.
    """

    # Number of context sets; presumably len(context_points_list) — confirm.
    num_context_samples: int
    # One entry per context set.
    context_points_list: List[torch.Tensor]
    context_normals_list: List[torch.Tensor]
    # Either target field may be None.
    target_points: Optional[torch.Tensor]
    target_normals: Optional[torch.Tensor]


@dataclass
class UniformPointCloudData:
    """Point-cloud record loaded from ``path``.

    ``offset`` and ``scale`` are stored with the geometry; presumably they
    describe a normalization applied to the vertices (TODO confirm against
    the loader).
    """

    path: str
    vertices: np.ndarray
    # Presumably one normal per vertex, as the name suggests — confirm.
    vertex_normals: np.ndarray
    offset: np.ndarray
    scale: float
    # Optional and None by default. The annotation is explicitly Optional
    # because PEP 484 no longer treats a None default as implicitly Optional.
    vertex_indices: Optional[np.ndarray] = None


@dataclass
class PartialPointCloudData:
    """Point cloud from ``path`` together with per-viewpoint index subsets.

    ``partial_point_indices_list`` holds one index array per entry;
    presumably it has ``num_viewpoints`` entries, each selecting the points
    visible from one viewpoint (TODO confirm against the loader).
    """

    path: str
    vertices: np.ndarray
    vertex_normals: np.ndarray
    num_viewpoints: int
    partial_point_indices_list: List[np.ndarray]
    # Presumably a normalization stored with the geometry — confirm.
    offset: np.ndarray
    scale: float
    # Optional and None by default. The annotation is explicitly Optional
    # because PEP 484 no longer treats a None default as implicitly Optional.
    vertex_indices: Optional[np.ndarray] = None


@dataclass
class SdfData:
    """Signed-distance-field samples read from ``path``.

    Samples are kept in two arrays, split (as the names suggest) by the sign
    of the distance. Array layout is not established here — TODO confirm
    with the loader.
    """

    path: str
    positive_sdf_samples: np.ndarray
    negative_sdf_samples: np.ndarray
    # Presumably the normalization stored with the samples — confirm.
    offset: np.ndarray
    scale: float


@dataclass
class MeshData:
    """Mesh geometry read from ``path``: vertex array plus index array.

    Presumably ``vertex_indices`` lists faces as indices into ``vertices``
    (TODO confirm with the loader).
    """

    path: str
    vertices: np.ndarray
    vertex_indices: np.ndarray


def _numpy_copy_or_none(array: np.ndarray):
    if array is None:
        return None
    return np.copy(array)


def _numpy_list_copy(array_list: List[np.ndarray]):
    return [np.copy(array) for array in array_list]


def _torch_from_numpy(array: np.ndarray, copy=True, dtype=torch.float32):
    if array is None:
        return None
    if isinstance(array, list):
        array_list = array
        if copy:
            array_list = [np.copy(array) for array in array_list]
        return [torch.from_numpy(array).type(dtype) for array in array_list]
    if copy:
        array = np.copy(array)
    return torch.from_numpy(array).type(dtype)


def _to_device(array: torch.Tensor, device: str):
    if array is None:
        return None
    if isinstance(array, list):
        array_list = array
        return [array.to(device) for array in array_list]
    return array.to(device)


class DatasetInterface(ABC):
    """Abstract base for datasets that ``MinibatchIterator`` draws from.

    Concrete datasets must provide ``len()``, iteration, indexing and a
    ``shuffle`` method. ``__getitem__`` is annotated as returning
    ``np.ndarray``, but ``CombinedDataset`` in this module returns lists
    of loaded objects.
    """

    @abstractmethod
    def __len__(self):
        """Return the number of items."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self):
        """Iterate over all items."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, indices) -> np.ndarray:
        """Return the item(s) selected by ``indices``."""
        raise NotImplementedError

    @abstractmethod
    def shuffle(self):
        """Produce the items in a random order."""
        raise NotImplementedError


class CombinedDataset(DatasetInterface):
    """Dataset whose items are groups of files loaded by positional readers.

    ``combined_data_path_list[i]`` holds the paths that make up item ``i``.
    The ``k``-th path of every item is loaded with ``read_func_list[k]``.
    Files are read on every access; nothing is cached.
    """

    def __init__(self, combined_data_path_list, read_func_list):
        self.combined_data_path_list = combined_data_path_list
        self.read_func_list = read_func_list

    def read_files(self, path_index: int):
        """Load every file of item ``path_index``; return the results as a list."""
        path_list = self.combined_data_path_list[path_index]
        # Index the readers by position (rather than zip) so an item with more
        # paths than readers still fails loudly with IndexError.
        return [
            self.read_func_list[k](data_path)
            for k, data_path in enumerate(path_list)
        ]

    def __len__(self):
        return len(self.combined_data_path_list)

    def __iter__(self):
        """Yield every item in dataset order."""
        for path_index in range(len(self.combined_data_path_list)):
            yield self.read_files(path_index)

    def __getitem__(self, indices):
        """Return one item for a scalar index, otherwise a list of items.

        Scalar indices may be Python ints or NumPy integer scalars (e.g. an
        element of an index array). A ``slice`` selects a range of items.
        Any other iterable is treated as a sequence of indices.
        """
        # np.integer is not a subclass of int. Without this check a NumPy
        # scalar index would fall through to the iterable branch and raise
        # TypeError.
        if isinstance(indices, (int, np.integer)):
            return self.read_files(int(indices))
        if isinstance(indices, slice):
            indices = range(len(self.combined_data_path_list))[indices]
        return [self.read_files(index) for index in indices]

    def shuffle(self):
        """Yield every item exactly once, in a random order.

        This is a generator: files are read lazily as it is consumed. The
        order comes from NumPy's global RNG.
        """
        indices = np.random.permutation(len(self.combined_data_path_list))
        for index in indices:
            yield self[int(index)]


class MinibatchGeneratorInterface(ABC):
    """Abstract callable that builds a ``Minibatch`` from a list of loaded data.

    The context/target bounds and ``num_context_samples`` are only stored
    here; subclasses decide how ``__call__`` uses them. ``transform`` is the
    shared post-processing step:
    1. It runs every function registered with ``map``.
    2. It converts the point/normal fields with ``_torch_from_numpy``
       (``torch.float32`` by default).
    3. It moves those fields to ``self.device``.
    """

    def __init__(self,
                 min_num_context: int = None,
                 max_num_context: int = None,
                 min_num_target: int = None,
                 max_num_target: int = None,
                 num_context_samples: int = 1,
                 device="cpu"):
        # Sampling configuration, interpreted by concrete subclasses.
        self.min_num_context = min_num_context
        self.max_num_context = max_num_context
        self.min_num_target = min_num_target
        self.max_num_target = max_num_target
        self.num_context_samples = num_context_samples
        # Device that ``transform`` moves tensors to.
        self.device = device
        # User transforms, applied in registration order by ``transform``.
        self.transform_functions: List[Callable[[Minibatch], Minibatch]] = []

    def map(self, func):
        """Register ``func`` (Minibatch -> Minibatch) to run inside ``transform``."""
        self.transform_functions.append(func)

    def transform(self, minibatch: Minibatch) -> Minibatch:
        """Apply registered transforms, then convert and move tensor fields."""
        for func in self.transform_functions:
            minibatch = func(minibatch)

        tensor_fields = (
            "context_points_list",
            "context_normals_list",
            "target_points",
            "target_normals",
        )
        # Convert every field first, then move them all, matching the
        # original two-phase order.
        for field_name in tensor_fields:
            setattr(minibatch, field_name,
                    _torch_from_numpy(getattr(minibatch, field_name)))
        for field_name in tensor_fields:
            setattr(minibatch, field_name,
                    _to_device(getattr(minibatch, field_name), self.device))

        return minibatch

    @abstractmethod
    def __call__(self,
                 data_list: List[UniformPointCloudData],
                 num_point_samples=None):
        """Build a minibatch from ``data_list``; implemented by subclasses."""
        raise NotImplementedError()


class IndexSampler:
    """Yield shuffled batches of dataset indices.

    Each pass draws a fresh random permutation of ``range(dataset_size)``
    from NumPy's global RNG and splits it into consecutive batches of
    ``batchsize`` Python ints. A trailing partial batch is yielded unless
    ``drop_last`` is true.

    Raises:
        ValueError: if ``batchsize`` is not positive. Previously, such a value
            made ``__len__`` fail or go negative, while iteration silently
            returned the whole dataset as one batch.
    """

    def __init__(self, dataset_size, batchsize, drop_last):
        if batchsize <= 0:
            raise ValueError(
                f"batchsize must be positive, got {batchsize!r}")
        self.dataset_size = dataset_size
        self.batchsize = batchsize
        self.drop_last = drop_last

    def __len__(self):
        """Return the number of batches one pass over the sampler yields."""
        # Integer divmod instead of float division + floor/ceil: exact for
        # any size.
        full_batches, remainder = divmod(self.dataset_size, self.batchsize)
        num_batches = int(full_batches)
        if remainder and not self.drop_last:
            num_batches += 1
        return num_batches

    def __iter__(self):
        ret = []
        rand_indices = np.random.permutation(self.dataset_size)
        for index in rand_indices:
            # Convert NumPy integers to plain Python ints for callers.
            ret.append(int(index))
            if len(ret) == self.batchsize:
                yield ret
                ret = []
        if len(ret) > 0 and not self.drop_last:
            yield ret


class MinibatchIterator:
    """Iterate a dataset in shuffled minibatches.

    Each step takes the next list of indices from an ``IndexSampler``. It
    loads those items with ``dataset[indices]`` and passes the resulting
    data to ``minibatch_generator``, returning whatever that call produces.
    Calling ``iter()`` starts a new pass with a fresh permutation.
    """

    def __init__(self,
                 dataset: DatasetInterface,
                 batchsize: int,
                 minibatch_generator: MinibatchGeneratorInterface,
                 drop_last=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.minibatch_generator = minibatch_generator
        sampler = IndexSampler(len(dataset), batchsize, drop_last)
        self.sampler = sampler
        # Primed here so ``next()`` works even without an explicit ``iter()``.
        self.batch_indices_iterator = iter(sampler)

    def __len__(self):
        """Number of minibatches in one pass."""
        return len(self.sampler)

    def __iter__(self) -> Iterator[Minibatch]:
        """Restart with a fresh shuffled order and return ``self``."""
        self.batch_indices_iterator = iter(self.sampler)
        return self

    def __next__(self) -> Minibatch:
        """Load the next batch; StopIteration from the sampler ends the pass."""
        indices = next(self.batch_indices_iterator)
        return self.minibatch_generator(self.dataset[indices])
