import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import wraps

from einops import rearrange, repeat

from src.utils.padding import pad_to_multiple
from src.ops.permutation import invert_permutation
from src.models.attention.hash_utils import sort_key_val
from src.models.attention.batching_utils import batched_index_select
from src.models.modules.masking import FullMask, LengthMask
from src.models.attention.mask_utils import pad_mask

# constants

# Logit written over the positions where a query attends to itself (query
# index == key index) in SparseAttention.forward. It is a large negative but
# finite number that still fits in float16 (|x| < 65504) -- carefully set for
# half precision to work.
TOKEN_SELF_ATTN_VALUE = -5e4

# helper fns

def chunked_sum(tensor, chunks=1):
    """Sum ``tensor`` over its last dimension, working on slices of rows.

    The tensor is flattened to two dimensions, split along the row axis into
    up to ``chunks`` pieces, each piece is reduced over its last axis, and the
    partial results are concatenated and reshaped to the original leading
    dimensions. The result equals ``tensor.sum(dim=-1)``.

    Args:
        tensor: tensor with at least one dimension.
        chunks: number of pieces passed to ``Tensor.chunk``.

    Returns:
        Tensor whose shape is ``tensor.shape[:-1]``.
    """
    *leading_shape, width = tensor.shape
    rows = tensor.reshape(-1, width)

    partial_sums = []
    for piece in rows.chunk(chunks, dim=0):
        partial_sums.append(piece.sum(dim=-1))

    flat_result = torch.cat(partial_sums, dim=0)
    return flat_result.reshape(leading_shape)

def default(val, default_val):
    """Return ``val`` unless it is ``None``, in which case return ``default_val``.

    Only ``None`` triggers the fallback; other falsy values (``0``, ``''``,
    ``False``) are returned unchanged.
    """
    if val is None:
        return default_val
    return val

def max_neg_value(tensor):
    """Return the most negative finite value of ``tensor``'s floating dtype.

    Computed as the negation of ``torch.finfo(tensor.dtype).max``.
    """
    dtype_info = torch.finfo(tensor.dtype)
    return -dtype_info.max

