
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class Attention(nn.Module):
    """Multi-head dictionary attention predicting target tokens from conditioning tokens.

    The token axis of the input is split, by hard-coded position, into
    24 conditioning tokens and 12 target tokens (see ``_COND_RANGES`` and
    ``_TGT_RANGES``). Each head projects the conditioning tokens with a learned
    ``dim x dim`` matrix (initialised to the identity) to form "key" atoms,
    scores the target tokens against them, and reconstructs the targets as a
    weighted sum of the *unprojected* conditioning tokens, summed over heads.

    Args:
        dim: Feature width of each token. Must match the last axis of the
            input, because ``params`` is contracted against that axis.
        heads: Number of heads (one projection matrix per head).
        dim_head: Only used to derive the softmax scale ``dim_head ** -0.5``.
        dropout: Accepted for interface compatibility; not used by this module.
        model_type: ``"sparse"`` selects soft-thresholding followed by
            ``lifting_attn``; any other value selects scaled softmax.
    """

    # Token positions (along dim=1) that forward() reads as conditioning
    # tokens and as target tokens. The input needs at least 36 tokens.
    _COND_RANGES = ((0, 8), (12, 20), (24, 32))
    _TGT_RANGES = ((8, 12), (20, 24), (32, 36))
    # Target positions whose input value is subtracted from the output.
    # NOTE(review): the third target group (32:36) is not part of this list in
    # the original code; that is preserved here — confirm it is intentional.
    _RESIDUAL_RANGES = ((8, 12), (20, 24))

    def __init__(self, dim,
            heads=8,
            dim_head=64,
            dropout=0., model_type=None):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)

        # One (dim x dim) projection per head, initialised to the identity.
        self.params = nn.Parameter(torch.eye(dim).unsqueeze(0).repeat(heads, 1, 1))
        # Per-head weights used by lifting_attn to combine 8 rows into one.
        self.attn_coeff = nn.Parameter(torch.ones(heads, 8, 1))
        torch.nn.init.normal_(self.attn_coeff, std=.02)

        self.model_type = model_type

    @staticmethod
    def _token_index(ranges, x):
        """Build a gather/scatter index over dim=1 of ``x`` for the given ranges."""
        positions = [i for start, stop in ranges for i in range(start, stop)]
        # Created directly on x's device to avoid a host-to-device copy.
        idx = torch.tensor(positions, device=x.device)
        return idx.view(1, -1, 1).expand(x.size(0), -1, x.size(2))

    def get_dict(self, x, target="k"):
        """Return the per-head atoms built from ``x`` of shape (b, n, q).

        Args:
            x: Tokens of shape (batch, n, q).
            target: ``"k"`` applies the learned per-head projection and returns
                (batch, heads, n, dim). ``"v"`` returns ``x`` replicated for
                every head, shape (batch, heads, n, q), as a broadcast view
                (no copy; do not write to it in place).

        Raises:
            NotImplementedError: For any other ``target``.
        """
        if target == "k":
            return torch.einsum('h p q, b n q -> b h n p', self.params, x)
        if target == "v":
            # Equivalent to multiplying by a per-head identity matrix, but
            # without allocating it on every call and without pinning the
            # result to float32 (which broke non-float32 inputs).
            return x.unsqueeze(1).expand(-1, self.heads, -1, -1)
        raise NotImplementedError

    def encode(self, x, wavelet):
        """Score tokens ``x`` (b, n, p) against atoms (b, h, m, p) -> (b, h, n, m)."""
        return torch.einsum('b n p, b h m p -> b h n m', x, wavelet)

    def decode(self, x, wavelet):
        """Combine atoms (b, h, m, p) with coefficients (b, h, n, m) and sum heads.

        Returns a tensor of shape (b, n, p).
        """
        out = torch.einsum('b h n m, b h m p -> b h n p', x, wavelet)
        out = torch.sum(out, dim=1)
        return out

    def soft_threshold(self, z, threshold):
        """Applies element-wise soft-thresholding."""
        return torch.sign(z) * F.relu(torch.abs(z) - threshold)

    def lifting_attn(self, attn):
        """Add a learned combination of the first 8 blocks of ``attn`` to the 9th.

        The last two axes of ``attn`` are cut into a 3 x 3 grid of blocks, each
        block flattened into one row (9 rows in total). Rows 0..7 are mixed
        with the per-head weights ``attn_coeff`` into a single row, which is
        added to row 8; the tensor is then folded back to its original layout.
        Both trailing axes must therefore be divisible by 3.
        """
        q = attn.shape[-1] // 3
        attn = rearrange(attn, "b c (h p) (w q) -> b c (h w) (p q)", h=3, w=3)
        alpha = attn[:, :, :8]
        alpha_p = torch.einsum('b c n d, c n m -> b c m d', alpha, self.attn_coeff)
        # Only the last (9th) row receives the update; all others get zero.
        mask = torch.zeros_like(attn)
        mask[:, :, 8:] = alpha_p
        attn = attn + mask
        attn = rearrange(attn, "b c (h w) (p q) -> b c (h p) (w q)", h=3, q=q)
        return attn

    def forward(self, x):
        """Predict the target tokens of ``x`` from its conditioning tokens.

        Args:
            x: Tensor of shape (batch, tokens, dim) with at least 36 tokens.

        Returns:
            dict with
            ``"attn_x"``: tensor shaped like ``x``; the reconstruction is
            written at the target positions (zero elsewhere) and the input is
            subtracted at positions 8:12 and 20:24;
            ``"attn_sparsity"``: fraction of coefficients with magnitude
            below 1e-6, as a Python float;
            ``"attn_map"``: coefficients of shape (batch, heads, 12, 24).
        """
        index_cond = self._token_index(self._COND_RANGES, x)  # (batch, 24, dim)
        index_tgt = self._token_index(self._TGT_RANGES, x)    # (batch, 12, dim)
        x_cond = torch.gather(x, dim=1, index=index_cond)
        x_tgt = torch.gather(x, dim=1, index=index_tgt)

        # Copy of the input at the residual positions, zero everywhere else.
        res = torch.zeros_like(x)
        for start, stop in self._RESIDUAL_RANGES:
            res[:, start:stop] = x[:, start:stop]

        k_atoms = self.get_dict(x_cond, "k")
        attn = self.encode(x_tgt, k_atoms)

        if self.model_type == "sparse":
            attn = self.soft_threshold(attn, 5)
            attn = self.lifting_attn(attn)
        else:
            attn = self.attend(attn * self.scale)
        near_zero = (torch.abs(attn) < 0.000001).sum().detach().item()
        sparsity_of_coeff = near_zero / attn.numel()

        v_atoms = self.get_dict(x_cond, "v")
        out = self.decode(attn, v_atoms)

        # Place the reconstructed targets back at their original positions.
        x_recon = torch.zeros_like(x).scatter(dim=1, index=index_tgt, src=out)
        out = x_recon - res

        ret_dict = {
            "attn_x": out,
            "attn_sparsity": sparsity_of_coeff,
            "attn_map": attn,
        }
        return ret_dict


