import torch
import torch.nn as nn


from hyperbolic_lib.lib.lorentz.blocks.transformer_blocks import LorentzMultiHeadAttention
from utils.utils_h import (BATCH1D_TYPES,
                                       patch_len,
                                       matt_covar)

from models.blocks import LorentzInceptionBlock





class BaselineDeviationEncoder(nn.Module):
    """Encode how an input deviates from an average-pooling inception baseline.

    Two ``LorentzInceptionBlock`` branches with the same configuration are
    applied to the same input: the main one uses the caller's ``pool_type``,
    the baseline one always uses ``"average"`` pooling. In ``forward``:

    1. The coordinates ``[..., 1:]`` of the two outputs are subtracted
       (baseline minus main).
    2. The difference is lifted back with ``manifold.add_time`` and
       layer-normalised.
    3. A ReLU is applied to coordinates ``[..., 1:]``, and the result is
       lifted again.
    4. The result goes through a ``LorentzMultiHeadAttention`` layer.

    Args:
        manifold: Manifold object; must provide ``add_time`` (presumably it
            recomputes the time coordinate at index 0 from the remaining
            coordinates -- confirm against the manifold implementation).
        in_channel: Number of input channels passed to both inception blocks.
        features: Spatial feature size. Each inception block is built with
            ``features // 4`` filters, and the norm / attention layers work
            on ``features + 1`` coordinates. NOTE(review): this presumably
            relies on the inception block emitting ``4 * n_filters`` spatial
            coordinates, i.e. ``features`` should be a multiple of 4 --
            confirm against ``LorentzInceptionBlock``.
        kernel_sizes: Kernel sizes forwarded to both inception blocks.
        inception_channels: Bottleneck channel count for both inception blocks.
        conv_type: Convolution variant forwarded to both inception blocks.
        batch_type: Key into ``BATCH1D_TYPES`` used to build ``layer_norm_2``
            (an empty ``nn.Sequential`` when ``None``). ``layer_norm_2`` is
            not used by :meth:`forward`.
        pool_type: Pooling variant for the main inception block; the
            baseline block always uses ``"average"``.
        dropout: Dropout forwarded to both inception blocks.
    """

    def __init__(self,
                 manifold,
                 in_channel=1,
                 features=18*18,
                 kernel_sizes=(9, 19, 39),
                 inception_channels=8,
                 conv_type="original",
                 batch_type="original",
                 pool_type="dirty",
                 dropout=0):
        super().__init__()

        self.manifold = manifold

        # Size of one point: the ``features`` spatial coordinates plus one
        # extra coordinate (presumably time).
        lorentz_dim = features + 1
        # Integer division; equals the previous ``int(features / 4)`` for
        # non-negative integer ``features``.
        n_filters = features // 4

        # NOTE: construction order below is kept identical to the original so
        # that seeded random initialisation produces the same parameters.

        # Main branch: pooling variant chosen by the caller.
        self.inception_block = LorentzInceptionBlock(self.manifold,
                                                     in_channels=in_channel,
                                                     n_filters=n_filters,
                                                     kernel_sizes=kernel_sizes,
                                                     bottleneck_channels=inception_channels,
                                                     activation=None,
                                                     return_indices=False,
                                                     conv_type=conv_type,
                                                     batch_type=None,
                                                     pool_type=pool_type,
                                                     dropout=dropout)

        # Baseline branch: same configuration, but always average pooling.
        self.baseline_block = LorentzInceptionBlock(self.manifold,
                                                    in_channels=in_channel,
                                                    n_filters=n_filters,
                                                    kernel_sizes=kernel_sizes,
                                                    bottleneck_channels=inception_channels,
                                                    activation=None,
                                                    return_indices=False,
                                                    conv_type=conv_type,
                                                    batch_type=None,
                                                    pool_type="average",
                                                    dropout=dropout)

        self.att = LorentzMultiHeadAttention(self.manifold,
                                             lorentz_dim,
                                             1,
                                             1,
                                             out_features=lorentz_dim)

        # NOTE(review): ``attention_prototypes``, ``layer_norm_2``, ``weight``
        # and ``flat`` are not used by ``forward``. They are presumably kept so
        # the state_dict layout matches existing checkpoints -- confirm before
        # removing.
        self.attention_prototypes = nn.Parameter(torch.rand(lorentz_dim))

        self.layer_norm = BATCH1D_TYPES["layer"](self.manifold, lorentz_dim)
        self.layer_norm_2 = (BATCH1D_TYPES[batch_type](self.manifold, lorentz_dim)
                             if batch_type is not None else nn.Sequential())

        # Explicit ``dim=0``: softmax without ``dim`` is deprecated and warns.
        # For a 1-D tensor the implicit choice was dim 0, so values are unchanged.
        self.weight = nn.Parameter(torch.nn.functional.softmax(torch.randn(2), dim=0))

        # Out-of-place ReLU: ``forward`` applies this to a slice (a view) of
        # ``layer_norm``'s output. Mutating another module's output in place
        # fails in autograd if that module saved its output for backward.
        self.activation = nn.ReLU()

        self.flat = nn.Flatten()

    def forward(self, x):
        """Return the attended baseline-deviation embedding of ``x``.

        Args:
            x: Input accepted by ``LorentzInceptionBlock``. Its shape is not
                established in this module -- TODO confirm with the caller.

        Returns:
            The output of ``self.att`` applied to the normalised,
            ReLU-activated deviation. The attention layer is built with
            ``out_features = features + 1``.
        """
        incepted = self.inception_block(x)
        baseline = self.baseline_block(x)

        # Subtract coordinates 1: (index 0 is dropped -- presumably the time
        # coordinate that ``add_time`` recomputes), then lift back.
        deviation = self.manifold.add_time(baseline[..., 1:] - incepted[..., 1:])
        deviation = self.layer_norm(deviation)
        # Activate only coordinates 1:, then recompute coordinate 0.
        deviation = self.manifold.add_time(self.activation(deviation[..., 1:]))

        return self.att(deviation)

