import torch
from torch import nn as nn
from torch.nn import functional as F
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)  # Enforce the use of deterministic algorithms
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer, compute_video_complexity

from .spynet_arch import SpyNet
from basicsr.archs.quantize import QConv2d
import kornia as K
from torch.utils.checkpoint import checkpoint

@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
    """A recurrent network for video SR. Now only x4 is supported.

    Each frame is processed by a backward branch (last frame to first) and a
    forward branch (first frame to last). Between neighbouring frames the
    propagated feature is warped with SPyNet optical flow. Per frame, the two
    branch features are fused, upsampled by two PixelShuffle(2) stages and
    added to a bilinear x4 upsampling of the input frame.

    Args:
        args: Configuration namespace; required despite the ``None`` default.
            ``__init__`` reads ``args.fq`` (build ``QConv2d`` head/reconstruction
            convolutions instead of ``nn.Conv2d``) and ``args.videowise``
            (create the video-complexity thresholds). ``forward`` additionally
            reads ``video_percentile`` and ``ema_beta`` in videowise mode. The
            object is also passed to ``QConv2d``, ``ConvResidualBlocks`` and
            ``compute_video_complexity``.
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.

    Raises:
        ValueError: If ``args`` is None.
    """

    def __init__(self, args=None, num_feat=64, num_block=15, spynet_path=None):
        super().__init__()
        if args is None:
            # Every branch below dereferences `args`; fail with a clear message
            # instead of an AttributeError on `None.fq`.
            raise ValueError('BasicVSR requires an `args` configuration object, got None.')
        self.fq = args.fq
        self.num_feat = num_feat

        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation -- head: projects cat(frame, propagated feature) to num_feat channels
        if args.fq:
            # QConv2d is called with a [tensor, bit] pair and returns two values,
            # so the activation lives in a separate module applied to the tensor only.
            self.backward_trunk_head = nn.Sequential(QConv2d(args, num_feat + 3, num_feat, 3, 1, 1, bias=True, non_adaptive=True, to_8bit=True))
            self.backward_trunk_head_relu = nn.Sequential(nn.LeakyReLU(negative_slope=0.1, inplace=True))
            self.forward_trunk_head = nn.Sequential(QConv2d(args, num_feat + 3, num_feat, 3, 1, 1, bias=True, non_adaptive=True, to_8bit=True))
            self.forward_trunk_head_relu = nn.Sequential(nn.LeakyReLU(negative_slope=0.1, inplace=True))
        else:
            self.backward_trunk_head = nn.Sequential(nn.Conv2d(num_feat + 3, num_feat, 3, 1, 1, bias=True),
                                                     nn.LeakyReLU(negative_slope=0.1, inplace=True))
            self.forward_trunk_head = nn.Sequential(nn.Conv2d(num_feat + 3, num_feat, 3, 1, 1, bias=True),
                                                    nn.LeakyReLU(negative_slope=0.1, inplace=True))

        # propagation -- body: stacks of residual blocks
        self.backward_trunk_body = ConvResidualBlocks(args, num_feat + 3, num_feat, num_block)
        self.forward_trunk_body = ConvResidualBlocks(args, num_feat + 3, num_feat, num_block)

        # reconstruction
        if args.fq:
            self.fusion = QConv2d(args, num_feat * 2, num_feat, 1, 1, 0, bias=True, non_adaptive=True, to_8bit=True)
            self.upconv1 = QConv2d(args, num_feat, num_feat * 4, 3, 1, 1, bias=True, non_adaptive=True, to_8bit=True)
            self.upconv2 = QConv2d(args, num_feat, 64 * 4, 3, 1, 1, bias=True, non_adaptive=True, to_8bit=True)
            self.conv_hr = QConv2d(args, 64, 64, 3, 1, 1, bias=True, non_adaptive=True, to_8bit=True)
            self.conv_last = QConv2d(args, 64, 3, 3, 1, 1, bias=True, non_adaptive=True, to_8bit=True)
        else:
            self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
            self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
            self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
            self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if args.videowise:
            # Lower/upper thresholds on the video complexity score (both start at 128).
            # They are created device-agnostic and follow the module on
            # `.to(device)` / `.cuda()` instead of being pinned to the current
            # CUDA device at construction time.
            self.measure_l = nn.Parameter(torch.tensor([128.0], dtype=torch.float32))
            self.measure_u = nn.Parameter(torch.tensor([128.0], dtype=torch.float32))
            self.tanh = nn.Tanh()
            self.ema_epoch = 1  # number of calibration calls seen so far + 1
            self.init = False  # when True, forward() (re)calibrates measure_l / measure_u

        self.args = args

    def get_flow(self, x):
        """Estimate optical flow between neighbouring frames with SPyNet.

        Args:
            x: Frames of shape (b, n, c, h, w).

        Returns:
            tuple: ``(flows_forward, flows_backward)``, each of shape
            (b, n - 1, 2, h, w). ``flows_backward[:, i]`` is
            ``spynet(frame i, frame i + 1)`` and ``flows_forward[:, i]`` is
            ``spynet(frame i + 1, frame i)``.
        """
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def _compute_bit_img(self, x, flows_forward, flows_backward):
        """Compute the video-wise bit decision fed to the residual bodies.

        While ``self.init`` is True (calibration), the thresholds
        ``measure_l`` / ``measure_u`` are set from quantiles
        (``video_percentile`` and ``100 - video_percentile`` percent) of the
        video complexity. They are set directly on the first calibration call
        and blended with an exponential moving average of weight
        ``args.ema_beta`` afterwards. A constant zero tensor of shape (1,) is
        returned.

        Otherwise the complexity is mapped to -1 (below ``measure_l``),
        0 (between the thresholds, inclusive) or +1 (above ``measure_u``) with
        a straight-through estimator. The forward value is the hard decision,
        and the gradient flows through ``tanh`` of the complexity rescaled so
        that [measure_l, measure_u] maps to [-1, 1]. The result is reshaped to
        (N, 1, 1, 1), N being the first dimension of the complexity tensor.
        """
        # Work on a copy of x; presumably guards against in-place changes
        # inside compute_video_complexity -- TODO confirm.
        video = x.clone()
        video_grad = compute_video_complexity(self.args, video, flows_forward, flows_backward)

        if self.init:
            if self.ema_epoch == 1:
                measure_l = torch.quantile(video_grad.detach(), self.args.video_percentile/100.0)
                measure_u = torch.quantile(video_grad.detach(), 1-self.args.video_percentile/100.0)
                nn.init.constant_(self.measure_l, measure_l)
                nn.init.constant_(self.measure_u, measure_u)
                print('update succesfully,now the measure_l',self.measure_l)
            else:
                beta = self.args.ema_beta
                new_measure_l = self.measure_l * beta + torch.quantile(video_grad.detach(), self.args.video_percentile/100.0) * (1-beta)
                new_measure_u = self.measure_u * beta + torch.quantile(video_grad.detach(), 1-self.args.video_percentile/100.0) * (1-beta)
                nn.init.constant_(self.measure_l, new_measure_l.item())
                nn.init.constant_(self.measure_u, new_measure_u.item())

            self.ema_epoch += 1
            return torch.zeros(1, device=x.device)

        # NOTE(review): if measure_u == measure_l (e.g. both still at their initial
        # 128 because calibration never ran) the scale factor below is infinite.
        # Scale the range to [-1, 1] with an interval length of 2.
        bit_img_soft = (video_grad - (self.measure_u + self.measure_l)/2) * (2/(self.measure_u - self.measure_l))
        bit_img_soft = self.tanh(bit_img_soft)
        bit_img_hard = (video_grad < self.measure_l) * (-1.0) + (video_grad >= self.measure_l) * (video_grad <= self.measure_u) * (0.0) + (video_grad> self.measure_u) *(1.0)
        # Straight-through estimator: value of bit_img_hard, gradient of bit_img_soft.
        # The order matters: soft - soft.detach() cancels to zero before adding the hard value.
        bit_img = bit_img_soft - bit_img_soft.detach() + bit_img_hard.detach()
        return bit_img.view(bit_img.shape[0], 1, 1, 1)

    def forward(self, x):
        """Forward function of BasicVSR.

        Args:
            x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.

        Returns:
            tuple: ``(out, feat, bit)``. ``out`` stacks the per-frame outputs
            along dim 1. ``feat`` and ``bit`` are the values returned by the last
            forward-branch residual body call.
        """
        flows_forward, flows_backward = self.get_flow(x)
        b, n, _, h, w = x.size()
        bit_img = None  # only used (and updated by the bodies) in videowise mode
        if self.args.videowise:
            bit_img = self._compute_bit_img(x, flows_forward, flows_backward)

        # backward branch
        out_l = []
        feat = None
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            if self.args.fq:
                # Allocate on x's device (not the current CUDA device) to avoid mismatches.
                bit_fq = torch.zeros(b, device=x.device)
                feat_prop, bit_fq = self.backward_trunk_head([feat_prop, bit_fq])
                feat_prop = self.backward_trunk_head_relu(feat_prop)
            else:
                feat_prop = self.backward_trunk_head(feat_prop)

            bit = torch.zeros(b, device=x.device)  # one entry per batch element

            if self.args.videowise:
                feat_prop, feat, bit, bit_img = self.backward_trunk_body([feat_prop, feat, bit, bit_img])
            else:
                feat_prop, feat, bit = self.backward_trunk_body([feat_prop, feat, bit])

            out_l.insert(0, feat_prop)
            backward_bit = bit.clone()

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            if self.args.fq:
                # bit_fq is carried over from the backward branch and is not reset per frame.
                feat_prop, bit_fq = self.forward_trunk_head([feat_prop, bit_fq])
                feat_prop = self.forward_trunk_head_relu(feat_prop)
            else:
                feat_prop = self.forward_trunk_head(feat_prop)
            # Every forward-branch frame starts from the bit left by the backward branch's last step.
            bit = backward_bit.clone()

            if self.args.videowise:
                feat_prop, feat, bit, bit_img = self.forward_trunk_body([feat_prop, feat, bit, bit_img])
            else:
                feat_prop, feat, bit = self.forward_trunk_body([feat_prop, feat, bit])
            out = torch.cat([out_l[i], feat_prop], dim=1)
            # upsample
            if self.args.fq:
                out, bit_fq = self.fusion([out, bit_fq])
                out = self.lrelu(out)
                out, bit_fq = self.upconv1([out, bit_fq])
                out = self.lrelu(self.pixel_shuffle(out))
                out, bit_fq = self.upconv2([out, bit_fq])
                out = self.lrelu(self.pixel_shuffle(out))
                out, bit_fq = self.conv_hr([out, bit_fq])
                out = self.lrelu(out)
                out, bit_fq = self.conv_last([out, bit_fq])
            else:
                out = self.lrelu(self.fusion(out))
                out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
                out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
                out = self.lrelu(self.conv_hr(out))
                out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            out_l[i] = out

        return torch.stack(out_l, dim=1), feat, bit


class ConvResidualBlocks(nn.Module):
    """Residual-block trunk used by each propagation branch of BasicVSR.

    No input convolution is built here, so ``num_in_ch`` is accepted only to
    keep the constructor signature and does not affect the module.

    Args:
        args: Configuration object passed to ``make_layer`` (presumably
            forwarded to each ``ResidualBlockNoBN``).
        num_in_ch (int): Unused. Default: 3.
        num_out_ch (int): Feature width given to the residual blocks as
            ``num_feat``. Default: 64.
        num_block (int): Number of residual blocks. Default: 15.
    """

    def __init__(self, args, num_in_ch=3, num_out_ch=64, num_block=15):
        super().__init__()
        residual_stack = make_layer(ResidualBlockNoBN, num_block, args=args, num_feat=num_out_ch)
        # A one-element Sequential keeps parameter names under ``main.0.*``.
        self.main = nn.Sequential(residual_stack)

    def forward(self, fea):
        """Run ``fea`` through the residual stack and return its output unchanged.

        ``fea`` is whatever the residual blocks accept; BasicVSR passes a list
        such as ``[feat_prop, feat, bit]`` or ``[feat_prop, feat, bit, bit_img]``.
        """
        result = self.main(fea)
        return result