class Transformer(nn.Module):
    """Stack of ``Attention`` layers applied to a patchified image.

    The image is cut into ``patch_size x patch_size`` patches, four learned
    "approximation" tokens are appended to the token sequence, every attention
    layer is applied with a residual connection, and the tokens are folded
    back into an image with ``n_channel`` (9) channels.

    Args:
        seq_len: Accepted for interface compatibility; not used.
        d_model: Token feature width. Must equal ``patch_size ** 2`` so the
            approximation tokens can be concatenated with the patch tokens.
        n_heads: Number of heads in every attention layer.
        num_layers: Number of attention layers.
        patch_size: Side length of a square patch.
        expand_ratio: Multiplier applied to ``d_model // n_heads`` to obtain
            the ``dim_head`` passed to each attention layer.
        dropout: Forwarded to ``Attention``.
        model_type: Forwarded to ``Attention``.
    """

    def __init__(self, seq_len, d_model, n_heads, num_layers, patch_size=20, expand_ratio=1, dropout=0., model_type=None):
        super().__init__()
        self.patch_size = patch_size
        self.n_channel = 9
        self.feat_dim = self.patch_size ** 2

        self.layers = nn.ModuleList([])
        self.num_layers = num_layers
        dim_head = d_model // n_heads * expand_ratio
        for _ in range(num_layers):
            self.layers.append(nn.ModuleList([
                Attention(d_model, heads=n_heads, dim_head=dim_head, dropout=dropout, model_type=model_type)
            ]))

        # Four learned tokens appended after the patch tokens in forward().
        self.approx_token = nn.Parameter(torch.ones(1, 4, d_model))
        torch.nn.init.normal_(self.approx_token, std=.1)

        # Holds the attention map of the most recently executed layer.
        self.ret_dict = {}

    def patchify(self, x):
        """Cut an image into flattened patches.

        imgs: (N, C, H, W)
        x: (N, C * H/patch_size * W/patch_size, patch_size**2)

        Also records the input height in ``self.img_h``.
        """
        _, _, self.img_h, _ = x.shape
        x = rearrange(x, "b c (h p) (w q) -> b (c h w) (p q)", p=self.patch_size, q=self.patch_size)
        return x

    def unpatchify(self, x):
        """Fold a token sequence back into an image.

        x: (N, n_channel * h * 2, p * patch_size)
        imgs: (N, n_channel, h * p, 2 * patch_size)

        The patch grid is assumed to be two patches wide, so the number of
        tokens must be divisible by ``2 * n_channel``.
        """
        x = rearrange(x, "b (c h w) (p q) -> b c (h p) (w q)", w=2, q=self.patch_size, c=self.n_channel)
        return x

    def forward(self, x):
        """Run all attention layers over the patchified input.

        Args:
            x: Image tensor of shape (batch, C, H, W). After four tokens are
                appended the token count must be divisible by
                ``2 * n_channel`` for ``unpatchify`` to succeed.
                NOTE(review): this appears to imply an 8-channel input on a
                2 x 2 patch grid (32 + 4 = 36 tokens) — confirm with callers.

        Returns:
            Tuple ``(image, sparsity)`` where ``image`` is the unpatchified
            result and ``sparsity`` is the mean of the layers'
            ``attn_sparsity`` values.
        """
        x = self.patchify(x)

        # Append the learned approximation tokens to every sample.
        approx_token = self.approx_token.expand(x.shape[0], -1, -1)
        x = torch.cat((x, approx_token), dim=1)

        sparsity = 0
        for i in range(self.num_layers):
            # Use the i-th layer. The previous code always applied
            # self.layers[0][0], leaving every other layer's parameters unused.
            attn = self.layers[i][0]
            x_attn = attn(x)
            x = x_attn["attn_x"] + x
            sparsity += x_attn["attn_sparsity"]
            self.ret_dict["attn_map"] = x_attn["attn_map"]
        if self.num_layers:
            # Guarded so a zero-layer model reports 0 instead of raising.
            sparsity /= self.num_layers
        x = self.unpatchify(x)

        return x, sparsity
