import os
import numpy as np
import torch
from PIL import Image
import torchvision
from torchvision import transforms
from src.fl_datasets.augmentation import RandAugmentMC
from src.fl_datasets.utils import split_data, split_clients, reassign_target
from .dataset_base import BasicDataset


def get_fashionmnist(cfgs, name, num_labels, num_classes, data_dir='./data'):
    """Download FashionMNIST and partition it for open-set federated training.

    The training split is divided into labeled / unlabeled parts by
    ``split_data`` (using a fixed set of seen classes), and the unlabeled part
    is then distributed across clients by ``split_clients``. Test targets are
    remapped with ``reassign_target`` using the same seen-class set.

    Args:
        cfgs: Configuration mapping; only ``cfgs['Dataset']`` is read here and
            forwarded to the split helpers.
        name: Dataset name; only the part before the first ``'_'`` is used
            (e.g. ``'fashionmnist_openset'`` -> ``'fashionmnist'``).
        num_labels: Forwarded to ``split_data``.
        num_classes: Not used by this function (kept for interface
            compatibility with other dataset loaders).
        data_dir: Root directory; data is stored under
            ``<data_dir>/<base name lowercased>`` and downloaded if missing.

    Returns:
        ``(lb_data, lb_targets, clients_set, test_data, test_targets)``, where
        ``test_data`` is the raw test images converted to a numpy array.
    """
    data_cfgs = cfgs['Dataset']
    base_name = name.split('_')[0]
    root = os.path.join(data_dir, base_name.lower())
    dataset_cls = torchvision.datasets.FashionMNIST

    num_all_classes = 10
    seen_classes = {0, 1, 2, 3, 4, 6}  # labels treated as "seen"; the rest are unseen

    # Training split. Images stay single-channel ([N, H, W]); no RGB
    # replication is applied.
    train_split = dataset_cls(root, train=True, download=True)
    train_images = train_split.data.numpy()
    print("\n\ntrain_data shape: ", train_images.shape)

    lb_data, lb_targets, ulb_data, ulb_targets = split_data(
        cfgs=data_cfgs,
        data=train_images,
        target=train_split.targets.numpy(),
        num_labels=num_labels,
        num_all_classes=num_all_classes,
        seen_classes=seen_classes,
        index=None,
        include_lb_to_ulb=False,
    )

    # Distribute the unlabeled pool across clients and report per-client sizes.
    clients_set = split_clients(cfgs=data_cfgs, ulb_data=ulb_data, ulb_targets=ulb_targets)
    for client_id, entry in clients_set.items():
        print(f"Client {client_id}: Data size = {len(entry['data'])}, Target size = {len(entry['targets'])}")

    # Test split, with targets remapped against the same seen-class set.
    test_split = dataset_cls(root, train=False, download=True)
    test_targets = reassign_target(test_split.targets, num_all_classes, seen_classes)
    test_images = test_split.data.numpy()
    print("\n\ntest_data shape: ", test_images.shape)

    return lb_data, lb_targets, clients_set, test_images, test_targets


def get_fashionmnist_server(cfgs, lb_data, lb_targets, test_data, test_targets):
    """Wrap server-side labeled and test arrays in ``BasicDataset`` objects.

    Args:
        cfgs: Configuration mapping. Reads ``cfgs['Dataset']`` (keys
            ``'dataset'``, ``'image_size'``, ``'crop_ratio'``, ``'num_classes'``)
            and ``cfgs['server_alg']``.
        lb_data, lb_targets: Labeled training data and targets.
        test_data, test_targets: Test data and targets.

    Returns:
        ``(lb_dset, test_dset)``: the labeled dataset uses the server algorithm
        with weak/strong transforms; the test dataset uses ``'supervised'`` with
        the evaluation transform.
    """
    data_cfgs = cfgs['Dataset']
    base_name = data_cfgs['dataset'].split('_')[0]
    server_alg = cfgs['server_alg']

    weak_tf, strong_tf, eval_tf = set_transform(
        alg=server_alg,
        crop_size=data_cfgs['image_size'],
        crop_ratio=data_cfgs['crop_ratio'],
        name=base_name,
    )
    n_classes = data_cfgs['num_classes']

    labeled = BasicDataset(
        alg=server_alg,
        data=lb_data,
        targets=lb_targets,
        num_classes=n_classes,
        transform=weak_tf,
        is_ulb=False,
        strong_transform=strong_tf,
    )
    evaluation = BasicDataset(
        alg='supervised',
        data=test_data,
        targets=test_targets,
        num_classes=n_classes,
        transform=eval_tf,
        is_ulb=False,
        strong_transform=None,
    )
    return labeled, evaluation


