import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from mmengine import Config


# LORA finetuning originally by edwardjhu
class LoRALayer():
    """Mixin that stores the LoRA hyper-parameters on a layer.

    Args:
        r: Rank of the low-rank update; stored unchanged as ``self.r``.
        lora_alpha: Stored unchanged as ``self.lora_alpha``.
        lora_dropout: Dropout probability. A value greater than zero installs
            an ``nn.Dropout``; otherwise ``self.lora_dropout`` is an identity
            callable.
        merge_weights: Stored unchanged as ``self.merge_weights``; this class
            itself never acts on it.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.merge_weights = merge_weights
        # A freshly constructed layer always starts in the unmerged state.
        self.merged = False
        # Dropout is optional: fall back to a pass-through callable.
        wants_dropout = lora_dropout > 0.
        self.lora_dropout = nn.Dropout(p=lora_dropout) if wants_dropout else (lambda x: x)


class LoRALinear(nn.Linear, LoRALayer):
    """``nn.Linear`` with an optional low-rank (LoRA) term added in ``forward``.

    When ``r > 0`` two extra parameters are created, ``lora_A`` with shape
    ``(r, in_features)`` and ``lora_B`` with shape ``(out_features, r)``, and
    ``weight.requires_grad`` is switched off. ``forward`` then returns the
    regular linear output plus
    ``lora_dropout(x) @ lora_A.T @ lora_B.T * (lora_alpha / r)``.

    Args:
        in_features: Forwarded to ``nn.Linear``.
        out_features: Forwarded to ``nn.Linear``.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Dropout probability applied to the input of the
            low-rank branch only.
        fan_in_fan_out: Set to True if the layer to replace stores its weight
            as ``(fan_in, fan_out)``; the weight is then kept transposed and
            transposed back inside ``forward``.
        merge_weights: Stored on the instance; no code in this class merges
            the update into ``weight``, so ``merged`` stays False unless it is
            set from outside.
        **kwargs: Forwarded to ``nn.Linear`` (e.g. ``bias``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        # nn.Linear.__init__ calls self.reset_parameters(), which initialises
        # weight and bias (the LoRA tensors do not exist yet at that point).
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Only the LoRA tensors still need initialising here; calling the
            # full reset_parameters() would re-draw the base weight a second time.
            self._reset_lora_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Initialise the base weight/bias and, if present, the LoRA tensors."""
        # Without this call the tensors allocated by nn.Linear (torch.empty)
        # would keep whatever happened to be in memory.
        nn.Linear.reset_parameters(self)
        self._reset_lora_parameters()

    def _reset_lora_parameters(self):
        """Initialise ``lora_A``/``lora_B``; a no-op while they do not exist."""
        if hasattr(self, 'lora_A'):
            # A gets the default nn.Linear initialisation and B starts at zero,
            # so the low-rank product is exactly zero right after construction.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor):
        """Return the linear output, plus the scaled low-rank term when active."""
        weight = self.weight.transpose(0, 1) if self.fan_in_fan_out else self.weight
        result = F.linear(x, weight, bias=self.bias)
        if self.r > 0 and not self.merged:
            low_rank = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
            result += low_rank * self.scaling
        return result

class ConvLoRA(nn.Conv2d, LoRALayer):
    """``nn.Conv2d`` whose kernel can receive a trainable low-rank (LoRA) update.

    When ``r > 0`` two parameters are created, ``lora_A`` with shape
    ``(r * kernel_size, in_channels * kernel_size)`` and ``lora_B`` with shape
    ``(out_channels // groups * kernel_size, r * kernel_size)``, and
    ``weight.requires_grad`` is switched off. ``forward`` convolves with
    ``weight + (lora_B @ lora_A).view(weight.shape) * (lora_alpha / r)``.

    Args:
        in_channels, out_channels: Forwarded to ``nn.Conv2d``.
        kernel_size: Must be an ``int`` (asserted).
        r: LoRA rank; ``0`` disables the low-rank update.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Stored by ``LoRALayer``; ``forward`` here does not use it.
        merge_weights: Stored on the instance; no code in this class merges
            the update into ``weight``, so ``merged`` stays False unless it is
            set from outside.
        **kwargs: Forwarded to ``nn.Conv2d`` (stride, padding, groups, ...).
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # nn.Conv2d.__init__ calls self.reset_parameters(), which initialises
        # weight and bias (the LoRA tensors do not exist yet at that point).
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Only the LoRA tensors still need initialising here.
            self._reset_lora_parameters()

    def reset_parameters(self):
        """Initialise the base kernel/bias and, if present, the LoRA tensors."""
        # Without this call the tensors allocated by nn.Conv2d (torch.empty)
        # would keep whatever happened to be in memory.
        nn.Conv2d.reset_parameters(self)
        self._reset_lora_parameters()

    def _reset_lora_parameters(self):
        """Initialise ``lora_A``/``lora_B``; a no-op while they do not exist."""
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Convolve ``x`` with the base kernel plus the scaled low-rank update."""
        weight = self.weight
        if self.r > 0 and not self.merged:
            delta = (self.lora_B @ self.lora_A).view(self.weight.shape)
            weight = weight + delta * self.scaling
        return F.conv2d(x, weight, bias=self.bias, stride=self.stride, padding=self.padding,
                        dilation=self.dilation, groups=self.groups)

class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer):
    """``nn.ConvTranspose2d`` whose kernel can receive a low-rank (LoRA) update.

    When ``r > 0`` two parameters are created, ``lora_A`` with shape
    ``(r * kernel_size, in_channels * kernel_size)`` and ``lora_B`` with shape
    ``(out_channels // groups * kernel_size, r * kernel_size)``, and
    ``weight.requires_grad`` is switched off. ``forward`` uses
    ``weight + (lora_B @ lora_A).view(weight.shape) * (lora_alpha / r)``.

    Args:
        in_channels, out_channels: Forwarded to ``nn.ConvTranspose2d``.
        kernel_size: Must be an ``int`` (asserted).
        r: LoRA rank; ``0`` disables the low-rank update.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Stored by ``LoRALayer``; ``forward`` here does not use it.
        merge_weights: Stored on the instance; no code in this class merges
            the update into ``weight``, so ``merged`` stays False unless it is
            set from outside.
        **kwargs: Forwarded to ``nn.ConvTranspose2d`` (stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # The base constructor calls self.reset_parameters(), which initialises
        # weight and bias (the LoRA tensors do not exist yet at that point).
        nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Only the LoRA tensors still need initialising here.
            self._reset_lora_parameters()

    def reset_parameters(self):
        """Initialise the base kernel/bias and, if present, the LoRA tensors."""
        # Without this call the tensors allocated by nn.ConvTranspose2d
        # (torch.empty) would keep whatever happened to be in memory.
        nn.ConvTranspose2d.reset_parameters(self)
        self._reset_lora_parameters()

    def _reset_lora_parameters(self):
        """Initialise ``lora_A``/``lora_B``; a no-op while they do not exist."""
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Apply the transposed convolution with the (optionally updated) kernel."""
        weight = self.weight
        if self.r > 0 and not self.merged:
            delta = (self.lora_B @ self.lora_A).view(self.weight.shape)
            weight = weight + delta * self.scaling
        return F.conv_transpose2d(
            x, weight,
            bias=self.bias, stride=self.stride, padding=self.padding,
            output_padding=self.output_padding, groups=self.groups, dilation=self.dilation)


class Conv2dLoRA(ConvLoRA):
    """Thin subclass of ``ConvLoRA``; every constructor argument is passed through."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class ConvTranspose2dLoRA(ConvTransposeLoRA):
    """Thin subclass of ``ConvTransposeLoRA``; every constructor argument is passed through."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


def compute_depth_expectation(prob, depth_values):
    """Return the sum over dim 1 of ``prob`` weighted by ``depth_values``.

    ``depth_values`` gets two trailing singleton dimensions appended so that it
    broadcasts against ``prob`` before the weighted sum is taken.
    """
    weights = depth_values[..., None, None]
    return (prob * weights).sum(dim=1)


def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    """Run ``F.interpolate`` on ``x`` after casting it to float32.

    All other arguments are forwarded to ``F.interpolate`` unchanged.
    """
    as_float = x.float()
    return F.interpolate(
        as_float,
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
    )


def upflow4(flow, mode='bilinear'):
    """Resize ``flow`` to four times its size along dims 2 and 3.

    Args:
        flow: Tensor whose dims 2 and 3 are the spatial dims to enlarge.
        mode: Interpolation mode passed to ``F.interpolate``.

    Returns:
        The resized tensor. For the interpolating modes the corners of input
        and output are aligned (``align_corners=True``).
    """
    new_size = (4 * flow.shape[2], 4 * flow.shape[3])
    # F.interpolate only accepts align_corners for the interpolating modes;
    # passing it together with e.g. 'nearest' raises a ValueError.
    interpolating = mode in ('linear', 'bilinear', 'bicubic', 'trilinear')
    return F.interpolate(flow, size=new_size, mode=mode,
                         align_corners=True if interpolating else None)


def coords_grid(batch, ht, wd):
    """Return an all-zero float32 tensor of shape ``(batch, 6, ht, wd)``.

    Despite the name, no coordinate values are generated: every one of the
    six channels is a zero map.
    """
    num_channels = 6
    return torch.zeros((batch, num_channels, ht, wd), dtype=torch.float32)


def norm_normalize(norm_out):
    """Normalise the first three channels to unit length and remap the fourth.

    ``norm_out`` is split into four single-channel tensors along dim 1. The
    first three are divided by their joint Euclidean norm (plus 1e-10); the
    fourth becomes ``elu(kappa) + 1 + 0.01``. The four results are
    concatenated back along dim 1.
    """
    min_kappa = 0.01
    norm_x, norm_y, norm_z, kappa = norm_out.split(1, dim=1)

    squared_length = norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0
    length = torch.sqrt(squared_length) + 1e-10
    kappa = F.elu(kappa) + 1.0 + min_kappa

    unit_components = [component / length for component in (norm_x, norm_y, norm_z)]
    return torch.cat(unit_components + [kappa], dim=1)


# uncertainty-guided sampling (only used during training)
@torch.no_grad()
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
    """Pick ``N = int(sampling_ratio * H * W)`` pixel locations per batch item.

    The negated last channel of ``init_normal`` is used as the uncertainty
    score. The ``int(beta * N)`` highest-scoring pixels are taken directly
    ("importance"); the rest of the budget is drawn at random, without
    replacement, from the remaining pixels ("coverage").

    Args:
        init_normal: Tensor of shape ``(B, C, H, W)``; only its last channel
            is read.
        gt_norm_mask: Optional mask. It is resized to ``(H, W)`` with nearest
            interpolation and pixels whose first channel is below 0.5 get a
            score of -1e4, i.e. they are ranked last.
        sampling_ratio: Fraction of the ``H * W`` pixels to sample.
        beta: Fraction of ``N`` taken by importance; the resulting count is
            clamped to ``[0, N]``.

    Returns:
        ``(point_coords, rows_int, cols_int)``. ``point_coords`` has shape
        ``(B, 1, N, 2)`` holding ``(x, y)`` scaled to ``[-1, 1]`` and lives on
        ``init_normal.device``; ``rows_int``/``cols_int`` are the integer
        pixel indices with shape ``(B, N)``.
    """
    device = init_normal.device
    B, _, H, W = init_normal.shape
    N = int(sampling_ratio * H * W)
    # Clamp so that an out-of-range beta can neither index backwards nor
    # produce more than N points.
    num_importance = max(0, min(int(beta * N), N))
    num_coverage = N - num_importance

    # uncertainty map (B, H, W); the multiplication yields a new tensor, so
    # the masking below does not modify init_normal.
    uncertainty_map = -1 * init_normal[:, -1, :, :]

    # gt_invalid_mask (B, H, W)
    if gt_norm_mask is not None:
        gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
        gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
        uncertainty_map[gt_invalid_mask] = -1e4

    # (B, H*W) flat pixel indices, most uncertain first
    _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)

    importance = idx[:, :num_importance]    # B, num_importance
    remaining = idx[:, num_importance:]     # B, H*W - num_importance

    if num_coverage > 0:
        # One independent random permutation per batch item.
        coverage = torch.stack(
            [
                remaining[i][torch.randperm(remaining.size(1))[:num_coverage]]
                for i in range(B)
            ],
            dim=0,
        )                                                   # B, num_coverage
        samples = torch.cat((importance, coverage), dim=1)  # B, N
    else:
        samples = importance

    # point coordinates
    rows_int = samples // W         # 0 for first row, H-1 for last row
    cols_int = samples % W          # 0 for first column, W-1 for last column

    # max(..., 1) avoids 0/0 = NaN when the map has a single row or column.
    rows_float = rows_int / float(max(H - 1, 1)) * 2.0 - 1.0    # -1.0 to 1.0
    cols_float = cols_int / float(max(W - 1, 1)) * 2.0 - 1.0    # -1.0 to 1.0

    # Built directly from the index tensors, so no CPU staging buffer is needed.
    point_coords = torch.stack((cols_float, rows_float), dim=-1).unsqueeze(1)   # B, 1, N, 2 as (x, y)
    point_coords = point_coords.to(device)
    return point_coords, rows_int, cols_int


class FlowHead(nn.Module):
    """Two parallel conv-ReLU-conv branches whose outputs are concatenated.

    One branch ends with ``output_dim_depth`` channels, the other with
    ``output_dim_norm`` channels; both use ``hidden_dim // 2`` channels in
    between. With ``tuning_mode == 'lora'`` every convolution uses LoRA rank 8,
    otherwise rank 0.
    """

    def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None):
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0
        mid_dim = hidden_dim // 2

        # first branch
        self.conv1d = Conv2dLoRA(input_dim, mid_dim, 3, padding=1, r=rank)
        self.conv2d = Conv2dLoRA(mid_dim, output_dim_depth, 3, padding=1, r=rank)

        # second branch
        self.conv1n = Conv2dLoRA(input_dim, mid_dim, 3, padding=1, r=rank)
        self.conv2n = Conv2dLoRA(mid_dim, output_dim_norm, 3, padding=1, r=rank)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the two branch outputs concatenated along dim 1."""
        depth_hidden = self.relu(self.conv1d(x))
        depth = self.conv2d(depth_hidden)
        normal_hidden = self.relu(self.conv1n(x))
        normal = self.conv2n(normal_hidden)
        return torch.cat((depth, normal), dim=1)


class ConvGRU(nn.Module):
    """Convolutional GRU cell with additive per-gate context terms.

    Each of the three gate convolutions maps ``hidden_dim + input_dim``
    channels to ``hidden_dim`` channels. With ``tuning_mode == 'lora'`` they
    use LoRA rank 8, otherwise rank 0.
    """

    def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None):
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0
        in_channels = hidden_dim + input_dim
        pad = kernel_size // 2
        self.convz = Conv2dLoRA(in_channels, hidden_dim, kernel_size, padding=pad, r=rank)
        self.convr = Conv2dLoRA(in_channels, hidden_dim, kernel_size, padding=pad, r=rank)
        self.convq = Conv2dLoRA(in_channels, hidden_dim, kernel_size, padding=pad, r=rank)

    def forward(self, h, cz, cr, cq, *x_list):
        """Return the updated hidden state.

        ``x_list`` is concatenated along dim 1 to form the input; ``cz``,
        ``cr`` and ``cq`` are added to the corresponding gate pre-activations.
        """
        inputs = torch.cat(x_list, dim=1)
        stacked = torch.cat([h, inputs], dim=1)

        update_gate = torch.sigmoid(self.convz(stacked) + cz)
        reset_gate = torch.sigmoid(self.convr(stacked) + cr)
        gated = torch.cat([reset_gate * h, inputs], dim=1)
        candidate = torch.tanh(self.convq(gated) + cq)

        return (1 - update_gate) * h + update_gate * candidate


def pool2x(x):
    """Average-pool ``x`` with a 3x3 window, stride 2 and padding 1."""
    return F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)


def pool4x(x):
    """Average-pool ``x`` with a 5x5 window, stride 4 and padding 1."""
    return F.avg_pool2d(x, kernel_size=5, stride=4, padding=1)


def interp(x, dest):
    """Bilinearly resize ``x`` (``align_corners=True``) to ``dest.shape[2:]``."""
    target_size = dest.shape[2:]
    return interpolate_float32(x, target_size, mode='bilinear', align_corners=True)


class BasicMultiUpdateBlock(nn.Module):
    """Multi-level ConvGRU update block with a flow head and a mask head.

    Reads ``n_gru_layers`` and ``n_downsample`` from
    ``args.model.decode_head``. ``hidden_dims`` is indexed at positions 0, 1
    and 2 to size the three GRUs (``gru32``, ``gru16``, ``gru08``).
    ``out_dims`` is accepted but not used in this class.
    """

    def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None):
        super().__init__()
        self.args = args
        decode_cfg = args.model.decode_head
        self.n_gru_layers = decode_cfg.n_gru_layers  # 3
        self.n_downsample = decode_cfg.n_downsample  # 3, resolution of the disparity field (1/2^K)

        rank = 8 if tuning_mode == 'lora' else 0
        # No correlation volume is used, so the "motion" input has 6 channels.
        encoder_output_dim = 6

        # The boolean factors switch the extra cross-level inputs on or off.
        gru08_in = encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1)
        gru16_in = hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2]

        self.gru08 = ConvGRU(hidden_dims[2], gru08_in, tuning_mode=tuning_mode)
        self.gru16 = ConvGRU(hidden_dims[1], gru16_in, tuning_mode=tuning_mode)
        self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode)
        self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode)

        factor = 2**self.n_downsample
        mask_channels = (factor**2)*9
        self.mask = nn.Sequential(
            Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r=rank),
            nn.ReLU(inplace=True),
            Conv2dLoRA(hidden_dims[2], mask_channels, 1, padding=0, r=rank))

    def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
        """Update the hidden states in ``net`` in place and optionally predict.

        Returns ``net`` alone when ``update`` is False, otherwise
        ``(net, mask, delta_flow)``.

        NOTE(review): ``self.encoder`` is not created in ``__init__``, so
        passing a non-None ``corr`` raises ``AttributeError``.
        """
        if iter32:
            net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))

        if iter16:
            extra16 = [interp(pool2x(net[0]), net[1])]
            if self.n_gru_layers > 2:
                extra16.append(interp(net[2], net[1]))
            net[1] = self.gru16(net[1], *(inp[1]), *extra16)

        if iter08:
            motion_features = flow if corr is None else self.encoder(flow, corr)
            extra08 = [motion_features]
            if self.n_gru_layers > 1:
                extra08.append(interp(net[1], net[0]))
            net[0] = self.gru08(net[0], *(inp[0]), *extra08)

        if not update:
            return net

        delta_flow = self.flow_head(net[0])

        # scale mask to balance gradients
        mask = .25 * self.mask(net[0])
        return net, mask, delta_flow


