import numpy as np
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss 
from torch import nn

from transformers import T5Tokenizer, T5ForConditionalGeneration, get_cosine_schedule_with_warmup
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Stack,T5Block,T5LayerSelfAttention,T5LayerCrossAttention,T5LayerFF,T5Attention,T5LayerNorm
import copy
from transformers.models.t5.configuration_t5 import T5Config
import torch
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,)

import math
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import math
from argparse import Namespace

from functools import partial
from contextlib import contextmanager
from pathlib import Path
from filelock import FileLock

import torch.nn.functional as F
from torch import nn, einsum

from knn_memory import KNNMemoryList, DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY

# helper functions

def identity(t):
    """Return *t* unchanged (a no-op placeholder callable)."""
    result = t
    return result

def exists(val):
    """Report whether *val* is set, i.e. is anything other than ``None``."""
    return not (val is None)

def unique(arr):
    """Return the distinct elements of *arr* as a list, in first-occurrence order."""
    return list(dict.fromkeys(arr))

def default(val, d):
    """Return *val*, falling back to *d* only when *val* is ``None``."""
    if val is None:
        return d
    return val

def cast_tuple(val, length = 1):
    """Return *val* as-is if it is a tuple, otherwise a tuple repeating it *length* times."""
    if isinstance(val, tuple):
        return val
    return (val,) * length

def l2norm(t):
    """Scale every vector along the last dimension of *t* to unit Euclidean length."""
    return F.normalize(t, p=2, dim=-1)

def stable_softmax(t, dim = -1):
    """Softmax along *dim*, shifting by the (detached) maximum first for numerical stability."""
    peak = t.amax(dim=dim, keepdim=True).detach()
    shifted = t - peak
    return shifted.softmax(dim=dim)

