from .unet_autoenc import *
from .sr_mot_enc import ResNetFaceEmbedding
import torch
import functools
from insightface.app import FaceAnalysis
import cv2

@dataclass
class BeatGANsSRConfig(BeatGANsAutoencConfig):
    """Model config for ``BeatGANsSRModel``.

    Extends ``BeatGANsAutoencConfig`` with the cross-attention context size and
    the identity fine-tuning switches read by ``BeatGANsSRModel.__init__``.
    """
    # NOTE(review): originally commented as "number of style channels"; mlp_dim
    # is not read anywhere in this module — confirm where it is consumed.
    mlp_dim: int = 512
    # dimension argument handed to every CrossAttentionBlock built by the model
    context_dim: int = 512
    # When True, the model freezes all of its weights and learns a single
    # identity embedding initialised from ``ft_img_path``.
    # Both fields are annotated so they are real dataclass fields (settable via
    # the constructor and listed by fields()/asdict()); without annotations
    # @dataclass treats them as plain class attributes.
    finetune: bool = False
    # image whose largest detected face initialises the learned identity embedding
    ft_img_path: str = '/data/yangjiarui/diffae/datasets/finetune/tmp1/0177.png'

    def make_model(self):
        """Build the ``BeatGANsSRModel`` described by this config."""
        return BeatGANsSRModel(self)

class BeatGANsSRModel(BeatGANsAutoencModel):
    """Super-resolution variant of ``BeatGANsAutoencModel``.

    Differences from the parent, as implemented below:

    * the first UNet input conv takes 6 channels, because ``forward``
      concatenates the noisy sample ``x`` with ``kwargs['lr']`` along dim 1;
    * every ``AttentionBlock`` reachable from the input, middle and output
      blocks is wrapped in a ``CrossAttentionBlock`` that receives ``context``;
    * the ``cond`` vector is obtained from ``self.encode(kwargs['warp'])``.

    When ``conf.finetune`` is True, every existing parameter is frozen and a
    single learnable context (``self.learned_context``) is created from the
    insightface embedding of the largest face in ``conf.ft_img_path``; it
    replaces the per-call ``id_emb`` context in ``forward``.
    """

    def __init__(self, conf: BeatGANsSRConfig):
        super().__init__(conf)
        self.conf = conf
        self.finetune = conf.finetune

        # 6 input channels = noisy sample + low-res image concatenated in
        # forward (assumes both are 3-channel — TODO confirm); 128 presumably
        # matches model_channels * channel_mult[0] of the parent — verify.
        self.input_blocks[0] = TimestepEmbedSequential(
                conv_nd(2, 6, 128, 3, padding=1))

        # Modules to (re-)enable for training; the CrossAttentionBlocks created
        # by _replace_attention_blocks are appended to this list.
        self._trainable_modules = [self.encoder.input_blocks]
        self._replace_attention_blocks(self.input_blocks)
        self._replace_attention_blocks(self.middle_block)
        self._replace_attention_blocks(self.output_blocks)

        if self.finetune:
            for p in self.parameters():
                p.requires_grad = False
            # Registered after the freeze loop, so it stays trainable.
            self.learned_context = nn.Parameter(
                self._load_identity_embedding(conf.ft_img_path))
        else:
            # NOTE(review): nothing above freezes parameters in this branch, so
            # this only matters if the parent class froze them — confirm intent.
            for m in self._trainable_modules:
                for p in m.parameters():
                    p.requires_grad = True

    @staticmethod
    def _load_identity_embedding(img_path):
        """Return the L2-normalised face embedding of the largest face in *img_path*.

        The embedding is returned as float32 with a leading batch dim of 1
        (shape ``(1, D)`` assuming insightface yields a 1-D embedding).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image.
            ValueError: if no face is detected in the image.
        """
        face_app = FaceAnalysis(name='antelopev2/antelopev2', root='./checkpoints',
                                providers=['CPUExecutionProvider'])
        face_app.prepare(ctx_id=0, det_size=(256, 256))
        img = cv2.imread(img_path)
        # cv2.imread returns None instead of raising on unreadable paths.
        if img is None:
            raise FileNotFoundError(f'could not read fine-tuning image: {img_path!r}')
        faces = face_app.get(img)
        if not faces:
            raise ValueError(f'no face detected in fine-tuning image: {img_path!r}')

        def _bbox_area(face):
            box = face['bbox']
            return (box[2] - box[0]) * (box[3] - box[1])

        largest_face = max(faces, key=_bbox_area)
        id_emb = torch.tensor(largest_face['embedding'], dtype=torch.float32)[None]
        return id_emb / torch.norm(id_emb, dim=1, keepdim=True)  # normalize

    def _replace_attention_blocks(self, block_list):
        """Wrap every ``AttentionBlock`` in *block_list* in a ``CrossAttentionBlock``.

        Items that are ``nn.Sequential`` are scanned one level deep and their
        matching children are swapped in place; items that are themselves an
        ``AttentionBlock`` are replaced via ``block_list[i] = ...``. Every new
        block is also appended to ``self._trainable_modules``.
        """
        for i, block in enumerate(block_list):
            if isinstance(block, nn.Sequential):
                # Overwriting existing keys does not resize the dict, so this
                # is safe while iterating over it.
                for name, layer in block._modules.items():
                    if isinstance(layer, AttentionBlock):
                        new_block = CrossAttentionBlock(layer, self.conf.context_dim)
                        block._modules[name] = new_block
                        self._trainable_modules.append(new_block)
            elif isinstance(block, AttentionBlock):
                new_block = CrossAttentionBlock(block, self.conf.context_dim)
                block_list[i] = new_block
                self._trainable_modules.append(new_block)

    def forward(self,
                x,
                t,
                x_start=None,
                cond=None,
                noise=None,
                **kwargs):
        """Run the conditioned UNet on ``x``.

        Args:
            x: noisy sample; concatenated with ``kwargs['lr']`` along dim 1.
            t: timesteps, or None (then no time embedding is computed).
            x_start: unused; kept for interface compatibility.
            cond: ignored; recomputed as ``self.encode(kwargs['warp'])['cond']``.
            noise: unused; kept for interface compatibility.
            **kwargs: ``lr`` (tensor concatenated to ``x``), ``warp`` (input
                to ``self.encode``), ``id_emb`` (cross-attention context,
                replaced by ``self.learned_context`` when fine-tuning).
                ``ref`` is accepted but currently unused.

        Returns:
            AutoencReturn with ``pred`` from ``self.out`` and the encoder ``cond``.
        """
        batch_size = x.size(0)
        lr = kwargs.get('lr')
        warp = kwargs.get('warp')
        context = kwargs.get('id_emb')
        x = torch.cat((x, lr), dim=1)
        # An earlier variant encoded torch.cat((warp, ref), dim=1); only warp
        # is encoded now, which is why kwargs['ref'] is not read.
        cond = self.encode(warp)['cond']

        if self.finetune:
            # Broadcast the shared learned embedding over the batch; for a
            # (1, D) parameter this yields (batch_size, 1, D).
            context = self.learned_context.expand(batch_size, -1, -1)

        if t is not None:
            # Both embeddings were computed with identical arguments; do it once.
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = _t_emb
        else:
            # this happens when training only autoenc
            _t_emb = None
            _t_cond_emb = None

        res = self.time_embed.forward(
            time_emb=_t_emb,
            cond=cond,
            time_cond_emb=_t_cond_emb,
        )
        # The same time / cond embeddings feed encoder, middle and decoder.
        time_emb = res.time_emb
        cond_emb = res.emb

        # Skip connections, grouped per resolution level.
        hs = [[] for _ in range(len(self.conf.channel_mult))]

        # x cannot be None here: torch.cat above would already have failed.
        h = x.type(self.dtype)
        k = 0
        for level, n_blocks in enumerate(self.input_num_blocks):
            for _ in range(n_blocks):
                h = self.input_blocks[k](h,
                                         emb=time_emb,
                                         cond=cond_emb,
                                         context=context)
                hs[level].append(h)
                k += 1
        assert k == len(self.input_blocks)

        h = self.middle_block(h, emb=time_emb, cond=cond_emb, context=context)

        k = 0
        for level, n_blocks in enumerate(self.output_num_blocks):
            for _ in range(n_blocks):
                # Levels are consumed deepest-first; a level may have no
                # remaining skip tensor, in which case no lateral is passed.
                try:
                    lateral = hs[-level - 1].pop()
                except IndexError:
                    lateral = None

                h = self.output_blocks[k](h,
                                          emb=time_emb,
                                          cond=cond_emb,
                                          lateral=lateral,
                                          context=context)
                k += 1

        pred = self.out(h)
        return AutoencReturn(pred=pred, cond=cond)


