import torch
import torch.nn as nn
import numpy as np
from math import sqrt
from utils.masking import TriangularCausalMask, TimerMultivariateMask, TimerCovariateMask
from einops import repeat
from layers.Attn_Projection import QueryKeyProjection, RotaryProjection
from layers.Attn_Bias import BinaryAttentionBias


class FullAttention(nn.Module):
    """Scaled dot-product attention with an optional boolean mask.

    Args:
        mask_flag: When true, positions selected by the mask are set to
            ``-inf`` before the softmax. If no mask is passed to ``forward``,
            a ``TriangularCausalMask`` is built for the query length.
        scale: Multiplier applied to the raw scores. Falls back to
            ``1 / sqrt(E)`` (``E`` = per-head query width) when unset.
        attention_dropout: Dropout probability applied to the attention
            weights.
        output_attention: When true, ``forward`` also returns the attention
            weights; otherwise the second return value is ``None``.
    """

    def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None,
                use_kv_cache=False):
        """Compute attention over 4-D ``[B, L, H, E]`` queries.

        Args:
            queries: Tensor laid out as ``[B, L, H, E]``.
            keys: Tensor laid out as ``[B, S, H, E]``.
            values: Tensor laid out as ``[B, S, H, D]``.
            attn_mask: Object exposing a boolean ``.mask`` tensor, or ``None``.
                Only consulted when ``mask_flag`` is set.
            n_vars, n_tokens, tau, delta: Accepted for signature parity with
                the other attention modules in this file; unused here.
            use_kv_cache: Accepted so that ``AttentionLayer`` (which forwards
                this keyword) can wrap this module; this class keeps no cache
                and ignores the flag.

        Returns:
            ``(output, attention)`` where ``output`` is ``[B, L, H, D]`` and
            ``attention`` is the ``[B, H, L, S]`` weight tensor or ``None``.
        """
        B, L, _, E = queries.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            # Masked positions get -inf so they receive zero softmax weight.
            scores.masked_fill_(attn_mask.mask, -np.inf)

        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        out = torch.einsum("bhls,bshd->blhd", weights, values)

        return out.contiguous(), (weights if self.output_attention else None)

