import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from typing import List, Dict, Optional
from datetime import datetime

from .log_storage_manager import LogStorageManager
from .tiered_cache_manager import TieredCacheManager
from .gaussian_block import FrustumCuller, compute_morton_code, compute_block_bounds


class SSDStorageAdapter:
    """Bridge a Gaussian model to block-based SSD storage.

    On construction the model's per-point parameters are reordered by Morton
    code, flattened into one float32 row per point and written to
    ``<storage_dir>/<timestamp>/base_file.bin``. Rows are grouped into blocks
    of ``block_size`` consecutive points: block ``i`` covers rows
    ``[i * block_size, min((i + 1) * block_size, num_points))``. A
    ``LogStorageManager`` / ``TieredCacheManager`` pair handles block I/O and
    a ``FrustumCuller`` maps cameras to visible block ids.
    """

    # Per-point parameter tensors, in the column order used for every packed row
    # (base file and RAM-cache sync must agree on this order).
    _PARAM_ATTRS = ('_xyz', '_scaling', '_rotation', '_opacity', '_features_dc', '_features_rest')

    def __init__(
        self,
        gaussians,
        cameras: List,
        storage_dir: str,
        block_size: int = 4096,
        max_ram_gb: float = 16.0
    ):
        """Sort, persist and index ``gaussians``.

        Args:
            gaussians: Model exposing ``get_xyz`` and the tensors named in
                ``_PARAM_ATTRS``. Its parameters are replaced in place by
                Morton-sorted copies.
            cameras: Cameras indexed by ``get_visible_blocks``; each must expose
                ``world_view_transform`` and ``projection_matrix``.
            storage_dir: Parent directory; a timestamped subdirectory is created
                inside it for this adapter's files.
            block_size: Number of points per storage block (must be > 0).
            max_ram_gb: RAM budget forwarded to ``TieredCacheManager``.

        Raises:
            ValueError: If ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        self.gaussians = gaussians
        self.cameras = cameras
        self.block_size = block_size
        self.max_ram_gb = max_ram_gb

        # NOTE(review): second-resolution timestamp — two adapters created in
        # the same second share (and overwrite) one directory.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.storage_dir = Path(storage_dir) / timestamp
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # Sorting must happen before anything reads point order.
        self._sort_by_morton_code()
        self.num_points = gaussians.get_xyz.shape[0]
        self.num_blocks = (self.num_points + block_size - 1) // block_size  # ceil division
        self.total_gaussians = self.num_points  # alias of num_points

        # Order matters: the base file sets point_dim and xyz_cpu_cache, which
        # the bounds computation and storage initialisation consume.
        self._create_base_file()
        self._compute_block_bounds()
        self._initialize_storage()

    @staticmethod
    def _pack_rows(tensors) -> torch.Tensor:
        """Concatenate per-point tensors column-wise into one ``(N, D)`` matrix.

        Each tensor is flattened to 2-D with ``flatten(1)`` first, so both
        ``(N, k)`` inputs and higher-rank layouts such as ``(N, 1, 3)`` /
        ``(N, 15, 3)`` are accepted (``torch.cat`` rejects mixed ranks).
        """
        return torch.cat([t.flatten(1) for t in tensors], dim=1)

    def _sort_by_morton_code(self):
        """Reorder every per-point parameter of ``self.gaussians`` by Morton code.

        Each attribute in ``_PARAM_ATTRS`` that exists and is not ``None`` is
        replaced by a new ``nn.Parameter`` holding the permuted data, keeping
        the original ``requires_grad`` flag and pinned-memory status.

        NOTE(review): an optimizer that already references the old parameters
        is not updated here — confirm the caller builds it afterwards.
        """
        with torch.no_grad():
            xyz_cpu = self.gaussians.get_xyz.cpu()
            global_min = xyz_cpu.min(dim=0).values
            global_max = xyz_cpu.max(dim=0).values

            morton_codes = compute_morton_code(xyz_cpu, global_min, global_max)
            sorted_indices = torch.argsort(morton_codes)

            for attr in self._PARAM_ATTRS:
                tensor = getattr(self.gaussians, attr, None)
                if tensor is None:
                    continue
                # Index on the tensor's own device (avoids an implicit index
                # transfer); advanced indexing already returns a fresh copy.
                sorted_tensor = tensor.data[sorted_indices.to(tensor.device)]
                if tensor.is_pinned():
                    sorted_tensor = sorted_tensor.pin_memory()
                setattr(self.gaussians, attr, nn.Parameter(sorted_tensor, requires_grad=tensor.requires_grad))

    def _create_base_file(self):
        """Write every point as a contiguous float32 row to ``base_file.bin``.

        Side effects: sets ``self.point_dim`` (floats per row) and
        ``self.xyz_cpu_cache`` (a NumPy snapshot of the sorted positions).
        """
        cpu_tensors = [getattr(self.gaussians, attr).detach().cpu() for attr in self._PARAM_ATTRS]
        full_tensor = self._pack_rows(cpu_tensors)
        self.point_dim = int(full_tensor.shape[1])

        base_file = self.storage_dir / "base_file.bin"
        with open(base_file, 'wb') as f:
            f.write(full_tensor.to(torch.float32).numpy().tobytes())

        # Copy: when _xyz already lives on the CPU, .cpu() returns the same
        # storage, and .numpy() would alias the live (trainable) parameter.
        self.xyz_cpu_cache = cpu_tensors[0].numpy().copy()

    def _compute_block_bounds(self):
        """Compute per-block bounds plus scene AABB, centre and radius.

        ``scene_radius`` is half the diagonal of the axis-aligned scene box.
        """
        self.block_bounds = compute_block_bounds(self.xyz_cpu_cache, self.block_size)
        self.scene_min = self.xyz_cpu_cache.min(axis=0)
        self.scene_max = self.xyz_cpu_cache.max(axis=0)
        self.scene_center = (self.scene_min + self.scene_max) / 2.0
        self.scene_radius = np.linalg.norm(self.scene_max - self.scene_min) / 2.0

    def _initialize_storage(self):
        """Create the storage manager, tiered cache and frustum culler."""
        self.storage = LogStorageManager(
            storage_dir=str(self.storage_dir),
            block_size=self.block_size,
            num_blocks=self.num_blocks,
            point_dim=self.point_dim
        )
        self.cache = TieredCacheManager(
            storage_manager=self.storage,
            max_ram_gb=self.max_ram_gb,
            block_size=self.block_size,
            point_dim=self.point_dim
        )
        self.culler = FrustumCuller(self.block_bounds, self.scene_radius)

    def get_visible_blocks(self, camera_idx: int) -> List[int]:
        """Return the block ids the culler reports as visible from ``cameras[camera_idx]``."""
        camera = self.cameras[camera_idx]
        view = camera.world_view_transform.cpu().numpy()
        proj = camera.projection_matrix.cpu().numpy()
        return self.culler.cull(view, proj)

    def prefetch_for_next_iteration(self, iteration: int, batch_size: int, training_schedule: List[int]):
        """Ask the cache to prefetch every block visible from an upcoming batch.

        The batch is ``batch_size`` consecutive entries of ``training_schedule``
        starting at ``(iteration * batch_size) % len(training_schedule)``,
        wrapping around the end of the schedule.

        NOTE(review): presumably ``iteration`` is already the index of the batch
        to prefetch (no ``+ 1`` is applied) — confirm against the caller.

        Args:
            iteration: Batch index used to locate the window in the schedule.
            batch_size: Number of cameras in the batch.
            training_schedule: Camera indices in training order (not mutated).
        """
        num_cameras = len(training_schedule)
        if num_cameras == 0:
            return  # nothing to prefetch; also avoids modulo by zero

        start = (iteration * batch_size) % num_cameras
        camera_ids = [training_schedule[(start + k) % num_cameras] for k in range(batch_size)]

        visible = set()
        for cam_id in camera_ids:
            visible.update(self.get_visible_blocks(cam_id))

        # Ascending ids give a deterministic order that also matches the
        # block layout in base_file.bin.
        self.cache.prefetch(sorted(visible))

    def wait_and_load_blocks(self, iteration: int) -> Dict[int, torch.Tensor]:
        """Return a shallow copy of the cache's current block mapping.

        ``iteration`` is currently unused; it is kept for interface compatibility.
        """
        return dict(self.cache.cache_data)

    def get_training_schedule(self, shuffle: bool = False) -> List[int]:
        """Return camera indices ``0..len(cameras)-1``, optionally shuffled in place
        with the global ``random`` RNG."""
        schedule = list(range(len(self.cameras)))
        if shuffle:
            import random
            random.shuffle(schedule)
        return schedule

    def update_ram_cache(self, gaussians, visible_block_ids: List[int], cuda_stream=None):
        """Push the current parameter values of the given blocks into the cache.

        Each valid block is packed into one ``(rows, point_dim)`` tensor, using
        the same column order as the base file. All blocks are handed to
        ``cache.sync_from_gpu`` in a single call. Ids outside
        ``[0, num_blocks)`` are skipped: they would otherwise produce an empty
        slice or, for negative ids, rows belonging to a different block.

        Args:
            gaussians: Model holding the tensors named in ``_PARAM_ATTRS``.
            visible_block_ids: Block ids to synchronise.
            cuda_stream: Forwarded unchanged to ``cache.sync_from_gpu``.
        """
        block_ids = []
        block_tensors = []

        for block_id in visible_block_ids:
            if not 0 <= block_id < self.num_blocks:
                continue
            start_idx = block_id * self.block_size
            end_idx = min(start_idx + self.block_size, self.num_points)
            block_tensors.append(self._pack_rows(
                getattr(gaussians, attr).data[start_idx:end_idx] for attr in self._PARAM_ATTRS
            ))
            block_ids.append(block_id)

        if block_ids:
            self.cache.sync_from_gpu(block_ids, block_tensors, cuda_stream)

    def shutdown(self):
        """Stop the cache, then close storage — even if the cache shutdown fails."""
        try:
            self.cache.shutdown()
        finally:
            self.storage.close()
