import torchvision
from torch.utils.data import Dataset

from PIL import Image
import requests
import os
import numpy as np


class Ombria(Dataset):
    """OMBRIA dataset loaded from the geodrak/OMBRIA GitHub repository.

    Each item is ``(sample, mask)``:

    * ``sample`` is the concatenation along the last axis, in this order, of
      the S1 BEFORE, S1 AFTER, S2 BEFORE and S2 AFTER images of one sample id.
    * ``mask`` is the matching ``OmbriaS1/<split>/MASK/S1_mask_<id>.png``.

    All pixel values are divided by 255.

    Args:
        root: Directory that contains (or will receive) the ``OmbriaS1`` and
            ``OmbriaS2`` folders.
        split: Name of the split sub-folder inside ``OmbriaS1``/``OmbriaS2``
            (default ``"train"``).
        transform: Optional callable applied to ``sample`` in ``__getitem__``.
        target_transform: Optional callable applied to ``mask`` in
            ``__getitem__``.
        download: If True and the ``OmbriaS1``/``OmbriaS2`` folders are not
            both present under ``root``, fetch and unpack the GitHub archive.
        resolution: 256, 128 or 64. For 128/64, both sample and mask are
            subsampled by keeping every 2nd/4th row and column. The source
            images are presumably 256x256 -- TODO confirm.

    Raises:
        ValueError: If ``resolution`` is not 64, 128 or 256.
    """

    link = "https://github.com/geodrak/OMBRIA/archive/refs/heads/master.zip"

    # Row/column subsampling stride for each supported output resolution.
    _STRIDES = {256: 1, 128: 2, 64: 4}

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False, resolution=256):
        if resolution not in self._STRIDES:
            raise ValueError("Resolution must be 64, 128 or 256")

        self.transform = transform
        self.target_transform = target_transform
        self.resolution = resolution

        if download:
            if self._is_downloaded(root):
                print("Dataset already downloaded.")
            else:
                self._download(root)

        self.data = []
        self.masks = []

        # Sample ids are discovered from the S1 BEFORE folder; every other
        # modality/state and the mask are looked up by that id.
        s1_dir = os.path.join(root, "OmbriaS1", split)
        for sample_id in self._sample_ids(os.listdir(os.path.join(s1_dir, "BEFORE"))):
            self.data.append(self._downsample(self._load_sample(root, split, sample_id)))
            # The mask must be subsampled with the same stride as the sample,
            # otherwise sample and mask end up with different spatial sizes.
            mask_path = os.path.join(s1_dir, "MASK", f"S1_mask_{sample_id}.png")
            self.masks.append(self._downsample(self._load_image(mask_path)))

    @staticmethod
    def _is_downloaded(root):
        """Return True if both ``OmbriaS1`` and ``OmbriaS2`` exist under *root*."""
        return all(
            os.path.isdir(os.path.join(root, folder))
            for folder in ("OmbriaS1", "OmbriaS2")
        )

    def _download(self, root):
        """Fetch the archive at ``self.link`` and move ``OmbriaS1``/``OmbriaS2`` into *root*."""
        import shutil
        import zipfile

        os.makedirs(root, exist_ok=True)
        archive_path = os.path.join(root, "OMBRIA.zip")
        raw_dir = os.path.join(root, "raw")

        response = requests.get(self.link)
        # Fail loudly on HTTP errors instead of writing an error page to disk
        # and then failing later inside zipfile with a confusing message.
        response.raise_for_status()
        with open(archive_path, "wb") as file:
            file.write(response.content)

        os.makedirs(raw_dir, exist_ok=True)
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(raw_dir)
        os.remove(archive_path)

        # Move the folders OmbriaS1 and OmbriaS2 to the root.
        extracted = os.path.join(raw_dir, "OMBRIA-master")
        for folder in ("OmbriaS1", "OmbriaS2"):
            os.rename(os.path.join(extracted, folder), os.path.join(root, folder))
        shutil.rmtree(extracted)

    @staticmethod
    def _sample_ids(filenames):
        """Return the trailing ``_<id>`` of each ``.png`` name, ordered by filename.

        Non-PNG entries (e.g. ``.DS_Store``) are skipped. Sorting makes item
        order independent of the arbitrary order of ``os.listdir``.
        """
        return [
            os.path.splitext(name)[0].rsplit("_", 1)[-1]
            for name in sorted(filenames)
            if name.endswith(".png")
        ]

    @staticmethod
    def _load_image(path):
        """Read the image at *path* into an array scaled by 1/255, closing the file."""
        with Image.open(path) as img:
            return np.array(img) / 255

    def _load_sample(self, root, split, sample_id):
        """Stack S1/S2 BEFORE/AFTER images of *sample_id* along the last axis."""
        channels = []
        for modality in ("S1", "S2"):
            for state in ("BEFORE", "AFTER"):
                path = os.path.join(
                    root, f"Ombria{modality}", split, state,
                    f"{modality}_{state.lower()}_{sample_id}.png",
                )
                image = self._load_image(path)
                if modality == "S1":
                    # S1 images are assumed to load as 2-D (single-band) arrays;
                    # add a channel axis so they concatenate with S2 images.
                    image = image[..., np.newaxis]
                channels.append(image)
        return np.concatenate(channels, axis=-1)

    def _downsample(self, array):
        """Keep every k-th row and column, with k taken from ``self.resolution``."""
        step = self._STRIDES[self.resolution]
        if step == 1:
            return array
        return array[::step, ::step]

    def __getitem__(self, index):
        """Return ``(sample, mask)``, with the optional transforms applied."""
        sample = self.data[index]
        mask = self.masks[index]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return sample, mask

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)