def get_fashionmnist_client(cfgs, cid, clients_set):
    """Build the ``BasicDataset`` for one client's local data.

    Args:
        cfgs: Configuration mapping. Reads ``cfgs['Dataset']`` (keys
            ``'dataset'``, ``'image_size'``, ``'crop_ratio'``) and
            ``cfgs['client_alg']``.
        cid: Key selecting this client's entry in ``clients_set``.
        clients_set: Mapping from client id to a dict with ``'data'`` and
            ``'targets'`` entries.

    Returns:
        A ``BasicDataset`` over the client's data. It is marked unlabeled
        (``is_ulb=True``) unless the client algorithm is ``'supervised'``.
    """
    data_cfgs = cfgs['Dataset']
    base_name = data_cfgs['dataset'].split('_')[0]
    client_alg = cfgs['client_alg']

    # The evaluation transform is not needed on the client side.
    weak_tf, strong_tf, _unused_eval_tf = set_transform(
        alg=client_alg,
        crop_size=data_cfgs['image_size'],
        crop_ratio=data_cfgs['crop_ratio'],
        name=base_name,
    )

    total_classes = 10  # full label space (seen + unseen classes)
    client_entry = clients_set[cid]

    return BasicDataset(
        alg=client_alg,
        data=client_entry['data'],
        targets=client_entry['targets'],
        num_classes=total_classes,
        transform=weak_tf,
        is_ulb=not (client_alg == 'supervised'),
        strong_transform=strong_tf,
    )


class _EnsureRGB:
    """Ensure PIL image is RGB for augmentations that expect 3 channels."""
    def __call__(self, img):
        # PIL.Image.Image일 경우에만 처리
        if hasattr(img, "mode"):
            return img.convert("RGB") if img.mode != "RGB" else img
        return img


def set_transform(alg, crop_size, crop_ratio, name):
    """Build the weak, strong and evaluation transforms for FashionMNIST.

    Args:
        alg: Algorithm name. For ``'ours'``, ``'openmatch'``, ``'magmatch'``
            and ``'prosub'``, the weak transform is returned as a two-element
            list ``[augmented, non_augmented]``. Otherwise it is a single
            ``Compose`` pipeline.
        crop_size: Size passed to ``Resize`` and ``RandomCrop``.
        crop_ratio: Sets the reflect padding applied before random cropping:
            ``int(crop_size * (1 - crop_ratio))`` pixels on each border.
        name: Dataset key used to look up normalization statistics. It is
            matched case-insensitively, so ``'FashionMNIST'`` works the same as
            ``'fashionmnist'``. Callers pass the raw config name, which is not
            lowercased.

    Returns:
        ``(transform_weak, transform_strong, transform_eval)``.

    Raises:
        KeyError: If no normalization statistics are registered for ``name``.
    """
    # Single-channel normalization statistics, keyed by lowercase dataset name.
    stats = {
        'fashionmnist': ([0.2860], [0.3530]),
    }
    key = name.lower()
    try:
        mean, std = stats[key]
    except KeyError:
        raise KeyError(f"No normalization statistics registered for dataset {name!r}") from None

    pad = int(crop_size * (1 - crop_ratio))

    def _plain():
        """Return a fresh resize -> tensor -> normalize pipeline (no randomness)."""
        return transforms.Compose([
            transforms.Resize(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def _crop_flip():
        """Return a fresh list of the geometric augmentations shared by weak/strong."""
        return [
            transforms.Resize(crop_size),
            transforms.RandomCrop(crop_size, padding=pad, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
        ]

    transform_weak = transforms.Compose(_crop_flip() + [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Strong path: RandAugmentMC is expected to operate on RGB images (converting
    # from "L" also avoids Cutout fill-color issues), so the image is converted to
    # RGB first. It is then collapsed back to one channel so the tensor matches
    # the single-channel normalization statistics.
    transform_strong = transforms.Compose(_crop_flip() + [
        _EnsureRGB(),
        RandAugmentMC(n=2, m=10),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_eval = _plain()

    if alg in ('ours', 'openmatch', 'magmatch', 'prosub'):
        # These algorithms receive both an augmented and an un-augmented weak
        # view. Presumably BasicDataset applies both; confirm against its usage.
        transform_weak = [transform_weak, _plain()]

    return transform_weak, transform_strong, transform_eval
