from collections import OrderedDict
import math

import torch
import torch.nn as nn
from torch.utils import model_zoo
from typing import Dict, Optional, Tuple

from models.detection.yolox.models.network_blocks import BaseConv
from sate.model.unet import SATE_STEM


class SATEEncoder(nn.Module):
    """Backbone: a recurrent ``SATE_STEM`` encoder followed by per-feature downsampling branches.

    The stem is run on the input together with its previous recurrent states. Several of
    its intermediate outputs are picked out by index, and each one goes through its own
    two-layer ``BaseConv`` branch (both layers ksize=3, stride=2). The first layer doubles
    the channel count.

    Args:
        dataset_name: Forwarded unchanged to ``SATE_STEM``.
        in_channels: Channel count of each tapped stem feature, in the same order as
            ``feature_keys``. Branch ``i`` maps ``in_channels[i]`` -> ``2 * in_channels[i]``
            channels.
        feature_keys: Keyword-only. Indices used to select features from the stem's second
            return value. Defaults to ``(2, 4, 8)``, which were the previously hard-coded taps.
            NOTE(review): these look like stride/scale keys — confirm against ``SATE_STEM``.

    Raises:
        ValueError: If ``in_channels`` and ``feature_keys`` differ in length. Previously a
            mismatch was silently truncated by ``zip`` in ``forward``.
    """

    def __init__(
        self,
        dataset_name,
        in_channels: Tuple[int, ...] = (64, 128, 256),
        *,
        feature_keys: Tuple[int, ...] = (2, 4, 8),
    ):
        super().__init__()
        in_channels = tuple(in_channels)
        feature_keys = tuple(feature_keys)
        # One branch per tapped feature; a length mismatch would otherwise silently drop
        # features (or leave unused branches) because forward() zips the two sequences.
        if len(in_channels) != len(feature_keys):
            raise ValueError(
                f"in_channels has {len(in_channels)} entries but feature_keys has "
                f"{len(feature_keys)}; each tapped stem feature needs exactly one "
                f"downsampling branch"
            )
        self.feature_keys = feature_keys

        # Stem configuration kept exactly as before. num_input_channels=5 is presumably the
        # number of event-representation bins per sample — confirm against the data pipeline.
        self.e2vid_encoder = SATE_STEM(
            num_input_channels=5,
            skip_type='sum',
            recurrent_block_type='convlstm',
            num_encoders=3,
            base_num_channels=32,
            num_residual_blocks=2,
            use_upsample_conv=False,
            norm='BN',
            dataset_name=dataset_name,
        )

        # Two stride-2 convs per branch. If BaseConv pads (ksize-1)//2 (as in YOLOX), this
        # shrinks spatial size by ~4x — NOTE(review): confirm.
        self.down_samp = nn.ModuleList([
            nn.Sequential(
                BaseConv(in_channels=ch, out_channels=ch * 2, ksize=3, stride=2),
                BaseConv(in_channels=ch * 2, out_channels=ch * 2, ksize=3, stride=2),
            )
            for ch in in_channels
        ])

    def forward(self, x, prev_states=None):
        """Run the stem and downsample each tapped intermediate feature.

        Args:
            x: Input passed straight to ``SATE_STEM``. Presumably a (B, 5, H, W) tensor —
                TODO confirm.
            prev_states: Recurrent states returned by the previous call, or ``None`` to
                start fresh.

        Returns:
            ``(mid_features, states)``: a list with one downsampled feature per entry of
            ``feature_keys``, plus the stem's updated recurrent states. Note that this order
            is swapped relative to the stem's own ``(states, outputs)`` return.
        """
        states, stem_outputs = self.e2vid_encoder(x, prev_states)
        mid_features = [
            branch(stem_outputs[key])
            for key, branch in zip(self.feature_keys, self.down_samp)
        ]
        return mid_features, states
    
def build_SATE_backbone(dataset_name):
    """Construct a ``SATEEncoder`` with its default channel configuration.

    Args:
        dataset_name: Forwarded to ``SATEEncoder`` (and from there to its stem).

    Returns:
        The newly created ``SATEEncoder`` module.
    """
    backbone = SATEEncoder(dataset_name=dataset_name)
    return backbone