"""
Supplementary functions for PyTorch tensors.
"""

import abc
import torch
import torch.utils.data as torch_data
import numpy as np
import itertools
from contextlib import contextmanager
from typing import  Sequence, Iterable, Callable, Self, Any, ClassVar, TypeVar, Literal
from torch import Tensor
import dataclasses as dc


def clone_to(x: Tensor, device: torch.device | str):
    y = x.to(device=device)
    if y is x:
        y = y.clone()
    return y


def pad(
    tensors: list[Tensor],
    value,
    dim: int,
    min_len: int | None = None
):
    
    max_len = max(tensor.size(dim) for tensor in tensors)
    if min_len is not None:
        max_len = max(max_len, min_len)

    out: list[Tensor] = []
    
    for tensor in tensors:
        if (padlen := max_len - tensor.size(dim)) > 0:
            pad_shape = torch.Size(
                (padlen if d == dim else size) for d, size in enumerate(tensor.shape)
            )
            tensor = torch.cat((
                tensor,
                torch.full(pad_shape, value, dtype=tensor.dtype, device=tensor.device),
            ), dim=dim)
        out.append(tensor)

    return out


def padded_stack(
    tensors: list[Tensor],
    pad_value,
    pad_dim,
    merge_dim: int = 0,
    min_len: int | None = None,
):
    """Pad `tensors` to a common length along `pad_dim`, then stack them.

    Padding is delegated to `pad` (fill value `pad_value`, optional lower
    bound `min_len`). The equal-shaped results are stacked along a new
    dimension inserted at `merge_dim`.
    """
    equalized = pad(tensors, value=pad_value, dim=pad_dim, min_len=min_len)
    return torch.stack(equalized, dim=merge_dim)


def nonzero_indices(tensor: Tensor) -> Iterable[tuple[int, ...]]:
    """Lazily yield the coordinates of every nonzero element of `tensor`.

    Each coordinate is a tuple of Python ints with one entry per dimension,
    in the order produced by `torch.nonzero`.
    """
    coords = torch.nonzero(tensor).tolist()
    for coord in coords:
        yield tuple(coord)


def regularize_logits(logits: Tensor):
    """Shift `logits` so the maximum along the last dimension becomes zero.

    Softmax is invariant to this shift, and it keeps a later `exp(logits)`
    from overflowing.
    """
    row_max = logits.amax(dim=-1, keepdim=True)
    return logits - row_max


def normalize_tensor(values: Tensor, mask: Tensor | None = None, eps: float = 1e-6):
    if mask is None:
        std, mean = torch.std_mean(values)
        std = torch.clamp(std, min=eps)
        values = (values - mean) / std
    else:
        std, mean = torch.std_mean(values[mask])
        std = torch.clamp(std, min=eps)
        values = torch.where(mask, (values - mean) / std, 0)

    return values


def fetch(values: Tensor, indices: Tensor, keep_dim: bool = False):
    """Gather `values[*k, indices[*k]]` for every leading index `k`.

    `indices` has the shape of `values` without its last dimension and is
    cast to int64. With `keep_dim`, the result keeps a trailing size-1
    dimension; otherwise it is squeezed away.
    """
    gathered = values.gather(-1, indices.long()[..., None])
    if keep_dim:
        return gathered
    return gathered.squeeze(-1)


def preprocess_logits(
    logits: Tensor,
    temperature: float | Tensor = 1,
    regularize: bool = False,
    shift_next: bool = False,
):
    """Apply the optional logit transforms shared by the probability helpers.

    Steps, in order:
      1. `regularize`: subtract the per-row maximum (`regularize_logits`) so
         that a later `exp` cannot overflow.
      2. Temperature: divide by `temperature`. A tensor temperature is
         expanded to `logits.shape[:-1]` and broadcast over the last
         dimension; a scalar equal to 1 is skipped.
      3. `shift_next`: shift along dim -2 with `shift_logits`, so position `t`
         holds what was at `t - 1` and position 0 becomes zeros.

    The layout is expected to be (..., len, vocab_size).
    """
    out = regularize_logits(logits) if regularize else logits
    if isinstance(temperature, Tensor):
        per_row = temperature.expand(out.shape[:-1])
        out = out / per_row.unsqueeze(-1)
    elif temperature != 1:
        out = out / temperature
    return shift_logits(out) if shift_next else out


