import os
import sys
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from argparse import ArgumentParser
from typing import List, Dict, Optional

from storage.storage_adapter import SSDStorageAdapter
from strategies.engine import ssd_offload_train_one_batch
from strategies.gpu_working_set import GPUWorkingSet


class DummyGaussians:
    """Randomly initialised stand-in for a Gaussian point model.

    Holds one ``nn.Parameter`` per attribute (position, scale, rotation,
    opacity and two feature tensors), a ``GPUWorkingSet`` manager and an Adam
    optimizer with one parameter group per attribute.

    Args:
        num_points: Number of rows in every parameter tensor.
        device: Torch device on which all tensors are allocated.
        block_size: Block size forwarded to ``GPUWorkingSet``.
    """

    def __init__(self, num_points: int = 100000, device: str = "cuda", block_size: int = 4096):
        self.device = device

        # NOTE: the order of the random draws below is kept stable so that a
        # fixed seed reproduces the same initial state.
        self._xyz = nn.Parameter(torch.randn(num_points, 3, device=device))
        self._scaling = nn.Parameter(torch.randn(num_points, 3, device=device) * 0.01)

        # Every rotation row starts as (1, 0, 0, 0).
        initial_rotation = torch.zeros(num_points, 4, device=device)
        initial_rotation[:, 0] = 1.0
        self._rotation = nn.Parameter(initial_rotation)

        self._opacity = nn.Parameter(torch.full((num_points, 1), 0.5, device=device))
        self._features_dc = nn.Parameter(torch.randn(num_points, 3, device=device))
        self._features_rest = nn.Parameter(torch.randn(num_points, 45, device=device))

        self.gpu_working_set_manager = GPUWorkingSet(
            num_total_gaussians=num_points,
            block_size=block_size,
            device=device,
        )

        # One Adam parameter group per attribute, each with its own learning rate.
        learning_rates = (
            (self._xyz, 0.00016),
            (self._scaling, 0.005),
            (self._rotation, 0.001),
            (self._opacity, 0.05),
            (self._features_dc, 0.0025),
            (self._features_rest, 0.0025 / 20),
        )
        self.optimizer = torch.optim.Adam(
            [{'params': param, 'lr': rate} for param, rate in learning_rates]
        )

    @property
    def get_xyz(self):
        """Return the position parameter, shape ``(num_points, 3)``."""
        return self._xyz

    @property
    def get_scaling(self):
        """Return the scaling parameter, shape ``(num_points, 3)``."""
        return self._scaling

    @property
    def get_rotation(self):
        """Return the rotation parameter, shape ``(num_points, 4)``."""
        return self._rotation

    @property
    def get_opacity(self):
        """Return the opacity parameter, shape ``(num_points, 1)``."""
        return self._opacity

    @property
    def get_features(self):
        """Return both feature tensors joined into shape ``(num_points, 16, 3)``.

        Slot 0 along dim 1 is ``_features_dc``; the remaining 15 slots are
        ``_features_rest`` reshaped from 45 columns to ``(15, 3)``.
        """
        dc_part = self._features_dc[:, None, :]
        rest_part = self._features_rest.reshape(-1, 15, 3)
        return torch.cat((dc_part, rest_part), dim=1)


class DummyCamera:
    """Synthetic camera with fixed intrinsics/extrinsics and a random target image.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        device: Torch device on which all tensors are allocated.
        global_idx: Index stored on the camera as ``global_idx``.

    Attributes set here: ``image_width``, ``image_height``, ``FoVx``, ``FoVy``,
    ``device``, ``global_idx``, ``world_view_transform``, ``projection_matrix``,
    ``full_proj_transform``, ``camera_center`` and ``original_image``.
    """

    def __init__(self, width: int = 1920, height: int = 1080, device: str = "cuda", global_idx: int = 0):
        self.image_width = width
        self.image_height = height
        self.FoVx = 1.2
        self.FoVy = 0.9
        self.device = device
        self.global_idx = global_idx

        # Identity rotation with a translation of -10 along z.
        self.world_view_transform = torch.eye(4, device=device, dtype=torch.float32)
        self.world_view_transform[2, 3] = -10.0

        znear, zfar = 0.01, 100.0
        tan_half_fov_x = np.tan(self.FoVx / 2)
        # The vertical scale must come from FoVy; using FoVx for both axes
        # would silently ignore the vertical field of view.
        tan_half_fov_y = np.tan(self.FoVy / 2)

        proj = torch.zeros(4, 4, device=device, dtype=torch.float32)
        proj[0, 0] = 1.0 / tan_half_fov_x
        proj[1, 1] = 1.0 / tan_half_fov_y
        proj[2, 2] = zfar / (zfar - znear)
        proj[2, 3] = -(zfar * znear) / (zfar - znear)
        proj[3, 2] = 1.0

        self.projection_matrix = proj
        self.full_proj_transform = proj @ self.world_view_transform
        self.camera_center = torch.tensor([0.0, 0.0, 10.0], device=device)
        # Random target image in (channels, height, width) layout.
        self.original_image = torch.rand(3, height, width, device=device)


class DummyPipeArgs:
    """Pipeline flag holder; both flags are disabled."""

    def __init__(self):
        self.compute_cov3D_python = False
        self.convert_SHs_python = False


