import time

import torch
import numpy as np
import os

import torch.nn as nn


# TODO: Remove fix_memory and memory_update_steps

class ExternalMemory(nn.Module):
    """Fixed-size banks of hidden-state memory, one bank per configured layer.

    Two fp16 tensors of shape ``(num_mem_modules, memory_size, dimension)`` are
    registered as non-trainable parameters:

    * ``memory_modules`` -- the bank that ``update_to_memory_modules`` writes into.
    * ``update_memory_modules`` -- the buffer that ``update_to_memory_modules``
      copies from.

    Per-layer bookkeeping lives in dicts keyed by the layer index taken from
    ``config.memory_insert_layers``: ``layer_index_map`` (layer -> bank row),
    ``layer_dstore_idx`` (next write offset), ``memory_full`` and
    ``memory_accumulate_steps``.
    """

    def __init__(self, config):
        """Allocate zeroed banks on the current CUDA device.

        Reads ``hidden_size``, ``memory_size``, ``memory_insert_layers``,
        ``use_gpu_to_search``, ``fix_memory`` and ``memory_update_steps`` from
        ``config``.

        Raises:
            NotImplementedError: if ``config.use_gpu_to_search`` is false.
        """
        super().__init__()
        self.dimension = config.hidden_size  # d
        self.memory_size = config.memory_size  # M
        self.memory_insert_layers = config.memory_insert_layers
        self.num_mem_modules = len(self.memory_insert_layers)
        self.use_gpu_to_search = config.use_gpu_to_search
        (self.layer_index_map, self.layer_dstore_idx, self.memory_full,
         self.memory_accumulate_steps) = self.build_layer_index_map()

        if not self.use_gpu_to_search:
            raise NotImplementedError

        bank_shape = (self.num_mem_modules, self.memory_size, self.dimension)
        self.memory_modules = nn.Parameter(
            torch.zeros(*bank_shape, dtype=torch.float16, device=torch.cuda.current_device()),
            requires_grad=False)
        self.update_memory_modules = nn.Parameter(
            torch.zeros(*bank_shape, dtype=torch.float16, device=torch.cuda.current_device()),
            requires_grad=False)

        self.dstore_idx = 0
        self.fix_memory = config.fix_memory
        self.memory_update_steps = config.memory_update_steps

    def init_from_hidden(self, mem_hidden_path):
        """Fill every bank from ``<mem_hidden_path>/<layer_idx>.npy``.

        Each array is truncated or zero-padded along axis 0 to ``memory_size``
        rows, then written to both ``memory_modules`` and
        ``update_memory_modules``. Every layer is marked full afterwards.
        """
        for slot, layer_idx in enumerate(self.memory_insert_layers):
            mem_hidden = np.load(os.path.join(mem_hidden_path, f"{layer_idx}.npy"))

            # Drop a singleton middle axis *before* padding: the zero padding
            # below is 2-D, so concatenating it onto an (N, 1, d) array fails.
            if mem_hidden.ndim == 3 and mem_hidden.shape[1] == 1:
                mem_hidden = mem_hidden[:, 0, :]

            # adjust mem_hidden size to (memory_size, dimension)
            n_rows = mem_hidden.shape[0]
            if n_rows > self.memory_size:
                mem_hidden = mem_hidden[:self.memory_size]  # truncate
            elif n_rows < self.memory_size:
                padding = np.zeros((self.memory_size - n_rows, self.dimension), dtype=np.float16)
                mem_hidden = np.concatenate((mem_hidden, padding), axis=0)

            # Build the tensor once on the bank's own device and copy it into
            # both buffers in place.
            bank = torch.tensor(mem_hidden, dtype=torch.float16, device=self.memory_modules.device)
            with torch.no_grad():
                self.memory_modules[slot].copy_(bank)
                self.update_memory_modules[slot].copy_(bank)

            # NOTE(review): a zero-padded bank is still marked full, as in the
            # original code -- confirm that is intended for short files.
            self.memory_full[layer_idx] = True
            self.layer_dstore_idx[layer_idx] = self.memory_size

    def reset(self):
        """Clear all stored memory and the per-layer bookkeeping.

        The banks are zeroed in place, so the registered parameters keep their
        identity, shape, dtype and device. The per-layer write offsets, full
        flags and accumulation counters are reset too; otherwise a cleared
        bank would still be treated as full.
        """
        self.dstore_idx = 0
        for layer_idx in self.memory_insert_layers:
            self.layer_dstore_idx[layer_idx] = 0
            self.memory_full[layer_idx] = False
            self.memory_accumulate_steps[layer_idx] = 0

        if self.use_gpu_to_search:
            with torch.no_grad():
                self.memory_modules.zero_()
                self.update_memory_modules.zero_()

    def build_layer_index_map(self):
        """Return the four per-layer bookkeeping dicts in their initial state.

        Returns:
            ``(layer_index_map, layer_dstore_idx, memory_full,
            memory_accumulation_steps)``, each keyed by layer index.
        """
        layers = self.memory_insert_layers[:self.num_mem_modules]
        layer_index_map = {layer: slot for slot, layer in enumerate(layers)}
        layer_dstore_idx = {layer: 0 for layer in layers}
        memory_full = {layer: False for layer in layers}
        memory_accumulation_steps = {layer: 0 for layer in layers}
        return layer_index_map, layer_dstore_idx, memory_full, memory_accumulation_steps

    def update_to_memory_modules(self):
        """Copy ``update_memory_modules`` into ``memory_modules`` in place."""
        self.memory_modules.data.copy_(self.update_memory_modules.data)

    def retrieve(self, layer_idx):
        """Deprecated; always raises ``NotImplementedError``."""
        # TODO: Remove this func
        raise NotImplementedError('Deprecated func.')

    def update_memory(self, layer_idx, hidden_states):
        """Deprecated; always raises ``NotImplementedError``."""
        # TODO: Remove this func
        raise NotImplementedError('Deprecated func.')