def probabilities(
    logits: Tensor,
    tokens: Tensor,
    temperature: float | Tensor,
    regularize=True,
    shift_next=True,
    bias: float = 1e-2,
):
    """Smoothed probability of each entry of `tokens`.

    `logits` are first passed through `preprocess_logits`. With `V` the size
    of the last dimension, the probability of token `k` is
    `(exp(l_k) + bias / V) / (sum_v exp(l_v) + bias)`. The `bias` terms keep
    the result strictly positive, and the smoothed values still sum to 1
    over the last dimension.
    """
    scaled = preprocess_logits(logits, temperature, regularize, shift_next)
    vocab = scaled.size(-1)
    weights = scaled.exp()
    numer = fetch(weights, tokens) + bias / vocab
    denom = weights.sum(-1) + bias
    return numer / denom


def prob_entropies(
    logits: Tensor,
    tokens: Tensor,
    temperature: float,
    regularize=True,
    shift_next=True,
    bias: float = 1e-2,
):
    """Smoothed token probabilities together with per-position entropies.

    After `preprocess_logits`, the smoothed distribution over the last
    dimension (size `V`) is `(exp(l) + bias / V) / (sum exp(l) + bias)`.

    Returns:
        `(p, ent)`. `p` is the smoothed probability of the entry selected by
        `tokens`. `ent` is the entropy `-sum(p * log p)` of the smoothed
        distribution. Both are reduced over the last dimension.
    """
    scaled = preprocess_logits(logits, temperature, regularize, shift_next)
    smoothing = bias / scaled.size(-1)
    weights = scaled.exp()
    probs = (weights + smoothing) / (weights.sum(-1, keepdim=True) + bias)
    ent = -(probs * probs.log()).sum(-1)
    return fetch(probs, tokens), ent


def entropies(
    logits: Tensor,
    temperature: float,
    regularize=True,
    shift_next=True,
    bias: float = 1e-2,  # to avoid exp(logits) == 0
):
    """Entropy of the smoothed softmax distribution at every position.

    After `preprocess_logits`, the distribution over the last dimension
    (size `V`) is `(exp(l) + bias / V) / (sum exp(l) + bias)`. The result is
    `-sum(p * log p)` over that dimension.
    """
    scaled = preprocess_logits(logits, temperature, regularize, shift_next)
    smoothing = bias / scaled.size(-1)
    weights = scaled.exp()
    probs = (weights + smoothing) / (weights.sum(-1, keepdim=True) + bias)
    return -(probs * probs.log()).sum(-1)


def quotient(
    logit_p: Tensor,
    logit_q: Tensor,
    logits_p: Tensor,
    logits_q: Tensor,
    bias: float = 1e-2,  # to avoid exp(logit) == 0
    debug_mask: Tensor | None = None,  # used to debug only
):
    exps_p = logits_p.exp()
    exps_q = logits_q.exp()
    exp_p = logit_p.exp() + bias / logits_p.size(-1)
    exp_q = logit_q.exp() + bias / logits_q.size(-1)
    sumexp_p = exps_p.sum(-1) + bias
    sumexp_q = exps_q.sum(-1) + bias
    q = (exp_p / sumexp_p) / (exp_q / sumexp_q)

    if debug_mask is not None:
        assert not q[debug_mask].isnan().any()

    return q


def shift_logits(logits: Tensor):
    """Shift `logits` one step forward along dim -2, zero-filling the front.

    Output position `t` holds input position `t - 1`. Position 0 is all zeros
    and the last input position is dropped, so the shape is unchanged.
    """
    lead = logits.new_zeros((*logits.shape[:-2], 1, logits.size(-1)))
    body = logits[..., :-1, :]
    return torch.cat([lead, body], dim=-2)


