#   Copyright 2014 Darsh Ranjan
#
#   This file is part of python-plyfile.
#
#   python-plyfile is free software: you can redistribute it and/or
#   modify it under the terms of the GNU General Public License as
#   published by the Free Software Foundation, either version 3 of the
#   License, or (at your option) any later version.
#
#   python-plyfile is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#   General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with python-plyfile.  If not, see
#       <http://www.gnu.org/licenses/>.

from itertools import islice as _islice
from sys import byteorder as _byteorder

import numpy as _np


# Use the lazy `xrange' where it exists (Python 2); otherwise the
# built-in `range' already behaves that way.
try:
    xrange
except NameError:
    _range = range
else:
    _range = xrange


# Many-to-many relation between PLY type names and numpy type codes.
# Order matters: when a name or a code occurs more than once, the
# last pair listed wins in the lookup tables built below.
_data_type_relation = [
    ("int8", "i1"),
    ("char", "i1"),
    ("uint8", "u1"),
    ("uchar", "b1"),
    ("uchar", "u1"),
    ("int16", "i2"),
    ("short", "i2"),
    ("uint16", "u2"),
    ("ushort", "u2"),
    ("int32", "i4"),
    ("int", "i4"),
    ("uint32", "u4"),
    ("uint", "u4"),
    ("float32", "f4"),
    ("float", "f4"),
    ("float64", "f8"),
    ("double", "f8"),
]

# _data_types:        PLY type name   -> numpy type code
# _data_type_reverse: numpy type code -> PLY type name
# _types_list:        every name and code, in order of first appearance
# _types_set:         the same strings as a set, for fast membership
_data_types = {}
_data_type_reverse = {}
_types_list = []
_types_set = set()

for (_a, _b) in _data_type_relation:
    _data_types[_a] = _b
    _data_type_reverse[_b] = _a

    if _a not in _types_set:
        _types_set.add(_a)
        _types_list.append(_a)
    if _b not in _types_set:
        _types_set.add(_b)
        _types_list.append(_b)


# PLY 'format' keyword -> numpy byte-order character.
_byte_order_map = {
    "ascii": "=",
    "binary_little_endian": "<",
    "binary_big_endian": ">",
}

# numpy byte-order character -> PLY binary 'format' keyword.  The
# native marker '=' has no PLY keyword of its own, so it is left out.
_byte_order_reverse = dict((order, fmt) for (fmt, order) in _byte_order_map.items() if order != "=")

# Byte-order character of the machine we are running on.
_native_byte_order = dict(little="<", big=">")[_byteorder]


def _lookup_type(type_str):
    """
    Return the PLY type name for `type_str', which may be either a
    numpy type code (like 'f4') or a PLY type name (like 'float32').
    Raise ValueError if it is neither.
    """
    # Already a numpy type code: translate directly.
    if type_str in _data_type_reverse:
        return _data_type_reverse[type_str]

    # Otherwise it has to be a PLY name; go through its numpy code so
    # that the result is the name registered for that code.
    try:
        code = _data_types[type_str]
    except KeyError:
        raise ValueError("field type %r not in %r" % (type_str, _types_list))

    return _data_type_reverse[code]


def make2d(array, cols=None, dtype=None):
    """
    Stack an array of equal-length arrays into a single 2D array.

    For an empty input there are no rows to stack, so the number of
    columns (`cols') and the element type (`dtype') must be given; a
    RuntimeError is raised if either is missing.
    """
    if len(array) > 0:
        return _np.vstack(array)

    if cols is None or dtype is None:
        raise RuntimeError("cols and dtype must be specified for empty array")

    return _np.empty((0, cols), dtype=dtype)


