import os
from pathlib import Path
import logging

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from .fl_dataset import FederatedDataset
from .utils import VisionDataset_FL

logger = logging.getLogger(__name__)

# ---- Transforms (32×32, CIFAR-100 stats) ----
def cifar100Transformation(augment):
    """Build the torchvision transform pipeline for 32x32 CIFAR-100 images.

    Args:
        augment: Selects the pipeline.
            * ``"jit"`` -- only ``ToTensor`` (no normalization here).
            * any other truthy value -- random crop (padding 4) and random
              horizontal flip, then ``ToTensor`` and ``Normalize``.
            * falsy -- ``ToTensor`` and ``Normalize`` only.

    Returns:
        A ``transforms.Compose`` holding the selected steps.
    """
    # Minimal transform for JIT/on-device pipelines; keep as ToTensor only.
    if augment == "jit":
        return transforms.Compose([transforms.ToTensor()])

    channel_mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    channel_std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    steps = []
    if augment:
        # Augmentations run first, ahead of tensor conversion.
        steps.append(transforms.RandomCrop(32, padding=4))
        steps.append(transforms.RandomHorizontalFlip())

    # Tensor conversion + normalization are shared by both remaining modes.
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(channel_mean, channel_std))
    return transforms.Compose(steps)

class Cifar100Dataset(FederatedDataset):
    """CIFAR-100 variant of ``FederatedDataset``.

    Adds generation of unified ``train.pt`` / ``test.pt`` files from the
    torchvision CIFAR-100 download, and construction of ``DataLoader``
    objects over per-pool / per-client partition files.
    """

    # Per-channel statistics passed to the Normalize transforms below.
    _MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    _STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class, then set up paths/transforms."""
        super().__init__(*args, **kwargs)

        # Where torchvision will download raw CIFAR-100
        self.data_dir = self.path_to_data
        # Where we store our unified *.pt files
        self.path_to_data = Path(os.path.join(self.path_to_data, "cifar-100-python"))

        # Keep these for parity with your CIFAR-10 class (use Compose, not nn.Sequential)
        # NOTE(review): no ToTensor step here, so these presumably operate on
        # tensors that were already converted -- confirm against the caller.
        self.jit_augment = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(self._MEAN, self._STD),
        ])

        self.jit_normalize = transforms.Compose([
            transforms.Normalize(self._MEAN, self._STD),
        ])

    def download(self):
        """Create unified train.pt/test.pt once if missing.

        Each file is checked individually. Checking only for the directory is
        not sufficient: ``cifar-100-python`` is also the folder torchvision
        extracts the raw archive into, so it may exist without the ``*.pt``
        files (and an interrupted run leaves it behind as well).

        Each file stores ``[data, targets]`` where ``targets`` is converted
        to an ``int64`` NumPy array.
        """
        out_dir = Path(self.path_to_data)
        pending = [
            (split, is_train)
            for split, is_train in (("test", False), ("train", True))
            if not (out_dir / f"{split}.pt").exists()
        ]
        if not pending:
            return

        logger.info("Generating unified CIFAR-100 dataset")
        out_dir.mkdir(parents=True, exist_ok=True)

        for split, is_train in pending:
            target = out_dir / f"{split}.pt"
            raw = datasets.CIFAR100(root=self.data_dir, train=is_train, download=True)
            features = raw.data
            labels = np.array(raw.targets, dtype=np.int64)

            # Write to a temp file and rename, so an interrupted save never
            # leaves a truncated file that would pass the existence check.
            tmp_target = target.with_name(target.name + ".tmp")
            torch.save([features, labels], tmp_target)
            os.replace(tmp_target, target)

    def _resolve_partition_path(self, data_pool, partition, cid, path):
        """Return the ``<partition>.pt`` file path for the given pool/client."""
        filename = f"{partition}.pt"

        if path is not None and os.path.exists(path):
            # Forced external path (optionally per-client)
            root = path
        elif data_pool == "server":
            assert cid is None
            return os.path.join(self.dataset_fl_root, filename)
        elif data_pool == "train":
            # client training pool
            root = self.fed_train_dir
        else:
            # client test pool
            root = self.fed_test_dir

        if cid is not None:
            # str() so integer client ids work; a no-op for string ids.
            root = os.path.join(root, str(cid))
        return os.path.join(root, filename)

    def get_dataloader(
        self,
        data_pool,
        partition,
        batch_size,
        num_workers,
        augment,
        cid=None,
        path=None,
        shuffle=False,
        val_ratio=0.0,
        seed=None,
        **kwargs,
    ):
        """
        Return a DataLoader for the requested split and pool.

        Arguments mirror your existing CIFAR-10 API for compatibility.

        Args:
            data_pool: One of "server", "train" or "test" (case-insensitive).
            partition: Partition name; the file read is ``<partition>.pt``.
            batch_size: Forwarded to ``DataLoader``.
            num_workers: Forwarded to ``DataLoader``.
            augment: Forwarded to ``cifar100Transformation``.
            cid: Optional client id, used as a sub-directory name. Must be
                ``None`` for the "server" pool unless ``path`` overrides it.
            path: Optional directory overriding the pool directories; it is
                used only if it exists, otherwise the pool directory is used.
            shuffle: Forwarded to ``DataLoader`` (for every returned loader).
            val_ratio: If truthy, fraction of samples split off for validation.
            seed: Seed for the split generator; required when ``val_ratio``
                is truthy.
            **kwargs: Extra keyword arguments forwarded to ``DataLoader``.

        Returns:
            A single ``DataLoader`` when ``val_ratio`` is falsy, otherwise a
            two-element list ``[train_loader, val_loader]``.

        Raises:
            AssertionError: On an unknown ``data_pool``, a ``cid`` given for
                the server pool, a validation split requested from a
                non-"train" partition, or a missing ``seed``.
        """
        data_pool = data_pool.lower()
        assert data_pool in ("server", "train", "test"), "Data pool must be in server, train, or test"

        # Resolve on-disk path for the requested partition
        path = self._resolve_partition_path(data_pool, partition, cid, path)

        # Build dataset with the appropriate transform
        transform = cifar100Transformation(augment)

        def make_loader(source):
            # Single place for the loader settings shared by all return paths.
            return DataLoader(
                source,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=True,
                drop_last=False,
                shuffle=shuffle,
                **kwargs,
            )

        if not val_ratio:
            return make_loader(VisionDataset_FL(path_to_data=path, transform=transform))

        # Match CIFAR-10 behavior: only split from train, deterministic seed required, and usually used with 'jit'
        assert partition.lower() == "train", "Validation split only supported from the training partition"
        assert seed is not None, "Provide a 'seed' for deterministic val split"

        dataset = VisionDataset_FL(path_to_data=path, transform=transform)

        total = len(dataset)
        val_len = int(val_ratio * total)
        train_len = total - val_len
        splits = torch.utils.data.random_split(
            dataset, [train_len, val_len], generator=torch.Generator().manual_seed(seed)
        )
        return [make_loader(subset) for subset in splits]