import os

import torch
import time
import torch.nn.functional as F
import einops
import matplotlib.pyplot as plt


import sys
from pathlib import Path
sys.path.append("/workspace")

from manar import Manar, Contextualization, append_manar_to_vit
from timm.layers.attention import Attention
from timm.models import create_model


def create_sliding_window_matrix(n: int, m: int, device) -> torch.Tensor:
    """
    Build an (n, m) integer matrix of absolute positions covered by sliding windows.

    Row ``i`` lists the positions ``i - m/2 + 1, ..., i + m/2``; equivalently
    entry ``[i, j]`` equals ``i + j - m // 2 + 1``.

    Args:
        n: Number of rows (one per window position).
        m: Number of columns (window length). Must be even.
        device: Device on which the result is allocated.

    Returns:
        An int64 tensor of shape (n, m).

    Raises:
        ValueError: If ``m`` is odd.
    """
    if m % 2:
        raise ValueError("m must be an even integer.")

    first = -m // 2 + 1
    # Offsets inside one window, shape (m,).
    window_offsets = torch.arange(first, first + m, device=device)
    # Row index as a column vector, shape (n, 1); broadcasting yields (n, m).
    row_index = torch.arange(n, device=device)[:, None]
    return row_index + window_offsets


def _pad(x: torch.Tensor, context_window_len: int, pad_value: int = 0, pad_dim: int = -1) -> torch.Tensor:
    """Pad and unfold the input tensor to create a context window."""
    assert context_window_len % 2 == 0, "Context window length must be either 1 or even for symmetric padding"
    # pad the input tensor
    if pad_dim < 0:
        pad_dim = x.ndim + pad_dim
    zeros_to_append = 2*(x.ndim - pad_dim - 1)
    pad = (0,) * zeros_to_append + (context_window_len // 2 - 1, context_window_len // 2)
    x = F.pad(x, pad, value=pad_value)
    return x


def _pad_and_unfold(x: torch.Tensor, context_window_len: int, pad_value: int = 0, pad_dim: int = -1) -> torch.Tensor:
    assert context_window_len > 0, "Context window length must be positive"
    if context_window_len > 1:
        # unfold the tensor to create a sliding window
        x = _pad(x, context_window_len, pad_value, pad_dim)
        # unfold the tensor along the specified dimension
        x = x.unfold(
            dimension=pad_dim,
            size=context_window_len,
            step=1
        )
    else:
        x = x.unsqueeze(-1)  # if context_window_len is 1, just add a dimension
    return x # shape (..., cwl)


def _make_cat_compatible(A: torch.Tensor, B: torch.Tensor):
    shared = torch.max(torch.tensor(A.shape[:-1]), torch.tensor(B.shape[:-1]))
    A_new_shape = torch.cat([shared, torch.tensor(A.shape[-1]).reshape(1)])
    B_new_shape = torch.cat([shared, torch.tensor(B.shape[-1]).reshape(1)])
    return A.expand(A_new_shape.tolist()), B.expand(B_new_shape.tolist())


def create_mask(score: torch.Tensor):
    """Mark sliding-window slots that point outside the key sequence.

    ``score`` is expected to have shape (..., n, m): n query positions, each
    with a window of m (even) key slots. Slot ``j`` of row ``i`` refers to key
    position ``i + j - m // 2 + 1`` (as laid out by
    ``create_sliding_window_matrix``). The slot is masked when that position is
    negative or ``>= n``.

    Args:
        score: Tensor of shape (..., n, m). Only its shape and device are read.

    Returns:
        Boolean tensor of shape (1, ..., 1, n, m), with the same rank as
        ``score`` so it broadcasts against it. True marks slots to drop.
    """
    seq_len = score.shape[-2]
    key_pos = create_sliding_window_matrix(seq_len, score.shape[-1], device=score.device)  # shape (n, m)
    mask = (key_pos < 0) | (key_pos >= seq_len)
    # Prepend singleton dims so the mask broadcasts over score's leading dims.
    return mask.reshape((1,) * (score.ndim - 2) + tuple(mask.shape))


def base_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
):
    """Plain scaled dot-product softmax attention (unmasked reference).

    Args:
        q: Queries of shape (..., n, d).
        k: Keys of shape (..., m, d).
        v: Values of shape (..., m, e).

    Returns:
        Tensor of shape (..., n, e): ``softmax(q k^T / sqrt(d)) v``.
    """
    head_dim = q.shape[-1]
    # Scale the queries (not the logits) before the matmul, like the original formulation.
    scaled_q = q * (1 / (head_dim ** 0.5))
    weights = torch.softmax(scaled_q @ k.transpose(-2, -1), dim=-1)
    return weights @ v