class _PlyHeaderParser(object):
    """
    Line-at-a-time parser for the header of a PLY file.

    Feed every raw header line to `consume'.  After each line,
    `_allowed' lists the keywords that may start the next line; it
    becomes empty once 'end_header' has been parsed.  Results are
    accumulated in `format', `elements', `comments' and `obj_info'.
    """

    def __init__(self):
        self.format = None
        # One tuple per element: (name, properties, count, comments).
        self.elements = []
        self.comments = []
        self.obj_info = []
        # Number of lines consumed so far; attached to parse errors.
        self.lines = 0
        self._allowed = ["ply"]

    def consume(self, raw_line):
        """
        Parse one header line (a bytes object as returned by
        `readline').  Return the list of keywords allowed on the next
        line, which is empty when the header is complete.  Raise
        PlyHeaderParseError on malformed input or premature end of
        file (an empty `raw_line').
        """
        self.lines += 1
        if not raw_line:
            self._error("early end-of-file")

        line = raw_line.decode("ascii").strip()
        try:
            keyword = line.split(None, 1)[0]
        except IndexError:
            # Blank line: there is no keyword at all.
            self._error()

        if keyword not in self._allowed:
            self._error("expected one of {%s}" % ", ".join(self._allowed))

        # Dispatch to parse_<keyword>, passing everything after the
        # keyword and the single separator character that follows it.
        getattr(self, "parse_" + keyword)(line[len(keyword) + 1 :])
        return self._allowed

    def _error(self, message="parse error"):
        """
        Raise PlyHeaderParseError for the current line.
        """
        raise PlyHeaderParseError(message, self.lines)

    def parse_ply(self, data):
        """
        Handle the leading 'ply' line, which must stand alone.
        """
        if data:
            self._error("unexpected characters after 'ply'")
        self._allowed = ["format", "comment", "obj_info"]

    def parse_format(self, data):
        """
        Handle 'format {format} 1.0' and record the format keyword.
        """
        fields = data.strip().split()
        if len(fields) != 2:
            self._error('expected "format {format} 1.0"')

        self.format = fields[0]
        if self.format not in _byte_order_map:
            # Report the offending keyword from the file, not the
            # `format' builtin.
            self._error("don't understand format %r" % self.format)

        if fields[1] != "1.0":
            self._error("expected version '1.0'")

        self._allowed = ["element", "comment", "obj_info", "end_header"]

    def parse_comment(self, data):
        """
        Record a comment: file-level before the first element,
        otherwise attached to the most recently declared element.
        """
        if not self.elements:
            self.comments.append(data)
        else:
            self.elements[-1][3].append(data)

    def parse_obj_info(self, data):
        """
        Record an 'obj_info' line.
        """
        self.obj_info.append(data)

    def parse_element(self, data):
        """
        Handle 'element {name} {count}' and start a new element with no
        properties and no comments.
        """
        fields = data.strip().split()
        if len(fields) != 2:
            self._error('expected "element {name} {count}"')

        name = fields[0]
        try:
            count = int(fields[1])
        except ValueError:
            self._error("expected integer count")

        self.elements.append((name, [], count, []))
        self._allowed = ["element", "comment", "property", "end_header"]

    def parse_property(self, data):
        """
        Handle 'property {type} {name}' or
        'property list {len_type} {val_type} {name}' and append the
        property to the most recently declared element.
        """
        properties = self.elements[-1][1]
        fields = data.strip().split()
        if len(fields) < 2:
            self._error("bad 'property' line")

        if fields[0] == "list":
            if len(fields) != 4:
                self._error('expected "property list ' '{len_type} {val_type} {name}"')

            try:
                properties.append(PlyListProperty(fields[3], fields[1], fields[2]))
            except ValueError as e:
                # Invalid name or type: report it as a header error.
                self._error(str(e))

        else:
            if len(fields) != 2:
                self._error('expected "property {type} {name}"')

            try:
                properties.append(PlyProperty(fields[1], fields[0]))
            except ValueError as e:
                self._error(str(e))

    def parse_end_header(self, data):
        """
        Handle 'end_header'; nothing may follow it on the line, and no
        further header lines are allowed afterwards.
        """
        if data:
            self._error("unexpected data after 'end_header'")
        self._allowed = []


class PlyParseError(Exception):
    """
    Common base class for every error raised while parsing a PLY
    file.
    """


