import os
import re
from typing import List, Tuple

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

class TPDataset(Dataset):
    """Fixed-length trajectory clips with per-frame task-progress labels.

    On-disk layout read by ``_load_samples``::

        <root_dir>/<split>/<trajectory>/lang.txt       # goal text
        <root_dir>/<split>/<trajectory>/images0/*.jpg  # frames

    Each trajectory of ``N`` frames is cut into consecutive, non-overlapping
    clips of ``traj_len`` frames. Frame ``k`` (1-indexed) of a trajectory gets
    progress label ``k / N``. A final short clip is padded by repeating its
    last frame; padded positions have progress 0 and mask 0.

    Items are ``(frames, goal_text, progress_label, valid_mask)``. ``frames``
    is ``torch.stack`` of ``transform`` applied to every frame, i.e.
    ``[T, C, H, W]`` with the default transform (``T == traj_len``).
    """

    def __init__(self, root_dir: str, split: str = "pnp_in", traj_len: int = 64, transform=None):
        if traj_len < 1:
            raise ValueError(f"traj_len must be >= 1, got {traj_len}")
        self.root_dir = os.path.join(root_dir, split)
        self.traj_len = traj_len
        # `is not None` so a caller-supplied transform is used even if it is falsy.
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.samples = self._load_samples()

    @staticmethod
    def _natural_sort_key(name: str):
        """Sort key that orders embedded integers numerically (``im_2`` < ``im_10``)."""
        # re.split with a capturing group alternates text / digit-run tokens,
        # so odd positions are always digit runs and compare as ints.
        parts = [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r"([0-9]+)", name))]
        # The raw name breaks ties (e.g. "im_01" vs "im_1") deterministically.
        return parts, name

    def _load_samples(self) -> List[Tuple[List[str], str, torch.Tensor, torch.Tensor]]:
        """Scan ``self.root_dir`` and build ``(frame_paths, goal_text, progress, mask)`` clips.

        Trajectories are visited in sorted order so sample indices are
        reproducible. Trajectories without ``lang.txt``, without an
        ``images0`` directory, or without any ``.jpg`` frame are skipped.
        """
        samples = []
        for traj in sorted(os.listdir(self.root_dir), key=self._natural_sort_key):
            traj_path = os.path.join(self.root_dir, traj)
            img_dir = os.path.join(traj_path, "images0")
            lang_file = os.path.join(traj_path, "lang.txt")

            if not os.path.exists(lang_file) or not os.path.isdir(img_dir):
                continue

            with open(lang_file, "r", encoding="utf-8") as fh:
                goal_text = fh.read().strip()

            # Natural sort: frames named im_0.jpg ... im_N.jpg without zero
            # padding would otherwise sort im_10 before im_2 and scramble the
            # temporal order (and therefore the progress labels).
            frame_names = sorted(
                (name for name in os.listdir(img_dir) if name.endswith(".jpg")),
                key=self._natural_sort_key,
            )
            if not frame_names:
                continue
            img_files = [os.path.join(img_dir, name) for name in frame_names]

            total_len = len(img_files)
            # Frame k (1-indexed) of the trajectory -> progress k / total_len.
            progress_labels = torch.linspace(1 / total_len, 1.0, steps=total_len)
            for start in range(0, total_len, self.traj_len):
                end = start + self.traj_len
                frame_paths = img_files[start:end]
                real_len = len(frame_paths)

                # Pad the final short clip by repeating its last frame.
                if real_len < self.traj_len:
                    frame_paths += [frame_paths[-1]] * (self.traj_len - real_len)

                progress = torch.zeros(self.traj_len)
                progress[:real_len] = progress_labels[start:end]

                mask = torch.zeros(self.traj_len)
                mask[:real_len] = 1.0

                samples.append((frame_paths, goal_text, progress, mask))

        return samples

    def __len__(self):
        return len(self.samples)

    def _load_frame(self, path: str):
        """Open one image, convert it to RGB and apply ``self.transform``; the file is closed afterwards."""
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        """Return ``(frames, goal_text, progress_label, valid_mask)`` for clip ``idx``."""
        frame_paths, goal_text, progress_label, valid_mask = self.samples[idx]
        frames = [self._load_frame(p) for p in frame_paths]
        frames_tensor = torch.stack(frames)  # [T, C, H, W] with the default transform

        return frames_tensor, goal_text, progress_label, valid_mask