def base_concatenated_attention(
    queries: torch.Tensor,
    self_keys: torch.Tensor,
    self_values: torch.Tensor,
    external_keys: torch.Tensor,
    external_values: torch.Tensor,
    context_window_len: int = 1,
):
    """Reference (unfused) attention over a local self-window plus external tokens.

    Each query attends jointly, under one softmax, to two groups of tokens:
      * the ``context_window_len`` self tokens in its sliding window; slots
        outside the sequence are masked to -inf when the window is longer
        than 1;
      * all ``m`` external tokens.

    All inputs must already be broadcast-compatible.

    Args:
        queries: (..., n, d)
        self_keys: (..., n, d)
        self_values: (..., n, d)
        external_keys: (..., m, d)
        external_values: (..., m, d)
        context_window_len: Local window length (1, or an even number).

    Returns:
        Tensor of shape (..., n, d): the local response plus the external response.
    """
    scale = 1 / (queries.shape[-1] ** 0.5)
    num_external = external_keys.shape[-2]

    # Materialise each query's local window: (..., n, d, cwl).
    windowed_keys = _pad_and_unfold(self_keys, context_window_len, pad_dim=-2, pad_value=0)
    windowed_values = _pad_and_unfold(self_values, context_window_len, pad_dim=-2, pad_value=0)

    # Logits against the local window: (..., n, cwl).
    local_logits = einops.einsum(queries, windowed_keys, '... n d, ... n d cwl -> ... n cwl') * scale
    if context_window_len > 1:
        # Padded slots must receive zero probability.
        local_logits = local_logits.masked_fill(create_mask(local_logits), float('-inf'))

    # Logits against every external token: (..., n, m).
    external_logits = einops.einsum(queries, external_keys, '... n d, ... m d -> ... n m') * scale

    # One softmax over the concatenated (cwl + m) slots, then split back.
    local_logits, external_logits = _make_cat_compatible(local_logits, external_logits)
    probs = torch.cat([local_logits, external_logits], dim=-1).softmax(dim=-1)
    local_probs, external_probs = torch.split(probs, [context_window_len, num_external], dim=-1)

    local_out = einops.einsum(local_probs, windowed_values, '... n cwl, ... n d cwl -> ... n d')
    external_out = einops.einsum(external_probs, external_values, '... n m, ... m d -> ... n d')
    return local_out + external_out