class PlyElementParseError(PlyParseError):

    """
    Raised when the data of a PLY element cannot be parsed.
    The attributes `element', `row', `prop', and `message' give
    additional information; each of the first three may be None.
    """

    def __init__(self, message, element=None, row=None, prop=None):
        self.message = message
        self.element = element
        self.row = row
        self.prop = prop

        # Build "element 'e': row N: property 'p': message", leaving
        # out every location part that was not supplied.
        pieces = []
        if self.element:
            pieces.append("element %r: " % self.element.name)
        if self.row is not None:
            pieces.append("row %d: " % self.row)
        if self.prop:
            pieces.append("property %r: " % self.prop.name)
        pieces.append(self.message)

        Exception.__init__(self, "".join(pieces))

    def __repr__(self):
        details = (type(self).__name__, self.message, self.element, self.row, self.prop)
        return "%s(%r, element=%r, row=%r, prop=%r)" % details


class PlyHeaderParseError(PlyParseError):

    """
    Raised when a PLY header cannot be parsed.
    The attribute `line' holds the number of the offending header line
    (or None), and `message' the bare error description.
    """

    def __init__(self, message, line=None):
        self.message = message
        self.line = line

        # Prefix the message with the line number when one is known.
        prefix = ("line %r: " % self.line) if self.line else ""

        Exception.__init__(self, prefix + self.message)

    def __repr__(self):
        details = (type(self).__name__, self.message, self.line)
        return "%s(%r, line=%r)" % details