def temporal_perturb_inplace(frames, progress_label, valid_mask, perturb_rate, device="cuda"):
    """Randomly permute the leading valid frames of each sequence, in place.

    For each batch element ``b`` with ``L = valid_mask[b].sum()`` valid frames,
    the first ``min(L, int(L * perturb_rate))`` positions are shuffled with one
    random permutation. The same permutation is applied to ``progress_label``,
    so every label stays attached to its frame. Nothing changes when fewer
    than two positions would be shuffled. ``valid_mask`` is only read.

    Args:
        frames (Tensor): (B, T, C, H, W); modified in place.
        progress_label (Tensor): (B, T); modified in place.
        valid_mask (Tensor): (B, T). Assumed to mark the valid frames as a
            contiguous prefix (as ``TPDataset`` produces) — positions past
            ``L`` are treated as padding and never touched.
        perturb_rate (float): fraction of valid frames to shuffle. Values
            above 1 are clamped to the valid length.
        device: unused; kept for backward compatibility. The permutation is
            placed on the devices of ``frames`` and ``progress_label``.

    Returns:
        tuple: ``(frames, progress_label)`` — the same tensors, now mutated.
    """
    batch_size = frames.shape[0]
    for b in range(batch_size):
        valid_len = valid_mask[b].sum().long().item()
        # Clamp so perturb_rate > 1 never reaches padding or indexes past T.
        perturb_len = min(valid_len, int(valid_len * perturb_rate))
        if perturb_len < 2:
            continue
        perm = torch.randperm(perturb_len)
        # Advanced indexing on the RHS copies first, so in-place assignment is safe.
        frames[b, :perturb_len] = frames[b, perm.to(frames.device)]
        progress_label[b, :perturb_len] = progress_label[b, perm.to(progress_label.device)]
    return frames, progress_label




def load_dataset(cfg):
    """Build the training and validation ``TPDataset`` loaders from a config.

    Reads from ``cfg.dataset``:
        - ``root_dir_train``, ``train_split``
        - ``root_dir_val``, ``val_split``
        - ``batch_size``, ``shuffle``, ``num_workers``, ``pin_memory``
        - optionally ``traj_len``. It defaults to 50 here, which differs from
          ``TPDataset``'s own default of 64.

    ``shuffle`` applies to the training loader only. The validation loader
    always keeps a fixed order.

    Returns:
        tuple: ``(train_loader, val_loader, traj_length)``.
    """
    ds_cfg = cfg.dataset
    # Only the clip length has a fallback; the split names are required keys.
    split = ds_cfg.train_split
    traj_length = ds_cfg.get("traj_len", 50)

    train_dataset = TPDataset(
        root_dir=ds_cfg.root_dir_train,
        split=split,
        traj_len=traj_length,
    )

    val_dataset = TPDataset(
        root_dir=ds_cfg.root_dir_val,
        split=ds_cfg.val_split,
        traj_len=traj_length,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=ds_cfg.batch_size,
        shuffle=ds_cfg.shuffle,
        num_workers=ds_cfg.num_workers,
        pin_memory=ds_cfg.pin_memory,
    )

    # Validation is never shuffled so metrics are computed over a stable order.
    val_loader = DataLoader(
        val_dataset,
        batch_size=ds_cfg.batch_size,
        shuffle=False,
        num_workers=ds_cfg.num_workers,
        pin_memory=ds_cfg.pin_memory,
    )

    return train_loader, val_loader, traj_length

    
