import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
import cv2

from .base.builder import MODELS
from .base.base_model import Base_model
# --------------------------------------------- Binarized Basic Units -----------------------------------------------------------------


class LearnableBias(nn.Module):
    """Add a trainable per-channel offset to an NCHW feature map.

    The offset is a ``(1, out_chn, 1, 1)`` parameter initialised to zeros, so a
    freshly constructed module acts as the identity.

    Args:
        out_chn: number of channels of the tensors this module is applied to.
    """

    def __init__(self, out_chn):
        super().__init__()
        shift = torch.zeros(1, out_chn, 1, 1)
        self.bias = nn.Parameter(shift, requires_grad=True)

    def forward(self, x):
        # expand_as (instead of implicit broadcasting) keeps the strict check
        # that the bias can be expanded to exactly x's shape.
        return x + self.bias.expand_as(x)



class RPReLU(nn.Module):
    """PReLU wrapped between two learnable per-channel shifts.

    Computes ``bias1 + PReLU(x + bias0)``. Each of ``bias0``, the PReLU slope
    and ``bias1`` has one value per channel.

    Args:
        inplanes: number of channels of the input (and of the output).
    """

    def __init__(self, inplanes):
        super().__init__()
        self.pr_bias0 = LearnableBias(inplanes)
        self.pr_prelu = nn.PReLU(inplanes)
        self.pr_bias1 = LearnableBias(inplanes)

    def forward(self, x):
        shifted = self.pr_bias0(x)
        activated = self.pr_prelu(shifted)
        return self.pr_bias1(activated)
    
class Q_A(torch.autograd.Function):
    """Binarize activations with ``sign`` and a piecewise-linear surrogate gradient.

    Forward returns ``torch.sign(x)``. The intent is a {-1, +1} output, but
    inputs that are exactly zero map to 0.

    Backward scales the incoming gradient by ``max(0, 2 - 2|x|)``. The scale is
    largest at ``x = 0`` and vanishes for ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x):
        # The pre-binarization input is needed to evaluate the surrogate.
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        surrogate = torch.clamp(2 - torch.abs(2 * saved), min=0)
        return surrogate * grad_output

class Q_W(torch.autograd.Function):
    """Binarize weights with ``sign`` using a straight-through (identity) gradient.

    Forward returns ``torch.sign(x)``; no XNOR-Net style scaling factor is
    applied. Backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad):
        # Straight-through estimator: treat sign() as the identity map.
        return grad

class BinaryConv2d(nn.Conv2d):
    """1-bit convolution with a full-precision residual shortcut.

    Input activations are binarized with ``Q_A`` and the weights with ``Q_W``.
    The convolution result passes through an ``RPReLU`` and is then added to
    the *unbinarized* input. This carries full-precision information past the
    binarization.

    Because of the residual add, the output must have the same shape as the
    input. Use ``in_chn == out_chn`` and a stride/padding that preserves the
    spatial size (e.g. ``kernel_size=3, stride=1, padding=1`` or
    ``kernel_size=1, stride=1, padding=0``).

    Args:
        in_chn, out_chn, kernel_size, stride, padding, bias, groups: forwarded
            to :class:`torch.nn.Conv2d`.
    """

    def __init__(self, in_chn, out_chn, kernel_size, stride=1, padding=0, bias=False, groups=1):
        super(BinaryConv2d, self).__init__(
            in_chn,
            out_chn,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias
        )
        # Applied to the conv output (out_chn channels); sized with in_chn,
        # which is only consistent because the residual requires in == out.
        self.relu = RPReLU(in_chn)

    def forward(self, x):
        x_fp = x  # full-precision input, kept for the shortcut
        x_bin = Q_A.apply(x)
        binary_weights = Q_W.apply(self.weight)
        out = F.conv2d(x_bin, binary_weights, self.bias, stride=self.stride,
                       padding=self.padding, groups=self.groups)
        out = self.relu(out)
        # Add the full-precision input, not its sign: adding the binarized
        # tensor would throw away exactly the information the shortcut exists
        # to preserve.
        return out + x_fp



