from torch import nn
from einops import rearrange
import torch.nn.functional as F
from utils.positional_embedding import PositionalEmbedding

class cs_block(nn.Module):
    """Three-layer MLP with a residual connection around the middle layer.

    With ``h = fc1(x)`` the module returns ``fc3(h + fc2(gelu(h)))``. All
    layers are ``nn.Linear`` and therefore act on the last dimension of the
    input only.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dim: Width of the two hidden projections (default 512).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=512):
        super().__init__()
        self.input_dim, self.output_dim, self.hidden_dim = (
            input_dim,
            output_dim,
            hidden_dim,
        )
        # Layers are created in the same order as before so that seeded
        # initialisation and state_dict key order are unaffected.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` from ``input_dim`` to ``output_dim`` features."""
        hidden = self.fc1(x)
        # Residual branch: GELU -> fc2, added back onto the fc1 output.
        return self.fc3(hidden + self.fc2(self.gelu(hidden)))

class cs_fusion(nn.Module):
    """Two-stage attention fusion over the segment axis of a sequence.

    The input ``(B, M, L)`` is split along its last axis into ``num_seg``
    segments of length ``p = L // num_seg`` and rearranged to
    ``(B, p * M, num_seg)``. Two rounds of scaled dot-product attention are
    applied on that layout, the result is added back to the rearranged input,
    passed through a final ``cs_block`` and rearranged back to ``(B, M, L)``.

    Args:
        num_seg: Number of segments the last input axis is split into; it is
            also the feature width of every ``cs_block`` used here.
        parameter_share: If ``True`` a single ``cs_block`` produces query,
            key and value in both stages and the output projection. If
            ``False`` seven independent blocks (``cs_block1`` ..
            ``cs_block7``) are used instead.
        cal_att_map: If ``True`` the attention inputs of each stage are kept
            on the module after ``forward``: ``qkv_1`` / ``qkv_2`` in shared
            mode, ``q_1`` / ``k_1`` / ``q_2`` / ``k_2`` otherwise.
    """

    def __init__(self, num_seg, parameter_share=True, cal_att_map=False):
        super(cs_fusion, self).__init__()
        self.num_seg = num_seg
        self.parameter_share = parameter_share

        def new_block():
            return cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)

        # Always created so the attribute and its state_dict keys exist in
        # both modes; forward() only uses it when parameter_share is True.
        self.cs_block = new_block()
        if not self.parameter_share:
            # Stage 1 q/k/v, stage 2 q/k/v, and the output projection.
            # forward() referenced these but they were never constructed.
            self.cs_block1 = new_block()
            self.cs_block2 = new_block()
            self.cs_block3 = new_block()
            self.cs_block4 = new_block()
            self.cs_block5 = new_block()
            self.cs_block6 = new_block()
            self.cs_block7 = new_block()

        self.gelu = nn.GELU()
        self.relu = nn.ReLU()
        self.cal_att_map = cal_att_map
        # Attention inputs captured by forward() when cal_att_map is True.
        self.qkv_1 = None
        self.qkv_2 = None
        self.q_1 = None
        self.k_1 = None
        self.q_2 = None
        self.k_2 = None

    def forward(self, x):
        """Fuse information across segments.

        Args:
            x: Tensor of shape ``(B, M, L)``; ``L`` must be divisible by
                ``num_seg`` (otherwise ``rearrange`` raises).

        Returns:
            Tensor of shape ``(B, M, L)``.
        """
        B, M, _ = x.shape
        # (B, M, n*p) -> (B, p*M, n): the segment index becomes the feature
        # axis that the cs_blocks and the attention operate on.
        x = rearrange(x, 'b m (n p) -> b (p m) n', b=B, m=M, n=self.num_seg)
        shared = self.parameter_share

        # ---- stage 1 ----
        if shared:
            q = k = v = self.cs_block(x)
        else:
            q = self.cs_block1(x)
            k = self.cs_block2(x)
            v = self.cs_block3(x)
        if self.cal_att_map:
            # Store only what exists in the active mode; previously this
            # read q/k in shared mode, where they were never assigned.
            if shared:
                self.qkv_1 = q.clone()
            else:
                self.q_1 = q.clone()
                self.k_1 = k.clone()
        att_x = self.relu(F.scaled_dot_product_attention(q, k, v))

        # ---- stage 2 ----
        if shared:
            q = k = v = self.cs_block(att_x)
        else:
            q = self.cs_block4(att_x)
            k = self.cs_block5(att_x)
            v = self.cs_block6(att_x)
        if self.cal_att_map:
            if shared:
                self.qkv_2 = q.clone()
            else:
                self.q_2 = q.clone()
                self.k_2 = k.clone()
        att_x = F.scaled_dot_product_attention(q, k, v)

        # Residual connection around both attention stages.
        x = x + att_x
        x = self.cs_block(x) if shared else self.cs_block7(x)
        # Undo the initial rearrangement: (B, p*M, n) -> (B, M, n*p).
        x = rearrange(x, 'b (p m) n -> b m (n p)', b=B, m=M, n=self.num_seg)
        return x