import os.path as osp

import numpy as np
import pytest
import torch
from PIL import Image

from sde.datasets import ImageFolderWithAttribution


def get_dummy_data(tmp_path, num_channels=3, *, num_classes=3, num_samples_per_class=5, height=10, width=10):
    """Write a tiny class-per-subdirectory dataset of constant-valued PNGs under ``tmp_path``.

    Three parallel trees are created: ``images``, ``attr_maps`` and ``gt_masks``. Each has one
    subdirectory per class, named ``'0'``, ``'1'``, ... Every sample is written to all three
    trees under the same file name ``f'{sample_index}.png'``. ``sample_index`` is a running
    counter across all classes, so file names are unique across the whole dataset. Every
    pixel of every file is 255.

    Args:
        tmp_path (pathlib.Path): Directory in which the trees are created (e.g. pytest's
            ``tmp_path`` fixture). The three root directories must not exist yet.
        num_channels (int): ``3`` writes RGB images; any other value writes single-channel
            ('L') images.
        num_classes (int): Number of class subdirectories (keyword-only).
        num_samples_per_class (int): Number of samples written per class (keyword-only).
        height (int): Image height in pixels (keyword-only).
        width (int): Image width in pixels (keyword-only).

    Returns:
        tuple[str, str, str]: Paths of the image root, the attribution-map root and the
        ground-truth-mask root.
    """
    img_root = tmp_path / 'images'
    attr_root = tmp_path / 'attr_maps'
    gt_mask_root = tmp_path / 'gt_masks'
    roots = (img_root, attr_root, gt_mask_root)
    for root in roots:
        root.mkdir(parents=True)

    # Pillow infers mode 'RGB' from a uint8 (H, W, 3) array and 'L' from a uint8 (H, W)
    # array. The explicit ``mode`` argument of ``Image.fromarray`` is deprecated
    # (Pillow >= 11.3), so it is omitted.
    img_shape = (height, width, 3) if num_channels == 3 else (height, width)
    img = Image.fromarray(np.full(img_shape, 255, dtype=np.uint8))
    # Attribution maps and GT masks are identical single-channel, all-255 images.
    single_channel_map = Image.fromarray(np.full((height, width), 255, dtype=np.uint8))

    sample_index = 0
    for class_index in range(num_classes):
        class_dirs = [root / str(class_index) for root in roots]
        for class_dir in class_dirs:
            class_dir.mkdir()
        img_sub_dir, attr_sub_dir, gt_mask_sub_dir = class_dirs

        for _ in range(num_samples_per_class):
            file_name = f'{sample_index}.png'
            img.save(img_sub_dir / file_name)
            single_channel_map.save(attr_sub_dir / file_name)
            single_channel_map.save(gt_mask_sub_dir / file_name)
            sample_index += 1

    return str(img_root), str(attr_root), str(gt_mask_root)


def test_valid_fn():
    """``_valid`` must reject malformed images and mismatched attr_map / gt_mask sizes."""
    validate = ImageFolderWithAttribution._valid

    # Image shapes that must be rejected. (3, 10, 10), used further below, is accepted.
    rejected_img_shapes = (
        (1, 3, 10, 10),  # 4-D tensor
        (10, 10),  # 2-D tensor, no channel dimension
        (5, 10, 10),  # 5 channels; presumably only 1 or 3 are accepted -- TODO confirm
    )
    for bad_shape in rejected_img_shapes:
        with pytest.raises(ValueError, match='img has invalid shape'):
            validate(torch.ones(bad_shape), torch.zeros([]), None)

    rgb_img = torch.ones((3, 10, 10))

    # A 3-D attr_map is reported as a spatial-size mismatch with the image.
    with pytest.raises(ValueError, match='img has spatial size'):
        validate(rgb_img, torch.ones((1, 10, 10)), None)

    # attr_map matches the image, but the 10x20 gt_mask does not.
    with pytest.raises(ValueError, match='img has spatial size'):
        validate(rgb_img, torch.ones((10, 10)), torch.ones((10, 20)))


@pytest.mark.parametrize('num_channels', [1, 3])
def test_image_folder_with_attribution(tmp_path, num_channels):
    """Load the dummy dataset and check each sample's label, image, attr_map and gt_mask."""
    img_root, attr_root, gt_mask_root = get_dummy_data(tmp_path, num_channels)
    dataset = ImageFolderWithAttribution(
        num_channels, img_root, attr_root, attr_file_format='png', gt_mask_root=gt_mask_root, cls_to_ind='str_to_int')
    # get_dummy_data writes 3 classes x 5 samples, each 10x10 pixels.
    assert len(dataset) == 15

    h, w = 10, 10
    # ToTensor will perform value scaling, while PILToTensor will not. The expected values
    # encode the dataset's behaviour: raw 255.0 for RGB images, scaled 1.0 for grayscale.
    # NOTE(review): this implies the dataset uses a different transform per channel count.
    full_value = 255.0 if num_channels == 3 else 1.0
    expected_img = torch.full((num_channels, h, w), full_value, dtype=torch.float32)
    # The all-255 PNG attribution maps are expected to come back scaled to 1.0.
    expected_attr_map = torch.full((h, w), 1.0, dtype=torch.float32)
    # gt_mask is not converted to float if the attr_file_format is png
    expected_gt_mask = torch.full((h, w), 255, dtype=torch.uint8)

    for sample_index in range(len(dataset)):
        sample = dataset[sample_index]
        # The class name is the image's parent directory. basename(dirname(...)) extracts it
        # whether img_file is relative ('0/3.png') or absolute.
        class_dir_name = osp.basename(osp.dirname(sample['img_file']))
        assert int(class_dir_name) == sample['label']

        img = sample['img']
        assert img.shape == (num_channels, h, w)
        torch.testing.assert_close(img, expected_img)

        torch.testing.assert_close(sample['attr_map'], expected_attr_map)
        torch.testing.assert_close(sample['gt_mask'], expected_gt_mask)