class BinaryConv2d_Down(nn.Module):
    '''
    Downsample by 2 with average pooling, then double the channel count.

    input: b,c,h,w
    output: b,2c,h/2,w/2   (odd h/w are floored by the 2x2 average pool)

    The pooled map feeds two independent ``BinaryConv2d`` branches
    (``in_channels`` -> ``in_channels`` each). Their outputs are concatenated
    along the channel axis. ``out_channels`` is not used: the output always has
    ``2 * in_channels`` channels. The remaining arguments configure both
    branches.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, groups=1):
        super(BinaryConv2d_Down, self).__init__()

        self.biconv_1 = BinaryConv2d(in_channels, in_channels, kernel_size, stride, padding, bias, groups)
        self.biconv_2 = BinaryConv2d(in_channels, in_channels, kernel_size, stride, padding, bias, groups)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        '''
        x: b,c,h,w
        out: b,2c,h/2,w/2
        '''
        pooled = self.avg_pool(x)
        # Both branches read the same pooled tensor. BinaryConv2d does not
        # modify its input in place, so no defensive copy is needed.
        out_1 = self.biconv_1(pooled)
        out_2 = self.biconv_2(pooled)
        return torch.cat([out_1, out_2], dim=1)



class BinaryConv2d_Up(nn.Module):
    '''
    Upsample by 2 (bilinear), then halve the channel count.

    input: b,c,h,w
    output: b,c/2,2h,2w

    The upsampled map is split into two channel halves. Each half goes through
    its own ``BinaryConv2d`` (``out_channels`` -> ``out_channels``) and the two
    results are averaged. ``in_channels`` is not used; ``out_channels`` is
    expected to equal ``c // 2``.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, groups=1):
        super(BinaryConv2d_Up, self).__init__()
        conv_args = (kernel_size, stride, padding, bias, groups)
        self.biconv_1 = BinaryConv2d(out_channels, out_channels, *conv_args)
        self.biconv_2 = BinaryConv2d(out_channels, out_channels, *conv_args)

    def forward(self, x):
        '''
        x: b,c,h,w
        out: b,c/2,2h,2w
        '''
        _, channels, _, _ = x.shape
        half = channels // 2
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear')

        lower = self.biconv_1(upsampled[:, :half])
        upper = self.biconv_2(upsampled[:, half:])
        return (lower + upper) / 2


class BinaryConv2d_Fusion_Decrease(nn.Module):
    '''
    Halve the channel count at constant resolution.

    input: b,c,h,w
    output: b,c/2,h,w

    The input is split into two channel halves. Each half is processed by its
    own ``BinaryConv2d`` (``out_channels`` -> ``out_channels``) and the two
    results are averaged. ``in_channels`` is not used; ``out_channels`` is
    expected to equal ``c // 2``.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, groups=1):
        super(BinaryConv2d_Fusion_Decrease, self).__init__()
        conv_args = (kernel_size, stride, padding, bias, groups)
        self.biconv_1 = BinaryConv2d(out_channels, out_channels, *conv_args)
        self.biconv_2 = BinaryConv2d(out_channels, out_channels, *conv_args)

    def forward(self, x):
        '''
        x: b,c,h,w
        out: b,c/2,h,w
        '''
        _, channels, _, _ = x.shape
        half = channels // 2

        first = self.biconv_1(x[:, :half])
        second = self.biconv_2(x[:, half:])
        return (first + second) / 2


class BinaryConv2d_Fusion_Increase(nn.Module):
    '''
    Double the channel count at constant resolution.

    input: b,c,h,w
    output: b,2c,h,w

    The same input feeds two independent ``BinaryConv2d`` branches
    (``in_channels`` -> ``in_channels``). Their outputs are concatenated along
    the channel axis. ``out_channels`` is not used.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, groups=1):
        super(BinaryConv2d_Fusion_Increase, self).__init__()
        conv_args = (kernel_size, stride, padding, bias, groups)
        self.biconv_1 = BinaryConv2d(in_channels, in_channels, *conv_args)
        self.biconv_2 = BinaryConv2d(in_channels, in_channels, *conv_args)

    def forward(self, x):
        '''
        x: b,c,h,w
        out: b,2c,h,w
        '''
        first = self.biconv_1(x)
        second = self.biconv_2(x)
        return torch.cat((first, second), dim=1)

# ---------------------------------------------------------- Binarized UNet------------------------------------------------------



class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last dimension, then call ``fn``.

    Args:
        dim: size of the last (normalized) dimension of the input.
        fn: callable invoked as ``fn(normed_x, *args, **kwargs)``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, *args, **kwargs)



