import torch.nn as nn
from einops import rearrange
from layers.DUET_clutser import Linear_extractor_cluster 
from layers.DUET_utils import Mahalanobis_mask, Encoder, EncoderLayer, FullAttention, AttentionLayer
import torch


class Model(nn.Module):
    """Forecasting model: cluster-based temporal extractor + masked channel attention.

    The pipeline, as wired below, is:

    1. ``self.cluster`` (``Linear_extractor_cluster``) turns the input series
       into a temporal feature and an auxiliary ``L_importance`` value.
    2. ``self.mask_generator`` (``Mahalanobis_mask``) derives an attention mask
       from the channel-major view of the input series.
    3. ``self.Channel_transformer`` attends across channels under that mask.
    4. ``self.linear_head`` projects ``hid_dim`` to ``predict_length``.
    5. ``self.cluster.revin(..., "denorm")`` is applied to the projection.

    ``args`` must expose: ``CI``, ``input_length``, ``att_dropout``,
    ``hid_dim``, ``n_heads``, ``enc_dropout``, ``activation_enc``,
    ``num_layers`` and ``predict_length`` (plus whatever
    ``Linear_extractor_cluster`` itself reads from it).
    """

    def __init__(self, args):
        super().__init__()

        self.cluster = Linear_extractor_cluster(args)
        # When true, every variable is pushed through the extractor on its own.
        self.CI = bool(args.CI)
        self.mask_generator = Mahalanobis_mask(args.input_length)

        hidden = args.hid_dim

        # Build the encoder stack one layer at a time, in order.
        encoder_layers = []
        for _ in range(args.num_layers):
            attention = AttentionLayer(
                # Positional arguments are forwarded as-is; presumably
                # (mask_flag, factor) -- confirm against layers.DUET_utils.
                FullAttention(
                    True,
                    None,
                    attention_dropout=args.att_dropout,
                    output_attention=False,
                ),
                hidden,
                args.n_heads,
            )
            encoder_layers.append(
                EncoderLayer(
                    attention,
                    hidden,
                    d_ff=None,
                    dropout=args.enc_dropout,
                    activation=args.activation_enc,
                )
            )

        self.Channel_transformer = Encoder(
            encoder_layers,
            norm_layer=torch.nn.LayerNorm(hidden),
        )

        projection = nn.Linear(hidden, args.predict_length)
        self.linear_head = nn.Sequential(projection, nn.Dropout(args.enc_dropout))

    def forward(self, input, mode=None):
        """Run the model on one batch.

        Args:
            input: 4-item sequence ``(x, x_time, y_time, aux_data)``. Only
                ``x`` is used. After ``x.squeeze(-1)`` it must be 3-D, laid
                out as [batch_size, seq_len, n_vars]; presumably ``x`` arrives
                as [batch_size, seq_len, n_vars, 1] -- TODO confirm with the
                data loader.
            mode: Accepted but not used by this method.

        Returns:
            Tuple ``(output, L_importance)``: ``output`` is the de-normalised
            projection with a trailing singleton axis appended, and
            ``L_importance`` is the second value returned by ``self.cluster``.
        """
        x, _x_time, _y_time, _aux_data = input
        series = x.squeeze(-1)
        batch_size, _, n_vars = series.shape

        if not self.CI:
            temporal_feature, L_importance = self.cluster(series, n_vars=n_vars)
        else:
            # Fold the variable axis into the batch axis so the extractor
            # sees one single-variable series per row, then unfold again.
            per_channel = rearrange(series, 'b l n -> (b n) l 1')
            extracted, L_importance = self.cluster(per_channel, n_vars=n_vars)
            temporal_feature = rearrange(extracted, '(b n) l 1 -> b l n', b=batch_size)

        # B x d_model x n_vars -> B x n_vars x d_model
        channel_tokens = rearrange(temporal_feature, 'b d n -> b n d')

        # The mask is computed from the channel-major view of the input
        # series. Mask and channel attention are applied unconditionally:
        # there is no shortcut for single-variable inputs.
        channel_mask = self.mask_generator(rearrange(series, 'b l n -> b n l'))
        encoded, _ = self.Channel_transformer(x=channel_tokens, attn_mask=channel_mask)

        # Project each channel, then move the channel axis back to the end
        # before handing the result to the extractor's RevIN "denorm" step.
        forecast = rearrange(self.linear_head(encoded), 'b n d -> b d n')
        forecast = self.cluster.revin(forecast, "denorm")

        return forecast.unsqueeze(-1), L_importance
