import torch
import torch.nn as nn
from typing import List, Dict, Optional, Tuple


def calculate_filters(
    batched_cameras: List,
    xyz: torch.Tensor,
    opacity: torch.Tensor,
    scaling: torch.Tensor,
    rotation: torch.Tensor,
    radius_clip: float = 0.0
) -> Tuple[List[torch.Tensor], None, None]:
    """Compute, per camera, the indices of points that pass a visibility test.

    A point is kept for a camera when all of the following hold:

    * its projected x and y (after the perspective divide) lie in
      ``[-1.2, 1.2]``;
    * its third view-space coordinate is greater than ``0.01``;
    * its opacity is greater than ``0.005``.

    Both camera matrices are applied as ``M @ p`` with ``p`` a homogeneous
    column vector.
    NOTE(review): some splatting codebases store these matrices transposed
    (row-vector convention) — confirm the cameras passed here match.

    Args:
        batched_cameras: Objects exposing ``world_view_transform`` and
            ``full_proj_transform`` (used as 4x4 matrices).
        xyz: Point positions; indexed as ``(N, 3)``.
        opacity: Per-point opacity; its trailing dimension is squeezed before
            thresholding.
        scaling: Not read by this function.
        rotation: Not read by this function.
        radius_clip: Not read by this function.

    Returns:
        ``(filters, None, None)`` where ``filters[i]`` is a 1-D index tensor
        of the points visible from ``batched_cameras[i]``. The last two
        elements are always ``None``.
    """
    # Everything below is camera-independent, so build it once instead of
    # once per camera.
    ones = torch.ones(xyz.shape[0], 1, device=xyz.device)
    xyz_homo_t = torch.cat([xyz, ones], dim=1).T
    visible_opacity = opacity.squeeze(-1) > 0.005

    filters = []
    for camera in batched_cameras:
        xyz_view = (camera.world_view_transform @ xyz_homo_t).T
        xyz_clip = (camera.full_proj_transform @ xyz_homo_t).T
        # Small epsilon guards the perspective divide against w == 0.
        xyz_ndc = xyz_clip[:, :3] / (xyz_clip[:, 3:4] + 1e-8)

        # 1.2 (rather than 1.0) leaves a margin around the image bounds.
        in_frustum = (
            (xyz_ndc[:, 0] >= -1.2) & (xyz_ndc[:, 0] <= 1.2) &
            (xyz_ndc[:, 1] >= -1.2) & (xyz_ndc[:, 1] <= 1.2) &
            (xyz_view[:, 2] > 0.01)
        )

        valid_mask = in_frustum & visible_opacity
        filters.append(torch.nonzero(valid_mask).squeeze(-1))

    return filters, None, None


