from torch import nn
from einops import rearrange
import torch.nn.functional as F
from utils.positional_embedding import PositionalEmbedding

class cs_block(nn.Module):
    """Three-layer MLP with a residual connection around the middle layer.

    With ``h = fc1(x)`` the block returns ``fc3(h + fc2(gelu(h)))``. Every
    Linear acts on the last axis of ``x``, so any leading axes pass through.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.
        hidden_dim: width of both hidden layers (default 512).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=512):
        super().__init__()
        self.input_dim, self.output_dim, self.hidden_dim = input_dim, output_dim, hidden_dim
        # Registration order (fc1, fc2, gelu, fc3) is kept so seeded parameter
        # initialisation and module listing match earlier versions.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.fc1(x)
        # Residual branch (GELU -> Linear) is added onto the *pre-activation*
        # output of fc1, not onto the activated one.
        hidden = hidden + self.fc2(self.gelu(hidden))
        return self.fc3(hidden)

class cs_encoder(nn.Module):
    """Two stacked self-attention stages over segment-mixed tokens.

    ``forward`` takes ``x`` of shape ``(B, M, L)`` and rearranges it with the
    einops pattern ``'b m (n p) -> b (p m) n'`` (``n = configs.num_seg``, so
    ``L`` must be divisible by ``num_seg``). This yields ``p*M`` tokens, each
    with ``num_seg`` features. Queries, keys and values come from ``cs_block``
    MLPs mapping ``num_seg -> num_seg``:

    * with ``configs.parameter_share`` a single ``self.cs_block`` produces
      q = k = v for both stages and is also the output projection;
    * otherwise ``cs_block1..3`` (stage 1), ``cs_block4..6`` (stage 2) and
      ``cs_block7`` (output) are independent.

    The output is rearranged back to ``(B, M, L)``.

    When ``configs.cal_att_map`` is true, clones of each stage's query/key
    tensors are stored as ``q_1``/``k_1`` and ``q_2``/``k_2``. In
    parameter-sharing mode they are additionally stored as ``qkv_1``/``qkv_2``.
    """

    def __init__(self, configs):
        super(cs_encoder, self).__init__()
        self.num_seg = configs.num_seg
        self.parameter_share = configs.parameter_share
        self.use_pos_emb = configs.use_pos_emb
        # Module names/creation order are part of the state_dict and of seeded
        # initialisation, so they are kept exactly as before.
        if self.parameter_share:
            self.cs_block = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
        else:
            self.cs_block1 = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
            self.cs_block2 = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
            self.cs_block3 = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
            self.cs_block4 = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
            self.cs_block5 = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
            self.cs_block6 = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
            self.cs_block7 = cs_block(self.num_seg, self.num_seg, hidden_dim=self.num_seg)
        self.gelu = nn.GELU()  # not used in forward
        self.relu = nn.ReLU()
        # Attention-map captures, filled by forward() when cal_att_map is set.
        self.qkv_1 = None
        self.qkv_2 = None
        self.q_1 = None
        self.k_1 = None
        self.q_2 = None
        self.k_2 = None
        self.cal_att_map = configs.cal_att_map
        if self.use_pos_emb:
            self.num_channels = configs.num_channels
            self.embed_dim = configs.seq_len//configs.num_seg
            # NOTE(review): argument meaning of the project-local
            # PositionalEmbedding is not visible here; num_channels*embed_dim
            # equals the token count p*M of the rearranged input.
            self.pos_emb = PositionalEmbedding(self.num_seg, self.num_channels*self.embed_dim)

    def forward(self, x):
        B, M, L = x.shape
        x = rearrange(x, 'b m (n p) -> b (p m) n', b=B, m=M, n=self.num_seg)
        if self.use_pos_emb:
            pos_emb = self.pos_emb(x)
            x = x + pos_emb

        # Stage 1: self-attention over the p*M tokens.
        if self.parameter_share:
            q = k = v = self.cs_block(x)
        else:
            q, k, v = self.cs_block1(x), self.cs_block2(x), self.cs_block3(x)
        if self.cal_att_map:
            self.q_1, self.k_1 = q.clone(), k.clone()
            if self.parameter_share:
                self.qkv_1 = q.clone()
        att_x = self.relu(F.scaled_dot_product_attention(q, k, v))

        # Stage 2: self-attention over the stage-1 output.
        if self.parameter_share:
            q = k = v = self.cs_block(att_x)
        else:
            q, k, v = self.cs_block4(att_x), self.cs_block5(att_x), self.cs_block6(att_x)
        if self.cal_att_map:
            self.q_2, self.k_2 = q.clone(), k.clone()
            if self.parameter_share:
                self.qkv_2 = q.clone()
        att_x = F.scaled_dot_product_attention(q, k, v)

        # Residual back to the (embedded) input tokens, then output projection.
        x = x + att_x
        x = self.cs_block(x) if self.parameter_share else self.cs_block7(x)
        x = rearrange(x, 'b (p m) n -> b m (n p)', b=B, m=M, n=self.num_seg)
        return x

class mm_encoder(nn.Module):
    """Single cross-attention layer: queries from ``x``, keys/values from ``enc_x``.

    Both inputs are ``(B, M, L)`` tensors with the same shape. ``att_method``
    selects how they are re-laid-out before attention (``n = num_seg``,
    ``p = patch_len = seq_len // num_seg``):

    * ``'ci'`` -- ``'b m (n p) -> b m p n'``; feature axis has size ``num_seg``.
    * ``'cc'`` -- ``'b m (n p) -> b (p m) n'``; feature axis has size ``num_seg``.
    * ``'cs'`` -- ``'b m (n p) -> b n m p'``; feature axis has size ``patch_len``.
    * ``'ml'`` -- no re-layout; feature axis has size ``seq_len``.

    The output is ``FNN(x + attention(q, k, v))``, mapped back to ``(B, M, L)``.

    Raises:
        ValueError: if ``att_method`` is not one of the four values above.
    """

    # att_method -> (forward pattern, inverse pattern); 'ml' needs none.
    _LAYOUTS = {
        'ci': ('b m (n p) -> b m p n', 'b m p n -> b m (n p)'),
        'cc': ('b m (n p) -> b (p m) n', 'b (p m) n -> b m (n p)'),
        'cs': ('b m (n p) -> b n m p', 'b n m p -> b m (n p)'),
    }
    _METHODS = ('ci', 'cc', 'cs', 'ml')

    def __init__(self, configs, att_method='cs'):
        super(mm_encoder, self).__init__()
        if att_method not in self._METHODS:
            raise ValueError(
                f"att_method must be one of {self._METHODS}, got {att_method!r}")
        self.num_seg = configs.num_seg
        self.seq_len = configs.seq_len
        self.patch_len = self.seq_len//self.num_seg
        self.parameter_share = configs.parameter_share  # stored but unused here
        self.use_pos_emb = configs.use_pos_emb
        self.att_method = att_method
        # Size of the last axis after re-layout, i.e. what every Linear acts on.
        if self.att_method == 'cs':
            self.out_len = self.patch_len
        elif self.att_method == 'ml':
            self.out_len = self.seq_len
        else:  # 'ci' / 'cc'
            self.out_len = self.num_seg
        # Creation order is kept so seeded initialisation is unchanged.
        self.att_key_0 = nn.Linear(self.out_len, self.out_len)
        self.att_value_0 = nn.Linear(self.out_len, self.out_len)
        self.att_query_0 = nn.Linear(self.out_len, self.out_len)
        # The *_1 projections are not used in forward; they are kept so the
        # state_dict keys (and existing checkpoints) stay compatible.
        self.att_key_1 = nn.Linear(self.out_len, self.out_len)
        self.att_value_1 = nn.Linear(self.out_len, self.out_len)
        self.att_query_1 = nn.Linear(self.out_len, self.out_len)
        self.FNN = nn.Sequential(nn.Linear(self.out_len, self.out_len), nn.ReLU(), nn.Linear(self.out_len, self.out_len))
        self.gelu = nn.GELU()  # not used in forward
        self.relu = nn.ReLU()  # not used in forward
        self.cal_att_map = configs.cal_att_map
        if self.use_pos_emb:
            self.num_channels = configs.num_channels
            self.embed_dim = configs.seq_len//configs.num_seg
            # NOTE(review): confirm the project-local PositionalEmbedding output
            # broadcasts against the re-laid-out x for the chosen att_method.
            self.pos_emb = PositionalEmbedding(self.num_seg, self.num_channels*self.embed_dim)

    def forward(self, x, enc_x):
        B, M, L = x.shape
        layout = self._LAYOUTS.get(self.att_method)
        if layout is not None:
            to_tokens, from_tokens = layout
            x = rearrange(x, to_tokens, b=B, m=M, n=self.num_seg)
            enc_x = rearrange(enc_x, to_tokens, b=B, m=M, n=self.num_seg)
        if self.use_pos_emb:
            # Positional embedding is applied to the query side only.
            pos_emb = self.pos_emb(x)
            x = x + pos_emb
        q = self.att_query_0(x)
        k = self.att_key_0(enc_x)
        v = self.att_value_0(enc_x)
        if self.cal_att_map:
            self.q_1 = q.clone()
            self.k_1 = k.clone()
        att_x = F.scaled_dot_product_attention(q, k, v)
        x = self.FNN(x + att_x)
        if layout is not None:
            x = rearrange(x, from_tokens, b=B, m=M, n=self.num_seg)
        return x