"""
Loads and creates a dataloader for the DomainNet dataset.
"""
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from experiments.helpers import worker_init_fn

class DomainNetDataset(Dataset):
    """Creates PyTorch Dataset for DomainNet.
    Optionally filters by domain.

    Samples are read from one annotation file per domain, named
    ``{domain}_{split_type}.txt`` inside ``path_to_annotations_dir``. Each
    non-blank line holds an image path (relative to ``path_to_dataset_dir``)
    followed by an integer class label, separated by whitespace. The class
    name is taken from the second ``/``-separated component of the path,
    with underscores replaced by spaces.

    Parameters
    ----------
    path_to_annotations_dir : str
        Directory containing the ``{domain}_{split_type}.txt`` files.
    path_to_dataset_dir : str
        Root directory that the annotated image paths are relative to.
    domains : list[str]
        Domains to load. Each domain's label is its index in this list.
    classes : list[str]
        Only its length is used, to size the label -> class-name table.
        Every label in the annotations must be < ``len(classes)``
        (otherwise an ``IndexError`` is raised while loading).
    transform : transforms.Compose
        Applied to each loaded PIL image; skipped when falsy.
    split_type : str
        Split suffix used in the annotation filenames (default ``"train"``).

    Attributes
    ----------
    samples : list[dict]
        One dict per annotation line with keys ``image_path``,
        ``class_label``, ``class_name``, ``domain_name``, ``domain_label``.
    class_to_label_map : list[str]
        Despite its name, a list indexed by class label whose entries are
        class names (``""`` for labels that never appear).
    class_names : list[str]
        Alias of ``class_to_label_map``.
    domain_to_label_map : dict[str, int]
        Domain name -> domain label.
    """
    def __init__(
        self,
        path_to_annotations_dir: str,
        path_to_dataset_dir: str,
        domains: list[str],
        classes: list[str],
        transform: transforms.Compose,
        split_type: str = "train"
    ):
        self.domain_to_label_map = {domain: i for i, domain in enumerate(domains)}
        self.path_to_dataset_dir = path_to_dataset_dir
        self.transform = transform
        # Label -> class-name table, filled in from the annotation files.
        self.class_to_label_map = [""] * len(classes)

        self.samples = []

        for domain in domains:
            domain_label = self.domain_to_label_map[domain]
            annotation_file = os.path.join(path_to_annotations_dir, f"{domain}_{split_type}.txt")
            with open(annotation_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines (e.g. a trailing empty line at EOF);
                    # they would otherwise fail to unpack below.
                    if not line:
                        continue
                    # Split the label off the right-hand end so that an image
                    # path containing whitespace is kept intact.
                    path, label = line.rsplit(maxsplit=1)
                    label = int(label)
                    # The second path component is treated as the class folder.
                    class_name = path.split('/')[1].replace('_', ' ')
                    self.class_to_label_map[label] = class_name
                    self.samples.append({
                        "image_path": path,
                        "class_label": label,
                        "class_name": class_name,
                        "domain_name": domain,
                        "domain_label": domain_label,
                    })

        self.class_names = self.class_to_label_map

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load and return a single sample.

        Returns
        -------
        image : Tensor
            Transformed image tensor (the PIL image is converted to RGB
            before the transform is applied).
        int
            Integer class label.
        metadata : Tensor
            Tensor with `[class_label, domain_label]`.
        """
        sample = self.samples[index]
        full_path = os.path.join(self.path_to_dataset_dir, sample["image_path"])

        # Same approach as torchvision's default image loader: open the file
        # explicitly so the handle is closed deterministically, and convert to
        # RGB so grayscale / palette / RGBA images all yield 3 channels and
        # can be batched together.
        with open(full_path, "rb") as f:
            image = Image.open(f)
            image = image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        metadata = torch.tensor([
            sample["class_label"],
            sample["domain_label"]
        ])

        return image, int(sample["class_label"]), metadata
    
def get_domainnet_dataloader(
        trial_num: int,
        path_to_annotations_dir: str,
        path_to_dataset_dir: str,
        domains: list[str],
        classes: list[str],
        transform: transforms.Compose,
        split_type: str = "train",
        batch_size=32,
        num_workers=4,
        pin_memory=True,
        is_training=True,
        create_val_from_train=False
    ):
    """Create a DataLoader for DomainNet.

    Parameters
    ----------
    trial_num : int
        Seed for the ``torch.Generator`` handed to the DataLoader, making the
        shuffling order reproducible for a given trial.
    path_to_annotations_dir, path_to_dataset_dir, domains, classes, transform, split_type
        Forwarded unchanged to ``DomainNetDataset``.
    batch_size : int
        Number of samples per batch.
    num_workers : int
        Number of DataLoader worker processes.
    pin_memory : bool
        Forwarded to ``DataLoader``.
    is_training : bool
        When True the data is shuffled; otherwise annotation order is kept.
    create_val_from_train : bool
        Not implemented: it has no effect on the returned dataset or loader.
        A ``UserWarning`` is emitted when it is True so the omission is not
        silent.

    Returns
    -------
    dataset : DomainNetDataset
        The constructed dataset instance.
    loader : DataLoader
        A PyTorch DataLoader whose workers are initialised with
        ``worker_init_fn`` (from ``experiments.helpers``).
    """
    import warnings

    if create_val_from_train:
        # The flag is accepted for interface compatibility but nothing here
        # carves a validation subset out of the training split.
        warnings.warn(
            "create_val_from_train=True is not implemented in "
            "get_domainnet_dataloader and is ignored; the full "
            f"'{split_type}' split is returned.",
            UserWarning,
            stacklevel=2,
        )

    g = torch.Generator()
    g.manual_seed(trial_num)

    dataset = DomainNetDataset(
        path_to_annotations_dir=path_to_annotations_dir,
        path_to_dataset_dir=path_to_dataset_dir,
        domains=domains,
        classes=classes,
        transform=transform,
        split_type=split_type,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=bool(is_training),
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        generator=g
    )
    return dataset, loader