def dummy_render(viewpoint_camera, pc, pipe, bg, visible_mask=None, **kwargs):
    """Produce a fake render result with the keys a real renderer would return.

    Only the camera's image size and ``pc.get_xyz`` are read; ``pipe``, ``bg``,
    ``visible_mask`` and any extra keyword arguments are accepted and ignored.

    Returns:
        dict with
        ``"render"``: random ``(3, H, W)`` tensor that requires grad,
        ``"viewspace_points"``: copy of the first two xyz columns, requires grad,
        ``"visibility_filter"``: all-True bool mask with one entry per point,
        ``"radii"``: all-ones tensor with one entry per point.
    """
    points = pc.get_xyz
    target_device = points.device
    point_count = points.shape[0]
    height = viewpoint_camera.image_height
    width = viewpoint_camera.image_width

    image = torch.rand(3, height, width, device=target_device, requires_grad=True)
    unit_radii = torch.ones(point_count, device=target_device)
    screen_points = points[:, :2].clone().requires_grad_(True)
    everything_visible = torch.ones(point_count, dtype=torch.bool, device=target_device)

    result = {}
    result["render"] = image
    result["viewspace_points"] = screen_points
    result["visibility_filter"] = everything_visible
    result["radii"] = unit_radii
    return result


def l1_loss(pred, target):
    """Return the mean absolute difference between ``pred`` and ``target``."""
    residual = pred - target
    return residual.abs().mean()


def run_training(
    num_points: int = 100000,
    batch_size: int = 4,
    num_iterations: int = 100,
    storage_dir: str = "./ssd_storage",
    device: str = "cuda"
):
    """Run the minimal SSD-offload training demo on synthetic data.

    Builds a ``DummyGaussians`` model and ``batch_size * 10`` ``DummyCamera``
    objects, wraps them in an ``SSDStorageAdapter``, then calls
    ``ssd_offload_train_one_batch`` once per iteration, cycling through the
    cameras in consecutive batches. Progress is printed on the first
    iteration and on every tenth one.

    Args:
        num_points: Number of Gaussian points to create.
        batch_size: Cameras per training batch; must be at least 1.
        num_iterations: Number of training iterations to run.
        storage_dir: Directory handed to ``SSDStorageAdapter``.
        device: Torch device for the model, cameras and background.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail fast: a non-positive batch size produces an empty camera list and
    # would otherwise crash with ZeroDivisionError inside the training loop,
    # after the storage adapter has already been initialised.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Single source of truth so the model and the adapter cannot disagree.
    block_size = 4096

    print("=== SSD Offload Minimal Training Demo ===")
    print(f"Num points: {num_points}")
    print(f"Batch size: {batch_size}")
    print(f"Iterations: {num_iterations}")
    print(f"Storage dir: {storage_dir}")
    print()

    print("[1/4] Creating dummy Gaussians...")
    gaussians = DummyGaussians(num_points=num_points, device=device, block_size=block_size)

    print("[2/4] Creating dummy cameras...")
    cameras = [DummyCamera(device=device, global_idx=i) for i in range(batch_size * 10)]
    num_cameras = len(cameras)

    print("[3/4] Initializing SSD Storage Adapter...")
    storage_adapter = SSDStorageAdapter(
        gaussians=gaussians,
        cameras=cameras,
        storage_dir=storage_dir,
        block_size=block_size,
        max_ram_gb=8.0
    )
    print(f"  - Total Gaussians: {storage_adapter.total_gaussians}")
    print(f"  - Num Blocks: {storage_adapter.num_blocks}")
    print(f"  - Block Size: {storage_adapter.block_size}")

    print("[4/4] Starting training loop...")
    background = torch.zeros(3, device=device)
    pipe = DummyPipeArgs()

    for iteration in range(1, num_iterations + 1):
        # The camera count is a multiple of batch_size, so the slice below
        # never runs past the end of the list; it wraps back to 0 instead.
        batch_start = (iteration - 1) * batch_size % num_cameras
        batched_cameras = cameras[batch_start:batch_start + batch_size]
        training_schedule = list(range(len(batched_cameras)))

        # cam_indices is returned by the engine but not used by this demo.
        losses, cam_indices, batch_time = ssd_offload_train_one_batch(
            gaussians=gaussians,
            scene=None,
            batched_cameras=batched_cameras,
            background=background,
            pipe_args=pipe,
            storage_adapter=storage_adapter,
            training_schedule=training_schedule,
            render_fn=dummy_render,
            loss_fn=l1_loss,
            iteration=iteration
        )

        # Mean per-camera loss; an empty result is reported as 0.0.
        total_loss = sum(losses) / len(losses) if losses else 0.0

        if iteration % 10 == 0 or iteration == 1:
            loss_val = total_loss.item() if isinstance(total_loss, torch.Tensor) else total_loss
            print(f"  Iter {iteration:4d}/{num_iterations} | Loss: {loss_val:.6f} | Time: {batch_time:.3f}s")

    print()
    print("=== Training Complete ===")
    print("Final storage stats:")
    print(f"  - Storage dir: {storage_adapter.storage_dir}")


def main():
    """Parse the command-line options and launch ``run_training`` with them."""
    cli = ArgumentParser(description="SSD Offload Minimal Training Demo")

    # (flag, type, default, help) for every supported option, in display order.
    option_table = (
        ("--num_points", int, 100000, "Number of Gaussian points"),
        ("--batch_size", int, 4, "Batch size"),
        ("--iterations", int, 100, "Number of iterations"),
        ("--storage_dir", str, "./ssd_storage", "SSD storage directory"),
        ("--device", str, "cuda", "Device (cuda or cpu)"),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    options = cli.parse_args()

    # The CLI calls it --iterations; run_training calls it num_iterations.
    run_training(
        num_points=options.num_points,
        batch_size=options.batch_size,
        num_iterations=options.iterations,
        storage_dir=options.storage_dir,
        device=options.device,
    )


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