def _match_suffix(tokens: Tensor, suffix: Tensor, ignore_mask: Tensor | None = None):
    """
    Match the tokens `(*n, T)` with suffices `(*m, L)`.
    Output a tensor `match` with shape `(*n, T, *m)`, such that `match[*i, t, *j]` is given by
    `torch.all((tokens[*i, t-L] == suffix[*j, :]) | ignore_mask[*j, :])`
    """

    T = tokens.size(-1)
    n = tokens.shape[:-1]
    dev = tokens.device

    L = suffix.size(-1)
    m = suffix.shape[:-1]
    suffix = suffix.to(device=dev, dtype=tokens.dtype)

    if ignore_mask is not None and ignore_mask.shape != suffix.shape:
        raise IndexError

    eq = tokens.reshape(tokens.shape + (1,) * suffix.ndim) == suffix
    assert eq.shape == (*n, T, *m, L)
    if ignore_mask is not None:
        ignore_mask = ignore_mask.to(tokens.device)

    _slices = tuple(slice(None) for _ in n)
    out = torch.ones(*n, T, *m, dtype=torch.bool, device=dev)

    for i in range(L):
        t = L - 1 - i
        if not 0 <= t < T:
            continue
        eq_i = torch.zeros_like(out)
        eq_i[_slices + (slice(t, T),)] = eq[_slices + (slice(T-t), ..., i)]
        if ignore_mask is not None:
            eq_i.masked_fill_(ignore_mask[..., i], True)
        out = out & eq_i
    
    return out


def match_suffix(tokens: Tensor, suffix: Tensor | Sequence[Tensor], ignore_mask: Tensor | None = None) -> Tensor:
    """
    Match the tokens `(*n, T)` with suffixes `(*m, L)`.

    `suffix` is either a single tensor of shape `(*m, L)`, optionally with an
    `ignore_mask` of the same shape, or a sequence of 1-D suffix tensors of
    possibly different lengths. In the latter case `_combine_suffixes`
    left-pads them to a common length and masks the padding as ignored,
    so `m` is `(len(suffix),)`; `ignore_mask` must then be None.

    Returns the boolean `(*n, T, *m)` tensor computed by `_match_suffix`.
    """
    if not isinstance(suffix, Tensor):
        assert ignore_mask is None
        suffix, ignore_mask = _combine_suffixes(*suffix)
    return _match_suffix(tokens, suffix, ignore_mask)


def _combine_suffixes(*suffixes: Tensor):

    maxlen: int = -1
    for suf in suffixes:
        if suf.ndim != 1:
            raise IndexError
        maxlen = max(maxlen, suf.size(0))
    
    if maxlen < 0:
        raise ValueError("No suffixes given.")

    sufs = []
    masks = []
    for suf in suffixes:
        padlen = maxlen - suf.size(0)
        pad = torch.full((padlen,), -1, dtype=suf.dtype, device=suf.device)
        suf = torch.cat((pad, suf))
        mask = torch.zeros_like(suf, dtype=torch.bool)
        mask[:padlen] = True
        sufs.append(suf)
        masks.append(mask)

    return torch.stack(sufs), torch.stack(masks)


def find_nonzero(x: Tensor, dim: int, 
                 mode: Literal['first', 'last'] = 'first', 
                 default: int | Tensor = -1,
                 bias: int = 0,
                 keepdim=False):
    
    # Handling the special case where the tensor has no data.
    if x.size(dim) == 0:
        if keepdim:
            size_ = x.shape[:dim] + (1,) + x.shape[dim:][1:]
        else:
            size_ = x.shape[:dim] + x.shape[dim:][1:]
        if isinstance(default, Tensor):
            idx = torch.empty(size_, dtype=torch.int64, device=x.device)
            idx.copy_(default)
        else:
            idx = torch.full(size_, default, dtype=torch.int64, device=x.device)
        return idx
    
    if x.dtype != torch.bool:
        x = (x != 0)
    
    has_nonzero = torch.any(x, dim, keepdim=keepdim)
    x = x.byte()
    if mode == 'first':
        idx = torch.argmax(x, dim, keepdim=keepdim)
    elif mode == 'last':
        idx = x.size(dim) - 1 - torch.argmax(x.flip(dim), dim, keepdim=keepdim)
    if bias != 0:
        idx = idx + bias
    return torch.where(has_nonzero, idx, default)


# Column container types accepted by `NamedDataset`; each supports `len()`
# and integer indexing.
type _DataLike = Sequence | Tensor | np.ndarray