def measure(func, *args):
    """
    Measure the wall-clock time and peak CUDA memory of a single call.

    When CUDA is available, the timed region is bracketed by
    ``torch.cuda.synchronize()``. CUDA kernels launch asynchronously, so this
    keeps work queued before the call out of the measurement and waits for all
    of ``func``'s work. Without CUDA there is nothing to synchronize and no
    allocator statistics, so the memory figure is reported as NaN instead of
    raising.

    Args:
        func: The callable to measure.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        A tuple ``(elapsed_seconds, peak_memory_mb, result)``:
        ``peak_memory_mb`` is ``torch.cuda.max_memory_allocated()`` in MiB
        since the reset below (NaN without CUDA), and ``result`` is whatever
        ``func`` returned.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        # Drain previously queued work (e.g. host->device copies) so it is not billed to func.
        torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    result = func(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait until every kernel launched by func has finished
    elapsed = time.perf_counter() - start

    memory = torch.cuda.max_memory_allocated() / (1024 ** 2) if use_cuda else float("nan")
    return elapsed, memory, result


def _take_grads(tensors):
    """Collect each tensor's accumulated gradient and reset ``.grad`` to None.

    Clearing the gradients lets a later backward pass on the same leaf tensors
    start from scratch instead of accumulating on top of these values.

    Args:
        tensors: Iterable of leaf tensors that took part in a backward pass.

    Returns:
        A list with one cloned gradient per tensor (None where none was produced).
    """
    grads = []
    for tensor in tensors:
        grad = tensor.grad
        grads.append(None if grad is None else grad.clone())
        tensor.grad = None
    return grads


def _report_grad_check(name, actual, expected, n, atol=1e-2, samples=10):
    """Print whether gradient ``actual`` matches ``expected``; on mismatch, show a sample of bad entries.

    Args:
        name: Short label used in the messages (e.g. ``"dq"``).
        actual: Gradient produced by the implementation under test.
        expected: Reference gradient.
        n: Sequence length, used only in the messages.
        atol: Absolute tolerance for ``torch.allclose`` and for selecting mismatching entries.
        samples: Maximum number of mismatching entries to print.
    """
    if actual is None or expected is None:
        print(f"---- Missing {name} for sequence length {n}")
        return
    if torch.allclose(actual, expected, atol=atol):
        print(f"++++ Correct {name} for sequence length {n}")
        return

    print(f"---- Wrong {name} for sequence length {n}")
    # One row of indices per entry whose absolute error exceeds atol.
    mismatches = torch.nonzero(torch.abs(actual - expected) > atol)
    # Show the same entries whose values are printed below, so indices and values line up.
    shown = mismatches[:samples]
    print(shown)
    print(f"Number of non-zero differences: {mismatches.shape[0]}")
    index = tuple(shown.t())  # advanced index selecting exactly the `shown` entries, for any rank
    print(f"ref_{name}:")
    print(expected[index])
    print(f"{name}:")
    print(actual[index])


def correctness():
    """Check the fused ``Contextualization`` op against ``base_concatenated_attention``.

    For each sequence length, random inputs are pushed through both
    implementations. The forward outputs are compared, then the gradients with
    respect to all five inputs, using the same random upstream gradient
    ``dO``. Each comparison uses ``atol=1e-2`` and its result is printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ns = [197]
    m = 128
    d = 64
    context_window_len = 96
    batch_size = 128
    num_heads = 12
    atol = 1e-2
    grad_names = ("dq", "dk", "dv", "dek", "dev")

    for n in ns:
        # Queries / self keys / self values have n tokens; external keys / values have m.
        inputs = [
            torch.randn(batch_size, num_heads, length, d).to(device).requires_grad_()
            for length in (n, n, n, m, m)
        ]
        queries, self_keys, self_values, external_keys, external_values = inputs

        # Upstream gradient dL/dO, shared by both backward passes.
        dO = torch.randn_like(queries)

        ref = base_concatenated_attention(queries, self_keys, self_values, external_keys, external_values, context_window_len)
        ref.backward(dO)
        ref_grads = _take_grads(inputs)

        # NOTE(review): the trailing `False` and the all-True 512x512 boolean tensor are
        # forwarded to Contextualization unchanged; their meaning is defined in `manar` - confirm there.
        O = Contextualization.apply(queries, self_keys, self_values, external_keys, external_values, context_window_len, False, torch.zeros(512,512).to(device) == 0)

        # Compare the fused forward output with the reference output.
        if not torch.allclose(O, ref, atol=atol):
            print(f"---- Wrong for sequence length {n}")
        else:
            print("++++ Correct")

        O.backward(dO)
        grads = _take_grads(inputs)

        for name, actual, expected in zip(grad_names, grads, ref_grads):
            _report_grad_check(name, actual, expected, n, atol=atol)