class AttentionExternalMemory(ExternalMemory):
    """Per-layer store of one attention score per memory slot.

    Mirrors ``ExternalMemory`` but the two fp16 parameters have shape
    ``(num_mem_modules, memory_size)`` -- a single scalar per slot instead of a
    hidden vector. ``update_indices_buffer`` maps each layer index to a list of
    slot indices that are still waiting to be replaced.
    """

    def __init__(self, config):
        """Allocate zeroed 2-D score banks on the current CUDA device.

        Raises:
            NotImplementedError: if ``config.use_gpu_to_search`` is false.
        """
        # Deliberately bypass ExternalMemory.__init__: it would allocate the
        # 3-D (num, M, d) hidden-state banks, which this class does not use.
        nn.Module.__init__(self)
        self.dimension = config.hidden_size  # d
        self.memory_size = config.memory_size  # M

        self.memory_insert_layers = config.memory_insert_layers
        self.num_mem_modules = len(self.memory_insert_layers)
        self.use_gpu_to_search = config.use_gpu_to_search
        (self.layer_index_map, self.layer_dstore_idx, self.memory_full,
         self.memory_accumulate_steps) = self.build_layer_index_map()

        self.update_indices_buffer = {idx: [] for idx in self.memory_insert_layers}

        if not self.use_gpu_to_search:
            raise NotImplementedError

        # MeM: one attention score per memory slot (attention-score based retrieval)
        bank_shape = (self.num_mem_modules, self.memory_size)
        self.memory_modules = nn.Parameter(
            torch.zeros(*bank_shape, dtype=torch.float16, device=torch.cuda.current_device()),
            requires_grad=False)
        self.update_memory_modules = nn.Parameter(
            torch.zeros(*bank_shape, dtype=torch.float16, device=torch.cuda.current_device()),
            requires_grad=False)

        self.dstore_idx = 0
        self.fix_memory = config.fix_memory
        self.memory_update_steps = config.memory_update_steps

    def reset(self):
        """Clear the score banks, pending replacement indices and counters.

        Overrides the inherited ``reset`` so the banks are zeroed in place and
        keep their 2-D ``(num_mem_modules, memory_size)`` shape. The pending
        replacement lists and per-layer counters are cleared as well.
        """
        self.dstore_idx = 0
        for layer_idx in self.memory_insert_layers:
            self.layer_dstore_idx[layer_idx] = 0
            self.memory_full[layer_idx] = False
            self.memory_accumulate_steps[layer_idx] = 0
            self.update_indices_buffer[layer_idx] = []

        with torch.no_grad():
            self.memory_modules.zero_()
            self.update_memory_modules.zero_()


