import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register
import models
from utils import make_coord
torch.cuda.empty_cache()
from models.e2d import InpaintGenerator
import os
# from clip import clip
from torchvision import transforms
from models.modeling.image_encoder import ImageEncoderViT



@register('text_liif')
class LIIF(nn.Module):
    """Implicit image function decoder over fused inpainting + CLIP features.

    The forward pass encodes the masked image with ``InpaintGenerator``,
    encodes it again with ``MaskCLIP`` and refines that second feature map
    with a separately constructed first-stage model, fuses both maps with a
    7x7 convolution, and finally decodes RGB values at arbitrary continuous
    coordinates with an implicit network (``imnet``).

    Args:
        encoder_spec: Accepted for config compatibility; not used by this
            class.
        imnet_spec: Spec passed to ``models.make`` to build the decoding
            network. When ``None`` no decoder is built and ``query_rgb`` /
            ``forward`` cannot be used.
        local_ensemble: If true, blend the predictions of the four feature
            cells surrounding each query, weighted by area.
        feat_unfold: If true, concatenate each feature vector with its 3x3
            neighbourhood before decoding (9x the channels).
        cell_decode: Stored on the instance; not used by this class.
        fir_dir: Path of the checkpoint whose ``'model'`` entry describes
            the first-stage model.
    """

    def __init__(self, encoder_spec, imnet_spec=None,
                 local_ensemble=True, feat_unfold=True, cell_decode=True,
                 fir_dir='./0831a/_train_celebAHQ-64-128_liif/epoch-200.pth'):
        super().__init__()
        # Honour the constructor argument (it used to be forced to True).
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode
        self.e2d = InpaintGenerator()

        # NOTE(review): `MaskCLIP` is not imported in the visible part of
        # this file; it must be in scope for construction to succeed.
        self.clip_encoder = MaskCLIP()
        self.l1_loss = nn.L1Loss()

        # Build the first-stage model from the spec stored in a checkpoint.
        # NOTE(review): `load_sd=False` presumably means the checkpoint
        # weights are NOT restored here -- confirm this is intended.
        fir_spec = torch.load(fir_dir)['model']
        self.fir_model = models.make(fir_spec, load_sd=False)

        # Fuses the channel-wise concatenation of both feature maps
        # (128 channels in total) back down to 64 channels.
        self.conv = nn.Conv2d(in_channels=128, out_channels=64,
                              kernel_size=7, padding=3)

        if imnet_spec is not None:
            # The decoder consumes the fused features, so its input width
            # is tied to the fusion conv's output channels.
            imnet_in_dim = self.conv.out_channels
            if self.feat_unfold:
                imnet_in_dim *= 9  # 3x3 neighbourhood per query
            imnet_in_dim += 2  # relative (axis-0, axis-1) coordinate
            self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
        else:
            self.imnet = None

    def gen_feat(self, masked_img, mask):
        """Encode the masked image together with its mask.

        Args:
            masked_img: Image tensor; concatenated with ``mask`` along dim 1.
            mask: Mask tensor whose non-channel dims match ``masked_img``.

        Returns:
            The output of ``InpaintGenerator`` for the float-cast,
            concatenated input.
        """
        masked_img = torch.cat([masked_img, mask], dim=1)
        feat = self.e2d(masked_img.float())
        return feat

    def query_rgb(self, hr_coord):
        """Decode predictions at ``hr_coord`` and score them against the target.

        Reads ``self.feat`` and ``self.gt_img``, which ``forward`` sets.

        Args:
            hr_coord: Query coordinates of shape ``(batch, queries, 2)``.

        Returns:
            Tuple ``(pred_img, loss)`` where ``loss`` is the L1 loss between
            ``pred_img`` and ``self.gt_img``.

        Raises:
            RuntimeError: If the model was built without ``imnet_spec``.
        """
        if self.imnet is None:
            raise RuntimeError(
                'LIIF was constructed without imnet_spec; cannot decode.')

        pred_img = self.liif(self.feat, hr_coord, self.imnet)
        loss = self.l1_loss(pred_img, self.gt_img)
        return pred_img, loss

    def forward(self, masked_img_feat, gt_img_feat, mask, gt_img, masked_img, hr_coord):
        """Run the full pipeline and return the prediction and its L1 loss.

        Args:
            masked_img_feat: Unused; kept so existing callers keep working.
            gt_img_feat: Forwarded to the first-stage model.
            mask: Mask tensor, used by both encoders' paths.
            gt_img: Target values the prediction is compared against
                (presumably shaped like the prediction -- confirm at caller).
            masked_img: Masked input image.
            hr_coord: Query coordinates of shape ``(batch, queries, 2)``.

        Returns:
            Tuple ``(pred_img, loss)`` as produced by ``query_rgb``.
        """
        self.clip_feat = self.clip_encoder(masked_img)
        self.feat = self.gen_feat(masked_img, mask)
        # The first-stage model also returns a second value, unused here.
        self.clip_feat, _ = self.fir_model(
            self.clip_feat, gt_img_feat, mask, hr_coord)

        # Fuse both feature maps channel-wise, then reduce with the conv.
        self.feat = torch.cat([self.feat, self.clip_feat], dim=1)
        self.feat = self.conv(self.feat)
        self.gt_img = gt_img

        return self.query_rgb(hr_coord)

    def liif(self, feat, coord, model):
        """Decode values at continuous coordinates from a feature map.

        For each query the nearest feature vector (and, with local ensemble,
        the nearest vector in each of the four diagonal directions) is
        looked up, concatenated with the query's offset from that feature's
        coordinate, and passed through ``model``.

        Args:
            feat: Feature map of shape ``(batch, channels, H, W)``.
            coord: Query coordinates of shape ``(batch, queries, 2)``;
                presumably normalised to [-1, 1] (shifted copies are clamped
                to that range).
            model: Callable mapping ``(batch * queries, in_dim)`` inputs to
                per-query predictions.

        Returns:
            Tensor of shape ``(batch, queries, out_dim)``.
        """
        n_batch, n_chan, height, width = feat.shape

        # Coordinate of every feature cell, laid out like `feat` so it can
        # be sampled with the same grid. Created on feat's device rather
        # than unconditionally on the current CUDA device.
        # (assumes make_coord returns an (H, W, 2) tensor -- TODO confirm)
        feat_coord = make_coord(feat.shape[-2:], flatten=False) \
            .to(feat.device) \
            .permute(2, 0, 1) \
            .unsqueeze(0).expand(n_batch, 2, height, width)

        if self.feat_unfold:
            feat = F.unfold(feat, 3, padding=1).view(
                n_batch, n_chan * 9, height, width)

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # Half the extent of one feature cell along each axis.
        rx = 2 / height / 2
        ry = 2 / width / 2

        bs, q = coord.shape[:2]
        preds = []
        areas = []

        for vx in vx_lst:
            for vy in vy_lst:
                # Shift the query half a cell so nearest-neighbour sampling
                # lands on the neighbouring cell in that direction.
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

                # grid_sample wants the last dim ordered (x, y), hence the
                # flip; the singleton dim makes the grid (N, 1, queries, 2).
                grid = coord_.flip(-1).unsqueeze(1)

                q_feat = F.grid_sample(
                    feat, grid,
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)

                q_coord = F.grid_sample(
                    feat_coord, grid,
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)

                # Offset of the query from the sampled cell, scaled by the
                # feature map size along each axis.
                rel_coord = coord - q_coord
                rel_coord[:, :, 0] *= height
                rel_coord[:, :, 1] *= width

                inp = torch.cat([q_feat, rel_coord], dim=-1)
                pred = model(inp.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)  # avoid a zero total area

        tot_area = torch.stack(areas).sum(dim=0)
        if self.local_ensemble:
            # Each prediction is weighted by the area belonging to the
            # diagonally opposite cell, so swap the areas pairwise.
            areas[0], areas[3] = areas[3], areas[0]
            areas[1], areas[2] = areas[2], areas[1]

        ret = 0
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        return ret