class TimeAttention(nn.Module):
    """Attention over a flattened (variable, token) sequence.

    Queries and keys are passed through ``QueryKeyProjection`` (configured with
    ``RotaryProjection``), a per-variable ``BinaryAttentionBias`` is added to
    the scores, and an optional Timer mask sets masked positions to ``-inf``.

    Args:
        mask_flag: When true, a mask is applied on top of the attention bias.
            If none is passed to ``forward``, a ``TimerCovariateMask`` or
            ``TimerMultivariateMask`` is built depending on ``covariate``.
        scale: Score multiplier for the non-flash path; falls back to
            ``1 / sqrt(E)`` when unset.
        attention_dropout: Dropout probability for the attention weights
            (non-flash path only).
        output_attention: When true, the non-flash path returns the attention
            weights as the second output. The flash path never materialises
            them and returns ``None``.
        d_model, num_heads, max_len: Forwarded to the projection / bias
            sub-modules.
        covariate: Selects which default mask class is used.
        flash_attention: Use ``torch.nn.functional.scaled_dot_product_attention``
            instead of the explicit einsum/softmax implementation.
    """

    def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False, d_model=512, num_heads=8, max_len=100, covariate=False, flash_attention=False):
        super(TimeAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.covariate = covariate
        self.flash_attention = flash_attention
        self.qk_proj = QueryKeyProjection(dim=d_model, num_heads=num_heads, proj_layer=RotaryProjection, kwargs=dict(max_len=max_len),
                                          partial_factor=(0.0, 0.5),)
        self.attn_bias = BinaryAttentionBias(dim=d_model, num_heads=num_heads)

        # KV cache, empty until a forward pass runs with use_kv_cache=True.
        self.cached_keys = None
        self.cached_values = None

    def reset_cache(self):
        """Drop any cached keys/values so the next cached call starts fresh."""
        self.cached_keys = None
        self.cached_values = None

    def forward(self, queries, keys, values, attn_mask, n_vars, n_tokens, tau=None, delta=None, use_kv_cache=False):
        """Run attention; inputs are ``[B, L|S, H, E|D]`` tensors.

        Args:
            queries: ``[B, L, H, E]`` tensor.
            keys: ``[B, S, H, E]`` tensor.
            values: ``[B, S, H, D]`` tensor.
            attn_mask: Object exposing a boolean ``.mask`` tensor, or ``None``.
            n_vars: Number of variables in the flattened sequence.
            n_tokens: Number of tokens per variable. Position ids
                ``0 .. n_vars * n_tokens - 1`` are used for both queries and
                keys.
            tau, delta: Unused; kept for signature parity.
            use_kv_cache: When true, previously cached keys/values are
                prepended along the sequence axis and the combined tensors
                are stored for the next call.

        Returns:
            ``(output, attention)``; ``output`` is ``[B, L, H, D]`` on both
            code paths.
        """
        B, _, H, E = queries.shape

        if use_kv_cache:
            if self.cached_keys is not None:
                keys = torch.cat([self.cached_keys, keys], dim=1)
                values = torch.cat([self.cached_values, values], dim=1)
            # Cache in the *input* layout and before the rotary projection:
            # the concatenation above runs along dim=1 of [B, S, H, E], and
            # position encoding must be applied exactly once per call.
            # NOTE(review): the same position ids / bias / mask are used for
            # queries and keys, so callers must pass n_vars * n_tokens equal
            # to the total (cached + new) key length — confirm intended usage.
            self.cached_keys = keys
            self.cached_values = values

        # [B, L, H, E] -> [B, H, L, E]
        queries = queries.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        device = queries.device

        # Created on the default device, exactly as before; only var_id is
        # moved explicitly.
        seq_id = torch.arange(n_tokens * n_vars).expand(B, H, -1)
        queries, keys = self.qk_proj(
            queries, keys, query_id=seq_id, kv_id=seq_id)

        scale = self.scale or 1. / sqrt(E)

        # Variable index of every flattened position: [0]*n_tokens, [1]*n_tokens, ...
        var_id = torch.arange(n_vars).repeat_interleave(n_tokens)
        var_id = var_id.view(1, 1, -1).expand(B, 1, -1).to(device)

        attn_bias = self.attn_bias(var_id, var_id)

        if self.mask_flag:
            if attn_mask is None:
                mask_cls = TimerCovariateMask if self.covariate else TimerMultivariateMask
                attn_mask = mask_cls(B, n_vars, n_tokens, device=device)
            attn_mask = attn_bias.masked_fill(attn_mask.mask, float("-inf"))
        else:
            attn_mask = attn_bias

        attn = None
        if self.flash_attention:
            # NOTE(review): this call uses SDPA's default scaling and no
            # dropout, so self.scale and self.dropout do not affect this path.
            out = torch.nn.functional.scaled_dot_product_attention(
                queries, keys, values.permute(0, 2, 1, 3), attn_mask)
            # SDPA returns [B, H, L, D]; bring it back to [B, L, H, D] so both
            # paths hand the same layout to the caller.
            out = out.permute(0, 2, 1, 3)
        else:
            scores = torch.einsum("bhle,bhse->bhls", queries, keys)
            scores += attn_mask

            # The bias/mask is added before scaling, so it is scaled as well.
            weights = self.dropout(torch.softmax(scale * scores, dim=-1))
            out = torch.einsum("bhls,bshd->blhd", weights, values)
            if self.output_attention:
                attn = weights

        return out.contiguous(), attn


class AttentionLayer(nn.Module):
    """Multi-head wrapper: linear Q/K/V projections around an inner attention.

    Args:
        attention: Callable/module invoked as
            ``attention(q, k, v, attn_mask, n_vars=..., n_tokens=..., tau=..., delta=...)``
            and returning ``(output, attention_weights)``.
        d_model: Width of the input and output features.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key width; defaults to ``d_model // n_heads``.
        d_values: Per-head value width; defaults to ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None, use_kv_cache=False):
        """Project inputs, run the inner attention, and project the result.

        Args:
            queries: ``[B, L, d_model]`` tensor.
            keys: ``[B, S, d_model]`` tensor.
            values: ``[B, S, d_model]`` tensor.
            attn_mask: Passed through to the inner attention unchanged.
            n_vars, n_tokens, tau, delta: Passed through unchanged.
            use_kv_cache: Forwarded to the inner attention only when truthy,
                so inner attentions that do not define this keyword keep
                working for the default (non-cached) call.

        Returns:
            ``(output, attention)`` where ``output`` is ``[B, L, d_model]`` and
            ``attention`` is whatever the inner attention returned.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split the projected features into H heads: [B, T, H, head_dim].
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        cache_kwargs = {"use_kv_cache": use_kv_cache} if use_kv_cache else {}
        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask,
            n_vars=n_vars,
            n_tokens=n_tokens,
            tau=tau,
            delta=delta,
            **cache_kwargs
        )
        # Merge the heads back; reshape also copes with non-contiguous outputs.
        out = out.reshape(B, L, -1)

        return self.out_projection(out), attn