class NamedDataset[_Item: _DataLike](torch_data.Dataset):

    __data_dict__: dict[str, _Item]
    __data_activated__: bool
    __synch_size__: bool = True
    __ignore_empty__: bool = True
    
    def __init__(self, **data: _Item):
        
        object.__setattr__(self, "__data_activated__", True)
        object.__setattr__(self, "__data_dict__", dict())

        for name, list_ in data.items():
            setattr(self, name, list_)

    @contextmanager
    def _attr_setting_mode(self):
        """`__setattr__` sets attrubutes in python attribute dict (`object.__dict__`) instead of
        the data dict (`object.__data_dict__`)."""
        old = self.__data_activated__
        object.__setattr__(self, "__data_activated__", False)
        try:
            yield self
        finally:
            object.__setattr__(self, "__data_activated__", old)
    
    def __setattr__(self, name: str, value: _Item) -> None:
        if self.__data_activated__ is True:
            self.__data_dict__[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name: str) -> _Item:
        try:
            return self.__data_dict__[name]
        except KeyError:
            raise AttributeError
        
    def __delattr__(self, name: str) -> None:
        try:
            del self.__data_dict__[name]
        except KeyError:
            raise AttributeError
        
    def _names(self) -> Iterable[str]:
        if self.__ignore_empty__:
            for k, v in self.__data_dict__.items():
                if len(v) > 0:
                    yield k
        else:
            yield from self.__data_dict__.keys()
        
    def _items(self) -> Iterable[tuple[str, _Item]]:
        if self.__ignore_empty__:
            for k, v in self.__data_dict__.items():
                if len(v) > 0:
                    yield k, v
        else:
            yield from self.__data_dict__.items()
    
    def clear(self):
        self.__data_dict__.clear()
    
    def __len__(self) -> int:
        
        lengths = [len(v) for _, v in self._items()]

        if len(lengths) == 0:
            return 0
        
        if self.__synch_size__:
            if not all(length == lengths[0] for length in lengths):
                raise IndexError("Incongruous lengths of data.")
            return lengths[0]
        else:
            return min(lengths)
    
    def __getitem__(self, idx: int):
        return {k: v[idx] for k, v in self._items()}

    def __repr__(self):
        keys = (k for k, _ in self._items())
        return self.__class__.__name__ + '(' + (", ".join(keys)) + ')'


class ListDataset[T](list[T], torch_data.Dataset):

    @classmethod
    def no_collate(cls, x: list[T]) -> list[T]:
        return x


class TensorCollection(abc.ABC):

    @abc.abstractmethod
    def iterate_tensors(self) -> Iterable[Tensor]:
        raise NotImplementedError

    def apply_operation(self, fn: Callable[[Tensor], Any]) -> None:
        """
        apply a in-place operation to all tensors.
        """
        for tensor in self.iterate_tensors():
            fn(tensor)
    
    @abc.abstractmethod
    def apply_transform(self, fn: Callable[[Tensor], Tensor]) -> Self:
        """
        apply a transform `fn`. Returns a new object whose tensors are the
        return values of the function call.
        """
        raise NotImplementedError
    
    def to(self, device: torch.device | str | None = None):
        return self.apply_transform(lambda x: x.to(device=device))
    
    def clone(self) -> Self:
        return self.apply_transform(torch.clone)
    
    @property
    def nelem(self):
        return sum(tensor.nelement() for tensor in self.iterate_tensors())

    @property
    def total_memory(self):
        return sum(tensor.element_size() * tensor.nelement()
                   for tensor in self.iterate_tensors())


def iterate_tensors(items: Iterable[Tensor | TensorCollection]) -> Iterable[Tensor]:
    """Flatten `items` into a stream of tensors.

    Tensors are yielded as-is. Any other item is expanded through its own
    `iterate_tensors()` method.
    """
    for item in items:
        if not isinstance(item, Tensor):
            yield from item.iterate_tensors()
            continue
        yield item


class TensorDict[Key](dict[Key, Tensor], TensorCollection):
    
    def iterate_tensors(self) -> Iterable[Tensor]:
        return self.values()
    
    def apply_transform(self, fn: Callable[[Tensor], Tensor]) -> Self:
        return self.__class__({key: fn(value) for key, value in self.items()})


