import os
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch


@dataclass()
class BlockLocation:
    """Position of one block's bytes: which file, where in it, and how long.

    Instances are stored as values of ``LogStorageManager.index``.
    """

    # Key into LogStorageManager.file_paths; 0 is the base file.
    file_id: int
    # Byte offset of the block's first byte inside that file.
    offset: int
    # Number of bytes the block occupies.
    size: int
    # Revision counter: 0 for the initial base-file slot, bumped by each patch.
    version: int


class LogStorageManager:
    """Block store backed by a preallocated base file plus per-write patch files.

    Every block initially lives in ``base_file.bin`` (file id 0). ``write_patch``
    writes new contents for a set of blocks into a fresh patch file, then
    repoints the in-memory index at it so later reads see the patched data.
    Block data is stored on disk as raw native-endian float32 values,
    ``point_dim`` values per point.

    The index is kept only in memory (guarded by ``index_lock``); patch files
    left on disk by an earlier process are not reloaded by ``__init__``.
    """

    # On-disk element type. Reads and writes always use this type.
    _DISK_DTYPE = np.float32

    def __init__(
        self,
        storage_dir: str,
        block_size: int,
        num_blocks: int,
        point_dim: int = 59,
        dtype: torch.dtype = torch.float32
    ):
        """Open (or create) a store in ``storage_dir``.

        Args:
            storage_dir: Directory holding the base and patch files; created if missing.
            block_size: Number of points per block in the base file.
            num_blocks: Number of blocks preallocated in the base file.
            point_dim: Number of float values per point.
            dtype: Stored as ``self.dtype``. NOTE(review): not currently used for
                reading or writing; data is always float32 on disk and returned
                as float32. Confirm whether reads should cast to it.
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.block_size = block_size
        self.num_blocks = num_blocks
        self.point_dim = point_dim
        self.dtype = dtype
        self.bytes_per_point = point_dim * np.dtype(self._DISK_DTYPE).itemsize
        self.bytes_per_block = block_size * self.bytes_per_point

        # block id -> current location of that block's bytes.
        self.index: Dict[int, BlockLocation] = {}
        self.index_lock = threading.Lock()

        # NOTE(review): nothing in this class adds to file_handles; reads and
        # writes open their files per call. close() still honours it.
        self.file_handles: Dict[int, object] = {}
        self.file_paths: Dict[int, Path] = {0: self.storage_dir / "base_file.bin"}

        self.next_patch_id = 1
        self.patch_counter_lock = threading.Lock()

        self._initialize_index()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def _initialize_index(self):
        """Create the base file if missing and map every block to its base-file slot."""
        base_file = self.file_paths[0]

        if not base_file.exists():
            total_bytes = self.num_blocks * self.bytes_per_block
            with open(base_file, 'wb') as f:
                # Seeking to the last byte and writing it extends the file to its
                # full size without writing every byte. Skipped for an empty store,
                # where seek(-1) would be invalid.
                if total_bytes > 0:
                    f.seek(total_bytes - 1)
                    f.write(b'\0')

        for block_id in range(self.num_blocks):
            self.index[block_id] = BlockLocation(
                file_id=0,
                offset=block_id * self.bytes_per_block,
                size=self.bytes_per_block,
                version=0
            )

    def read_blocks(self, block_ids: List[int]) -> Dict[int, torch.Tensor]:
        """Read the current contents of the given blocks.

        Args:
            block_ids: Ids to read; each must already be in the index.

        Returns:
            Mapping of block id to a float32 CPU tensor of shape
            ``(num_points, point_dim)``. ``num_points`` is derived from the
            stored byte size of that block.

        Raises:
            ValueError: If a block id is not in the index.
            OSError: If a file yields fewer bytes than the index records.
        """
        if not block_ids:
            return {}

        # Snapshot locations under the lock; do the file I/O without holding it.
        file_groups: Dict[int, List[Tuple[int, BlockLocation]]] = {}
        with self.index_lock:
            for block_id in block_ids:
                location = self.index.get(block_id)
                if location is None:
                    raise ValueError(f"Block {block_id} not in index")
                file_groups.setdefault(location.file_id, []).append((block_id, location))

        result: Dict[int, torch.Tensor] = {}
        for file_id, blocks in file_groups.items():
            # Visit blocks in offset order so seeks move forward through the file.
            blocks.sort(key=lambda item: item[1].offset)
            with open(self.file_paths[file_id], 'rb') as f:
                for block_id, location in blocks:
                    f.seek(location.offset)
                    data_bytes = f.read(location.size)
                    if len(data_bytes) != location.size:
                        raise OSError(
                            f"Short read for block {block_id}: expected "
                            f"{location.size} bytes, got {len(data_bytes)}"
                        )
                    # frombuffer over bytes is read-only; copy so torch gets a
                    # writable buffer it can own.
                    np_array = np.frombuffer(data_bytes, dtype=self._DISK_DTYPE).copy()
                    num_points = np_array.size // self.point_dim
                    result[block_id] = torch.from_numpy(np_array).reshape(
                        num_points, self.point_dim
                    )

        return result

    def write_patch(self, block_dict: Dict[int, torch.Tensor]) -> int:
        """Write new contents for several blocks into a fresh patch file.

        Blocks are written in ascending id order. The index is updated only
        after the file has been fully written and closed.

        Args:
            block_dict: Mapping of block id to tensor. Any device, dtype or
                shape is accepted, as long as the element count is a multiple
                of ``point_dim``. Values are stored as float32.

        Returns:
            The id of the new patch file, or -1 if ``block_dict`` is empty.

        Raises:
            ValueError: If a tensor's element count is not a multiple of
                ``point_dim``. Nothing is written in that case.
        """
        if not block_dict:
            return -1

        block_ids = sorted(block_dict)
        # Validate before touching disk: read_blocks reshapes to
        # (n, point_dim), so such a block would be unreadable later.
        for block_id in block_ids:
            numel = block_dict[block_id].numel()
            if numel % self.point_dim != 0:
                raise ValueError(
                    f"Block {block_id} has {numel} elements, which is not a "
                    f"multiple of point_dim={self.point_dim}"
                )

        with self.patch_counter_lock:
            patch_id = self.next_patch_id
            self.next_patch_id += 1

        timestamp = int(time.time() * 1000000)
        patch_file = self.storage_dir / f"patch_{patch_id:06d}_{timestamp}.bin"

        # (block_id, offset, size) of each block inside the patch file.
        written: List[Tuple[int, int, int]] = []
        offset = 0
        try:
            with open(patch_file, 'wb') as f:
                for block_id in block_ids:
                    # detach(): numpy() refuses tensors that require grad.
                    # Converting in torch also handles CUDA tensors and dtypes
                    # numpy lacks (e.g. bfloat16).
                    tensor = block_dict[block_id].detach().to(
                        device='cpu', dtype=torch.float32
                    )
                    data_bytes = tensor.numpy().tobytes()
                    f.write(data_bytes)
                    written.append((block_id, offset, len(data_bytes)))
                    offset += len(data_bytes)
        except BaseException:
            # Don't leave a partial patch file behind; the index was not touched.
            try:
                patch_file.unlink()
            except OSError:
                pass
            raise

        # Register the path before publishing, so any reader that sees the
        # new locations can resolve the file id.
        self.file_paths[patch_id] = patch_file

        # Compute versions and publish in one critical section, so concurrent
        # patches to the same block always get distinct, increasing versions.
        with self.index_lock:
            for block_id, block_offset, size in written:
                previous = self.index.get(block_id)
                self.index[block_id] = BlockLocation(
                    file_id=patch_id,
                    offset=block_offset,
                    size=size,
                    version=(previous.version if previous is not None else 0) + 1
                )

        return patch_id

    def close(self):
        """Close every handle in ``self.file_handles`` and clear the mapping.

        Best-effort: an OSError from one handle does not prevent the others
        from being closed.
        """
        for fh in self.file_handles.values():
            try:
                fh.close()
            except OSError:
                pass
        self.file_handles.clear()