# class AttentionLayer(nn.Module):
#     def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
#         super(AttentionLayer, self).__init__()

#         d_keys = d_keys or (d_model // n_heads)
#         d_values = d_values or (d_model // n_heads)

#         self.inner_attention = attention
#         self.query_projection = nn.Linear(d_model, d_keys * n_heads)
#         self.key_projection = nn.Linear(d_model, d_keys * n_heads)
#         self.value_projection = nn.Linear(d_model, d_values * n_heads)
#         self.out_projection = nn.Linear(d_values * n_heads, d_model)
#         self.n_heads = n_heads

#     def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None):
#         B, L, _ = queries.shape
#         _, S, _ = keys.shape
#         H = self.n_heads

#         queries = self.query_projection(queries).view(B, L, H, -1)
#         keys = self.key_projection(keys).view(B, S, H, -1)
#         values = self.value_projection(values).view(B, S, H, -1)

#         out, attn = self.inner_attention(
#             queries,
#             keys,
#             values,
#             attn_mask,
#             n_vars=n_vars,
#             n_tokens=n_tokens,
#             tau=tau,
#             delta=delta
#         )
#         out = out.view(B, L, -1)

#         return self.out_projection(out), attn



# class TimeAttention(nn.Module):
#     def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False, d_model=512, num_heads=8, max_len=100, covariate=False, flash_attention=False):
#         super(TimeAttention, self).__init__()
#         self.scale = scale
#         self.mask_flag = mask_flag
#         self.output_attention = output_attention
#         self.dropout = nn.Dropout(attention_dropout)
#         self.covariate = covariate
#         self.flash_attention = flash_attention
#         self.qk_proj = QueryKeyProjection(dim=d_model, num_heads=num_heads, proj_layer=RotaryProjection, kwargs=dict(max_len=max_len),
#                                           partial_factor=(0.0, 0.5),)
#         self.attn_bias = BinaryAttentionBias(dim=d_model, num_heads=num_heads)

#     def forward(self, queries, keys, values, attn_mask, n_vars, n_tokens, tau=None, delta=None):
#         B, L, H, E = queries.shape
#         _, S, _, D = values.shape

#         # [B, H, L, E]
#         queries = queries.permute(0, 2, 1, 3)
#         keys = keys.permute(0, 2, 1, 3)
#         if self.flash_attention:
#             values = values.permute(0, 2, 1, 3)

#         seq_id = torch.arange(n_tokens * n_vars)
#         seq_id = repeat(seq_id, 'n -> b h n', b=B, h=H)

#         queries, keys = self.qk_proj(
#             queries, keys, query_id=seq_id, kv_id=seq_id)

#         scale = self.scale or 1. / sqrt(E)

#         var_id = repeat(torch.arange(n_vars),
#                         'C -> (C n_tokens)', n_tokens=n_tokens)
#         var_id = repeat(var_id, 'L -> b h L', b=B, h=1).to(queries.device)

#         attn_bias = self.attn_bias(var_id, var_id)

#         if self.mask_flag:
#             if attn_mask is None:
#                 if self.covariate:
#                     attn_mask = TimerCovariateMask(
#                         B, n_vars, n_tokens, device=queries.device)
#                 else:
#                     attn_mask = TimerMultivariateMask(
#                         B, n_vars, n_tokens, device=queries.device)
#             attn_mask = attn_bias.masked_fill(attn_mask.mask, float("-inf"))
#         else:
#             attn_mask = attn_bias

#         if self.flash_attention:
#             V = torch.nn.functional.scaled_dot_product_attention(
#                 queries, keys, values, attn_mask)
#         else:
#             scores = torch.einsum("bhle,bhse->bhls", queries, keys)
#             scores += attn_mask
            
#             A = self.dropout(torch.softmax(scale * scores, dim=-1))
#             V = torch.einsum("bhls,bshd->blhd", A, values)

#         if self.output_attention:
#             return V.contiguous(), None
#         else:
#             return V.contiguous(), None