class NestedTensorDict[K, T: Tensor | TensorCollection](dict[K, T], TensorCollection):
    
    def iterate_tensors(self) -> Iterable[Tensor]:
        return iterate_tensors(self.values())
    
    def apply_transform(self, fn: Callable[[Tensor], Tensor]) -> Self:
        return self.__class__({
            key: fn(value) if not isinstance(value, TensorCollection) else value.apply_transform(fn)
            for key, value in self.items()
        })


class TensorList(list[Tensor], TensorCollection):
    """A `list` of tensors usable as a `TensorCollection`."""

    def iterate_tensors(self) -> Iterable[Tensor]:
        """Yield the tensors in list order."""
        for tensor in self:
            yield tensor

    def apply_transform(self, fn: Callable[[Tensor], Tensor]) -> Self:
        """Return a new instance of this class holding `fn(t)` for each element."""
        return type(self)(map(fn, self))

    def apply_transform_(self, fn: Callable[[Tensor], Tensor]) -> Self:
        """Replace every element with `fn(element)` in place and return `self`."""
        for pos, tensor in enumerate(self):
            self[pos] = fn(tensor)
        return self


class NestedTensorList[T: Tensor | TensorCollection](list[T], TensorCollection):
    
    def iterate_tensors(self) -> Iterable[Tensor]:
        return iterate_tensors(self)
                
    def apply_transform(self, fn: Callable[[Tensor], Tensor]) -> Self:
        return self.__class__(
            fn(item) if not isinstance(item, TensorCollection) else item.apply_transform(fn)
            for item in self
        )


class TensorDataClass(TensorCollection):
    """Mixin for dataclasses whose fields hold tensors or tensor collections.

    Subclasses must also be decorated with `@dataclasses.dataclass`.
    `iterate_tensors` walks every dataclass field. `apply_transform` rebuilds
    the object from its `init=True` fields only: tensor and collection values
    are transformed, and other values are passed through unchanged.
    """

    # Per-class cache of the names of `init=True` fields.
    __init_fields__: ClassVar[frozenset[str]]

    def _init_fields(self) -> Iterable[str]:
        """Names of this dataclass's `init=True` fields, cached on the class."""
        cls = type(self)
        # Look only in this class's own namespace. A cache stored on a parent
        # class would otherwise be inherited and hide the fields the subclass
        # adds.
        cached = cls.__dict__.get("__init_fields__")
        if cached is None:
            assert dc.is_dataclass(self)
            cached = frozenset(
                field.name for field in dc.fields(self) if field.init
            )
            cls.__init_fields__ = cached
        return cached

    def iterate_tensors(self) -> Iterable[Tensor]:
        """Yield tensors found in any dataclass field, descending into collections."""
        assert dc.is_dataclass(self)
        for field in dc.fields(self):
            value = getattr(self, field.name, None)
            if isinstance(value, Tensor):
                yield value
            elif isinstance(value, TensorCollection):
                yield from value.iterate_tensors()

    def apply_transform(self, fn: Callable[[Tensor], Tensor]) -> Self:
        """Construct a new instance from the `init=True` fields, with `fn`
        applied to tensors (recursively inside tensor collections)."""
        init_args = {}
        for name in self._init_fields():
            val = getattr(self, name)
            if isinstance(val, Tensor):
                val = fn(val)
            elif isinstance(val, TensorCollection):
                val = val.apply_transform(fn)
            init_args[name] = val
        return self.__class__(**init_args)


class ObjectTensor[T](TensorCollection):

    def __init__(self, indices: torch.Tensor, objects: tuple[T, ...]):
        self._objects = objects
        self._indices: torch.Tensor = indices
    
    def __getitem__(self, idx):
        return self._objects[self.idxmap(idx)]
    
    def idxmap(self, idx: tuple[int, ...]):
        if len(idx) != self._indices.ndim:
            raise IndexError
        return int(self._indices[idx])
    
    def iterate_tensors(self) -> Iterable[Tensor]:
        yield self._indices
    
    def apply_transform(self, fn: Callable[[Tensor], Tensor]) -> Self:
        indices = fn(self._indices)
        return self.__class__(indices, self._objects)
    
    @property
    def shape(self):
        return self._indices.shape
    
    def size(self, dim: int | None):
        return self._indices.size(dim)
    
    @property
    def ndim(self):
        return self._indices.ndim
    