import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.FMRF import FMRF


class ConvBlock(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by 2-D batch norm.

    With kernel 3 and padding 1 the spatial size of the input is preserved.
    No activation is applied here; callers append their own.

    Args:
        input_channel: Number of channels of the incoming feature map.
        output_channel: Number of channels produced by the convolution and
            normalised by the batch-norm layer.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=input_channel,
            out_channels=output_channel,
            kernel_size=3,
            padding=1,
        )
        norm = nn.BatchNorm2d(num_features=output_channel)
        # Kept under the name ``layers`` so state_dict keys stay
        # ``layers.0.*`` (conv) and ``layers.1.*`` (batch norm).
        self.layers = nn.Sequential(conv, norm)

    def forward(self, inp):
        """Run *inp* through the convolution and then the batch norm."""
        out = self.layers(inp)
        return out


class BackBone(nn.Module):
    """Four-stage convolutional backbone with an FMRF module after each stage.

    Every stage is ``ConvBlock`` (3x3 conv, padding 1, + BatchNorm) -> ReLU ->
    ``MaxPool2d(2)``. Each stage therefore floor-halves the spatial side
    length (e.g. 84 -> 42 -> 21 -> 10 -> 5). The FMRF module that follows a
    stage is constructed with ``sequence_length`` equal to that stage's
    output side length squared. This is why the input resolution has to be
    known when the model is built.

    Args:
        num_channel: Channel width of every conv stage, also used as the FMRF
            ``embedding_dim``.
        image_size: Side length of the input images (the side-length
            computation assumes square inputs). The default of 84 reproduces
            the previously hard-coded sequence lengths 42*42, 21*21, 10*10
            and 5*5.

    Raises:
        ValueError: If ``image_size`` is so small that the feature map would
            shrink to zero size within the four pooling stages.
    """

    def __init__(self, num_channel=64, image_size=84):
        super().__init__()
        sides = self._stage_sides(image_size, 4)
        if sides[-1] < 1:
            raise ValueError(
                f"image_size={image_size} is too small: the feature map "
                f"vanishes after 4 MaxPool2d(2) stages (need >= 16)")

        # FMRF modules are registered before the conv stages, as in the
        # original code, so parameter registration order, state_dict layout
        # and RNG consumption during initialisation are unchanged.
        self.fmrf1 = self._make_fmrf(sides[0] * sides[0], num_channel)
        self.fmrf2 = self._make_fmrf(sides[1] * sides[1], num_channel)
        self.fmrf3 = self._make_fmrf(sides[2] * sides[2], num_channel)
        self.fmrf4 = self._make_fmrf(sides[3] * sides[3], num_channel)

        self.layer1 = self._make_stage(3, num_channel)
        self.layer2 = self._make_stage(num_channel, num_channel)
        self.layer3 = self._make_stage(num_channel, num_channel)
        self.layer4 = self._make_stage(num_channel, num_channel)

    @staticmethod
    def _stage_sides(image_size, num_stages):
        """Return the feature-map side length after each MaxPool2d(2) stage."""
        sides = []
        side = image_size
        for _ in range(num_stages):
            # MaxPool2d(2) uses floor rounding, so odd sizes lose a row/column
            # (21 -> 10); the 3x3/padding-1 conv itself preserves the size.
            side //= 2
            sides.append(side)
        return sides

    @staticmethod
    def _make_fmrf(sequence_length, embedding_dim):
        """Build an FMRF module with the configuration shared by all stages."""
        return FMRF(
            sequence_length=sequence_length,
            embedding_dim=embedding_dim,
            resnet=False,
            num_layers=1,
            num_heads=1,
            mlp_dropout_rate=0.,
            attention_dropout=0.,
            positional_embedding='sine')

    @staticmethod
    def _make_stage(in_channels, out_channels):
        """Build one ConvBlock -> ReLU -> MaxPool2d(2) stage."""
        return nn.Sequential(
            ConvBlock(in_channels, out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def forward(self, inp):
        """Run the four conv stages, applying the matching FMRF after each.

        Args:
            inp: Input image batch with 3 channels (``layer1`` starts with
                ``ConvBlock(3, ...)``); presumably shaped
                (B, 3, image_size, image_size) -- TODO confirm against callers.

        Returns:
            The output of ``fmrf4``. Its exact shape is determined by FMRF,
            which is not visible here. Since each FMRF output is fed straight
            into the next conv stage, it is presumably a
            (B, num_channel, H, W) feature map -- verify.
        """
        l1 = self.layer1(inp)
        l1 = self.fmrf1(l1)

        l2 = self.layer2(l1)
        l2 = self.fmrf2(l2)

        l3 = self.layer3(l2)
        l3 = self.fmrf3(l3)

        l4 = self.layer4(l3)
        l4 = self.fmrf4(l4)

        return l4