class PlyData(object):

    """
    PLY file header and data.
    A PlyData instance is created in one of two ways: by the static
    method PlyData.read (to read a PLY file), or directly from __init__
    given a sequence of elements (which can then be written to a PLY
    file).

    Elements can be iterated over, counted with len(), tested for by
    name with `in', and fetched by name with indexing.
    """

    # NOTE: the list defaults below are never mutated; the property
    # setters store their own copies.
    def __init__(self, elements=[], text=False, byte_order="=", comments=[], obj_info=[]):
        """
        elements: sequence of PlyElement instances.
        text: whether the resulting PLY file will be text (True) or
            binary (False).
        byte_order: '<' for little-endian, '>' for big-endian, or '='
            for native.  This is only relevant if `text' is False.
        comments: sequence of strings that will be placed in the header
            between the 'ply' and 'format ...' lines.
        obj_info: like comments, but will be placed in the header with
            "obj_info ..." instead of "comment ...".
        """
        # Binary output needs an explicit byte order for the header.
        if byte_order == "=" and not text:
            byte_order = _native_byte_order

        self.byte_order = byte_order
        self.text = text

        self.comments = comments
        self.obj_info = obj_info
        self.elements = elements

    def _get_elements(self):
        return self._elements

    def _set_elements(self, elements):
        # Stored as a tuple; the name lookup is rebuilt on every
        # assignment and rejects duplicate element names.
        self._elements = tuple(elements)
        self._index()

    elements = property(_get_elements, _set_elements)

    def _get_byte_order(self):
        return self._byte_order

    def _set_byte_order(self, byte_order):
        if byte_order not in ["<", ">", "="]:
            raise ValueError("byte order must be '<', '>', or '='")

        self._byte_order = byte_order

    byte_order = property(_get_byte_order, _set_byte_order)

    def _index(self):
        """
        Rebuild the name -> element lookup.  Raise ValueError if two
        elements share a name.
        """
        self._element_lookup = dict((elt.name, elt) for elt in self._elements)
        if len(self._element_lookup) != len(self._elements):
            raise ValueError("two elements with same name")

    def _get_comments(self):
        # Hand out a copy so callers cannot bypass validation.
        return list(self._comments)

    def _set_comments(self, comments):
        _check_comments(comments)
        self._comments = list(comments)

    comments = property(_get_comments, _set_comments)

    def _get_obj_info(self):
        return list(self._obj_info)

    def _set_obj_info(self, obj_info):
        _check_comments(obj_info)
        self._obj_info = list(obj_info)

    obj_info = property(_get_obj_info, _set_obj_info)

    @staticmethod
    def _parse_header(stream):
        """
        Parse a PLY header from a readable file-like stream and return
        a PlyData instance whose elements have not been loaded yet.
        """
        parser = _PlyHeaderParser()
        # `consume' returns the keywords allowed next, which is empty
        # (false) once 'end_header' has been read.
        while parser.consume(stream.readline()):
            pass

        return PlyData(
            [PlyElement(*e) for e in parser.elements],
            parser.format == "ascii",
            _byte_order_map[parser.format],
            parser.comments,
            parser.obj_info,
        )

    @staticmethod
    def read(stream):
        """
        Read PLY data from a readable file-like object or filename.
        A file opened here (from a filename) is closed again before
        returning; a stream passed in by the caller is left open.
        """
        (must_close, stream) = _open_stream(stream, "read")
        try:
            data = PlyData._parse_header(stream)
            for elt in data:
                elt._read(stream, data.text, data.byte_order)
        finally:
            if must_close:
                stream.close()

        return data

    def write(self, stream):
        """
        Write PLY data to a writeable file-like object or filename.
        A file opened here (from a filename) is closed again before
        returning; a stream passed in by the caller is left open.
        """
        (must_close, stream) = _open_stream(stream, "write")
        try:
            stream.write(self.header.encode("ascii"))
            stream.write(b"\n")
            for elt in self:
                elt._write(stream, self.text, self.byte_order)
        finally:
            if must_close:
                stream.close()

    @property
    def header(self):
        """
        Provide PLY-formatted metadata for the instance.
        """
        lines = ["ply"]

        if self.text:
            lines.append("format ascii 1.0")
        else:
            # `byte_order' may still be '=' here (it is a legal value
            # for the setter, and `text' can be switched off after
            # construction).  PLY has no keyword for "native", so
            # resolve it to the actual byte order of this machine.
            byte_order = self.byte_order
            if byte_order == "=":
                byte_order = _native_byte_order
            lines.append("format " + _byte_order_reverse[byte_order] + " 1.0")

        # Some information is lost here, since all comments are placed
        # between the 'format' line and the first element.
        for c in self.comments:
            lines.append("comment " + c)

        for c in self.obj_info:
            lines.append("obj_info " + c)

        lines.extend(elt.header for elt in self.elements)
        lines.append("end_header")
        return "\n".join(lines)

    def __iter__(self):
        return iter(self.elements)

    def __len__(self):
        return len(self.elements)

    def __contains__(self, name):
        return name in self._element_lookup

    def __getitem__(self, name):
        return self._element_lookup[name]

    def __str__(self):
        return self.header

    def __repr__(self):
        return "PlyData(%r, text=%r, byte_order=%r, " "comments=%r, obj_info=%r)" % (
            self.elements,
            self.text,
            self.byte_order,
            self.comments,
            self.obj_info,
        )


def _open_stream(stream, read_or_write):
    """
    Normalize `stream' to an open binary file object.

    `read_or_write' is the method name ('read' or 'write') the stream
    must provide.  Return (must_close, stream): objects that already
    have that method are returned unchanged with must_close False;
    anything else is treated as a filename and opened in binary mode,
    with must_close True.  Raise RuntimeError if `stream' is not
    acceptable to open().
    """
    if not hasattr(stream, read_or_write):
        try:
            # 'read' -> 'rb', 'write' -> 'wb'
            handle = open(stream, read_or_write[0] + "b")
        except TypeError:
            raise RuntimeError("expected open file or filename")
        return (True, handle)

    return (False, stream)