class FeedForward(nn.Module):
    """Binarized feed-forward block operating on channels-last tensors.

    Channel path:
      * ``dim -> 2*dim -> 4*dim`` via two ``BinaryConv2d_Fusion_Increase``;
      * ``RPReLU``;
      * a grouped 3x3 ``BinaryConv2d`` (``groups=dim``);
      * ``RPReLU``;
      * ``4*dim -> 2*dim -> dim`` via two ``BinaryConv2d_Fusion_Decrease``.

    Args:
        dim: channel count of the input and output.
        mult: expansion factor. The fusion units always double or halve the
            channel count, so ``mult=2`` is the only value that yields
            consistent shapes.

    Raises:
        ValueError: if ``mult`` is not 2.
    """

    def __init__(self, dim, mult=2):
        super().__init__()
        if mult != 2:
            # Any other factor builds fine but fails later, inside forward,
            # with an opaque channel-mismatch error.
            raise ValueError(f"FeedForward only supports mult=2, got mult={mult}")
        self.net = nn.Sequential(
            BinaryConv2d_Fusion_Increase(dim, dim * mult, 1, 1, bias=False),
            BinaryConv2d_Fusion_Increase(dim * mult, dim * mult * mult, 1, 1, bias=False),
            RPReLU(dim * mult * mult),
            BinaryConv2d(dim * mult * mult, dim * mult * mult, 3, 1, 1, bias=False, groups=dim),
            RPReLU(dim * mult * mult),
            BinaryConv2d_Fusion_Decrease(dim * mult * mult, dim * mult, 1, 1, bias=False),
            BinaryConv2d_Fusion_Decrease(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        # The binary convs expect NCHW; convert in and back out.
        out = self.net(x.permute(0, 3, 1, 2))
        return out.permute(0, 2, 3, 1)



class Block(nn.Module):
    """Stack of pre-normalized binarized feed-forward layers with residual adds.

    Args:
        dim: number of channels.
        dim_head, heads: accepted for interface compatibility; not used.
        num_blocks: number of ``PreNorm(FeedForward)`` layers.
    """

    def __init__(
            self,
            dim,
            dim_head,
            heads,
            num_blocks,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [PreNorm(dim, FeedForward(dim=dim)) for _ in range(num_blocks)]
        )

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        # PreNorm's LayerNorm normalizes the last axis, so work channels-last.
        y = x.permute(0, 2, 3, 1)
        for layer in self.blocks:
            y = layer(y) + y
        return y.permute(0, 3, 1, 2)



class body(nn.Module):
    """Binarized U-shaped feature body: encoder, bottleneck, decoder.

    Shape constraints:
      * ``BinaryConv2d`` adds its input back to its output, and ``forward``
        adds ``x`` to the result. So ``in_dim``, ``out_dim`` and ``dim`` must
        all be equal.
      * H and W must be divisible by ``2 ** stage``. Otherwise the decoder's
        up-sampled maps no longer match the encoder skips they are
        concatenated with.

    Args:
        in_dim: channels of the input feature map.
        out_dim: channels of the output feature map.
        dim: channel width of the first level. It doubles at every
            down-sampling.
        stage: number of down-/up-sampling levels.
        num_blocks: per-level Block depths. Entries ``0 .. stage-1`` are used
            by the encoder/decoder levels, and the last entry by the
            bottleneck.
    """

    def __init__(self, in_dim=16, out_dim=16, dim=16, stage=2, num_blocks=(2, 4, 4)):
        super(body, self).__init__()
        self.dim = dim
        self.stage = stage

        # Input projection
        self.embedding = BinaryConv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder: each level is a Block at width dim_stage followed by a 2x
        # down-sampling that doubles the width.
        self.encoder_layers = nn.ModuleList([])
        dim_stage = dim
        for i in range(stage):
            self.encoder_layers.append(nn.ModuleList([
                Block(dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim, heads=dim_stage // dim),
                BinaryConv2d_Down(dim_stage, dim_stage * 2, 3, 1, 1, bias=False),
            ]))
            dim_stage *= 2

        # Bottleneck at the coarsest level
        self.bottleneck = Block(
            dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1])

        # Decoder: each level up-samples by 2 (halving the width), fuses with
        # the matching encoder skip (concat, then halve the width), and then
        # applies a Block.
        self.decoder_layers = nn.ModuleList([])
        for i in range(stage):
            self.decoder_layers.append(nn.ModuleList([
                BinaryConv2d_Up(dim_stage, dim_stage // 2, 3, 1, 1, bias=False),
                BinaryConv2d_Fusion_Decrease(dim_stage, dim_stage // 2, 1, 1, bias=False),
                Block(
                    dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim,
                    heads=(dim_stage // 2) // dim),
            ]))
            dim_stage //= 2

        # Output projection
        self.mapping = BinaryConv2d(self.dim, out_dim, 3, 1, 1, bias=False)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        fea = self.embedding(x)

        # Encoder: remember each level's features for the decoder skips.
        skips = []
        for blocks, down in self.encoder_layers:
            fea = blocks(fea)
            skips.append(fea)
            fea = down(fea)

        fea = self.bottleneck(fea)

        # Decoder: pop() yields the skips deepest-first, which matches the
        # decoder's level order.
        for up, fusion, blocks in self.decoder_layers:
            fea = up(fea)
            fea = fusion(torch.cat([fea, skips.pop()], dim=1))
            fea = blocks(fea)

        # Global residual around the whole body.
        return self.mapping(fea) + x

def sampling_(x, s_factor, mode_='bicubic'):
    """Resize an NCHW tensor by ``s_factor`` using ``F.interpolate``.

    ``align_corners=False`` and ``recompute_scale_factor=False`` are fixed.
    The given scale factor is therefore used directly for the interpolation.

    Args:
        x: tensor to resize.
        s_factor: spatial scale factor.
        mode_: interpolation mode passed to ``F.interpolate``.
    """
    resized = F.interpolate(
        x,
        scale_factor=s_factor,
        mode=mode_,
        align_corners=False,
        recompute_scale_factor=False,
    )
    return resized

class Net(nn.Module):
    """Fusion network: upsampled ``ms`` + ``pan`` -> conv -> binarized bodies -> conv.

    Args:
        in_channels: channels of ``cat([upsampled ms, pan])``.
        out_channels: channels of the prediction.
        n_feat: feature width inside the network.
        stage: number of ``body`` modules applied in sequence. Each ``body``
            itself has two down-/up-sampling levels.
        num_blocks: per-level Block depths forwarded to every ``body``.
    """

    def __init__(self, in_channels=5, out_channels=4, n_feat=16, stage=1, num_blocks=(1, 1, 1)):
        super(Net, self).__init__()
        self.stage = stage
        self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=3, padding=(3 - 1) // 2, bias=False)
        # body's own in_dim/out_dim default to 16. Pass n_feat explicitly so
        # that widths other than 16 produce matching shapes.
        modules_body = [
            body(in_dim=n_feat, out_dim=n_feat, dim=n_feat, stage=2, num_blocks=num_blocks)
            for _ in range(self.stage)
        ]
        self.body = nn.Sequential(*modules_body)
        self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=3, padding=(3 - 1) // 2, bias=False)

    def forward(self, ms, pan):
        """Predict the fused image.

        Args:
            ms: NCHW tensor, upsampled by a fixed factor of 4 (bicubic) before
                fusion. Presumably the low-resolution multispectral image.
            pan: NCHW tensor concatenated with the upsampled ``ms`` along
                channels. Its batch and spatial size must match that tensor.

        Returns:
            Tensor ``[b, out_channels, H, W]``, where ``H, W`` are the spatial
            size of ``pan`` / the upsampled ``ms``.
        """
        ms = sampling_(ms, s_factor=4)
        x = torch.cat([ms, pan], dim=1)
        h_inp, w_inp = x.shape[-2:]
        # Reflect-pad H and W up to a multiple of 8. This satisfies the
        # divisibility needed by each body's two 2x down-samplings. The
        # padding is cropped off again at the end.
        hb, wb = 8, 8
        pad_h = (hb - h_inp % hb) % hb
        pad_w = (wb - w_inp % wb) % wb
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
        feat = self.conv_in(x)
        feat = self.body(feat)
        out = self.conv_out(feat)
        return out[:, :, :h_inp, :w_inp]
    

@MODELS.register_module()
class E2fif(Base_model):
    """Trainer wrapper that registers a default :class:`Net` as ``core_module``.

    Constructor arguments are forwarded unchanged to ``Base_model``.
    """

    def __init__(self, cfg, logger, train_data_loader, test_data_loader0, test_data_loader1):
        super().__init__(cfg, logger, train_data_loader, test_data_loader0, test_data_loader1)

        self.add_module('core_module', Net())

    def get_model_output(self, input_batch):
        """Run the core network on a batch dict with 'input_lr' and 'input_pan'."""
        input_pan = input_batch['input_pan']
        input_lr = input_batch['input_lr']
        output = self.module_dict['core_module'](input_lr, input_pan)
        return output

    def train_iter(self, iter_id, input_batch, log_freq=10):
        """Run one optimisation step with the reconstruction loss and log it.

        Reads 'input_pan', 'input_lr' and 'target' from ``input_batch``. The
        reconstruction loss is weighted by ``cfg['loss_cfg']['rec_loss'].w``.

        Raises:
            RuntimeError: if no 'rec_loss' entry is configured in
                ``self.loss_module``. It is the only loss this model optimises.
        """
        if 'rec_loss' not in self.loss_module:
            # Without it the total loss would stay the int 0, and .item() /
            # .backward() would fail with an unhelpful AttributeError.
            raise RuntimeError("E2fif.train_iter requires a 'rec_loss' entry in loss_module")

        G = self.module_dict['core_module']
        G_optim = self.optim_dict['core_module']

        input_pan = input_batch['input_pan']
        input_lr = input_batch['input_lr']
        output = G(input_lr, input_pan)

        loss_cfg = self.cfg.get('loss_cfg', {})
        target = input_batch['target']
        rec_loss = self.loss_module['rec_loss'](
            out=output, gt=target
        )
        loss_g = rec_loss * loss_cfg['rec_loss'].w

        loss_res = dict()
        loss_res['rec_loss'] = rec_loss.item()
        loss_res['full_loss'] = loss_g.item()

        G_optim.zero_grad()
        loss_g.backward()
        G_optim.step()

        self.print_train_log(iter_id, loss_res, log_freq)

       