class MemoryT5(T5ForConditionalGeneration):
    """T5ForConditionalGeneration whose encoder and decoder stacks share one KNN memory.

    Extra keyword arguments configure the memory and are forwarded unchanged to both
    ``ModifiedT5Stack`` instances. Read here: ``max_knn_memories``,
    ``knn_memory_multiprocessing`` and ``batch_size``. Optional: ``base_model_name`` (default
    ``'google/flan-t5-base'``), the checkpoint whose weights are copied into this model at the end
    of construction.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config)
        args = Namespace(**kwargs)
        self.knn_memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY
        self.knn_mem_kwargs = dict(
            dim = config.d_kv,
            max_memories = args.max_knn_memories,
            multiprocessing = args.knn_memory_multiprocessing,
        )
        # Every setting the model and its sub-modules need must be supplied through kwargs.
        with self.knn_memories_context(batch_size = args.batch_size) as knn_memories:
            # create_knn_memories builds a single memory layer; encoder and decoder share it.
            knn_memory = next(iter(knn_memories))

            encoder_config = copy.deepcopy(config)
            encoder_config.is_decoder = False
            encoder_config.use_cache = False
            encoder_config.is_encoder_decoder = False
            self.encoder = ModifiedT5Stack(encoder_config, self.shared, knn_memory=knn_memory, **kwargs)

            decoder_config = copy.deepcopy(config)
            decoder_config.is_decoder = True
            decoder_config.is_encoder_decoder = False
            decoder_config.num_layers = config.num_decoder_layers
            self.decoder = ModifiedT5Stack(decoder_config, self.shared, knn_memory=knn_memory, **kwargs)

        # Copy pretrained weights into the freshly built stacks. strict=False because the extra
        # parameters introduced by the modified attention layers have no counterpart in the
        # checkpoint (and the checkpoint may contain keys this model does not use).
        base_model_name = kwargs.get('base_model_name', 'google/flan-t5-base')
        original_model = T5ForConditionalGeneration.from_pretrained(base_model_name)
        self.load_state_dict(original_model.state_dict(), strict=False)

    def create_knn_memories(
        self,
        *,
        batch_size
    ):
        """Build a KNNMemoryList with one memory layer for ``batch_size`` rows.

        The list is stored under ``self.knn_memories_directory`` and configured with
        ``self.knn_mem_kwargs``.
        """
        return KNNMemoryList.create_memories(
            batch_size = batch_size,
            num_memory_layers=1,
            memories_directory = self.knn_memories_directory,
        )(**self.knn_mem_kwargs)

    @contextmanager
    def knn_memories_context(
        self,
        **kwargs
    ):
        """Create the KNN memories while holding a file lock on the memories directory.

        The directory is created if needed. The new memories are stored on
        ``self.knn_memories`` and yielded. The lock is released when the ``with`` block exits.
        The memories are not cleaned up on exit (the cleanup call is intentionally disabled).
        """
        knn_dir = Path(self.knn_memories_directory)
        knn_dir.mkdir(exist_ok = True, parents = True)
        lock = FileLock(str(knn_dir / 'mutex'))

        with lock:
            self.knn_memories = self.create_knn_memories(**kwargs)
            yield self.knn_memories
            #self.knn_memories.cleanup()

    def clear_memory(self, x, token_id):
        """Clear the KNN memory of every batch row of ``x`` that contains ``token_id``.

        Intended for auto-clearing the memory at string boundaries (e.g. on an EOS token).

        Args:
            x: tensor whose first dimension indexes the batch (e.g. token ids of shape
                (batch, seq_len)).
            token_id: id whose presence anywhere in a row triggers clearing that row.
        """
        hits = x == token_id
        if hits.dim() > 1:
            # Collapse all non-batch dimensions to one flag per batch row.
            hits = hits.flatten(start_dim=1).any(dim=-1)
        # nonzero(as_tuple=True) on a 1-D tensor returns a 1-tuple, hence the [0].
        batch_indices_to_clear = hits.nonzero(as_tuple=True)[0].tolist()

        if not batch_indices_to_clear:
            return

        # The memories live on the instance (set by knn_memories_context); the previous code
        # referenced an undefined global and raised NameError.
        self.knn_memories.clear_memory(batch_indices_to_clear)
        
class ModifiedT5Stack(T5Stack):
    """T5Stack with one block swapped for a KNN-memory-aware block.

    Decoder stacks get a ``ModifiedT5Block`` (reads the memory) at index ``decoder_memory_layer``.
    Encoder stacks get an ``EncoderT5Block`` (writes the memory) at index ``encoder_memory_layer``.
    The defaults (9 and 11) are the previously hard-coded indices and need a stack of at least
    10 / 12 blocks (e.g. flan-t5-base). Negative indices count from the end. Remaining ``kwargs``
    are forwarded to the replacement block.
    """

    def __init__(self, config, embed_tokens=None, knn_memory=None, *,
                 decoder_memory_layer=9, encoder_memory_layer=11, **kwargs):
        super().__init__(config, embed_tokens)
        self.external_memory = True

        if config.is_decoder:
            block_cls, memory_layer = ModifiedT5Block, decoder_memory_layer
        else:
            block_cls, memory_layer = EncoderT5Block, encoder_memory_layer

        num_blocks = len(self.block)
        if not -num_blocks <= memory_layer < num_blocks:
            raise IndexError(
                f"memory layer index {memory_layer} is out of range for a stack of {num_blocks} blocks"
            )

        # Mirror T5Stack, which (per transformers' modeling_t5) gives the relative attention bias
        # only to block 0. For the default indices this is False, as before.
        has_relative_attention_bias = memory_layer % num_blocks == 0
        self.block[memory_layer] = block_cls(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            knn_memory=knn_memory,
            **kwargs,
        )

class EncoderT5Block(T5Block):
    """T5 block whose self-attention sub-layer writes keys/values into a KNN memory.

    The parent constructor builds a standard block. Its ``layer`` list is then rebuilt as
    [EncoderT5LayerSelfAttention, T5LayerCrossAttention (decoder configs only), T5LayerFF].
    """

    def __init__(self, config, has_relative_attention_bias=False, knn_memory=None, **kwargs):
        super().__init__(config, has_relative_attention_bias=has_relative_attention_bias)
        self.external_memory = True
        self.is_decoder = config.is_decoder

        self_attention = EncoderT5LayerSelfAttention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            knn_memory=knn_memory,
            **kwargs,
        )
        sublayers = [self_attention]
        if self.is_decoder:
            # Only taken for decoder configs; this file instantiates the block for encoders.
            sublayers.append(T5LayerCrossAttention(config))
        sublayers.append(T5LayerFF(config))
        self.layer = nn.ModuleList(sublayers)
        
class EncoderT5LayerSelfAttention(T5LayerSelfAttention):
    """Self-attention sub-layer whose attention module is an EncoderT5Attention (memory writer)."""

    def __init__(self, config, has_relative_attention_bias=False, knn_memory=None, **kwargs):
        super().__init__(config, has_relative_attention_bias=has_relative_attention_bias)
        # Replace the stock attention built by the parent with the memory-writing variant.
        attention = EncoderT5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            knn_memory=knn_memory,
            **kwargs,
        )
        self.SelfAttention = attention

class EncoderT5Attention(T5Attention):
    """T5 attention that also appends every call's keys/values to a KNN memory.

    The attention computation is the stock ``T5Attention.forward``. In addition, while
    ``self.add_knn_memory`` is true and a memory is attached, the projected keys and values of
    every head are written to ``self.knn_memory``.
    """

    def __init__(self, config: T5Config, has_relative_attention_bias=False, knn_memory=None, **kwargs):
        super().__init__(config, has_relative_attention_bias=has_relative_attention_bias)
        args = Namespace(**kwargs)
        self.knn_memory = knn_memory
        self.max_knn_memories = args.max_knn_memories
        # Set to False to run attention without writing to the memory (default: always write).
        self.add_knn_memory = True

    def _write_knn_memory(self, key_states, value_states):
        """Append keys/values to the KNN memory, one head at a time.

        Args:
            key_states, value_states: (batch_size, n_heads, seq_length, dim_per_head).

        Each head is passed to ``knn_memory.add`` as a (batch_size, seq_length, 2, dim_per_head)
        tensor, with keys and values stacked on dim 2.
        """
        mem_k = rearrange(key_states, 'b h n d -> b n h d')
        mem_v = rearrange(value_states, 'b h n d -> b n h d')
        # (b, n, 2, h, d): k/v on dim 2, heads on dim 3. The previous loop ran over dim 2
        # (size 2) and so only ever stored heads 0 and 1.
        mem_kv = torch.stack((mem_k, mem_v), dim=-3)
        for head_kv in mem_kv.unbind(dim=-2):
            self.knn_memory.add(head_kv)

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        Keys/values are additionally written to the KNN memory (see ``_write_knn_memory``).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # key_states / value_states are (batch_size, n_heads, key_length, dim_per_head) here.
        if self.add_knn_memory and self.knn_memory is not None:
            self._write_knn_memory(key_states, value_states)

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs

###Decoder PART
class ModifiedT5Block(T5Block):
    """T5 block whose self-attention sub-layer also reads from a KNN memory.

    The parent constructor builds a standard block. Its ``layer`` list is then rebuilt as
    [ModifiedT5LayerSelfAttention, T5LayerCrossAttention (decoder configs only), T5LayerFF].
    """

    def __init__(self, config, has_relative_attention_bias=False, knn_memory=None, **kwargs):
        super().__init__(config, has_relative_attention_bias=has_relative_attention_bias)
        self.external_memory = True
        self.is_decoder = config.is_decoder

        self_attention = ModifiedT5LayerSelfAttention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            knn_memory=knn_memory,
            **kwargs,
        )
        sublayers = [self_attention]
        if self.is_decoder:
            # Standard T5 encoder-decoder cross-attention; this file instantiates the block for decoders.
            sublayers.append(T5LayerCrossAttention(config))
        sublayers.append(T5LayerFF(config))
        self.layer = nn.ModuleList(sublayers)
        
class ModifiedT5LayerSelfAttention(T5LayerSelfAttention):
    """Self-attention sub-layer whose attention module is a ModifiedT5Attention (memory reader)."""

    def __init__(self, config, has_relative_attention_bias=False, knn_memory=None, **kwargs):
        super().__init__(config, has_relative_attention_bias=has_relative_attention_bias)
        # Replace the stock attention built by the parent with the memory-reading variant.
        attention = ModifiedT5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            knn_memory=knn_memory,
            **kwargs,
        )
        self.SelfAttention = attention

class ModifiedT5Attention(T5Attention):
    """T5 attention that also attends over nearest neighbours retrieved from a KNN memory.

    Local attention is the stock ``T5Attention.forward``. Each query additionally retrieves
    ``num_retrieved_memories`` key/value pairs from ``self.knn_memory`` and attends over them with
    a separate softmax. That memory output is added to the local attention output before the
    output projection ``self.o``.
    """

    def __init__(self, config: T5Config, has_relative_attention_bias=False, knn_memory=None, **kwargs):
        super().__init__(config, has_relative_attention_bias=has_relative_attention_bias)
        # NOTE(review): output_gate is registered (so it appears in the state dict) but is not
        # used in forward; the memory output is currently added ungated.
        self.output_gate = nn.Parameter(torch.zeros(1))
        args = Namespace(**kwargs)
        self.knn_memory = knn_memory
        # number of neighbours fetched from the memory per query
        self.num_retrieved_memories = args.num_retrieved_memories
        attn_scale_init = 20
        # learned per-head log-scale for the memory similarities; exp() is applied when used
        self.scale = nn.Parameter(torch.ones(config.num_heads, 1, 1) * math.log(attn_scale_init))

    def _read_knn_memory(self, query_states):
        """Attend over each query's nearest neighbours in the KNN memory.

        Args:
            query_states: (batch_size, n_heads, seq_length, dim_per_head).

        Returns:
            float32 tensor of shape (batch_size, seq_length, n_heads * dim_per_head).
        """
        # assumes search returns key/values stacked on dim -2 plus a boolean validity mask shaped
        # like the similarity tensor below — TODO confirm against knn_memory.search
        mem_kv, mem_mask = self.knn_memory.search(query_states, self.num_retrieved_memories)
        # Compute in float32, like the local softmax, so half-precision queries and memories
        # stored in another dtype can be combined.
        mem_k, mem_v = mem_kv.float().unbind(dim = -2)
        scale = self.scale.float().exp()

        sim_mem = einsum('b h i d, b h i j d -> b h i j', query_states.float(), mem_k) * scale
        # Take the fill value from the tensor actually being filled so it is representable.
        mask_value = -torch.finfo(sim_mem.dtype).max
        sim_mem = sim_mem.masked_fill(~mem_mask, mask_value)

        mem_attn = stable_softmax(sim_mem)
        # A row whose retrievals are all invalid would otherwise get uniform weights over the
        # invalid slots; give every invalid slot zero weight explicitly.
        mem_attn = mem_attn.masked_fill(~mem_mask, 0.)

        mem_out = einsum('b h i j, b h i j d -> b h i d', mem_attn, mem_v)
        return rearrange(mem_out, 'b h n d -> b n (h d)')

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        The KNN-memory output (see ``_read_knn_memory``) is added before the output projection.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)

        # Add the distant (memory) contribution to the local attention output, cast back to the
        # model's dtype before the output projection.
        if self.knn_memory is not None:
            attn_output = attn_output + self._read_knn_memory(query_states).type_as(attn_output)

        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs