from typing import List
import os
import torch
from torch import nn
from einops import rearrange, repeat, reduce
import torch.nn.functional as F

from src.models.blocks.transformer_t import FuseTransformer, Transformer
from src.models.blocks.resnet import ResBlock
from src.utils.project import ray_sample, get_ray_limits_box


def compute_projections(points, cameras):
    """Project 3D points into every camera and flag the visible ones.

    Args:
        points: ``b x n x 3`` tensor of 3D points.
        cameras: ``b x f x 32`` tensor; per camera the first 16 values are
            read as a flattened 4x4 intrinsics matrix and the last 16 as a
            flattened 4x4 camera-to-world matrix.

    Returns:
        A tuple ``(pixel_locations, mask)``:

        * ``pixel_locations`` -- ``b x f x n x 2`` projected coordinates,
          clamped to ``[-10, 10]``.
        * ``mask`` -- ``b x f x n`` boolean, ``True`` where the point has
          positive depth and projects inside ``[0, 1] x [0, 1]``.
    """
    batch, num_views, _ = cameras.shape
    intrinsics = cameras[..., :16].reshape(batch, num_views, 4, 4)
    cam_to_world = cameras[..., 16:].reshape(batch, num_views, 4, 4)

    # Homogeneous coordinates: append a trailing 1 to every point -> (b, n, 4).
    homogeneous = torch.cat((points, torch.ones_like(points[..., :1])), dim=-1)

    # world -> camera (inverse of c2w) followed by the intrinsics.
    world_to_pixel = torch.matmul(intrinsics, torch.inverse(cam_to_world))
    projected = torch.einsum(
        "b f i j, b n j -> b f n i", world_to_pixel, homogeneous
    )

    # Perspective divide. The depth is floored at 1e-8, so points at or
    # behind the camera yield huge values that the clamp below tames.
    # TODO: check is this correct
    depth = torch.clamp(projected[..., 2:3], min=1e-8)
    pixel_locations = projected[..., :2] / depth
    # Bounded range keeps a later grid_sample from producing NaNs.
    pixel_locations = torch.clamp(pixel_locations, min=-10, max=10)

    in_front = projected[..., 2] > 0  # opencv camera
    # NOTE(review): the bounds test uses the unit square, so the intrinsics
    # are presumably normalised by the image size -- confirm with callers.
    inside = ((pixel_locations >= 0) & (pixel_locations <= 1.)).all(dim=-1)
    return pixel_locations, in_front & inside


class FeatureAggregator(nn.Module):
    """Fuse multi-view feature maps through learned multi-resolution tri-planes.

    For each resolution level a learned query grid is passed through ``T1``
    with the input features as context, producing three feature planes
    (``construct_spatial_volume``). Rays from every camera are then sampled
    against those planes and the samples along each ray are fused with ``T2``
    and a softmax weighting (``rendering``). The per-level results are
    concatenated and mapped to 4 output channels by ``to_out``.

    Args:
        dim: accepted for interface compatibility; not used by this class.
        depth: depth passed to both ``T1`` and ``T2``.
        embed_dim: accepted for interface compatibility; not used.
        use_resnet: if true, the planes are post-processed by a ``ResBlock``.
        use_fhw_attn: accepted for interface compatibility; not used.
        pos_encs: accepted for interface compatibility; not used (and never
            mutated, so the shared list default is harmless here).
        voxel_size: side length of one plane at the finest level.
        level: number of resolution levels; level ``i`` uses planes of side
            ``voxel_size // 2**i``.

    Raises:
        ValueError: if ``voxel_size`` is too small for ``level`` levels.
    """

    # Width of the timestep embedding consumed by the ResBlock; also used for
    # the all-zero fallback embedding.
    _T_EMB_DIM = 1280

    def __init__(
        self,
        dim: int,
        depth: int = 1,
        embed_dim: int = 0,
        use_resnet: bool = True,
        use_fhw_attn: bool = False,
        pos_encs: List[str] = ["abs", "ref"],
        voxel_size: int = 16,
        level: int = 4,
    ):
        super().__init__()

        d_head = 64
        n_heads = 8
        self.voxel_size = voxel_size
        # Not read by forward(); retained so the module's parameter set (and
        # therefore its state_dict keys) is unchanged.
        self.voxel = nn.Parameter(torch.randn(voxel_size, 3 * voxel_size, 64) * 0.001)
        voxel_list = []
        for i in range(level):
            side = voxel_size // (2 ** i)
            if side < 1:
                raise ValueError(
                    f"voxel_size={voxel_size} is too small for level={level}: "
                    f"level {i} would have an empty plane"
                )
            # Each query grid is `side x 3*side`: three square planes laid out
            # along the width. The width is derived from `side` so the layout
            # stays exactly 3 squares even when voxel_size is not divisible
            # by 2**i.
            voxel_list.append(nn.Parameter(torch.randn(side, 3 * side, 64) * 0.001))
        self.voxel_list = nn.ParameterList(voxel_list)
        self.T1 = Transformer(
            in_channels=64,
            out_channels=32,
            context_dim=4,
            n_heads=n_heads,
            d_head=d_head,
            depth=depth,
        )
        # Not used by forward(); retained to keep the parameter set unchanged.
        self.view_fuse = nn.Linear(32, 1, bias=False)
        self.T2 = FuseTransformer(
            in_channels=32 * 3,
            out_channels=32,
            n_heads=1,
            d_head=32,
            depth=depth,
        )
        self.ray_fuse = nn.Linear(32, 1, bias=False)
        self.to_out = nn.Linear(32 * level, 4)
        self.use_resnet = use_resnet
        if use_resnet:
            self.resnet = ResBlock(32, self._T_EMB_DIM, use_scale_shift_norm=True)
        # When true, sample depths along each ray are randomly jittered.
        self.perturb = False

    @staticmethod
    def _expand_planes_per_frame(volume, num_frames, num_planes=3):
        """Tile planes over frames: ``(b i) c h w -> (b f i) c h w``.

        The plane index ``i`` must stay innermost so the result lines up,
        batch entry by batch entry, with a sampling grid laid out ``(bf i)``.
        """
        total, c, h, w = volume.shape
        b = total // num_planes
        volume = volume.reshape(b, 1, num_planes, c, h, w)
        volume = volume.expand(-1, num_frames, -1, -1, -1, -1)
        return volume.reshape(b * num_frames * num_planes, c, h, w)

    @staticmethod
    def _fill_invalid_ray_limits(nears, fars):
        """Overwrite (in place) the limits of rays whose far is not beyond near.

        Such rays get ``near = min(valid nears)`` and ``far = max(valid nears)``
        (the fill values are kept exactly as in the original code). Nothing is
        changed when no ray is valid.
        """
        # NOTE(review): a ray is treated as valid when far > near; this assumes
        # get_ray_limits_box reports misses with near >= far -- confirm.
        is_ray_valid = fars > nears
        if torch.any(is_ray_valid).item():
            nears[~is_ray_valid] = nears[is_ray_valid].min()
            fars[~is_ray_valid] = nears[is_ray_valid].max()
        return nears, fars

    def construct_spatial_volume(self, query, features, t_embed = None, cameras=None):
        """Build the three feature planes for one resolution level.

        Args:
            query: learned ``V x 3V x 64`` query grid.
            features: ``b x f x c x h x w`` input feature maps.
            t_embed: ``(b f) x c`` embedding; only frame 0 of every batch
                element is used. Defaults to zeros when ``None``.
            cameras: unused.

        Returns:
            ``(b 3) x c x V x V`` tensor holding the three planes per batch
            element.
        """
        b, f, c, h, w = features.shape
        if t_embed is None:
            t_embed = torch.zeros(b * f, self._T_EMB_DIM, device=features.device)
        side = query.shape[0]
        query = repeat(query, 'h w c -> b (h w) c', b=b)
        context = rearrange(features, 'b f c h w -> b (f h w) c')
        embd = rearrange(t_embed, '(b f) c -> b f c', b=b)[:, 0]
        volume = self.T1(query, embd, context=context)
        # Split the 3V-wide grid into its three V x V planes.
        volume = rearrange(volume, "b (h i w) c -> (b i) c h w", b=b, i=3, h=side, w=side)
        if self.use_resnet:
            # NOTE(review): volume has batch 3*b while embd has batch b; this
            # relies on ResBlock broadcasting the embedding (fine for b == 1)
            # -- confirm behaviour for b > 1.
            volume = self.resnet(volume, embd)

        return volume

    def rendering(self, images, volume, cameras, t_emb):
        """Sample the planes along camera rays and fuse the samples per ray.

        Args:
            images: ``b x f x c x h x w`` tensor; only its shape is used.
            volume: ``(b 3) x c x V x V`` planes from
                ``construct_spatial_volume``.
            cameras: ``b x f x 32`` tensor (16 intrinsics values followed by
                16 camera-to-world values per camera).
            t_emb: ``(b f) x c`` embedding passed to ``T2``.

        Returns:
            ``(b f) x n x c`` fused per-ray features with ``n = h * w``.
        """
        device = volume.device
        b, f, c, h, w = images.shape
        c2w = cameras[..., 16:].reshape(-1, 4, 4)
        intrinsics = cameras[..., :16].reshape(-1, 4, 4)[:, :3, :3]
        # NOTE(review): only h is passed as the resolution and n = h*w is
        # assumed below, so square feature maps are presumably expected.
        ray_origins, ray_directions = ray_sample(c2w, intrinsics, h)
        num_steps = h
        # Planes are shared by all frames of a batch element. Tile them in
        # (b, f, i) order so every batch entry matches the grid built below.
        volume = self._expand_planes_per_frame(volume, f)
        nears, fars = get_ray_limits_box(
            ray_origins, ray_directions, 2
        )
        nears, fars = self._fill_invalid_ray_limits(nears, fars)

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device)  # s
        z_vals = repeat(z_vals, 's -> (b f) 1 s 1', b=b, f=f)
        nears = rearrange(nears, 'bf n 1 -> bf n 1 1')
        fars = rearrange(fars, 'bf n 1 -> bf n 1 1')
        z_vals = nears + (fars - nears) * z_vals  # bf x n x s x 1
        if self.perturb:
            # Jitter every sample by up to half a step in either direction.
            sample_dist = (fars - nears) / num_steps
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist

        xyzs = ray_origins.unsqueeze(-2) + ray_directions.unsqueeze(-2) * z_vals  # bf x n x s x 3
        # Project every 3D sample onto the xy, yz and xz planes.
        xy, yz, xz = xyzs[..., [0, 1]], xyzs[..., [1, 2]], xyzs[..., [0, 2]]
        coordinates = torch.stack([xy, yz, xz], dim=-1)  # bf n s c i
        grid = rearrange(coordinates, 'bf n s c i -> (bf i) n s c')
        sampled_features = F.grid_sample(
            volume, grid, mode='bilinear', padding_mode='zeros', align_corners=False
        )
        # Concatenate the three planes' features for every sample.
        sampled_features = rearrange(
            sampled_features, '(bf i) c n s -> (bf n) s (i c)', n=h*w, s=num_steps, i=3
        )
        embd = repeat(t_emb, 'bf c -> (bf n) 1 c', n=h*w)
        sampled_features = self.T2(sampled_features, embd)
        sampled_features = rearrange(sampled_features, '(bf n) s c -> bf n s c', n=h*w, s=num_steps)
        # Softmax over the samples of a ray, then a weighted sum along the ray.
        weight = self.ray_fuse(sampled_features)  # bf x n x s x 1
        weight = F.softmax(weight, dim=-2)
        images = (weight * sampled_features).sum(dim=-2)  # bf x n x c

        return images

    def forward(
        self,
        features,
        t_emb = None,
        cameras = None, 
    ):
        """Aggregate ``features`` across views at every resolution level.

        Args:
            features: ``b x f x c x h x w`` input feature maps.
            t_emb: ``(b f) x 1280`` embedding; zeros are used when ``None``.
            cameras: ``b x f x 32`` camera tensor. Required in practice: it is
                indexed by ``rendering``.

        Returns:
            ``b x f x 4 x h x w`` tensor.
        """
        b, f, _, h, w = features.shape

        if t_emb is None:
            t_emb = torch.zeros(b * f, self._T_EMB_DIM, device=features.device)
        images_list = []
        for query in self.voxel_list:
            triplane = self.construct_spatial_volume(query, features, t_emb, cameras)
            images = self.rendering(features, triplane, cameras, t_emb)
            images_list.append(images)
        images = torch.stack(images_list, dim=-1)  # bf n c i
        # Concatenate the per-level channels before the output projection.
        images = rearrange(images, 'bf n c i -> bf n (i c)')
        images = self.to_out(images)
        images = rearrange(images, '(b f) (h w) c -> b f c h w', f=f, h=h, w=w)

        return images
