import torch
import numpy as np
from typing import Dict, List, Optional, Tuple


class GPUWorkingSet:
    """Device-resident working set of Gaussian attributes, assembled block by block.

    Gaussians are grouped into fixed-size blocks of ``block_size`` entries; the
    last block may be shorter. Each call to
    :meth:`load_visible_blocks_with_retention` concatenates the attributes of
    the requested blocks into one tensor per attribute. With retention enabled,
    blocks that were loaded on the previous call are reused from that call's
    per-block data instead of being read again from ``active_blocks_ram``.
    """

    # Column layout of a packed block tensor: attribute name -> (start, stop) columns.
    # The insertion order here is also the order of keys in the returned dicts.
    _COLUMN_LAYOUT: Dict[str, Tuple[int, int]] = {
        'xyz': (0, 3),
        'scaling': (3, 6),
        'rotation': (6, 10),
        'opacity': (10, 11),
        'features_dc': (11, 14),
        'features_rest': (14, 59),
    }

    def __init__(self, num_total_gaussians: int, block_size: int, device: str = 'cuda'):
        """Create an empty working set.

        Args:
            num_total_gaussians: Total number of Gaussians across all blocks.
            block_size: Number of Gaussians per block (the last block may be shorter).
            device: Device on which the assembled tensors are placed.
        """
        self.num_total = num_total_gaussians
        self.block_size = block_size
        self.device = torch.device(device)
        # Ceiling division: a trailing partial block still counts as a block.
        self.num_blocks = (num_total_gaussians + block_size - 1) // block_size

        self.gpu_xyz: Optional[torch.Tensor] = None
        self.gpu_scaling: Optional[torch.Tensor] = None
        self.gpu_rotation: Optional[torch.Tensor] = None
        self.gpu_opacity: Optional[torch.Tensor] = None
        self.gpu_features_dc: Optional[torch.Tensor] = None
        self.gpu_features_rest: Optional[torch.Tensor] = None

        # Blocks loaded on the previous call, and their per-attribute tensors
        # (slices of the device copy of each block) kept for retention.
        self.previous_blocks: List[int] = []
        self.previous_block_data: Dict[int, Dict[str, torch.Tensor]] = {}

        # Cumulative counters across calls; see get_retention_stats().
        self.hotspot_stats = {
            'total_iterations': 0,
            'total_blocks_loaded': 0,
            'total_blocks_retained': 0,
            'total_blocks_cold': 0,
        }

        self.loaded_blocks: List[int] = []
        # Row i of every gpu_* tensor belongs to global Gaussian local_to_global_idx[i].
        self.local_to_global_idx: Optional[torch.Tensor] = None

    def _split_block(self, block_tensor: torch.Tensor, block_len: int) -> Dict[str, torch.Tensor]:
        """Move one packed block to ``self.device`` and slice it into per-attribute views."""
        # .to() is a no-op when the tensor is already on the target device.
        block_tensor = block_tensor.to(self.device, non_blocking=True)
        return {
            key: block_tensor[:block_len, start:stop]
            for key, (start, stop) in self._COLUMN_LAYOUT.items()
        }

    def load_visible_blocks_with_retention(
        self,
        visible_block_ids: List[int],
        active_blocks_ram: Dict[int, torch.Tensor],
        enable_retention: bool = True
    ) -> Tuple[Dict[str, torch.Tensor], Dict]:
        """Assemble the working set for ``visible_block_ids``.

        Each unique visible block is taken, in ascending id order, from one of
        two sources:

        * **retained** — ``enable_retention`` is set and the block was loaded
          on the previous call, so its cached per-attribute tensors are reused;
        * **cold** — the block is read from ``active_blocks_ram`` and moved to
          ``self.device``.

        A block available from neither source is skipped. It contributes
        neither rows nor indices, so ``local_to_global_idx`` always stays
        aligned row-for-row with the returned tensors.

        Args:
            visible_block_ids: Ids of the blocks to load. Duplicates are ignored.
            active_blocks_ram: Packed block tensors keyed by block id. Each row
                holds one Gaussian's attributes laid out as in ``_COLUMN_LAYOUT``.
            enable_retention: If False, every block is read from ``active_blocks_ram``.

        Returns:
            ``(tensors, stats)``. ``tensors`` maps each attribute name to its
            concatenated tensor. ``stats`` contains ``total_count``,
            ``hotspot_count``, ``cold_count``, ``missing_count``, ``hit_rate``,
            ``memory_mb`` and ``num_gaussians``.

        Raises:
            ValueError: If any id lies outside ``[0, num_blocks)``.
        """
        visible_set = set(visible_block_ids)
        # Validate before mutating any state; an out-of-range id would otherwise
        # produce a negative block length and slice the wrong rows.
        invalid = sorted(b for b in visible_set if not 0 <= b < self.num_blocks)
        if invalid:
            raise ValueError(f"block ids out of range [0, {self.num_blocks}): {invalid}")

        retained_blocks: List[int] = []
        cold_blocks: List[int] = []
        missing_blocks: List[int] = []
        per_key: Dict[str, List[torch.Tensor]] = {key: [] for key in self._COLUMN_LAYOUT}
        index_chunks: List[torch.Tensor] = []
        new_block_data: Dict[int, Dict[str, torch.Tensor]] = {}

        for block_id in sorted(visible_set):
            start_idx = block_id * self.block_size
            end_idx = min(start_idx + self.block_size, self.num_total)

            data = self.previous_block_data.get(block_id) if enable_retention else None
            if data is not None:
                retained_blocks.append(block_id)
            elif block_id in active_blocks_ram:
                data = self._split_block(active_blocks_ram[block_id], end_idx - start_idx)
                cold_blocks.append(block_id)
            else:
                # Not cached and not resident in RAM: skip entirely (indices included).
                missing_blocks.append(block_id)
                continue

            new_block_data[block_id] = data
            index_chunks.append(torch.arange(start_idx, end_idx, dtype=torch.long))
            for key, chunks in per_key.items():
                chunks.append(data[key])

        gpu_tensors: Dict[str, torch.Tensor] = {}
        for key, chunks in per_key.items():
            if chunks:
                gpu_tensors[key] = torch.cat(chunks, dim=0)
            else:
                start, stop = self._COLUMN_LAYOUT[key]
                gpu_tensors[key] = torch.empty(0, stop - start, device=self.device)

        self.gpu_xyz = gpu_tensors['xyz']
        self.gpu_scaling = gpu_tensors['scaling']
        self.gpu_rotation = gpu_tensors['rotation']
        self.gpu_opacity = gpu_tensors['opacity']
        self.gpu_features_dc = gpu_tensors['features_dc']
        self.gpu_features_rest = gpu_tensors['features_rest']

        # Build the index map on the host and transfer it once.
        if index_chunks:
            self.local_to_global_idx = torch.cat(index_chunks).to(self.device)
        else:
            self.local_to_global_idx = torch.empty(0, dtype=torch.long, device=self.device)

        loaded = sorted(new_block_data)
        self.loaded_blocks = loaded
        self.previous_blocks = list(loaded)
        self.previous_block_data = new_block_data

        hit_rate = len(retained_blocks) / len(visible_set) if visible_set else 0.0
        self.hotspot_stats['total_iterations'] += 1
        self.hotspot_stats['total_blocks_loaded'] += len(visible_set)
        self.hotspot_stats['total_blocks_retained'] += len(retained_blocks)
        self.hotspot_stats['total_blocks_cold'] += len(cold_blocks)

        # Measured from the assembled tensors rather than assuming 59 float32 columns.
        memory_bytes = sum(t.numel() * t.element_size() for t in gpu_tensors.values())
        num_gaussians = int(self.local_to_global_idx.numel())

        retention_stats = {
            'total_count': len(visible_set),
            'hotspot_count': len(retained_blocks),
            'cold_count': len(cold_blocks),
            'missing_count': len(missing_blocks),
            'hit_rate': hit_rate,
            'memory_mb': memory_bytes / (1024**2),
            'num_gaussians': num_gaussians,
        }

        return gpu_tensors, retention_stats

    def get_retention_stats(self) -> Dict:
        """Return cumulative retention statistics over all calls so far.

        ``avg_hit_rate`` and ``bandwidth_savings_ratio`` are both the aggregate
        ratio retained / requested blocks across all calls. This is not a mean
        of the per-call hit rates.
        """
        total_loaded = self.hotspot_stats['total_blocks_loaded']
        total_retained = self.hotspot_stats['total_blocks_retained']
        ratio = total_retained / total_loaded if total_loaded > 0 else 0.0
        return {
            'total_iterations': self.hotspot_stats['total_iterations'],
            'total_blocks_loaded': total_loaded,
            'total_blocks_retained': total_retained,
            'total_blocks_cold': self.hotspot_stats['total_blocks_cold'],
            'avg_hit_rate': ratio,
            'bandwidth_savings_ratio': ratio,
        }

    def clear(self, reset_retention: bool = False):
        """Drop references to the current working-set tensors.

        Args:
            reset_retention: By default the retention cache (``previous_blocks`` /
                ``previous_block_data``) is kept, so the next load can still
                reuse blocks. That cache holds slices of the loaded block
                tensors, so device memory is only fully released when this flag
                is True.
        """
        self.gpu_xyz = None
        self.gpu_scaling = None
        self.gpu_rotation = None
        self.gpu_opacity = None
        self.gpu_features_dc = None
        self.gpu_features_rest = None
        self.local_to_global_idx = None
        self.loaded_blocks = []
        if reset_retention:
            self.previous_blocks = []
            self.previous_block_data = {}
