import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from models.labits_utlis.utils import local_time_mask
from models.labits_utlis.sn import *


def double_conv(in_channels, out_channels):
    """Build two stacked 3x3 conv stages, each followed by InstanceNorm and ReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Both convolutions use ``padding=1``, so the spatial size
    of the input is preserved.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` of six layers:
        Conv2d, InstanceNorm2d, ReLU, Conv2d, InstanceNorm2d, ReLU.
    """
    stages = []
    # Stage 1 consumes the input width; stage 2 consumes stage 1's output.
    for stage_in in (in_channels, out_channels):
        stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
        stages.append(nn.InstanceNorm2d(out_channels))
        stages.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stages)

def conv_lr(in_channels, out_channels, hid_channels=128):
    """Build a small two-conv prediction head.

    The head is: 3x3 conv to ``hid_channels``, InstanceNorm, ReLU, then a 3x3
    conv to ``out_channels`` with no normalisation or activation after it.
    Both convolutions use ``padding=1``, so spatial size is preserved.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the prediction.
        hid_channels: Width of the intermediate feature map.

    Returns:
        An ``nn.Sequential`` of four layers: Conv2d, InstanceNorm2d, ReLU, Conv2d.
    """
    hidden_conv = nn.Conv2d(in_channels, hid_channels, kernel_size=3, padding=1)
    hidden_norm = nn.InstanceNorm2d(hid_channels)
    hidden_act = nn.ReLU(inplace=True)
    # The output conv is left linear (no norm / activation).
    output_conv = nn.Conv2d(hid_channels, out_channels, kernel_size=3, padding=1)
    return nn.Sequential(hidden_conv, hidden_norm, hidden_act, output_conv)


class UNet(nn.Module):
    """The Labits-to-APLOF net definition.

    A UNet-style network. The encoder is three ``double_conv`` + 2x max-pool
    stages followed by a bottleneck ``double_conv``, so the bottleneck runs at
    1/8 of the input resolution with ``hid_channels * 8`` channels. The
    bottleneck features, masked by a low-resolution local-time mask, are always
    returned. The decoder and the two prediction heads are only built and run
    when ``visualization`` is true.

    Args:
        in_channels: Input channels of the first encoder stage. ``forward``
            feeds every time slice as a single-channel image, so this must be
            1 for ``forward`` to run.
        out_channels: Channels of the low- and high-resolution predictions
            (only used when ``visualization`` is true).
        hid_channels: Width of the first encoder stage; deeper stages use
            2x, 4x and 8x this value.
        visualization: If true, build the decoder and prediction heads and
            make ``forward`` also return the predictions.
    """

    def __init__(self, in_channels, out_channels, hid_channels=16, visualization=False):
        super().__init__()
        self.visualization = visualization
        hid = hid_channels

        # Encoder: channel width doubles at every stage.
        self.dconv_down1 = double_conv(in_channels, hid)
        self.dconv_down2 = double_conv(hid, hid * 2)
        self.dconv_down3 = double_conv(hid * 2, hid * 4)
        self.dconv_down4 = double_conv(hid * 4, hid * 8)

        self.maxpool = nn.MaxPool2d(2)

        if visualization:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

            # Decoder inputs are the upsampled features concatenated with the
            # matching encoder skip: 8+4, 4+2 and 2+1 multiples of ``hid``.
            self.dconv_up3 = double_conv(hid * 12, hid * 4)
            self.dconv_up2 = double_conv(hid * 6, hid * 2)
            self.dconv_up1 = double_conv(hid * 3, hid)

            # Full-resolution head (1x1 conv) and bottleneck-resolution head.
            self.conv_last = nn.Conv2d(hid, out_channels, 1)
            self.conv_lr = conv_lr(hid * 8, out_channels, hid_channels=hid * 2)

    def forward(self, labits, threshold):
        """Encode ``labits`` and return masked bottleneck features.

        Args:
            labits: Tensor of shape (B, T, H, W). Each of the ``B * T`` slices
                is processed independently as a single-channel image. When
                ``visualization`` is true, H and W must be divisible by 8 so
                the upsampled decoder features line up with the encoder skips.
            threshold: Forwarded unchanged to ``local_time_mask``.

        Returns:
            If ``visualization`` is false: ``flow_feat`` of shape
            (B, T, hid_channels * 8, H // 8, W // 8).

            If ``visualization`` is true: a tuple
            ``(pred_lr, pred_hr, flow_feat)`` where ``pred_lr`` has shape
            (B, T, out_channels, H / 8, W / 8) and ``pred_hr`` has shape
            (B, T, out_channels, H, W).
        """
        # The time-step count is read from the input instead of being
        # hard-coded, so any batch size and sequence length is accepted.
        num_steps = labits.shape[1]

        # Fold time into the batch axis; every slice becomes a 1-channel image.
        x = rearrange(labits, 'B T H W -> (B T) 1 H W')

        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        # Unfold (B T) back into separate axes; B is inferred by einops.
        # No copy of ``x`` is needed: the multiplication below allocates a new
        # tensor and nothing writes to ``flow_feat`` in place.
        flow_feat = rearrange(x, '(B T) C h w -> B T C h w', T=num_steps)  # (B, T, C, h, w)

        # NOTE(review): assumes local_time_mask returns a (B, T, H, W) mask
        # (per the original inline comment) — confirm against the helper.
        local_mask = local_time_mask(labits, threshold=threshold)  # (B, T, H, W)
        # Downsample the mask by 8 to bottleneck resolution. A cell is kept
        # when the mean over its 8x8 patch is >= 8/64 (for a binary mask: at
        # least 8 of the 64 pixels are set).
        local_mask_lr = F.avg_pool2d(local_mask.float(), 8) >= (8 / 64)  # (B, T, h, w)
        local_mask_lr = local_mask_lr.unsqueeze(2)  # (B, T, 1, h, w), broadcast over C

        flow_feat = flow_feat * local_mask_lr

        if not self.visualization:
            return flow_feat

        # Low-resolution prediction straight from the bottleneck features.
        pred_lr = self.conv_lr(x)

        # Decoder: upsample, concatenate the encoder skip, refine.
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        pred_hr = self.conv_last(x)

        # Restore the separate batch and time axes.
        pred_lr = rearrange(pred_lr, '(B T) C H W -> B T C H W', T=num_steps)
        pred_hr = rearrange(pred_hr, '(B T) C H W -> B T C H W', T=num_steps)

        return pred_lr, pred_hr, flow_feat