def ssd_offload_train_one_batch(
    gaussians,
    scene,
    batched_cameras: List,
    background: torch.Tensor,
    pipe_args,
    storage_adapter,
    training_schedule: List[int],
    render_fn,
    loss_fn,
    iteration: int = 0
) -> Tuple[List[torch.Tensor], List[int], float]:
    """Run one training step over a batch of cameras using block-wise loading.

    Steps, in order:

    1. Ask ``storage_adapter`` to prefetch for the next iteration.
    2. Collect the union of block ids reported visible for each camera.
    3. Wait for / load those blocks, then obtain working tensors from
       ``gaussians.gpu_working_set_manager``.
    4. Rebind ``gaussians._xyz``, ``_scaling``, ``_rotation``, ``_opacity``,
       ``_features_dc`` and ``_features_rest`` to fresh ``nn.Parameter``
       objects wrapping those tensors.
    5. For each camera, render with its visibility filter, compute the loss
       and call ``backward()`` (gradients accumulate across cameras).
    6. Step and zero the optimizer, then hand the updated state back to
       ``storage_adapter.update_ram_cache``.

    Args:
        gaussians: Model object; must expose ``gpu_working_set_manager``,
            ``optimizer``, ``get_opacity``, ``get_scaling`` and
            ``get_rotation``.
        scene: Not read by this function.
        batched_cameras: Cameras for this batch; each must expose
            ``global_idx`` and ``original_image`` plus the attributes
            ``calculate_filters`` reads.
        background: Passed through to ``render_fn``.
        pipe_args: Passed through to ``render_fn``.
        storage_adapter: Block storage backend.
        training_schedule: Passed through to the prefetch call.
        render_fn: Called with keyword arguments ``gaussians``, ``camera``,
            ``filter_indices``, ``background`` and ``pipe_args``.
        loss_fn: Called as ``loss_fn(rendered_image, gt_image)``; its result
            must support ``backward()`` and ``detach()``.
        iteration: Current iteration number; retention in the working-set
            manager is enabled only when it is greater than 1.

    Returns:
        ``(losses, ordered_cams, batch_time)``: one detached loss per camera
        (a zero scalar for cameras with no visible points),
        ``list(range(len(batched_cameras)))``, and the elapsed seconds for
        the step. When no blocks or no Gaussians are available, returns
        ``([], ordered_cams, 0.0)`` without touching the optimizer.
    """
    import time

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()

    bsz = len(batched_cameras)

    storage_adapter.prefetch_for_next_iteration(iteration, bsz, training_schedule)

    visible_block_ids = set()
    for camera in batched_cameras:
        visible_block_ids.update(storage_adapter.get_visible_blocks(camera.global_idx))
    visible_block_ids = sorted(visible_block_ids)

    if not visible_block_ids:
        # NOTE(review): the prefetch above has already been issued but
        # wait_and_load_blocks is skipped on this path — confirm the adapter
        # tolerates that.
        return [], list(range(bsz)), 0.0

    active_blocks_ram = storage_adapter.wait_and_load_blocks(iteration)

    with torch.no_grad():
        gpu_tensors, _ = gaussians.gpu_working_set_manager.load_visible_blocks_with_retention(
            visible_block_ids=visible_block_ids,
            active_blocks_ram=active_blocks_ram,
            enable_retention=(iteration > 1)
        )

        # NOTE(review): these are brand-new Parameter objects every call and
        # gaussians.optimizer is not rebound here. If the optimizer keeps
        # references to earlier Parameter objects, step() below would not
        # update these — confirm how the optimizer tracks them.
        for name in ('xyz', 'scaling', 'rotation', 'opacity',
                     'features_dc', 'features_rest'):
            setattr(
                gaussians,
                '_' + name,
                nn.Parameter(gpu_tensors[name].requires_grad_(True)),
            )

    if gaussians._xyz.shape[0] == 0:
        return [], list(range(bsz)), 0.0

    with torch.no_grad():
        filters, _, _ = calculate_filters(
            batched_cameras,
            gaussians._xyz,
            gaussians.get_opacity,
            gaussians.get_scaling,
            gaussians.get_rotation,
        )

    losses = []
    for camera, filter_indices in zip(batched_cameras, filters):
        if len(filter_indices) == 0:
            # Nothing visible for this camera: record a zero loss. It is a
            # plain (non-grad) tensor so every entry in `losses` is detached,
            # matching the real losses appended below.
            losses.append(torch.tensor(0.0, device='cuda'))
            continue

        rendered_image = render_fn(
            gaussians=gaussians,
            camera=camera,
            filter_indices=filter_indices,
            background=background,
            pipe_args=pipe_args
        )

        gt_image = camera.original_image.cuda()
        loss = loss_fn(rendered_image, gt_image)
        # Gradients accumulate across cameras; a single optimizer step
        # follows the loop.
        loss.backward()
        losses.append(loss.detach())

    gaussians.optimizer.step()
    gaussians.optimizer.zero_grad()

    storage_adapter.update_ram_cache(gaussians, visible_block_ids)

    batch_time = time.perf_counter() - start_time

    return losses, list(range(bsz)), batch_time