def retrieve_layer_memory(layer_idx: int, external_memory_v: ExternalMemory,
                          external_memory_k: ExternalMemory = None, ):
    """Look up the memory bank(s) stored for ``layer_idx``.

    Args:
        layer_idx: layer index, used as a key into each store's ``layer_index_map``.
        external_memory_v: store whose ``memory_modules`` row is returned first.
        external_memory_k: optional second store; skipped when ``None``.

    Returns:
        ``(memory_v, memory_k)`` where ``memory_k`` is ``None`` if no second
        store was given.
    """

    def _bank_of(store):
        # Translate the layer index to the store's bank row, then index the bank.
        return store.memory_modules[store.layer_index_map[layer_idx]]

    value_bank = _bank_of(external_memory_v)
    key_bank = None if external_memory_k is None else _bank_of(external_memory_k)
    return value_bank, key_bank


def update_memory_fifo(layer_idx, external_memory: "ExternalMemory", hidden_states):
    """Append ``hidden_states`` to a layer's memory bank, evicting oldest rows first.

    The whole ``update_memory_modules`` buffer is first re-synchronised from
    ``memory_modules``; the new rows are then written into the row for
    ``layer_idx`` of ``update_memory_modules`` only.

    Args:
        layer_idx: key into ``external_memory.layer_index_map``.
        external_memory: store providing ``memory_modules``,
            ``update_memory_modules``, ``layer_dstore_idx``, ``memory_full``,
            ``memory_size`` and ``fix_memory``.
        hidden_states: tensor unpacked as ``[bsz, 1, hidden_dim]``.

    Returns:
        None. When ``fix_memory`` is set and the layer is already full, nothing
        is written beyond the initial re-synchronisation.
    """
    # First IN First OUT
    bsz, _, _ = hidden_states.shape
    idx = external_memory.layer_index_map[layer_idx]
    dstore_idx = external_memory.layer_dstore_idx[layer_idx]
    memory_size = external_memory.memory_size

    with torch.no_grad():
        external_memory.update_memory_modules.data.copy_(external_memory.memory_modules.data)

    # MeM: Test fix memory
    if external_memory.fix_memory and external_memory.memory_full[layer_idx] is True:
        return

    memory = external_memory.update_memory_modules[idx]
    new_rows = hidden_states.squeeze(1)

    with torch.no_grad():
        if bsz >= memory_size:
            # The batch alone fills (or overfills) the bank: keep only its
            # newest `memory_size` rows. Without this branch the write offset
            # would go negative.
            memory.copy_(new_rows[bsz - memory_size:])
            dstore_idx = memory_size
        else:
            overflow = dstore_idx + bsz - memory_size
            if overflow > 0:
                # Shift surviving rows towards the front. Clone first because
                # source and destination slices overlap.
                survivors = memory[overflow:, :].clone()
                memory[:-overflow, :].copy_(survivors)
                dstore_idx -= overflow
            memory[dstore_idx: dstore_idx + bsz, :] = new_rows
            dstore_idx += bsz

    # Mark the bank full as soon as the write offset reaches the end, not only
    # after the first overflow.
    if dstore_idx >= memory_size:
        external_memory.memory_full[layer_idx] = True
    external_memory.layer_dstore_idx[layer_idx] = dstore_idx

def update_attn_weights(external_memory: "AttentionExternalMemory", layer_idx, mem_attn_weights):
    """Fold one batch of attention weights into a layer's running mean.

    The running mean is stored in ``external_memory.update_memory_modules`` at
    the row for ``layer_idx``, and the number of batches folded in so far is
    kept in ``external_memory.memory_accumulate_steps[layer_idx]``.

    Args:
        external_memory: store holding the running mean and the step counter.
        layer_idx: key into ``layer_index_map`` / ``memory_accumulate_steps``.
        mem_attn_weights: weights averaged over axis 0 before accumulation
            (per the original comment, the result is ``(memory_size,)``).
    """
    with torch.no_grad():
        idx = external_memory.layer_index_map[layer_idx]
        steps = external_memory.memory_accumulate_steps[layer_idx]

        # Accumulate in float32: the stored mean is fp16, and `prev * steps`
        # overflows fp16 (max 65504) and loses precision long before that.
        batch_mean = mem_attn_weights.float().mean(0)  # MeM: Average over batch
        previous_mean = external_memory.update_memory_modules[idx].float()
        updated_mean = (previous_mean * steps + batch_mean) / (steps + 1)

        # Assignment casts back to the storage dtype of the bank.
        external_memory.update_memory_modules[idx] = updated_mean
        external_memory.memory_accumulate_steps[layer_idx] = steps + 1


