""" Module for LabelTensor """
from typing import Any
import torch
from torch import Tensor


class LabelTensor(torch.Tensor):
    """Torch tensor with a label for any column."""

    @staticmethod
    def __new__(cls, x, labels, *args, **kwargs):
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x, labels):
        '''
        Construct a `LabelTensor` by passing a tensor and a list of column
        labels. Such labels uniquely identify the columns of the tensor,
        allowing for an easier manipulation.

        :param torch.Tensor x: the data tensor.
        :param labels: the labels of the columns.
        :type labels: str or iterable(str)

        :Example:
            >>> from pina import LabelTensor
            >>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
            >>> tensor
            tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
                [9.2392e-01, 8.2065e-01, 4.1986e-04],
                [8.9266e-01, 5.5446e-01, 6.3500e-01],
                ...,
                [5.8194e-01, 9.4268e-01, 4.1841e-01],
                [1.0246e-01, 9.5179e-01, 3.7043e-02],
                [9.6150e-01, 8.0656e-01, 8.3824e-01]])
            >>> tensor.extract('a')
            tensor([[0.0671],
                    [0.9239],
                    [0.8927],
                    ...,
                    [0.5819],
                    [0.1025],
                    [0.9615]])
            >>> tensor.extract(['a', 'b'])
            tensor([[0.0671, 0.4889],
                    [0.9239, 0.8207],
                    [0.8927, 0.5545],
                    ...,
                    [0.5819, 0.9427],
                    [0.1025, 0.9518],
                    [0.9615, 0.8066]])
            >>> tensor.extract(['b', 'a'])
            tensor([[0.4889, 0.0671],
                    [0.8207, 0.9239],
                    [0.5545, 0.8927],
                    ...,
                    [0.9427, 0.5819],
                    [0.9518, 0.1025],
                    [0.8066, 0.9615]])
        '''
        if x.ndim == 1:
            x = x.reshape(-1, 1)

        if isinstance(labels, str):
            labels = [labels]

        if len(labels) != x.shape[-1]:
            raise ValueError(
                'the tensor has not the same number of columns of '
                'the passed labels.'
            )
        self._labels = labels

    @property
    def labels(self):
        """Property decorator for labels

        :return: labels of self
        :rtype: list
        """
        return self._labels

    @labels.setter
    def labels(self, labels):
        if len(labels) != self.shape[self.ndim - 1]:  # small check
            raise ValueError(
                'the tensor has not the same number of columns of '
                'the passed labels.')

        self._labels = labels   # assign the label

    # TODO remove try/ except thing IMPORTANT
    # make the label None of default
    def clone(self, *args, **kwargs):
        """
        Clone the LabelTensor. For more details, see
        :meth:`torch.Tensor.clone`.

        :return: a copy of the tensor
        :rtype: LabelTensor
        """
        try:
            out = LabelTensor(super().clone(*args, **kwargs), self.labels)
        except:
            out = super().clone(*args, **kwargs)

        return out

    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion. For more details, see
        :meth:`torch.Tensor.to`.
        """
        tmp = super().to(*args, **kwargs)
        new = self.__class__.clone(self)
        new.data = tmp.data
        return new

    def select(self, *args, **kwargs):
        """
        Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
        """
        tmp = super().select(*args, **kwargs)
        tmp._labels = self._labels
        return tmp

    def extract(self, label_to_extract):
        """
        Extract the subset of the original tensor by returning all the columns
        corresponding to the passed `label_to_extract`.

        :param label_to_extract: the label(s) to extract.
        :type label_to_extract: str or iterable(str)
        :raises TypeError: labels are not str
        :raises ValueError: label to extract is not in the labels list
        """

        if isinstance(label_to_extract, str):
            label_to_extract = [label_to_extract]
        elif isinstance(label_to_extract, (tuple, list)):  # TODO
            pass
        else:
            raise TypeError(
                '`label_to_extract` should be a str, or a str iterator')

        indeces = []
        for f in label_to_extract:
            try:
                indeces.append(self.labels.index(f))
            except ValueError:
                raise ValueError(f'`{f}` not in the labels list')

        new_data = super(Tensor, self.T).__getitem__(indeces).float().T
        new_labels = [self.labels[idx] for idx in indeces]

        extracted_tensor = new_data.as_subclass(LabelTensor)
        extracted_tensor.labels = new_labels

        return extracted_tensor

    def append(self, lt, mode='std'):
        """
        Return a copy of the merged tensors.

        :param LabelTensor lt: the tensor to merge.
        :param str mode: {'std', 'first', 'cross'}
        :return: the merged tensors
        :rtype: LabelTensor
        """
        if set(self.labels).intersection(lt.labels):
            raise RuntimeError('The tensors to merge have common labels')

        new_labels = self.labels + lt.labels
        if mode == 'std':
            new_tensor = torch.cat((self, lt), dim=1)
        elif mode == 'first':
            raise NotImplementedError
        elif mode == 'cross':
            tensor1 = self
            tensor2 = lt
            n1 = tensor1.shape[0]
            n2 = tensor2.shape[0]

            tensor1 = LabelTensor(
                tensor1.repeat(n2, 1),
                labels=tensor1.labels)
            tensor2 = LabelTensor(
                tensor2.repeat_interleave(n1, dim=0),
                labels=tensor2.labels)
            new_tensor = torch.cat((tensor1, tensor2), dim=1)

        new_tensor = new_tensor.as_subclass(LabelTensor)
        new_tensor.labels = new_labels
        return new_tensor

    def __getitem__(self, index):
        """
        Return a copy of the selected tensor.
        """
        selected_lt = super(Tensor, self).__getitem__(index)

        try:
            len_index = len(index)
        except TypeError:
            len_index = 1

        if isinstance(index, int) or len_index == 1:
            if selected_lt.ndim == 1:
                selected_lt = selected_lt.reshape(1, -1)
            if hasattr(self, 'labels'):
                selected_lt.labels = self.labels
        elif len_index == 2:
            if selected_lt.ndim == 1:
                selected_lt = selected_lt.reshape(-1, 1)
            if hasattr(self, 'labels'):
                if isinstance(index[1], list):
                    selected_lt.labels = [self.labels[i] for i in index[1]]
                else:
                    selected_lt.labels = self.labels[index[1]]
                    
        return selected_lt

    def __len__(self) -> int:
        return super().__len__()
        
    def __str__(self):
        if hasattr(self, 'labels'):
            s = f'labels({str(self.labels)})\n'
        else:
            s = 'no labels\n'
        s += super().__str__()
        return s
