"""
Scalable, Detailed and Mask-free Universal Photometric Stereo Network (CVPR2023)
# Copyright (c) 2023 Satoshi Ikehata
# All rights reserved.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math
import torch.nn.init as init

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False, attention_dropout = 0.1, dim_feedforward = 512,  
                 q_bucket_size = 1024, k_bucket_size = 2048, attn_mode = 'Normal'):
        super(MultiHeadAttentionBlock, self).__init__()
        if attn_mode == 'Efficient':
            self.q_bucket_size = q_bucket_size
            self.k_bucket_size = k_bucket_size
        self.attn_mode = attn_mode
        self.dim_V = dim_out
        self.dim_Q = dim_in
        self.dim_K = dim_in
        self.num_heads = num_heads
        self.fc_q = nn.Linear(self.dim_Q, self.dim_V, bias=False) # dimin -> dimhidden
        self.fc_k = nn.Linear(self.dim_K, self.dim_V, bias=False) # dimin -> dimhidden
        self.fc_v = nn.Linear(self.dim_K, self.dim_V, bias=False) # dimhidden -> dim
        if ln:
            self.ln0 = nn.LayerNorm(self.dim_Q)
            self.ln1 = nn.LayerNorm(self.dim_V)

        self.dropout_attn = nn.Dropout(attention_dropout)
        self.fc_o1 = nn.Linear(self.dim_V, dim_feedforward, bias=False)
        self.fc_o2 = nn.Linear(dim_feedforward, self.dim_V, bias=False)
        self.dropout1 = nn.Dropout(attention_dropout)
        self.dropout2 = nn.Dropout(attention_dropout)

        # memory efficient attention related parameters
        # can be overriden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size
     

    # memory efficient attention
    def summarize_qkv_chunk(self, q, k, v):
        weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        weight_max = weight.amax(dim = -1, keepdim = True).detach()
        weight = weight - weight_max

        exp_weight = self.dropout_attn(weight.exp()) # attention_dropout
        weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)
        return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')

    def memory_efficient_attention(
        self,
        q, k, v,
        q_bucket_size = 512,
        k_bucket_size = 1024,
        eps = 1e-8,
    ):
        scale = q.shape[-1] ** -0.5
        q = q * scale

        summarize_qkv_fn = self.summarize_qkv_chunk

        # chunk all the inputs
        q_chunks = q.split(q_bucket_size, dim = -2)
        k_chunks = k.split(k_bucket_size, dim = -2)
        v_chunks = v.split(k_bucket_size, dim = -2)

        # loop through all chunks and accumulate
        values = []
        weights = []
        for q_chunk in q_chunks:
            exp_weights = []
            weighted_values = []
            weight_maxes = []  
            for (k_chunk, v_chunk) in zip(k_chunks, v_chunks):
                exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                    q_chunk,
                    k_chunk,
                    v_chunk
                )
                exp_weights.append(exp_weight_chunk)
                weighted_values.append(weighted_value_chunk)
                weight_maxes.append(weight_max_chunk)
            weight_maxes = torch.stack(weight_maxes, dim = -1)
            weighted_values = torch.stack(weighted_values, dim = -1)
            exp_weights = torch.stack(exp_weights, dim = -1)

            global_max = weight_maxes.amax(dim = -1, keepdim = True)
            renorm_factor = (weight_maxes - global_max).exp().detach()

            exp_weights = exp_weights * renorm_factor
            weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')         

            all_values = weighted_values.sum(dim = -1)
            all_weights = exp_weights.sum(dim = -1)
            values.append(all_values)
            weights.append(all_weights)    
        values = torch.cat(values, dim=2)
        weights = torch.cat(weights, dim=2)
        # (rearrange(weights, '... -> ... 1') 
        normalized_values = values / (rearrange(weights, '... -> ... 1')  + eps)
        return normalized_values


    def forward(
        self,
        x,y,
    ):
        
        x = x if getattr(self, 'ln0', None) is None else self.ln0(x) # pre-normalization
        Q = self.fc_q(x) # input_dim -> embed dim       
        K, V = self.fc_k(y), self.fc_v(y) # input_dim -> embed dim
        dim_split = self.dim_V // self.num_heads # multi-head attention
        if self.attn_mode == 'Efficient':
            q_bucket_size = self.q_bucket_size
            k_bucket_size = self.k_bucket_size
            
            Q_ = torch.stack(Q.split(int(dim_split), 2), 1)
            K_ = torch.stack(K.split(int(dim_split), 2), 1)
            V_ = torch.stack(V.split(int(dim_split), 2), 1) 
            
            A = self.memory_efficient_attention(Q_, K_, V_, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
            A = A.reshape(-1, A.shape[2], A.shape[3])
            Q_ = Q_.reshape(-1, Q_.shape[2], Q_.shape[3])
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)
        else: # Basic
            Q_ = torch.cat(Q.split(int(dim_split), 2), 0)
            K_ = torch.cat(K.split(int(dim_split), 2), 0)
            V_ = torch.cat(V.split(int(dim_split), 2), 0)
            A = self.dropout_attn(torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)) # this may not be correct due to mult-head attention
            A =  A.bmm(V_) # A(Q, K, V) attention_output
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)

        
        O_ = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        O = O + self.dropout2(self.fc_o2(self.dropout1(F.gelu(self.fc_o1(O_))))) 
        return O


class SAB(nn.Module): # self attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout = 0.1, dim_feedforward = 512, attn_mode = 'Normal'):
        super(SAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)
    def forward(self, X):
        return self.mab(X, X)

class CAB(nn.Module): # cross attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout = 0.1, dim_feedforward = 512, attn_mode = 'Normal'):
        super(CAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)
    def forward(self, q, kv):       
        return self.mab(q, kv)
    
class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False, attn_mode='Normal'):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        init.xavier_uniform_(self.S)
        self.mab = MultiHeadAttentionBlock(dim, dim, num_heads, ln=ln, attn_mode=attn_mode)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)

class CommunicationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab = 3, dim_hidden=384, dim_feedforward = 1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CommunicationBlock, self).__init__()

        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)

    def forward(self, x):
        x = self.enc(x)
        return x

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab = 3, dim_hidden=384, dim_feedforward = 1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CrossAttentionBlock, self).__init__()

        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(CAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(CAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.layers = nn.ModuleList(modules_enc)
    def forward(self, q, kv):
        for k in range(len(self.layers)):
            q = self.layers[k](q, kv) # x[0] query, x[1] key value       
        return q # the output token length is len(x)
    
class AggregationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab = 3, num_outputs = 1, dim_hidden=384, dim_feedforward = 1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(AggregationBlock, self).__init__()

        self.num_outputs = num_outputs
        self.dim_hidden = dim_hidden

        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'

        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)
        modules_dec = []
        modules_dec.append(PMA(dim_hidden, num_heads, num_outputs, attn_mode=attn_mode)) # after the PMA we should not put drop out
        self.dec = nn.Sequential(*modules_dec)
        
    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        x = x.view(-1, self.num_outputs * self.dim_hidden)
        return x