def update_memory_attn(layer_idx, external_memory_attn: "AttentionExternalMemory", external_memory: "ExternalMemory",
                       hidden_states, update_ratio: float = 0.2):
    """Overwrite the least-attended memory slots of a layer with new hidden states.

    When no replacement indices are pending for ``layer_idx``, the
    ``int(memory_size * update_ratio)`` slots with the smallest accumulated
    attention weight are selected and cached in
    ``external_memory_attn.update_indices_buffer``. Each call then consumes up
    to ``bsz`` cached indices and writes the first rows of ``hidden_states``
    into them. Once the cache is exhausted, the layer's
    ``memory_accumulate_steps`` counter is reset to 0.

    Like ``update_memory_fifo``, the whole ``update_memory_modules`` buffer of
    ``external_memory`` is first re-synchronised from ``memory_modules``.

    Args:
        layer_idx: key into both stores' ``layer_index_map``.
        external_memory_attn: store holding per-slot attention weights, the
            pending-index cache and the accumulation counter.
        external_memory: store whose hidden-state bank is updated.
        hidden_states: tensor unpacked as ``[bsz, 1, hidden_dim]``.
        update_ratio: fraction of the bank selected for replacement per
            selection round (default 0.2, the previously hard-coded value).
    """
    # MeM: Update the external memory by dropping memory tokens with the smallest attention weights.
    idx_mem = external_memory.layer_index_map[layer_idx]
    bsz, _, _ = hidden_states.shape

    with torch.no_grad():
        external_memory.update_memory_modules.data.copy_(external_memory.memory_modules.data)
    memory = external_memory.update_memory_modules[idx_mem]  # MeM: shape: [memory_size, hidden_dim]

    pending = external_memory_attn.update_indices_buffer[layer_idx]

    # If no update indices are cached yet, find the smallest attention weights
    if not pending:
        idx_attn = external_memory_attn.layer_index_map[layer_idx]
        attn_weights = external_memory_attn.update_memory_modules[idx_attn]  # MeM: shape: [memory_size]

        memory_size = external_memory.memory_size
        # Clamp so topk never receives a negative or oversized k.
        total_update_size = min(max(int(memory_size * update_ratio), 0), memory_size)
        _, smallest_indices = torch.topk(attn_weights, k=total_update_size, largest=False)
        pending = smallest_indices.tolist()

    update_size = min(len(pending), bsz)  # MeM: Limit to batch size if smaller

    # MeM: Replace identified tokens
    if update_size > 0:
        # Explicit long dtype: an index tensor must be integral.
        indices_to_replace = torch.tensor(pending[:update_size], dtype=torch.long, device=memory.device)
        with torch.no_grad():
            new_rows = hidden_states.squeeze(1)[:update_size, :]  # [update_size, hidden_dim]
            # Indexed assignment needs matching dtypes, so cast to the bank's dtype.
            memory[indices_to_replace, :] = new_rows.to(device=memory.device, dtype=memory.dtype)

    # Remove used indices
    remaining = pending[update_size:]
    external_memory_attn.update_indices_buffer[layer_idx] = remaining

    # Reset memory_accumulate_steps only once every selected slot has been replaced
    if not remaining:
        external_memory_attn.memory_accumulate_steps[layer_idx] = 0


# def update_memory_fifo(layer_idx, external_memory: ExternalMemory, hidden_states):
#     # First IN First OUT
#     # hidden_states: [bz, 1, hidden_dim]
#     bsz, _, hidden_dim = hidden_states.shape
#     idx = external_memory.layer_index_map[layer_idx]
#     memory = external_memory.memory_modules[idx].clone()
#     dstore_idx = external_memory.layer_dstore_idx[layer_idx]
#
#     # MeM: Test fix memory
#     if external_memory.fix_memory and external_memory.memory_full[layer_idx] is True:
#         return
#
#     if dstore_idx + bsz >= external_memory.memory_size:
#         external_memory.memory_full[layer_idx] = True
#         update_size = dstore_idx + bsz - external_memory.memory_size
#         dstore_idx = dstore_idx - update_size
#         tmp = torch.zeros(update_size, external_memory.dimension, dtype=torch.float16,
#                           device=torch.cuda.current_device())
#         memory = torch.cat((memory[update_size:, :], tmp))
#         del tmp
#
#     memory[dstore_idx: dstore_idx + bsz, :] = hidden_states.squeeze(1)
#     new_memory = nn.Parameter(memory)
#     # new_memory.requires_grad = False
#
#     dstore_idx += bsz
#     external_memory.layer_dstore_idx[layer_idx] = dstore_idx
#     external_memory.update_memory_modules[idx] = new_memory.detach()
#
#     del new_memory, memory, hidden_states
