"""
Data generation and loading script for video data.
Now includes batched generation to handle large datasets and memory-safe
CPU-based generation with pre-allocation to prevent crashes.
"""
import torch
from pathlib import Path
from dataclasses import dataclass

def set_seed(seed: int):
    """Seed PyTorch's random number generators so results are repeatable.

    The default (CPU) generator is always seeded. When CUDA is available the
    current CUDA device and then all CUDA devices are seeded as well.

    Args:
        seed: Value handed to every generator.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # Same call order as before: current device first, then every device.
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

@dataclass
class DataConfig:
    """Settings that control how the synthetic video datasets are generated.

    Field order is part of the constructor signature and must not change.
    """
    # Directory the generated ``.pt`` files are written to / read from.
    data_dir: Path = Path("data")
    # Number of samples in the training split (128 * 100).
    n_train: int = 12_800
    # Number of samples in the validation split.
    n_val: int = 1000
    # Number of samples in the test split.
    n_test: int = 1000
    # Number of frames generated per sample.
    no_timesteps: int = 30
    # Side length, in pixels, of each square frame.
    image_size: int = 16
    # Side length, in pixels, of the moving square.
    cube_size: int = 3
    # How many samples are generated per progress-reporting chunk.
    generation_batch_size: int = 2048
    # Per-timestep probability that the square makes a vertical jump.
    jump_prob: float = 0.3

class JumpDataLoader:
    """Loads pre-generated video datasets from disk and prepares them for training.

    Each split is expected at
    ``<data_dir>/<split>_data_size<image_size>_jump_prob<jump_prob>.pt`` with
    ``<split>`` being ``train``, ``val`` or ``test``.
    """

    def __init__(self, data_dir: Path, seed: int, image_size: int, jump_prob: float = 0.3):
        """Store the loader settings and make sure ``data_dir`` exists.

        Args:
            data_dir: Directory holding the ``.pt`` dataset files. It is
                created (including parents) if it does not exist yet.
            seed: Seed used when subsampling the training timesteps.
            image_size: Frame side length; part of the dataset filename and
                used to shape the time tensor.
            jump_prob: Jump probability the datasets were generated with;
                part of the dataset filename.
        """
        self.data_dir = data_dir
        self.seed = seed
        self.image_size = image_size
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.jump_prob = jump_prob

    def _split_path(self, split: str) -> Path:
        """Return the on-disk path of the dataset file for ``split``."""
        return self.data_dir / f"{split}_data_size{self.image_size}_jump_prob{self.jump_prob}.pt"

    def get_data(
        self,
        subsample_time: int,
        no_timesteps: int,
        batch_size: int = 64,
        device: str = "cuda",
        equidist: bool = False
    ):
        """Load the three splits and build the training ``DataLoader``.

        Args:
            subsample_time: Number of timesteps kept per training sample and
                number of evaluation time indices requested.
            no_timesteps: Total number of timesteps in the stored videos.
            batch_size: Batch size of the returned training loader.
            device: Device every returned tensor is moved to.
            equidist: Forwarded to ``random_times``; selects evenly spaced
                evaluation times instead of random ones.

        Returns:
            A tuple ``(dataloader, val_data, test_data, times_eval)``. The
            loader iterates (shuffled, dropping the last incomplete batch)
            over the subsampled training tensor produced by
            ``random_subsample``; validation and test data are returned in
            full, without subsampling.

        Raises:
            FileNotFoundError: If any of the three dataset files is missing.
        """
        # Pin the loaded tensors to CPU: torch.load otherwise restores them to
        # the device they were saved from, while the time tensor below is
        # built on CPU and has to live on the same device for subsampling.
        train_data, val_data, test_data = (
            torch.load(self._split_path(split), map_location="cpu")
            for split in ("train", "val", "test")
        )

        # Full time grid, broadcast (without copying) to one value per pixel.
        time_1d = torch.linspace(0., no_timesteps - 1, no_timesteps)
        full_train_times = time_1d.view(1, no_timesteps, 1, 1, 1).expand(
            train_data.shape[0], -1, 1, self.image_size, self.image_size
        )

        # Subsample on CPU, then move the combined data+time tensor to the target device.
        data_train_subsampled = random_subsample(
            train_data, full_train_times, subsample_time,
            random_seed=self.seed
        ).to(device)

        dataloader = torch.utils.data.DataLoader(
            data_train_subsampled, batch_size=batch_size, shuffle=True, drop_last=True
        )

        # Evaluation times use a fixed seed of 0, independent of ``self.seed``.
        times_eval = random_times(
            no_timesteps, subsample_time, random_seed=0, equidist=equidist, device=device
        )

        return dataloader, val_data.to(device), test_data.to(device), times_eval


def create_video_data_with_time(
    num_samples: int,
    no_timesteps: int = 50,
    image_size: int = 32,
    cube_size: int = 3,
    generation_batch_size: int = 2048,
    jump_prob: float = 0.3
) -> torch.Tensor:
    """Generate videos of a square that drifts horizontally and randomly jumps vertically.

    Every frame is an ``image_size`` x ``image_size`` image that is zero
    except for a ``cube_size`` x ``cube_size`` square of ones. At each
    timestep the square moves one pixel horizontally and, with probability
    ``jump_prob``, three pixels vertically; it reverses direction at the
    image border. All tensors are created on CPU and frames are written
    straight into one pre-allocated tensor.

    Args:
        num_samples: Number of videos to generate.
        no_timesteps: Number of frames per video.
        image_size: Frame side length in pixels.
        cube_size: Side length of the moving square in pixels.
        generation_batch_size: Number of samples generated between progress
            messages. Must be at least 1.
        jump_prob: Per-timestep probability of a vertical jump.

    Returns:
        A float tensor of shape
        ``(num_samples, no_timesteps, 2, image_size, image_size)``. Channel 0
        holds the frames; channel 1 holds the timestep index, constant over
        each frame.

    Raises:
        ValueError: If ``generation_batch_size`` is smaller than 1 (the
            generation loop would otherwise never terminate).
    """
    if generation_batch_size < 1:
        raise ValueError(
            f"generation_batch_size must be >= 1, got {generation_batch_size}"
        )

    print(f"Generating {num_samples} samples in batches of {generation_batch_size} on CPU...")

    # Largest valid top-left coordinate of the square.
    max_pos = image_size - cube_size

    # Pre-allocated and already zeroed, so each frame only needs its square set to one.
    full_data = torch.zeros((num_samples, no_timesteps, 1, image_size, image_size), device="cpu")
    num_generated = 0

    while num_generated < num_samples:
        current_batch_size = min(generation_batch_size, num_samples - num_generated)

        for n in range(num_generated, num_generated + current_batch_size):
            # NOTE(review): `high` is exclusive, so the start position is drawn
            # from [0, max_pos - 1]; kept as-is to preserve the seeded random stream.
            x = torch.randint(low=0, high=max_pos, size=(1,)).item()
            y = torch.randint(low=0, high=max_pos, size=(1,)).item()
            direction = -1 if torch.rand(1).item() < 0.5 else 1
            v_direction = -3 if torch.rand(1).item() < 0.5 else 3

            for t in range(no_timesteps):
                full_data[n, t, 0, y:y + cube_size, x:x + cube_size] = 1.0

                new_x = x + direction
                if not 0 <= new_x <= max_pos:  # Bounce off walls
                    direction = -direction
                    new_x = x + direction
                # The clamp only matters when the free range is smaller than
                # the step; without it the square would leave the frame.
                x = min(max(new_x, 0), max_pos)

                if torch.rand(1).item() < jump_prob:
                    new_y = y + v_direction
                    if not 0 <= new_y <= max_pos:
                        v_direction = -v_direction
                        new_y = y + v_direction
                    y = min(max(new_y, 0), max_pos)

        # Report progress after counting the chunk that was just finished.
        num_generated += current_batch_size
        print(f"Generated {num_generated} samples...")

    # Create and attach the time channel on CPU; `expand` is a view, the copy happens in `cat`.
    time = torch.linspace(0., no_timesteps - 1, no_timesteps, device="cpu")
    time_channel = time.view(1, no_timesteps, 1, 1, 1).expand(num_samples, -1, 1, image_size, image_size)
    print(f"Created and attached time channel...")
    final_data = torch.cat([full_data, time_channel], dim=2)
    print(f"Final data size: {final_data.shape}")
    return final_data


def random_subsample(
    data: torch.Tensor,
    times: torch.Tensor,
    subsample_time: int,
    random_seed: int = 0
) -> torch.Tensor:
    """Keep a random, chronologically ordered subset of timesteps per sample.

    For every sample an independent random subset of ``subsample_time``
    timestep indices is drawn, sorted ascending, and used to select the same
    timesteps from ``data`` and ``times``.

    Note that this reseeds the global PyTorch generators via ``set_seed``.

    Args:
        data: Tensor whose first two dimensions are (sample, timestep).
        times: Tensor with the same leading (sample, timestep) dimensions as
            ``data`` and matching trailing dimensions except for dim 2.
        subsample_time: Number of timesteps to keep per sample.
        random_seed: Seed applied before drawing the subsets.

    Returns:
        The selected ``data`` and ``times`` concatenated along dim 2.
    """
    set_seed(random_seed)

    batch_size, total_timesteps = data.shape[:2]

    # Argsort of uniform noise gives an independent random permutation per row.
    perm = torch.rand(batch_size, total_timesteps, device=data.device).argsort(dim=1)
    perm_sorted = torch.sort(perm[:, :subsample_time], dim=1).values

    # Index (sample, timestep) pairs directly. This selects the same elements
    # as a gather along dim 1, but avoids materialising an int64 index with one
    # entry per output pixel, and does not require `times` to have as many
    # channels as `data`.
    rows = torch.arange(batch_size, device=data.device).unsqueeze(1)
    subsampled_data = data[rows, perm_sorted]
    subsampled_times = times[rows, perm_sorted]

    return torch.cat([subsampled_data, subsampled_times], dim=2)


def random_times(
    no_timesteps: int,
    subsample_time: int,
    random_seed: int = 0,
    device: str = "cuda",
    equidist: bool = False
) -> torch.Tensor:
    """Return a sorted tensor of at most ``subsample_time`` unique time indices.

    Args:
        no_timesteps: Indices are drawn from ``range(no_timesteps)``.
        subsample_time: Number of indices requested.
        random_seed: Seed applied (via ``set_seed``, which reseeds the global
            generators) before drawing random indices. Unused when
            ``equidist`` is true.
        device: Device the returned tensor is created on.
        equidist: If true, return evenly spaced indices starting at 0 with a
            stride of ``no_timesteps // subsample_time`` instead of random
            ones.

    Returns:
        A 1-D integer tensor of ascending indices. It has ``subsample_time``
        entries when ``subsample_time <= no_timesteps`` and ``no_timesteps``
        entries otherwise.
    """
    if equidist:
        # A stride of at least 1 keeps `arange` valid when more indices are
        # requested than there are timesteps.
        jump_every = max(1, no_timesteps // subsample_time)
        # The strided range can overshoot (e.g. 30 steps, 4 requested ->
        # stride 7 -> 5 indices), so cut it to the requested length.
        return torch.arange(0, no_timesteps, jump_every, device=device)[:subsample_time]

    set_seed(random_seed)
    perm = torch.randperm(no_timesteps, device=device)[:subsample_time]
    return torch.sort(perm).values


def generate_and_save_data(config: DataConfig, device: str = "cuda"):
    """Generate and save the train, validation and test datasets that are missing.

    For each split the file
    ``<data_dir>/<split>_data_size<image_size>_jump_prob<jump_prob>.pt`` is
    created unless it already exists. Only the frame channel is stored; the
    time channel returned by the generator is dropped.

    Args:
        config: Generation settings (sizes, counts, output directory).
        device: Currently unused; generation always runs on CPU. Kept so
            existing callers continue to work.
    """
    print(f"Checking for video datasets with size={config.image_size}...")
    config.data_dir.mkdir(parents=True, exist_ok=True)

    # (filename prefix, label used in the log message, number of samples)
    splits = (
        ("train", "training", config.n_train),
        ("val", "validation", config.n_val),
        ("test", "test", config.n_test),
    )

    for prefix, label, num_samples in splits:
        path = config.data_dir / f"{prefix}_data_size{config.image_size}_jump_prob{config.jump_prob}.pt"
        if path.exists():
            continue

        print(f"Creating {label} data (size={config.image_size})...")
        data = create_video_data_with_time(
            num_samples, config.no_timesteps, config.image_size, config.cube_size,
            config.generation_batch_size, config.jump_prob
        )
        # torch.save serialises the whole storage behind a view, so saving the
        # bare slice would also write the discarded time channel. Cloning
        # gives the slice its own, minimal storage.
        torch.save(data[:, :, 0:1].clone(), path)
        # Release this split before the next one is allocated.
        del data

    print(f"All datasets for size={config.image_size} are available.")