class PlyElement(object):

    """
    PLY file element.
    A client of this library doesn't normally need to instantiate this
    directly, so the following is only for the sake of documenting the
    internals.
    Creating a PlyElement instance is generally done in one of two ways:
    as a byproduct of PlyData.read (when reading a PLY file) and by
    PlyElement.describe (before writing a PLY file).

    Indexing an element (elt[key], elt[key] = value) is forwarded to
    its `data' array.
    """

    # NOTE: the default `comments' list is never mutated; the setter
    # stores its own copy.
    def __init__(self, name, properties, count, comments=[]):
        """
        This is not part of the public interface.  The preferred methods
        of obtaining PlyElement instances are PlyData.read (to read from
        a file) and PlyElement.describe (to construct from a numpy
        array).
        """
        _check_name(name)
        self._name = str(name)
        self._count = count

        self._properties = tuple(properties)
        self._index()

        self.comments = comments

        # Elements with list properties cannot be (de)serialized as one
        # fixed-size record array; see _read and _write.
        self._have_list = any(isinstance(p, PlyListProperty) for p in self.properties)

    @property
    def count(self):
        return self._count

    def _get_data(self):
        return self._data

    def _set_data(self, data):
        # Assigning data also updates the row count and verifies that
        # every declared property has a matching field.
        self._data = data
        self._count = len(data)
        self._check_sanity()

    data = property(_get_data, _set_data)

    def _check_sanity(self):
        """
        Raise ValueError if some property has no field of the same name
        in the data array.
        """
        for prop in self.properties:
            if prop.name not in self._data.dtype.fields:
                raise ValueError("dangling property %r" % prop.name)

    def _get_properties(self):
        return self._properties

    def _set_properties(self, properties):
        self._properties = tuple(properties)
        self._check_sanity()
        self._index()

    properties = property(_get_properties, _set_properties)

    def _get_comments(self):
        # Hand out a copy so callers cannot bypass validation.
        return list(self._comments)

    def _set_comments(self, comments):
        _check_comments(comments)
        self._comments = list(comments)

    comments = property(_get_comments, _set_comments)

    def _index(self):
        """
        Rebuild the name -> property lookup.  Raise ValueError if two
        properties share a name.
        """
        self._property_lookup = dict((prop.name, prop) for prop in self._properties)
        if len(self._property_lookup) != len(self._properties):
            raise ValueError("two properties with same name")

    def ply_property(self, name):
        """
        Return the property called `name' (KeyError if there is none).
        """
        return self._property_lookup[name]

    @property
    def name(self):
        return self._name

    def dtype(self, byte_order="="):
        """
        Return the numpy dtype of the in-memory representation of the
        data.  (If there are no list properties, and the PLY format is
        binary, then this also accurately describes the on-disk
        representation of the element.)
        """
        return _np.dtype([(prop.name, prop.dtype(byte_order)) for prop in self.properties])

    # NOTE: the mutable defaults below are only read, never modified.
    @staticmethod
    def describe(data, name, len_types={}, val_types={}, comments=[]):
        """
        Construct a PlyElement from an array's metadata.
        len_types and val_types can be given as mappings from list
        property names to type strings (like 'u1', 'f4', etc., or
        'int8', 'float32', etc.). These can be used to define the length
        and value types of list properties.  List property lengths
        always default to type 'u1' (8-bit unsigned integer), and value
        types default to 'i4' (32-bit integer).

        Raise TypeError if `data' is not a numpy array, and ValueError
        if it is not one-dimensional, has nested-record or unnamed
        fields, or uses an unknown type string.
        """
        if not isinstance(data, _np.ndarray):
            raise TypeError("only numpy arrays are supported")

        if len(data.shape) != 1:
            raise ValueError("only one-dimensional arrays are supported")

        count = len(data)

        properties = []
        descr = data.dtype.descr

        for t in descr:
            # Each entry is (name, type) or (name, type, shape).
            if not isinstance(t[1], str):
                raise ValueError("nested records not supported")

            if not t[0]:
                raise ValueError("field with empty name")

            is_object = t[1][1] == "O"

            if len(t) != 2 or is_object:
                # non-scalar field, which corresponds to a list
                # property in PLY.

                if is_object and len(t) != 2:
                    raise ValueError("non-scalar object fields not supported")

                # Go through _lookup_type so that both numpy codes
                # ('u1') and PLY names ('uint8') are accepted, exactly
                # as for val_types.
                len_str = _lookup_type(len_types.get(t[0], "u1"))
                if is_object:
                    val_str = _lookup_type(val_types.get(t[0], "i4"))
                else:
                    # Strip the byte-order character from e.g. '<f4'.
                    val_str = _lookup_type(t[1][1:])

                prop = PlyListProperty(t[0], len_str, val_str)
            else:
                val_str = _lookup_type(t[1][1:])
                prop = PlyProperty(t[0], val_str)

            properties.append(prop)

        elt = PlyElement(name, properties, count, comments)
        elt.data = data

        return elt

    def _read(self, stream, text, byte_order):
        """
        Read the actual data from a PLY file.
        """
        dtype = self.dtype(byte_order)
        if text:
            self._read_txt(stream)
        elif _can_mmap(stream) and not self._have_list:
            # Loading the data is straightforward.  We will memory map
            # the file in copy-on-write mode.
            num_bytes = self.count * dtype.itemsize
            offset = stream.tell()
            # Seek to the end to find out how many bytes are available.
            stream.seek(0, 2)
            max_bytes = stream.tell() - offset
            if max_bytes < num_bytes:
                raise PlyElementParseError("early end-of-file", self, max_bytes // dtype.itemsize)
            self._data = _np.memmap(stream, dtype, "c", offset, self.count)
            # Fix stream position
            stream.seek(offset + self.count * dtype.itemsize)
        else:
            # A simple load is impossible.
            self._read_bin(stream, byte_order)

        self._check_sanity()

    def _write(self, stream, text, byte_order):
        """
        Write the data to a PLY file.
        """
        if text:
            self._write_txt(stream)
        else:
            if self._have_list:
                # There are list properties, so serialization is
                # slightly complicated.
                self._write_bin(stream, byte_order)
            else:
                # no list properties, so serialization is
                # straightforward.
                stream.write(self.data.astype(self.dtype(byte_order), copy=False).data)

    def _read_txt(self, stream):
        """
        Load a PLY element from an ASCII-format PLY file.  The element
        may contain list properties.
        """
        self._data = _np.empty(self.count, dtype=self.dtype())

        k = 0
        # Read at most `count' lines, stopping early at end-of-file.
        for line in _islice(iter(stream.readline, b""), self.count):
            fields = iter(line.strip().split())
            for prop in self.properties:
                try:
                    self._data[prop.name][k] = prop._from_fields(fields)
                except StopIteration:
                    raise PlyElementParseError("early end-of-line", self, k, prop)
                except ValueError:
                    raise PlyElementParseError("malformed input", self, k, prop)
            # Every field on the line must have been consumed.
            try:
                next(fields)
            except StopIteration:
                pass
            else:
                raise PlyElementParseError("expected end-of-line", self, k)
            k += 1

        if k < self.count:
            del self._data
            raise PlyElementParseError("early end-of-file", self, k)

    def _write_txt(self, stream):
        """
        Save a PLY element to an ASCII-format PLY file.  The element may
        contain list properties.
        """
        for rec in self.data:
            fields = []
            for prop in self.properties:
                fields.extend(prop._to_fields(rec[prop.name]))

            _np.savetxt(stream, [fields], "%.18g", newline="\n")

    def _read_bin(self, stream, byte_order):
        """
        Load a PLY element from a binary PLY file.  The element may
        contain list properties.
        """
        self._data = _np.empty(self.count, dtype=self.dtype(byte_order))

        for k in _range(self.count):
            for prop in self.properties:
                try:
                    self._data[prop.name][k] = prop._read_bin(stream, byte_order)
                except StopIteration:
                    # The property readers signal a short read this way.
                    raise PlyElementParseError("early end-of-file", self, k, prop)

    def _write_bin(self, stream, byte_order):
        """
        Save a PLY element to a binary PLY file.  The element may
        contain list properties.
        """
        for rec in self.data:
            for prop in self.properties:
                prop._write_bin(rec[prop.name], stream, byte_order)

    @property
    def header(self):
        """
        Format this element's metadata as it would appear in a PLY
        header.
        """
        lines = ["element %s %d" % (self.name, self.count)]

        # Some information is lost here, since all comments are placed
        # between the 'element' line and the first property definition.
        for c in self.comments:
            lines.append("comment " + c)

        lines.extend(list(map(str, self.properties)))

        return "\n".join(lines)

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    def __str__(self):
        return self.header

    def __repr__(self):
        return "PlyElement(%r, %r, count=%d, comments=%r)" % (self.name, self.properties, self.count, self.comments)


def _check_comments(comments):
    """
    Validate a sequence of header comment strings.  Raise ValueError
    if any of them contains a non-ASCII character or a newline, since
    neither can be written to a PLY header line.
    """
    for text in comments:
        for ch in text:
            code = ord(ch)
            if code < 0 or code >= 128:
                raise ValueError("non-ASCII character in comment")
            if ch == "\n":
                raise ValueError("embedded newline in comment")


class PlyProperty(object):

    """
    PLY property description.  This class is pure metadata; the data
    itself is contained in PlyElement instances.
    """

    def __init__(self, name, val_dtype):
        """
        name: property name (ASCII, no whitespace).
        val_dtype: value type, as a numpy type code (like 'f4') or a
            PLY type name (like 'float32').
        """
        _check_name(name)
        self._name = str(name)
        self.val_dtype = val_dtype

    def _get_val_dtype(self):
        return self._val_dtype

    def _set_val_dtype(self, val_dtype):
        # Normalize whatever was given to a numpy type code.
        self._val_dtype = _data_types[_lookup_type(val_dtype)]

    val_dtype = property(_get_val_dtype, _set_val_dtype)

    @property
    def name(self):
        return self._name

    def dtype(self, byte_order="="):
        """
        Return the numpy dtype description for this property (as a
        string: the byte-order character followed by the type code).
        """
        return byte_order + self.val_dtype

    def _from_fields(self, fields):
        """
        Parse from generator.  Raise StopIteration if the property could
        not be read.
        """
        return _np.dtype(self.dtype()).type(next(fields))

    def _to_fields(self, data):
        """
        Return generator over one item.
        """
        yield _np.dtype(self.dtype()).type(data)

    def _read_bin(self, stream, byte_order):
        """
        Read data from a binary stream.  Raise StopIteration if the
        property could not be read.
        """
        try:
            return _read_array(stream, self.dtype(byte_order), 1)[0]
        except IndexError:
            # Nothing was read: end of file.
            raise StopIteration

    def _write_bin(self, data, stream, byte_order):
        """
        Write data to a binary stream, using the given byte order.
        """
        # Build a 0-d array rather than a numpy scalar: scalars are
        # always stored in native byte order, so converting through
        # `dtype.type' would silently ignore `byte_order'.
        stream.write(_np.array(data, dtype=self.dtype(byte_order)).tobytes())

    def __str__(self):
        val_str = _data_type_reverse[self.val_dtype]
        return "property %s %s" % (val_str, self.name)

    def __repr__(self):
        return "PlyProperty(%r, %r)" % (self.name, _lookup_type(self.val_dtype))


class PlyListProperty(PlyProperty):

    """
    PLY list property description: a variable-length list of values,
    stored as a length followed by that many values.
    """

    def __init__(self, name, len_dtype, val_dtype):
        """
        name: property name (ASCII, no whitespace).
        len_dtype: type of the list length.
        val_dtype: type of the list values.
        Both types may be numpy type codes or PLY type names.
        """
        PlyProperty.__init__(self, name, val_dtype)

        self.len_dtype = len_dtype

    def _get_len_dtype(self):
        return self._len_dtype

    def _set_len_dtype(self, len_dtype):
        # Normalize whatever was given to a numpy type code.
        self._len_dtype = _data_types[_lookup_type(len_dtype)]

    len_dtype = property(_get_len_dtype, _set_len_dtype)

    def dtype(self, byte_order="="):
        """
        List properties always have a numpy dtype of "object".
        """
        return "|O"

    def list_dtype(self, byte_order="="):
        """
        Return the pair (len_dtype, val_dtype) (both numpy-friendly
        strings).
        """
        return (byte_order + self.len_dtype, byte_order + self.val_dtype)

    def _from_fields(self, fields):
        """
        Parse a length and that many values from an iterator of text
        fields.  Raise StopIteration if the fields run out first.
        """
        (len_t, val_t) = self.list_dtype()

        n = int(_np.dtype(len_t).type(next(fields)))

        data = _np.loadtxt(list(_islice(fields, n)), val_t, ndmin=1)
        if len(data) < n:
            raise StopIteration

        return data

    def _to_fields(self, data):
        """
        Return generator over the (numerical) PLY representation of the
        list data (length followed by actual data).
        """
        (len_t, val_t) = self.list_dtype()

        data = _np.asarray(data, dtype=val_t).ravel()

        yield _np.dtype(len_t).type(data.size)
        for x in data:
            yield x

    def _read_bin(self, stream, byte_order):
        """
        Read a length and that many values from a binary stream.  Raise
        StopIteration if the stream ends first.
        """
        (len_t, val_t) = self.list_dtype(byte_order)

        try:
            # Convert to a plain int right away: arithmetic on a small
            # numpy integer (e.g. uint8) can overflow when the byte
            # count of the list is computed from it.
            n = int(_read_array(stream, _np.dtype(len_t), 1)[0])
        except (IndexError, ValueError, OverflowError):
            # Nothing was read, or the length is not a finite number.
            raise StopIteration

        data = _read_array(stream, _np.dtype(val_t), n)
        if len(data) < n:
            raise StopIteration

        return data

    def _write_bin(self, data, stream, byte_order):
        """
        Write data to a binary stream: the length, then the values.
        """
        (len_t, val_t) = self.list_dtype(byte_order)

        data = _np.asarray(data, dtype=val_t).ravel()

        _write_array(stream, _np.array(data.size, dtype=len_t))
        _write_array(stream, data)

    def __str__(self):
        len_str = _data_type_reverse[self.len_dtype]
        val_str = _data_type_reverse[self.val_dtype]
        return "property list %s %s %s" % (len_str, val_str, self.name)

    def __repr__(self):
        return "PlyListProperty(%r, %r, %r)" % (self.name, _lookup_type(self.len_dtype), _lookup_type(self.val_dtype))


def _check_name(name):
    """
    Validate an element or property name.  Raise ValueError if it
    contains a non-ASCII or a whitespace character, since header lines
    are ASCII and split on whitespace.
    """
    for ch in name:
        code = ord(ch)
        if code < 0 or code >= 128:
            raise ValueError("non-ASCII character in name %r" % name)
        if ch.isspace():
            raise ValueError("space character(s) in name %r" % name)


def _read_array(stream, dtype, n):
    """
    Read `n' items of type `dtype' from a binary stream and return
    them as a numpy array backed by the bytes that were read.  The
    result is shorter than `n' if the stream ends early.  Raise
    StopIteration if the data cannot be read or decoded at all.
    """
    try:
        # Convert `n' to a plain int before multiplying: it is often a
        # small numpy integer (a list length read from the file), and
        # multiplying that by the item size can overflow its type.
        size = _np.dtype(dtype).itemsize * int(n)
        return _np.frombuffer(stream.read(size), dtype)
    except Exception:
        raise StopIteration


def _write_array(stream, array):
    """
    Write the raw bytes of a numpy array (or array scalar) to a binary
    stream.
    """
    # `tobytes' is the supported spelling of the deprecated `tostring'
    # and produces the same bytes.
    stream.write(array.tobytes())


def _can_mmap(stream):
    """
    Return True if `stream' can be memory-mapped by numpy, and False
    otherwise.  This is tested by actually attempting a mapping; the
    stream position is put back to where it was before returning.
    Any failure, including a stream without a usable position, yields
    False.
    """
    try:
        saved = stream.tell()
        try:
            _np.memmap(stream, "u1", "c")
            stream.seek(saved)
            mappable = True
        except Exception:
            stream.seek(saved)
            mappable = False
        return mappable
    except Exception:
        return False