def cache_method_decorator(cache_attr, cache_namespace, reexecute=False):
    """Build a decorator that stores a method's result in a dict on the instance.

    The wrapped method gains three keyword-only arguments:

    * ``key_namespace`` -- suffix for the cache key; ``None`` means ``''``.
      The key is ``f'{cache_namespace}:{key_namespace}'``.
    * ``fetch`` -- when true, return the value already stored under the key
      (a missing key raises ``KeyError``). If ``reexecute`` is also true the
      method is still called, but its result is discarded.
    * ``set_cache`` -- when true (and ``fetch`` is false), store the freshly
      computed value. The cache attribute is replaced by a new dict rather
      than mutated in place.

    Args:
        cache_attr: name of the instance attribute holding the cache dict.
        cache_namespace: prefix used for every key written by this decorator.
        reexecute: whether to call the method anyway on a cache fetch.
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs):
            suffix = str(default(key_namespace, ''))
            store = getattr(self, cache_attr)
            cache_key = f'{cache_namespace}:{suffix}'

            if not fetch:
                result = fn(self, *args, **kwargs)
                if set_cache:
                    # Rebind to a fresh dict so earlier references to the old
                    # cache object are left untouched.
                    refreshed = {**store, cache_key: result}
                    setattr(self, cache_attr, refreshed)
                return result

            cached = store[cache_key]
            if reexecute:
                # Run the method for its side effects only; the cached value
                # is what gets returned.
                fn(self, *args, **kwargs)
            return cached
        return wrapper
    return inner_fn

# LSH attention as described in https://openreview.net/pdf?id=rkgNKkHtvB
# Adapted from trax and stripped down to what the paper says is needed for it to work,
# namely that buckets need to hold at least 64 items, with 8 rounds of hashing.
# https://github.com/google/trax/blob/master/trax/layers/research/efficient_attention.py#L442

class SparseAttention(nn.Module):
    """LSH (Reformer-style) bucketed attention with shared queries and keys.

    Tokens are hashed into buckets with random rotations, sorted by bucket id,
    cut into chunks of ``bucket_size`` and attended chunk-locally. The results
    of the individual hash rounds are combined with a softmax over each
    round's log-normaliser.

    Limitations of the current ``forward``:

    * it processes one un-batched, single-head sequence -- a batch axis and a
      head axis of size 1 are inserted internally;
    * ``k``, ``attn_mask`` and ``key_padding_mask`` are accepted but ignored;
    * ``causal``, ``allow_duplicate_attention`` and ``attend_across_buckets``
      are stored but have no effect, because the corresponding masking steps
      are not implemented in this class (see the trax / reformer-pytorch
      implementations referenced in this file for the original logic);
    * rows added by padding are not masked out and take part as keys;
    * with ``reverse=True`` the learned projections are 3x3, so the feature
      size of ``qk`` must be 3.

    Args:
        softmax_temp: scale applied to the attention logits; when falsy,
            ``1 / sqrt(feature_dim)`` is used.
        attention_dropout: dropout probability applied to attention weights.
        bucket_size: number of tokens per chunk after sorting by bucket.
        n_hashes: number of hash rounds.
        causal: stored only; currently has no effect.
        reverse: if true, queries and keys are projected with learned 3x3
            matrices *after* being sorted into chunks.
        adjacency: which chunks a query chunk attends to -- ``'adjacent'``
            (own, previous and next chunk), ``'oneback'`` (own and previous
            chunk) or ``'intra'`` (own chunk only).
        allow_duplicate_attention: stored only; currently has no effect.
        attend_across_buckets: stored only; currently has no effect.
        rehash_each_round: draw separate rotations for every hash round
            instead of taking the top ``n_hashes`` buckets of one rotation.
        drop_for_hash_rate: dropout probability applied to the vectors before
            hashing.
        random_rotations_per_head: draw separate rotations per batch entry.
        device: accepted but unused.
        dtype: accepted but unused.

    Raises:
        AssertionError: if both ``rehash_each_round`` and
            ``allow_duplicate_attention`` are false.
        ValueError: if ``adjacency`` is not one of the supported modes.
    """

    # Number of bucket-sized chunks every query chunk attends to, per mode.
    _ATTENTION_CHUNKS = {'adjacent': 3, 'oneback': 2, 'intra': 1}

    def __init__(self,
                 softmax_temp=None,
                 attention_dropout=0.,
                 bucket_size=64,
                 n_hashes=8,
                 causal=False, reverse = False,
                 adjacency = 'adjacent',
                 allow_duplicate_attention=True,
                 attend_across_buckets=True,
                 rehash_each_round=True,
                 drop_for_hash_rate=0.0,
                 random_rotations_per_head=False,
                 device=None, dtype=None):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout = nn.Dropout(attention_dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        self.causal = causal
        self.bucket_size = int(bucket_size)

        self.n_hashes = n_hashes

        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round
        self._random_rotations_per_head = random_rotations_per_head

        self.reverse = reverse
        if self.reverse:
            # Learned query / key projections applied after bucketing.
            self.WQ = nn.Parameter(torch.randn(3,3))
            self.WK = nn.Parameter(torch.randn(3,3))

        # cache buckets for reversible network, reported by authors to make Reformer work at depth
        self._cache = {}

        self.adjacency = adjacency
        # Fail at construction time instead of deep inside forward().
        self._chunks_per_query()

    def _chunks_per_query(self):
        """Return how many chunks each query chunk attends to for ``self.adjacency``."""
        if self.adjacency not in self._ATTENTION_CHUNKS:
            raise ValueError(
                f'Unknown adjacency {self.adjacency!r}; '
                f'expected one of {sorted(self._ATTENTION_CHUNKS)}')
        return self._ATTENTION_CHUNKS[self.adjacency]

    @staticmethod
    def _look_one_back(x):
        """Append to every chunk (dim 1) the contents of the previous chunk, wrapping around."""
        previous = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
        return torch.cat([x, previous], dim=2)

    @staticmethod
    def _look_adjacent(x):
        """Append to every chunk (dim 1) the previous and the next chunk, wrapping around."""
        previous = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
        following = torch.cat([x[:, 1:, ...], x[:, :1, ...]], dim=1)
        return torch.cat([x, previous, following], dim=2)

    @cache_method_decorator('_cache', 'buckets', reexecute=True)
    def hash_vectors(self, n_buckets, vecs):
        """Assign every vector to one bucket per hash round using random rotations.

        The method is wrapped by ``cache_method_decorator``, so callers may
        additionally pass the keyword arguments ``key_namespace``, ``fetch``
        and ``set_cache``.

        Args:
            n_buckets: number of buckets; must be even.
            vecs: tensor indexed as (batch, seq_len, feature).

        Returns:
            Integer tensor of bucket ids with shape
            (batch_size, self.n_hashes, seq_len).
        """
        batch_size = vecs.shape[0]
        device = vecs.device

        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        assert n_buckets % 2 == 0

        rotations_shape = (
            batch_size if self._random_rotations_per_head else 1,
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            n_buckets // 2)

        random_rotations = torch.randn(
            rotations_shape, dtype=vecs.dtype, device=device).expand(batch_size, -1, -1, -1)

        dropped_vecs = self.dropout_for_hash(vecs)
        rotated_vecs = torch.einsum('btf,bfhi->bhti', dropped_vecs, random_rotations)
        # Concatenating the negated projections doubles the half-size rotation
        # to n_buckets candidate buckets: [batch, n_hash, seq_len, buckets].
        rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)

        if self._rehash_each_round:
            buckets = torch.argmax(rotated_vecs, dim=-1)
        else:
            # In this configuration, we map each item to the top self.n_hashes buckets
            rotated_vecs = torch.squeeze(rotated_vecs, 1)
            bucket_range = torch.arange(rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            # buckets size [batch size, seq_len, buckets]
            buckets = buckets[... , -self.n_hashes:].transpose(1, 2)

        # buckets is now (batch_size, self.n_hashes, seq_len).
        return buckets

    def forward(self, qk, k, v, attn_mask=None, key_padding_mask=None, need_weights=False):
        """Run bucketed attention over one sequence.

        Args:
            qk: 2-D tensor used as both queries and keys; it is treated as
                (seq_len, feature).
            k: ignored -- queries and keys are assumed identical (``qk``).
            v: 2-D tensor of values. It is transposed on entry, so it is
                expected as (value_dim, seq_len).
            attn_mask: ignored.
            key_padding_mask: ignored.
            need_weights: if true, also build and return the dense attention
                matrix.

        Returns:
            Tuple ``(out, attn)``. ``out`` has shape (seq_len, value_dim).
            ``attn`` is ``None`` unless ``need_weights`` is true, in which
            case it has shape (1, 1, seq_len, seq_len).

        Raises:
            ValueError: if ``self.adjacency`` is not a supported mode.
        """
        chunks_per_query = self._chunks_per_query()

        # Work in float32; values arrive transposed.
        qk = qk.type(torch.float)
        v = v.type(torch.float).T

        # Insert the batch and head axes: (t, e) -> (1, t, 1, e).
        qk = qk.unsqueeze(0).unsqueeze(2)
        v = v.unsqueeze(0).unsqueeze(2)

        # Standardise qk for better hashing: centre over the sequence axis,
        # then scale every vector to unit L2 norm.
        qk = qk - qk.mean(dim=1, keepdim=True)
        qk = F.normalize(qk, p=2, dim=-1)

        _, seqlen_og, n_head, _ = qk.shape
        # NOTE(review): no key-padding mask is applied, so rows added here
        # take part in attention as keys.
        qk = pad_to_multiple(qk, self.bucket_size * 2, dims=1)
        v = pad_to_multiple(v, self.bucket_size * 2, dims=1)

        softmax_temp = self.softmax_temp or 1 / math.sqrt(qk.shape[-1])

        # rearrange to (batch * heads) x seq_length x feature
        qk = rearrange(qk, 'b t h e -> (b h) t e')
        v = rearrange(v, 'b s h d -> (b h) s d')
        batch_size, seqlen, _ = qk.shape
        device = qk.device

        assert seqlen % (self.bucket_size * 2) == 0, f'Sequence length ({seqlen}) needs to be divisible by target bucket size  x 2 - {self.bucket_size * 2}'

        n_buckets = seqlen // self.bucket_size

        # Bucket assignment of every vector for every hash round:
        # (batch_size, n_hashes, seqlen).
        buckets = self.hash_vectors(n_buckets, qk, set_cache=self.training)

        assert buckets.shape[1] == self.n_hashes
        assert buckets.shape[2] == seqlen

        # Sort the vectors by bucket id to get ready for chunking. `perm`
        # holds, per hash round, the original position of every sorted vector;
        # `perm_inv` maps the sorted order back to the original order.
        buckets = rearrange(buckets, 'b nhashes seqlen -> nhashes b seqlen')
        _, perm = torch.sort(buckets, dim=-1, stable=True)
        perm_inv = invert_permutation(perm)

        # Allow each chunk to attend within itself and, depending on the
        # adjacency mode, to its neighbouring chunks. Chunk boundaries might
        # occur in the middle of a run of items from the same bucket, so this
        # increases the chances of attending to relevant items.
        #
        # We differ here from https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py
        # the neighbouring chunk always comes from the same hash function,
        # while lucidrains's implementation could look back at the bucket from
        # the previous hash function.
        perm_chunks = rearrange(perm, 'h b (nbuckets bucketsz) '
                                      '-> (h b) nbuckets bucketsz',
                                nbuckets=n_buckets)
        if self.adjacency == 'adjacent':
            key_chunks = self._look_adjacent(perm_chunks)
        elif self.adjacency == 'oneback':
            key_chunks = self._look_one_back(perm_chunks)
        else:  # 'intra': attend only inside the chunk
            key_chunks = perm_chunks
        perm_oneback = rearrange(key_chunks,
                                 '(h b) nbuckets2 bucketsz -> h b (nbuckets2 bucketsz)',
                                 h=self.n_hashes)
        key_chunk_size = self.bucket_size * chunks_per_query

        # sort queries, keys, values
        def sort_to_buckets(x, perm, bucketsz):
            x = rearrange(x, 'b s d -> 1 b s d')
            return rearrange(batched_index_select(x, perm),
                             'h b (nbuckets bucketsz) d -> h b nbuckets bucketsz d',
                             bucketsz=bucketsz)

        # tq: (n_hashes, batch, n_buckets, bucket_size, feature)
        # tk / tv: same, but every chunk holds key_chunk_size entries
        tq = sort_to_buckets(qk, perm, self.bucket_size)
        tk = sort_to_buckets(qk, perm_oneback, key_chunk_size)
        tv = sort_to_buckets(v, perm_oneback, key_chunk_size)

        if self.reverse:
            # under reverse, we assume no linear projection has been applied yet, and so we sort to bucket -> then project
            # rather than project -> sort to bucket
            tq = torch.einsum("abcde, ef -> abcdf", tq, self.WQ)
            tk = torch.einsum("abcde, ef -> abcdf", tk, self.WK)

        # Dot-product attention inside every chunk:
        # inner is (n_hashes, batch, n_buckets, bucket_size, key_chunk_size).
        inner = torch.einsum('zbhie,zbhje->zbhij', tq, tk) * softmax_temp

        # Original sequence positions of the queries / keys of every chunk,
        # shaped so that comparing them broadcasts to inner's shape.
        bq_idx = rearrange(perm, 'h b (nbuckets bucketsz) -> h b nbuckets bucketsz 1',
                           bucketsz=self.bucket_size)
        bkv_idx = rearrange(perm_oneback, 'h b (nbuckets bucketsz2) -> h b nbuckets 1 bucketsz2',
                            bucketsz2=key_chunk_size)

        # NOTE: attention-mask, key-padding, causal, cross-bucket and
        # duplicate-pair masking are not applied in this implementation.

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_idx == bkv_idx
        inner.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE)
        del self_mask

        # Softmax over the keys of every chunk. dots_logsumexp keeps the log
        # normalising constant of every query row (last axis of size 1).
        dots_logsumexp = torch.logsumexp(inner, dim=-1, keepdim=True)
        dots = torch.exp(inner - dots_logsumexp).type_as(inner)
        dropped_dots = self.dropout(dots)

        # Weighted sum of the values: (..., bucket_size, value_dim).
        so = torch.einsum('...ij,...jd->...id', dropped_dots, tv)

        # Undo the sort: flatten the chunks back into one sequence, then
        # restore the original order with the inverse permutation.
        def unsort_from_buckets(s_x, perm_inverse):
            b_x = rearrange(s_x, 'h b nbuckets bucketsz d -> h b (nbuckets bucketsz) d')
            return batched_index_select(b_x, perm_inverse)

        # o: (n_hashes, batch, seqlen, value_dim); logits: (n_hashes, batch, seqlen, 1)
        o = unsort_from_buckets(so, perm_inv)
        logits = unsort_from_buckets(dots_logsumexp, perm_inv)

        # Combine the hash rounds with a softmax over their log normalisers.
        # A round in which a vector was mis-hashed yields a small normaliser
        # for it and is therefore down-weighted in the sum over rounds.
        probs = torch.exp(logits - torch.logsumexp(logits, dim=0, keepdim=True))
        out = torch.sum(o * probs, dim=0)
        # rearrange to (batch, seq_length, heads, dim)
        out = rearrange(out, '(b h) t d -> b t h d', h=n_head)
        # get back original sequence length in case of padding
        out = out[:, :seqlen_og]

        # Drop only the batch and head axes inserted above. A bare squeeze()
        # would also drop the sequence or feature axis when its size is 1.
        out = out.squeeze(2).squeeze(0)

        attn = None
        if need_weights:
            # Scatter the chunk-local weights into a dense
            # (seqlen x seqlen) matrix per hash round, then merge the rounds
            # with the same per-round weights as the output.
            dot_attn_indices = rearrange(bq_idx * seqlen + bkv_idx,
                                         'h b nbuckets qbucketsz kbucketsz -> h b (nbuckets qbucketsz kbucketsz)')
            unsorted_dots = torch.zeros(self.n_hashes, batch_size, seqlen * seqlen, device=device)
            unsorted_dots.scatter_(-1, dot_attn_indices, dots.view_as(dot_attn_indices))
            del dot_attn_indices
            unsorted_dots = rearrange(unsorted_dots,
                                      'h b (q_seqlen k_seqlen) -> h b q_seqlen k_seqlen',
                                      q_seqlen = seqlen)
            attn = torch.sum(unsorted_dots * probs, dim=0)
            attn = rearrange(attn, '(b h) t s -> b h t s', h=n_head)[:, :, :seqlen_og, :seqlen_og]

        return out, attn
    


