from typing import Dict, List, Optional, Literal
from collections import defaultdict
import torch


def _add_last_dim(dictionary, key, new_value, prepend=False):
    # append/prepend a new value to the last dimension of a tensor in a dictionary
    # if the key does not exist, put the new value in
    # append by default
    if key in dictionary:
        dictionary[key] = torch.cat([dictionary[key], new_value], -1)
    else:
        dictionary[key] = new_value


class KeyValueMemoryStore:
    """
    Works for key/value pairs type storage
    e.g., working and long-term memory
    """
    def __init__(self, save_selection: bool = False, save_usage: bool = False):
        """
        We store keys and values of objects that first appear in the same frame in a bucket.
        Each bucket contains a set of object ids.
        Each bucket is associated with a single key tensor
            and a dictionary of value tensors indexed by object id.

        The keys and values are stored as the concatenation of a permanent part and a temporary part.
        """
        self.save_selection = save_selection
        self.save_usage = save_usage

        self.global_bucket_id = 0  # does not reduce even if buckets are removed
        self.buckets: Dict[int, List[int]] = {}  # indexed by bucket id
        self.k: Dict[int, torch.Tensor] = {}  # indexed by bucket id
        self.v: Dict[int, torch.Tensor] = {}  # indexed by object id

        # indexed by bucket id; the end point of permanent memory
        self.perm_end_pt: Dict[int, int] = defaultdict(int)

        # shrinkage and selection are just like the keys
        self.s = {}
        if self.save_selection:
            self.e = {}  # does not contain the permanent memory part

        # usage
        if self.save_usage:
            self.use_cnt = {}  # indexed by bucket id, does not contain the permanent memory part
            self.life_cnt = {}  # indexed by bucket id, does not contain the permanent memory part

    def add(self,
            key: torch.Tensor,
            values: Dict[int, torch.Tensor],
            shrinkage: torch.Tensor,
            selection: torch.Tensor,
            supposed_bucket_id: int = -1,
            as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
        """
        key: (1/2)*C*N
        values: dict of values ((1/2)*C*N), object ids are used as keys
        shrinkage: (1/2)*1*N
        selection: (1/2)*C*N

        supposed_bucket_id: used to sync the bucket id between working and long-term memory
        if provided, the input should all be in a single bucket indexed by this id
        as_permanent: whether to store the input as permanent memory
            'no': don't
            'first': only store it as permanent memory if the bucket is empty
            'all': always store it as permanent memory
        """
        bs = key.shape[0]
        ne = key.shape[-1]
        assert len(key.shape) == 3
        assert len(shrinkage.shape) == 3
        assert not self.save_selection or len(selection.shape) == 3
        assert as_permanent in ['no', 'first', 'all']

        # add the value and create new buckets if necessary
        if supposed_bucket_id >= 0:
            enabled_buckets = [supposed_bucket_id]
            bucket_exist = supposed_bucket_id in self.buckets
            for obj, value in values.items():
                if bucket_exist:
                    assert obj in self.v
                    assert obj in self.buckets[supposed_bucket_id]
                    _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
                else:
                    assert obj not in self.v
                    self.v[obj] = value
            self.buckets[supposed_bucket_id] = list(values.keys())
        else:
            new_bucket_id = None
            enabled_buckets = set()
            for obj, value in values.items():
                assert len(value.shape) == 3
                if obj in self.v:
                    _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
                    bucket_used = [
                        bucket_id for bucket_id, object_ids in self.buckets.items()
                        if obj in object_ids
                    ]
                    assert len(bucket_used) == 1  # each object should only be in one bucket
                    enabled_buckets.add(bucket_used[0])
                else:
                    self.v[obj] = value
                    if new_bucket_id is None:
                        # create new bucket
                        new_bucket_id = self.global_bucket_id
                        self.global_bucket_id += 1
                        self.buckets[new_bucket_id] = []
                    # put the new object into the corresponding bucket
                    self.buckets[new_bucket_id].append(obj)
                    enabled_buckets.add(new_bucket_id)

        # increment the permanent size if necessary
        add_as_permanent = {}  # indexed by bucket id
        for bucket_id in enabled_buckets:
            add_as_permanent[bucket_id] = False
            if as_permanent == 'all':
                self.perm_end_pt[bucket_id] += ne
                add_as_permanent[bucket_id] = True
            elif as_permanent == 'first':
                if self.perm_end_pt[bucket_id] == 0:
                    self.perm_end_pt[bucket_id] = ne
                    add_as_permanent[bucket_id] = True

        # create new counters for usage if necessary
        if self.save_usage and as_permanent != 'all':
            new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
            new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7

        # add the key to every bucket
        for bucket_id in self.buckets:
            if bucket_id not in enabled_buckets:
                # if we are not adding new values to a bucket, we should skip it
                continue

            _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
            _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
            if not add_as_permanent[bucket_id]:
                if self.save_selection:
                    _add_last_dim(self.e, bucket_id, selection)
                if self.save_usage:
                    _add_last_dim(self.use_cnt, bucket_id, new_count)
                    _add_last_dim(self.life_cnt, bucket_id, new_life)

    def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
        # increase all life count by 1
        # increase use of indexed elements
        if not self.save_usage:
            return

        usage = usage[:, self.perm_end_pt[bucket_id]:]
        if usage.shape[-1] == 0:
            # if there is no temporary memory, we don't need to update
            return
        self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
        self.life_cnt[bucket_id] += 1

    def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
        # keep only the temporary elements *outside* of this range (with some boundary conditions)
        # the permanent elements are ignored in this computation
        # i.e., concat (a[:start], a[end:])
        # bucket with size <= min_size are not modified

        assert start >= 0
        assert end <= 0

        object_ids = self.buckets[bucket_id]
        bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
        if bucket_num_elements <= min_size:
            return

        if end == 0:
            # negative 0 would not work as the end index!
            # effectively make the second part an empty slice
            end = self.k[bucket_id].shape[-1] + 1

        p_size = self.perm_end_pt[bucket_id]
        start = start + p_size

        k = self.k[bucket_id]
        s = self.s[bucket_id]
        if self.save_selection:
            e = self.e[bucket_id]
        if self.save_usage:
            use_cnt = self.use_cnt[bucket_id]
            life_cnt = self.life_cnt[bucket_id]

        self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
        self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
        if self.save_selection:
            self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
        if self.save_usage:
            self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
            self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
                                                 -1)
        for obj_id in object_ids:
            v = self.v[obj_id]
            self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)

    def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
        self.sieve_by_range(bucket_id, 0, -max_len, max_len)

    def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
        # for long-term memory only
        object_ids = self.buckets[bucket_id]

        assert self.perm_end_pt[bucket_id] == 0  # permanent memory should be empty in LT memory

        # normalize with life duration
        usage = self.get_usage(bucket_id)
        bs = usage.shape[0]

        survivals = []

        for bi in range(bs):
            _, survived = torch.topk(usage[bi], k=max_size)
            survivals.append(survived.flatten())
            assert survived.shape[-1] == survivals[0].shape[-1]

        self.k[bucket_id] = torch.stack(
            [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        self.s[bucket_id] = torch.stack(
            [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        if self.save_selection:
            # Long-term memory does not store selection so this should not be needed
            self.e[bucket_id] = torch.stack(
                [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        for obj_id in object_ids:
            self.v[obj_id] = torch.stack(
                [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        self.use_cnt[bucket_id] = torch.stack(
            [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
        self.life_cnt[bucket_id] = torch.stack(
            [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)

    def get_usage(self, bucket_id: int) -> torch.Tensor:
        # return normalized usage
        if not self.save_usage:
            raise RuntimeError('I did not count usage!')
        else:
            usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
            return usage

    def get_all_sliced(
        self, bucket_id: int, start: int, end: int
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
        # return k, sk, ek, value, normalized usage in order, sliced by start and end
        # this only queries the temporary memory

        assert start >= 0
        assert end <= 0

        p_size = self.perm_end_pt[bucket_id]
        start = start + p_size

        if end == 0:
            # negative 0 would not work as the end index!
            k = self.k[bucket_id][:, :, start:]
            sk = self.s[bucket_id][:, :, start:]
            ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
        else:
            k = self.k[bucket_id][:, :, start:end]
            sk = self.s[bucket_id][:, :, start:end]
            ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None

        return k, sk, ek, value, usage

    def purge_except(self, obj_keep_idx: List[int]):
        # purge certain objects from the memory except the one listed
        obj_keep_idx = set(obj_keep_idx)

        # remove objects that are not in the keep list from the buckets
        buckets_to_remove = []
        for bucket_id, object_ids in self.buckets.items():
            self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
            if len(self.buckets[bucket_id]) == 0:
                buckets_to_remove.append(bucket_id)

        # remove object values that are not in the keep list
        self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}

        # remove buckets that are empty
        for bucket_id in buckets_to_remove:
            del self.buckets[bucket_id]
            del self.k[bucket_id]
            del self.s[bucket_id]
            if self.save_selection:
                del self.e[bucket_id]
            if self.save_usage:
                del self.use_cnt[bucket_id]
                del self.life_cnt[bucket_id]

    def clear_non_permanent_memory(self):
        # clear all non-permanent memory
        for bucket_id in self.buckets:
            self.sieve_by_range(bucket_id, 0, 0, 0)

    def get_v_size(self, obj_id: int) -> int:
        return self.v[obj_id].shape[-1]

    def size(self, bucket_id: int) -> int:
        if bucket_id not in self.k:
            return 0
        else:
            return self.k[bucket_id].shape[-1]

    def perm_size(self, bucket_id: int) -> int:
        return self.perm_end_pt[bucket_id]

    def non_perm_size(self, bucket_id: int) -> int:
        return self.size(bucket_id) - self.perm_size(bucket_id)

    def engaged(self, bucket_id: Optional[int] = None) -> bool:
        if bucket_id is None:
            return len(self.buckets) > 0
        else:
            return bucket_id in self.buckets

    @property
    def num_objects(self) -> int:
        return len(self.v)

    @property
    def key(self) -> Dict[int, torch.Tensor]:
        return self.k

    @property
    def value(self) -> Dict[int, torch.Tensor]:
        return self.v

    @property
    def shrinkage(self) -> Dict[int, torch.Tensor]:
        return self.s

    @property
    def selection(self) -> Dict[int, torch.Tensor]:
        return self.e

    def __contains__(self, key):
        return key in self.v