def perf_single(manar=False):
    """Benchmark a single attention layer in inference mode over increasing sequence lengths.

    Each sequence length ``n`` is paired with a batch size and a context window
    of ``n // 2``. For each one, a fresh layer is built, called once as a
    warm-up, and called again under ``measure``. Throughput, latency per
    request and peak memory per request are printed.

    Args:
        manar: Benchmark ``Manar`` when True, timm ``Attention`` otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ns = [128, 256, 512, 1024, 1024+512, 2048, 2048+512, 2048+1024, 2048+1536, 4096]
    num_memory_cells = 256
    m = 32
    d = 64
    context_window_lens = [n // 2 for n in ns]
    batch_sizes = [1024, 1024, 512, 512, 256, 128, 128, 64, 64, 32]
    num_heads = 12

    for n, context_window_len, batch_size in zip(ns, context_window_lens, batch_sizes):
        x = torch.randn(batch_size, n, num_heads * d).to(device)

        # No CUDA cache/peak-stat bookkeeping here: `measure` does it right before
        # the timed call. Calling it here unguarded would also fail on the CPU path.
        if manar:
            layer = Manar(
                dim=num_heads * d,
                num_heads=num_heads,
                num_memory_cells=num_memory_cells,
                conceptual_representation_size=m,
                context_window_len=context_window_len,
            ).to(device)
        else:
            layer = Attention(
                dim=num_heads * d,
                num_heads=num_heads,
            ).to(device)
        layer.eval()

        with torch.no_grad():
            measure(layer, x)  # warm-up; result discarded
            elapsed, memory, _ = measure(layer, x)

        print("~~~~~~~~~~~~~~~~~~~~")
        print(f"Model: {'MANAR' if manar else 'MHA'}, Sequence length: {n}")
        print(f"\tThroughput: {batch_size / elapsed:.2f} Req/s")
        print(f"\tTime per request: {elapsed / batch_size * 1000:.2f} ms")
        print(f"\tMemory per request: {memory / batch_size:.2f} MB/Req")


def perf_model(max_batch_sizes, manar=False, model="s"):
    """Benchmark a full DeiT model (optionally converted to MANAR) at increasing input resolutions.

    The resolutions are the square sizes 256, 384, ..., 1536, each paired with
    the matching entry of ``max_batch_sizes``; ``zip`` stops at the shorter of
    the two lists. For every resolution:
      1. a fresh DeiT is built and resized to that resolution;
      2. it is optionally converted with ``append_manar_to_vit``;
      3. it is warmed up once and timed once with ``measure``.
    Both calls run under ``torch.no_grad()``.

    Args:
        max_batch_sizes: Batch size per resolution, in resolution order.
        manar: Convert the model with ``append_manar_to_vit`` before measuring.
        model: ``"s"`` selects ``deit_small_patch16_224``; any other value
            selects ``deit_tiny_patch16_224``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resolutions = [128 * step + 256 for step in range(11)]
    num_memory_cells = 256
    m = 32
    arch = 'deit_small_patch16_224' if model == "s" else 'deit_tiny_patch16_224'

    for res, batch_size in zip(resolutions, max_batch_sizes):
        x = torch.randn(batch_size, 3, res, res).to(device)
        # Patches are 16x16, so there are res*res/256 of them; the window is half
        # of that, rounded up to an even length.
        cwl = ((res * res) // 256) // 2
        if cwl % 2 != 0:
            cwl += 1

        # Build a fresh model per resolution so the MANAR conversion is applied
        # exactly once, to an unmodified DeiT, with this resolution's cwl.
        net = create_model(
            arch,
            pretrained=False,
            num_classes=1000,
            drop_rate=0.0,
            drop_path_rate=0.0,
        )
        net.set_input_size(res, 16)
        if manar:
            net, _ = append_manar_to_vit(net, M=num_memory_cells, m=m, cwl=cwl)
        net = net.to(device)
        net.eval()

        # The timed call must also run without autograd; otherwise it records a
        # graph and keeps activations alive, inflating both time and memory.
        with torch.no_grad():
            measure(net, x)  # warm-up; result discarded
            print("~~~~~~~~~~~~~~~~~~~~")
            elapsed, memory, _ = measure(net, x)

        print(f"Model: {'MANAR' if manar else 'DeiT'}, Resolution: {res}x{res}")
        print(f"\tThroughput: {batch_size / elapsed:.2f} Req/s")
        print(f"\tTime per request: {elapsed / batch_size * 1000:.2f} ms")
        print(f"\tMemory per request: {memory / batch_size:.2f} MB/Req")


if __name__ == "__main__":
    # Per-resolution batch sizes passed to perf_model (one entry per resolution,
    # in increasing order). The *_t (DeiT-Tiny) lists are currently unused.
    # DeiT-Tiny, standard MHA
    max_batch_sizes_deit_t = [256, 256, 128, 64, 32, 16, 8, 8, 8, 4, 1]
    # DeiT-Small, standard MHA
    max_batch_sizes_deit_s = [256, 128, 64, 16, 16, 8, 8, 4, 4, 2, 1]
    # DeiT-Tiny, MANAR
    max_batch_sizes_deit_manar_t = [512, 512, 256, 128, 128, 128, 64, 64, 32, 32]
    # DeiT-Small, MANAR
    max_batch_sizes_deit_manar_s = [512, 256, 128, 64, 64, 64, 32, 32, 32, 32]

    correctness()
    for use_manar, batch_sizes in ((False, max_batch_sizes_deit_s), (True, max_batch_sizes_deit_manar_s)):
        perf_model(batch_sizes, manar=use_manar, model="s")
    for use_manar in (False, True):
        perf_single(manar=use_manar)