class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` applied over dim 1 of a 4-D tensor.

    The input is permuted so that dim 1 becomes the last dim, normalised by
    the parent ``nn.LayerNorm``, and permuted back to the original layout.
    """

    def __init__(self, dim):
        super().__init__(dim)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1).contiguous()
        normalized = super().forward(channels_last)
        return normalized.permute(0, 3, 1, 2).contiguous()


class ResidualBlock(nn.Module):
    """Two 3x3 conv + norm + ReLU stages with a residual connection.

    ``norm_fn`` selects the normalisation: ``'group'``, ``'batch'``,
    ``'instance'``, ``'layer'`` or ``'none'``. Any other value creates no
    norm layers. When ``stride != 1`` or ``in_planes != planes`` the shortcut
    is a 1x1 convolution followed by ``norm3``; that same ``norm3`` module is
    therefore registered both as ``self.norm3`` and inside
    ``self.downsample``.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None):
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0
        needs_projection = not (stride == 1 and in_planes == planes)

        self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r=rank)
        self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r=rank)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        # Factories are only invoked for the selected norm type.
        norm_builders = {
            'group': lambda: nn.GroupNorm(num_groups=num_groups, num_channels=planes),
            'batch': lambda: nn.BatchNorm2d(planes),
            'instance': lambda: nn.InstanceNorm2d(planes),
            'layer': lambda: LayerNorm2d(planes),
            'none': lambda: nn.Sequential(),
        }
        build_norm = norm_builders.get(norm_fn)
        if build_norm is not None:
            self.norm1 = build_norm()
            self.norm2 = build_norm()
            if needs_projection:
                self.norm3 = build_norm()

        if needs_projection:
            self.downsample = nn.Sequential(
                Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r=rank), self.norm3)
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual_branch(x))``."""
        branch = self.relu(self.norm1(self.conv1(x)))
        branch = self.relu(self.norm2(self.conv2(branch)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + branch)


class ContextFeatureEncoder(nn.Module):
    '''
    Encoder features are used to:
        1. initialize the hidden state of the update operator
        2. and are also injected into the GRU during each iteration of the update operator
    '''
    def __init__(self, in_dim, output_dim, tuning_mode=None):
        '''
        in_dim     = [x4, x8, x16, x32]
        output_dim = [hidden_dims,    context_dims]
                    [[x4,x8,x16,x32],[x4,x8,x16,x32]]

        For each of the first three levels, one head (a ResidualBlock followed
        by a 3x3 convolution) is built per entry of output_dim.
        '''
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0

        def build_heads(level):
            # One head per entry of output_dim, all reading the same level.
            return nn.ModuleList([
                nn.Sequential(
                    ResidualBlock(in_dim[level], dim[level], 'layer', stride=1, tuning_mode=tuning_mode),
                    Conv2dLoRA(dim[level], dim[level], 3, padding=1, r=rank))
                for dim in output_dim
            ])

        self.outputs04 = build_heads(0)
        self.outputs08 = build_heads(1)
        self.outputs16 = build_heads(2)

    def forward(self, encoder_features):
        '''
        Apply every head of a level to that level's feature map.

        encoder_features must unpack into four items; the fourth (x32) is not used.
        Returns a tuple of three lists (x4 heads, x8 heads, x16 heads).
        '''
        x_4, x_8, x_16, _ = encoder_features
        per_level = (
            (self.outputs04, x_4),
            (self.outputs08, x_8),
            (self.outputs16, x_16),
        )
        return tuple([head(feature) for head in heads] for heads, feature in per_level)


class ConvBlock(nn.Module):
    """Residual unit of two ReLU + 3x3 conv stages (reimplementation of DPT).

    With ``tuning_mode == 'lora'`` both convolutions use LoRA rank 8,
    otherwise rank 0.
    """

    def __init__(self, channels, tuning_mode=None):
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0

        self.act = nn.ReLU(inplace=True)
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, r=rank)
        self.conv1 = Conv2dLoRA(channels, channels, **conv_kwargs)
        self.conv2 = Conv2dLoRA(channels, channels, **conv_kwargs)

    def forward(self, x):
        """Return ``x + conv2(act(conv1(act(x))))``.

        The activation is in-place, so ``x`` itself is rectified by the first
        call; the skip connection therefore adds the rectified ``x`` and the
        caller's tensor is modified.
        """
        out = self.act(x)
        out = self.conv2(self.act(self.conv1(out)))
        return x + out


class FuseBlock(nn.Module):
    """Fuse an optional skip input, refine, optionally upsample, then project.

    Reimplementation of DPT. ``way_branch`` only exists when ``fuse`` is
    True. With ``tuning_mode == 'lora'`` the convolutions use LoRA rank 8,
    otherwise rank 0.
    """

    def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None):
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0

        self.fuse = fuse
        self.scale_factor = scale_factor
        self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode)
        if self.fuse:
            self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode)

        # 1x1 projection to the output width
        self.out_conv = Conv2dLoRA(in_channels, out_channels, kernel_size=1, stride=1, padding=0, r=rank)
        self.upsample = upsample

    def forward(self, x1, x2=None):
        """Return the fused feature map.

        When ``x2`` is given it is passed through ``way_branch`` and added to
        ``x1`` before the trunk block runs.
        """
        if x2 is not None:
            x1 = x1 + self.way_branch(x2)

        out = self.way_trunk(x1)

        if self.upsample:
            out = interpolate_float32(
                out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
            )
        return self.out_conv(out)


class Readout(nn.Module):
    """Combine patch tokens with the projected learned tokens (from DPT).

    When ``use_cls_token == True`` the module projects ``x[0]`` and ``x[1]``
    separately, sums them and applies GELU; otherwise it is an identity.
    """

    def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0
        self.use_cls_token = use_cls_token

        with_cls = self.use_cls_token == True
        if with_cls:
            learned_width = (1 + num_register_tokens) * in_features
            self.project_patch = LoRALinear(in_features, in_features, r=rank)
            self.project_learn = LoRALinear(learned_width, in_features, bias=False, r=rank)
            self.act = nn.GELU()
        else:
            self.project = nn.Identity()

    def forward(self, x):
        with_cls = self.use_cls_token == True
        if not with_cls:
            return self.project(x)

        patch_features = self.project_patch(x[0])
        # Broadcast the learned-token projection to the patch layout.
        learned = self.project_learn(x[1]).expand_as(patch_features).contiguous()
        return self.act(patch_features + learned)


class Token2Feature(nn.Module):
    """Read out tokens and resample them into a feature map (from DPT).

    The resampling layer depends on ``scale_factor``:
      * int > 1:        transposed convolution with kernel = stride = scale_factor
      * non-int > 1:    1x1 convolution (the resize itself happens in ``forward``)
      * < 1:            strided convolution with stride ``int(1 / scale_factor)``
      * otherwise:      identity
    """

    def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
        super().__init__()
        rank = 8 if tuning_mode == 'lora' else 0
        self.scale_factor = scale_factor
        self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)

        if scale_factor > 1:
            if isinstance(scale_factor, int):
                self.sample = ConvTranspose2dLoRA(
                    r=rank,
                    in_channels=vit_channel,
                    out_channels=feature_channel,
                    kernel_size=scale_factor,
                    stride=scale_factor,
                    padding=0,
                )
            else:
                self.sample = nn.Sequential(
                    Conv2dLoRA(
                        r=rank,
                        in_channels=vit_channel,
                        out_channels=feature_channel,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                    ),
                )
        elif scale_factor < 1:
            # self.scale_factor keeps the original value; only the conv uses the inverse.
            stride = int(1.0 / scale_factor)
            self.sample = Conv2dLoRA(
                r=rank,
                in_channels=vit_channel,
                out_channels=feature_channel,
                kernel_size=stride + 1,
                stride=stride,
                padding=1,
            )
        else:
            self.sample = nn.Identity()

    def forward(self, x):
        """Return the resampled feature map.

        The readout result is permuted with ``(0, 3, 1, 2)`` — presumably from
        a channels-last layout to channels-first; confirm against the caller.
        A float ``scale_factor`` triggers a nearest-neighbour resize before
        ``self.sample`` is applied.
        """
        tokens = self.readoper(x)
        features = tokens.permute(0, 3, 1, 2).contiguous()
        if isinstance(self.scale_factor, float):
            features = interpolate_float32(features.float(), scale_factor=self.scale_factor, mode='nearest')
        return self.sample(features)


class EncoderFeature(nn.Module):
    """Four ``Token2Feature`` readers, one per entry of ``ref_feature``.

    ``read_3`` and ``read_2`` use scale factor 1, ``read_1`` uses 2 and
    ``read_0`` uses 7/2.
    """

    def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None):
        super().__init__()
        self.vit_channel = vit_channel
        self.num_ch_dec = num_ch_dec

        shared = dict(
            use_cls_token=use_cls_token,
            num_register_tokens=num_register_tokens,
            tuning_mode=tuning_mode,
        )
        self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, **shared)
        self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, **shared)
        self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, **shared)
        self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, **shared)

    def forward(self, ref_feature):
        """Return the four read-out maps, ordered from index 3 down to index 0."""
        readers = (self.read_3, self.read_2, self.read_1, self.read_0)  # 1/14, 1/14, 1/7, 1/4
        return tuple(reader(ref_feature[level]) for reader, level in zip(readers, (3, 2, 1, 0)))


class DecoderFeature(nn.Module):
    """Chain of three ``FuseBlock`` stages over the encoder read-outs.

    ``use_cls_token`` is accepted but not used in this class. The last stage
    outputs ``num_ch_dec[1] + 2`` channels and upsamples by 7/4.
    """

    def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None):
        super().__init__()
        self.vit_channel = vit_channel
        self.num_ch_dec = num_ch_dec

        channels = self.num_ch_dec
        # No skip input and no upsampling at the first stage.
        self.upconv_3 = FuseBlock(channels[4], channels[3], fuse=False, upsample=False, tuning_mode=tuning_mode)
        self.upconv_2 = FuseBlock(channels[3], channels[2], tuning_mode=tuning_mode)
        self.upconv_1 = FuseBlock(channels[2], channels[1] + 2, scale_factor=7/4, tuning_mode=tuning_mode)

    def forward(self, ref_feature):
        """Return the decoded map; the fourth item of ``ref_feature`` is not used."""
        deepest, skip_a, skip_b, _ = ref_feature  # 1/14 1/14 1/7 1/4

        out = self.upconv_3(deepest)        # 1/14
        out = self.upconv_2(out, skip_a)    # 1/7
        return self.upconv_1(out, skip_b)   # 1/4


class RAFTDepthNormalDPT5(nn.Module):
    """Decode head predicting depth, confidence and normals from ViT tokens.

    The tokens are read into multi-scale feature maps, decoded into an initial
    low-resolution estimate (depth via bin-probability expectation, normals via
    a small conv stack) and then refined for ``iters`` steps by a multi-level
    GRU update block. Every refinement step is upsampled and collected.
    """

    def __init__(self, cfg):
        """Build the head from ``cfg.model.decode_head`` and ``cfg.data_basic``.

        ``cfg.model.decode_head.tuning_mode`` is optional; when it is absent
        the head is built with ``tuning_mode=None``.
        """
        super().__init__()
        self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024]
        self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14]
        self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14]
        self.use_cls_token = cfg.model.decode_head.use_cls_token
        self.up_scale = cfg.model.decode_head.up_scale
        self.num_register_tokens = cfg.model.decode_head.num_register_tokens
        self.min_val = cfg.data_basic.depth_normalize[0]
        self.max_val = cfg.data_basic.depth_normalize[1]
        # Depth is carried through the refinement as (depth - max_val) / regress_scale.
        self.regress_scale = 100.0

        # Only a missing config entry should fall back to None; anything else
        # is a real error and must surface.
        try:
            tuning_mode = cfg.model.decode_head.tuning_mode
        except (AttributeError, KeyError):
            tuning_mode = None
        self.tuning_mode = tuning_mode
        # Rank handed to every Conv2dLoRA below: 8 in 'lora' mode, 0 otherwise.
        lora_rank = 8 if tuning_mode == 'lora' else 0

        self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128]
        self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3
        self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
        self.iters = cfg.model.decode_head.iters # 22
        self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True

        self.num_depth_regressor_anchor = 256 # 512
        # Channel count of the decoder level that feeds both prediction heads.
        self.used_res_channel = self.decoder_channels[1]
        self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode)
        self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode)
        # Per-pixel logits over the depth bins.
        self.depth_regressor = nn.Sequential(
            Conv2dLoRA(self.used_res_channel,
                       self.num_depth_regressor_anchor,
                       kernel_size=3,
                       padding=1, r=lora_rank),
            nn.ReLU(inplace=True),
            Conv2dLoRA(self.num_depth_regressor_anchor,
                       self.num_depth_regressor_anchor,
                       kernel_size=1, r=lora_rank),
        )
        # Per-pixel 3-channel normal output.
        self.normal_predictor = nn.Sequential(
            Conv2dLoRA(self.used_res_channel,
                       128,
                       kernel_size=3,
                       padding=1, r=lora_rank),
            nn.ReLU(inplace=True),
            Conv2dLoRA(128, 128, kernel_size=1, r=lora_rank), nn.ReLU(inplace=True),
            Conv2dLoRA(128, 128, kernel_size=1, r=lora_rank), nn.ReLU(inplace=True),
            Conv2dLoRA(128, 3, kernel_size=1, r=lora_rank),
        )

        self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode)
        self.context_zqr_convs = nn.ModuleList([
            Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r=lora_rank)
            for i in range(self.n_gru_layers)
        ])
        self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode)

        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _report_nonfinite(tensor, nan_msg, inf_msg, show_shape=False):
        """Print ``nan_msg`` / ``inf_msg`` if ``tensor`` contains NaN / Inf values."""
        if torch.isnan(tensor).any():
            print(nan_msg)
            if show_shape:
                print(tensor.shape)
        if torch.isinf(tensor).any():
            print(inf_msg)
            if show_shape:
                print(tensor.shape)

    def _default_device(self):
        """Device of the first parameter; without parameters, CUDA if available else CPU."""
        first_param = next(self.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_bins(self, bins_num, device=None):
        """Return ``bins_num`` depth values spaced uniformly in log space.

        The values run from ``min_val`` to ``max_val`` inclusive.

        Args:
            bins_num: number of bins.
            device: where to create the tensor; defaults to the module's device.
        """
        if device is None:
            device = self._default_device()
        log_bins = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device=device)
        return torch.exp(log_bins)

    def register_depth_expectation_anchor(self, bins_num, B, device=None):
        """(Re)register the non-persistent ``depth_expectation_anchor`` buffer.

        The buffer holds the depth bins repeated ``B`` times, shape ``(B, bins_num)``.
        """
        depth_bins_vec = self.get_bins(bins_num, device=device)
        depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
        self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)

    def clamp(self, x):
        """Clamp ``x`` to ``[min_val, max_val]`` using ReLU arithmetic."""
        y = self.relu(x - self.min_val) + self.min_val
        y = self.max_val - self.relu(self.max_val - y)
        return y

    def regress_depth(self, feature_map_d):
        """Regress depth as the expectation over softmaxed depth-bin logits.

        Returns:
            A pair ``(depth, logits)``: ``depth`` is the clamped expectation
            shifted by ``max_val`` and divided by ``regress_scale``; ``logits``
            is the raw output of ``depth_regressor``.
        """
        prob_feature = self.depth_regressor(feature_map_d)
        prob = prob_feature.softmax(dim=1)

        # Error logging
        self._report_nonfinite(prob, 'prob_feat_nan!!!', 'prob_feat_inf!!!')

        B = prob.shape[0]
        # The cached anchor is sized for one batch size on one device; rebuild
        # it when a larger batch arrives or the input lives elsewhere, since
        # slicing ``[:B]`` from a smaller buffer would return too few rows.
        anchor = self._buffers.get("depth_expectation_anchor")
        if anchor is None or anchor.shape[0] < B or anchor.device != prob.device:
            self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B, device=prob.device)
        d = compute_depth_expectation(
            prob,
            self.depth_expectation_anchor[:B, ...]).unsqueeze(1)

        # Error logging
        self._report_nonfinite(d, 'd_nan!!!', 'd_inf!!!')

        return (self.clamp(d) - self.max_val) / self.regress_scale, prob_feature

    def pred_normal(self, feature_map, confidence):
        """Predict normals and return them, with ``confidence`` appended, after ``norm_normalize``."""
        normal_out = self.normal_predictor(feature_map)

        # Error logging
        self._report_nonfinite(normal_out, 'norm_nan!!!', 'norm_feat_inf!!!')

        return norm_normalize(torch.cat([normal_out, confidence], dim=1))

    def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
        """Return a ``(batch, 2, height, width)`` float32 grid of (x, y) pixel coordinates.

        ``set_buffer`` is accepted for compatibility but has no effect: the
        grid is only returned, never registered on the module.
        """
        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
                               torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
        meshgrid = torch.stack((x, y))
        meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
        return meshgrid

    def upsample_flow(self, flow, mask):
        """Upsample ``flow`` by ``2 ** n_downsample`` using a convex combination.

        Each output pixel is a softmax-weighted sum of the 3x3 neighbourhood of
        its source pixel. Unlike optical-flow upsampling, the values are not
        multiplied by the upsampling factor.

        Args:
            flow: tensor of shape ``(N, D, H, W)``.
            mask: weights viewable as ``(N, 1, 9, factor, factor, H, W)``.

        Returns:
            Tensor of shape ``(N, D, factor * H, factor * W)``.
        """
        N, D, H, W = flow.shape
        factor = 2 ** self.n_downsample
        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, D, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, D, factor*H, factor*W)

    def initialize_flow(self, img):
        """Return two identical coordinate grids; flow is defined as ``coords1 - coords0``."""
        N, _, H, W = img.shape

        coords0 = coords_grid(N, H, W).to(img.device)
        coords1 = coords_grid(N, H, W).to(img.device)

        return coords0, coords1

    def upsample(self, x, scale_factor=2):
        """Nearest-neighbour upsample ``x`` by ``scale_factor * up_scale / 8``."""
        return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest")

    def forward(self, vit_features, **kwargs):
        """Run the head on ViT tokens.

        Args:
            vit_features: pair ``(tokens, meta)`` where ``tokens`` is a list of
                token tensors and ``meta`` is a 6-tuple whose entries 0, 1, 2
                and 5 are ``B``, ``H``, ``W`` and ``num_register_tokens``.
            **kwargs: ignored.

        Returns:
            dict with the final and per-step depth (``prediction``,
            ``predictions_list``), confidence (``confidence``,
            ``confidence_list``) and normals (``prediction_normal``,
            ``normal_out_list``), ``pred_logit`` (always ``None``) and
            ``low_resolution_init`` (the initial estimate in training mode,
            an empty list otherwise).
        """
        # read vit token to multi-scale features
        B, H, W, _, _, num_register_tokens = vit_features[1]
        vit_features = vit_features[0]

        # Error logging
        self._report_nonfinite(vit_features[0], 'vit_feature_nan!!!', 'vit_feature_inf!!!')

        if self.use_cls_token:
            # Split each token sequence into the patch grid and the leading
            # cls/register tokens (flattened into one channel vector).
            vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]),
                             ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))]
                            for ft in vit_features]
        else:
            vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features]
        encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4

        # Error logging
        for en_ft in encoder_features:
            self._report_nonfinite(en_ft, 'decoder_feature_nan!!!', 'decoder_feature_inf!!!', show_shape=True)

        # decode features to init-depth (and confidence)
        ref_feat = self.decoder_mono(encoder_features) # now, 1/4 for depth

        # Error logging
        self._report_nonfinite(ref_feat, 'ref_feat_nan!!!', 'ref_feat_inf!!!')

        feature_map = ref_feat[:, :-2, :, :] # feature map shared by depth and normal prediction
        depth_confidence_map = ref_feat[:, -2:-1, :, :]
        normal_confidence_map = ref_feat[:, -1:, :, :]
        depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth
        normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal

        depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W)

        # encoder features to context-feature for init-hidden-state and context-features
        cnet_list = self.context_feature_encoder(encoder_features[::-1])
        net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state
        inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features

        # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
        inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i, conv in zip(inp_list, self.context_zqr_convs)]

        # Seed the refinement with the initial estimate.
        coords0, coords1 = self.initialize_flow(net_list[0])
        coords1 = coords1 + depth_init

        if self.training:
            # In training the initial estimate is also supervised, both at low
            # resolution and as the first entry of every prediction list.
            low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())]
            init_depth = upflow4(depth_init)
            flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)]
            conf_predictions = [init_depth[:,1:2]]
            normal_outs = [norm_normalize(init_depth[:,2:].clone())]
        else:
            flow_predictions = []
            conf_predictions = []
            normal_outs = []
            low_resolution_init = []

        for _ in range(self.iters):
            flow = coords1 - coords0
            if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU
                net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
            if self.n_gru_layers >= 2 and self.slow_fast_gru: # Update low-res GRU and mid-res GRU
                net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False)
            net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = self.upsample(coords1-coords0, 4)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            # Channel 0: depth, channel 1: confidence, remaining: normals.
            flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val))
            conf_predictions.append(flow_up[:,1:2])
            normal_outs.append(norm_normalize(flow_up[:,2:].clone()))

        outputs = dict(
            prediction=flow_predictions[-1],
            predictions_list=flow_predictions,
            confidence=conf_predictions[-1],
            confidence_list=conf_predictions,
            pred_logit=None,
            prediction_normal=normal_outs[-1],
            normal_out_list=normal_outs,
            low_resolution_init=low_resolution_init,
        )

        return outputs


if __name__ == "__main__":
    # Manual smoke test: load a decoder config, override the head for
    # 384-channel tokens with a cls token plus 4 register tokens, and push
    # random tokens through the head on the GPU.
    cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py')
    cfg.model.decode_head.in_channels = [384] * 4
    cfg.model.decode_head.feature_channels = [96, 192, 384, 768]
    cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384]
    cfg.model.decode_head.hidden_channels = [48] * 5
    cfg.model.decode_head.up_scale = 7

    cfg.model.decode_head.use_cls_token = True
    cfg.model.decode_head.num_register_tokens = 4

    # Four token tensors of (batch, 74*74 patches + 1 cls + 4 register, 384)
    # followed by the meta tuple that the head's forward() unpacks.
    vit_feature = [
        [torch.rand((2, (74 * 74) + 5, 384)).cuda() for _ in range(4)],
        (2, 74, 74, 1036, 1036, 4),
    ]

    decoder = RAFTDepthNormalDPT5(cfg).cuda()
    output = decoder(vit_feature)
    temp = 1