class AutoencReturn(NamedTuple):
    """Bundle returned by the SR model's ``forward``.

    ``pred`` holds the network output; ``cond`` holds the encoder's
    conditioning vector and may be left out, in which case it is None.
    """

    # network output (result of ``self.out``)
    pred: Tensor
    # encoder conditioning vector; optional
    cond: Tensor = None

if __name__ == '__main__':
    # Smoke test: build the SR model from the ffhq256 SR template, load a
    # checkpoint and run one forward pass on random tensors.
    from templates import *
    from templates_latent import *

    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    conf = ffhq256_autoenc_sr()
    ckpt_path = '/data/yangjiarui/diffae/checkpoints/last.ckpt'
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    # Strip the 'model.' prefix from keys; all other keys are kept unchanged.
    prefix = 'model.'
    new_weights = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model = conf.make_model_conf().make_model()
    # strict=False silently skips mismatched keys; report them so that a wrong
    # or partially compatible checkpoint is noticed.
    load_result = model.load_state_dict(new_weights, strict=False)
    print(f'missing keys: {len(load_result.missing_keys)}, '
          f'unexpected keys: {len(load_result.unexpected_keys)}')
    model.to(device)

    shape = (2, 3, 256, 256)
    lr = torch.randn(shape, device=device)
    hr = torch.randn(shape, device=device)
    ref = torch.randn(shape, device=device)
    warp = torch.randn(shape, device=device)
    kwargs = {'lr': lr, "ref": ref, 'warp': warp}
    t = torch.randn((2,), device=device)

    # Shape check only: no gradients needed.
    with torch.no_grad():
        out = model(x=hr, t=t, x_start=hr, **kwargs)
    print(out.pred.shape)
