import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from einops import rearrange
import math
import random
from .encoder import CrossViewEncoder
from .lifting import lifting
def build_backbone(name, type):
    """Load a pretrained image backbone through ``torch.hub``.

    Args:
        name: Backbone family. Only ``'dinov2'`` is supported.
        type: Backbone variant; one of ``'vits14'``, ``'vitb14'``,
            ``'vitl14'`` or ``'vitg14'``. (The parameter name shadows the
            builtin but is kept so keyword callers continue to work.)

    Returns:
        A ``(backbone, down_rate, backbone_dim)`` tuple: the loaded module,
        the value 14 used as the down-sampling rate, and the embedding width
        of the selected variant.

    Raises:
        ValueError: If ``name`` or ``type`` is not supported. Both checks run
            before anything is fetched from the hub.
    """
    # Embedding width of every DINOv2 ViT variant. The previous if/elif chain
    # only covered vits14/vitb14, so vitl14/vitg14 crashed at the return.
    dinov2_dims = {
        'vits14': 384,
        'vitb14': 768,
        'vitl14': 1024,
        'vitg14': 1536,
    }

    if name != 'dinov2':
        raise ValueError(
            "Unsupported backbone name {!r}; expected 'dinov2'".format(name)
        )
    if type not in dinov2_dims:
        raise ValueError(
            "Unsupported dinov2 type {!r}; expected one of {}".format(
                type, sorted(dinov2_dims)
            )
        )

    backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_{}'.format(type))
    down_rate = 14
    backbone_dim = dinov2_dims[type]
    return backbone, down_rate, backbone_dim

class LEAP(nn.Module):
    """Pipeline that chains projection, encoding, lifting and rendering.

    ``forward`` projects the input images with a 1x1 convolution, passes the
    result through ``encoder``, then ``lifting``, and finally hands the lifted
    features together with the target cameras to ``render``.

    Args:
        encoder: Module applied to the per-view feature maps.
        lifting: Module applied to the encoder output.
        render: Module called as ``render(features_3d, target_cameras)``.
    """

    def __init__(self, encoder, lifting, render) -> None:
        super().__init__()
        # Cross-view feature encoder. Features currently come from a 1x1
        # convolution over the 3-channel input.
        # NOTE: no ``self.backbone`` is created here; ``extract_feature``
        # needs one to be attached before it can be used.
        self.proj = nn.Conv2d(3, 256, 1)
        self.encoder = encoder

        # 2D-3D lifting
        self.lifting = lifting

        # 3D-2D render module
        self.render_module = render

    def extract_feature(self, x, return_h_w=False):
        """Run ``self.backbone`` on ``x`` and return a channel-first feature map.

        Args:
            x: 4-D batch; its last two dimensions are divided by the
                backbone's patch size to obtain the feature grid size.
            return_h_w: When true, also return the feature grid height and
                width.

        Returns:
            The features reshaped to ``[b, dim, h, w]``, or the tuple
            ``(features, h, w)`` when ``return_h_w`` is true.

        Raises:
            AttributeError: If no ``backbone`` has been attached to the model.
        """
        backbone = getattr(self, "backbone", None)
        if backbone is None:
            raise AttributeError(
                "LEAP.extract_feature requires `self.backbone`, but none is set; "
                "assign a backbone module to the model first"
            )

        b, _, h_origin, w_origin = x.shape
        out = backbone.get_intermediate_layers(x, n=1)[0]
        patch_h, patch_w = backbone.patch_embed.patch_size[:2]
        # Integer floor division instead of int(a / b): no float round-trip.
        h, w = h_origin // patch_h, w_origin // patch_w
        dim = out.shape[-1]
        # Token sequence -> spatial grid, then move channels in front of h, w.
        out = out.reshape(b, h, w, dim).permute(0, 3, 1, 2)
        if return_h_w:
            return out, h, w
        return out

    def forward(self, images, target_cameras):
        '''
        imgs in shape [b,t,C,H,W]

        C must be 3 to match the input channels of ``self.proj``.
        ``target_cameras`` is forwarded untouched to the render module, whose
        output is returned.
        '''
        b, t = images.shape[:2]
        # Affine rescale x -> 2x - 1 (presumably maps [0, 1] inputs to
        # [-1, 1]; confirm against the data loader).
        images = images * 2.0 - 1.0
        images = rearrange(images, 'b t c h w -> (b t) c h w')

        # features = self.extract_feature(images)
        features = self.proj(images)

        # Restore the separate batch / view axes for the encoder.
        features = rearrange(features, '(b t) c h w -> b t c h w', b=b, t=t)

        features = self.encoder(features)

        features_3d = self.lifting(features)

        results = self.render_module(features_3d, target_cameras)
        return results

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """Forward the xformers attention setting to every descendant module.

        Each descendant that defines a method with this same name is called
        with ``(valid, attention_op)``. The walk starts from the children
        rather than from ``self``, so this method never re-invokes itself.
        """
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        # children() only ever yields nn.Module instances, so no type check
        # is needed here.
        for module in self.children():
            fn_recursive_set_mem_eff(module)




