import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
from utils import make_coord, to_cuda


@register('liif')
class LIIF(nn.Module):
    '''Local Implicit Image Function model.

    An encoder turns the input into a feature map ``self.feat`` of shape
    (N, C, H_l, W_l); a decoder ("imnet") then predicts values at continuous
    query coordinates from the nearby latent code(s), the query's offset to
    them and, optionally, the query cell size.
    '''

    def __init__(self, encoder_spec, imnet_spec=None,
                 local_ensemble=True, feat_unfold=True, cell_decode=True):
        '''
        Args:
            encoder_spec: spec passed to ``models.make`` to build the image
                encoder. The built encoder must expose ``out_dim``, the channel
                count C of its output, whose total shape is (N, C, H_l, W_l).
            imnet_spec: spec passed to ``models.make`` to build the decoder,
                or None. ``imnet_spec['name']`` selects how the decoder is
                called in ``query_rgb`` ("mlp" or "banddec"). With None,
                ``query_rgb`` returns nearest-sampled encoder features.
            local_ensemble: blend the predictions made from the 4 latent codes
                surrounding each query (bilinear-style area weights).
            feat_unfold: replace each feature vector by its 3x3 neighbourhood
                (channel count becomes 9C).
            cell_decode: append the query cell size to the decoder input.
        '''
        super().__init__()
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode

        self.encoder = models.make(encoder_spec)

        if imnet_spec is not None:
            imnet_in_dim = self.encoder.out_dim
            if self.feat_unfold:
                # Feature unfold concatenates the 9 features of a 3x3 window.
                imnet_in_dim *= 9
            imnet_in_dim += 2  # relative coordinate of the query
            if self.cell_decode:
                # Cell decoding appends the relative cell size (2 values).
                imnet_in_dim += 2
            self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
            self.imnet_name = imnet_spec['name']
        else:
            self.imnet = None
            self.imnet_name = None

    def gen_feat(self, inp):
        '''Encode ``inp`` and cache the feature map (N, C, H_l, W_l) on ``self.feat``.'''
        self.feat = self.encoder(inp)
        return self.feat

    def _decode(self, inp, band_coord):
        '''Run the decoder on per-query inputs of shape (N, X, D_in).

        Returns a tensor of shape (N, X, D_out).

        Raises:
            ValueError: if ``self.imnet_name`` is not a supported decoder.
        '''
        if self.imnet_name == 'mlp':
            # The MLP works on flat rows: (N*X, D_in) -> (N*X, D_out).
            bs, q = inp.shape[:2]
            return self.imnet(inp.reshape(bs * q, -1)).reshape(bs, q, -1)
        if self.imnet_name == 'banddec':
            return self.imnet(inp, band_coord)
        raise ValueError(
            f'unsupported imnet {self.imnet_name!r}; expected "mlp" or "banddec"')

    def query_rgb(self, coord, cell=None, band_coord=None):
        '''
        Predict values at continuous query coordinates from ``self.feat``.

        ``gen_feat`` must have been called first.

        Args:
            coord: shape (N, X, 2), query coordinates in [-1, 1]; X is the
                number of queried pixels (between 0 and H_h * W_h).
            cell: shape (N, X, 2), size of each query pixel in the same
                coordinate system, e.g. (2/H_h, 2/W_h). Required when
                ``cell_decode`` is enabled and a decoder exists.
            band_coord: forwarded to the decoder when its name is "banddec";
                documented as shape (N, C or num_band_sample, 2), the band
                (start, end) intervals in [-1, 1].

        Returns:
            shape (N, X, D): D is the decoder output dim, or the encoder
            channel count C when there is no decoder.

        Raises:
            ValueError: if ``cell`` is missing while ``cell_decode`` is on, or
                the decoder name is not supported.
        '''
        feat = self.feat

        if self.imnet is None:
            # Nearest-neighbour lookup of the encoder feature under each query.
            # grid_sample takes (x, y) = (width, height) order, hence flip(-1).
            # grid_sample(...): (N, C, 1, X) -> ret: (N, X, C)
            return F.grid_sample(
                feat, coord.flip(-1).unsqueeze(1),
                mode='nearest', align_corners=False)[:, :, 0, :] \
                .permute(0, 2, 1)

        if self.cell_decode and cell is None:
            raise ValueError('cell must be provided when cell_decode=True')

        if self.feat_unfold:
            # F.unfold: (N, 9C, H_l * W_l) -> view as (N, 9C, H_l, W_l)
            feat = F.unfold(feat, kernel_size=3, padding=1).view(
                feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        if self.local_ensemble:
            vx_lst, vy_lst, eps_shift = [-1, 1], [-1, 1], 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        h_l, w_l = feat.shape[-2:]
        # Half the size of one low-res cell, with the image spanning [-1, 1].
        rx = 2 / h_l / 2
        ry = 2 / w_l / 2

        # Coordinate of every low-res position (from make_coord), laid out
        # like a feature map: (H_l, W_l, 2) -> (N, 2, H_l, W_l). It is placed
        # on coord's device/dtype because it is sampled with a grid built from
        # coord; grid_sample requires input and grid to match.
        feat_coord = make_coord(feat.shape[-2:], flatten=False) \
            .to(device=coord.device, dtype=coord.dtype) \
            .permute(2, 0, 1) \
            .unsqueeze(0) \
            .expand(feat.shape[0], 2, *feat.shape[-2:])

        # The cell size is the same for every ensemble member: express it in
        # low-res pixels, e.g. (2*H_l/H_h, 2*W_l/W_h). Shape (N, X, 2).
        rel_cell = None
        if self.cell_decode:
            rel_cell = cell.clone()
            rel_cell[:, :, 0] *= h_l
            rel_cell[:, :, 1] *= w_l

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                # Shift the query by half a low-res cell towards one of the
                # 4 surrounding latent codes; the small eps nudges samples off
                # exact cell boundaries. Clamp keeps them inside the image.
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
                grid = coord_.flip(-1).unsqueeze(1)  # (N, 1, X, 2)

                # Latent code nearest to the shifted query: (N, X, C or 9C).
                q_feat = F.grid_sample(
                    feat, grid,
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                # Coordinate of that latent code: (N, X, 2).
                q_coord = F.grid_sample(
                    feat_coord, grid,
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)

                # Offset from the latent code to the query, rescaled to units
                # of low-res pixels so the decoder input does not depend on
                # the feature map resolution. Shape (N, X, 2).
                rel_coord = coord - q_coord
                rel_coord[:, :, 0] *= h_l
                rel_coord[:, :, 1] *= w_l

                inp = torch.cat([q_feat, rel_coord], dim=-1)
                if rel_cell is not None:
                    inp = torch.cat([inp, rel_cell], dim=-1)

                # pred: (N, X, decoder output dim)
                preds.append(self._decode(inp, band_coord))

                # Area of the rectangle between the query and this latent code
                # (in low-res pixels); eps avoids division by zero below.
                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)

        tot_area = torch.stack(areas).sum(dim=0)  # (N, X)
        if self.local_ensemble:
            # Bilinear-style weighting. Entries are ordered (vx, vy) =
            # (-1,-1), (-1,1), (1,-1), (1,1). Each prediction is weighted by
            # the area belonging to the diagonally opposite code, so the code
            # closest to the query gets the largest weight.
            areas[0], areas[3] = areas[3], areas[0]
            areas[1], areas[2] = areas[2], areas[1]
        ret = 0
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        return ret

    def forward(self, inp, coord, cell=None, band_coord=None):
        '''Encode ``inp`` and query it at ``coord`` (see ``query_rgb``).'''
        self.gen_feat(inp)
        return self.query_rgb(coord, cell, band_coord)
