import os
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from typing import List
from source.constants import CITYSCAPES_PATH


class CityscapesDataset(Dataset):
    """Map-style dataset pairing images with targets, optionally flip-augmented.

    With ``flip=False`` the dataset has ``len(images)`` items. With
    ``flip=True`` it has ``2 * len(images)`` items: indices ``[0, N)`` return
    the stored images and indices ``[N, 2N)`` return the same images flipped
    along dim 2, paired with the unchanged target.
    """

    def __init__(self, images, targets, flip=False):
        self.images = images
        self.targets = targets
        self.flip = flip

    def __len__(self):
        return len(self.images) * 2 if self.flip else len(self.images)

    def __getitem__(self, idx):
        size = len(self)
        requested = idx
        # Support negative indices relative to the (possibly doubled) length.
        if idx < 0:
            idx += size
        # Bounds-check against len(self), not len(self.images): without this,
        # a non-flipping dataset would serve flipped images for [N, 2N), and
        # Python's legacy __getitem__ iteration would yield 2N items.
        if not 0 <= idx < size:
            raise IndexError(
                f"index {requested} out of range for dataset of size {size}"
            )

        num_images = len(self.images)
        if idx >= num_images:
            # Second half of a flip-augmented dataset: mirror the image.
            # NOTE(review): dim 2 is the width axis only for (C, H, W) images —
            # confirm the stored layout.
            idx -= num_images
            image = self.images[idx].flip(2)
        else:
            image = self.images[idx]

        return image, self.targets[idx]


def _adjust(images, targets):
    # normalize to [0, 1]
    targets /= 224**2
    # normalize images to [0, 1]
    images = images / 255

    return images, targets


def get_train_val(target=None, seed=2357):
    """Load the Cityscapes training tensors and split them into train and val.

    Args:
        target: Column of the loaded target tensor to keep, or ``None`` to
            keep every column.
        seed: Seed for the NumPy generator that shuffles sample indices, so
            the split is reproducible.

    Returns:
        ``(train_images, train_targets, val_images, val_targets)``, rescaled
        by ``_adjust``. The validation set is ``len(images) // 8`` randomly
        chosen samples; the remaining samples form the training set.

    Raises:
        ValueError: if ``target`` is neither ``None`` nor a valid column index.
    """
    targets = torch.load(
        os.path.join(CITYSCAPES_PATH, "train_targets.pt"), weights_only=True
    ).float()

    # Explicit check rather than `assert`, so it survives `python -O`; without
    # it a negative `target` would silently select a column from the end.
    # Validated before loading the (larger) image tensor.
    num_columns = targets.shape[1]
    if target is not None and not 0 <= target < num_columns:
        raise ValueError(
            f"Invalid target {target!r}: expected None or "
            f"0 <= target < {num_columns}"
        )

    images = torch.load(
        os.path.join(CITYSCAPES_PATH, "train_images.pt"), weights_only=True
    ).float()

    # select target
    if target is not None:
        targets = targets[:, target]

    images, targets = _adjust(images, targets)

    # Split into train and val (the test set is stored in separate files).
    rng = np.random.default_rng(seed)
    indices = np.arange(len(images))
    rng.shuffle(indices)

    val_size = len(images) // 8
    train_indices = indices[val_size:]
    val_indices = indices[:val_size]

    train_images, train_targets = images[train_indices], targets[train_indices]
    val_images, val_targets = images[val_indices], targets[val_indices]

    return train_images, train_targets, val_images, val_targets


def get_test(target=None):
    """Load the Cityscapes test tensors.

    Args:
        target: Column of the loaded target tensor to keep, or ``None`` to
            keep every column.

    Returns:
        ``(images, targets)`` rescaled by ``_adjust``.

    Raises:
        ValueError: if ``target`` is neither ``None`` nor a valid column index.
    """
    targets = torch.load(
        os.path.join(CITYSCAPES_PATH, "test_targets.pt"), weights_only=True
    ).float()

    # Explicit check rather than `assert`, so it survives `python -O`; without
    # it a negative `target` would silently select a column from the end.
    # Validated before loading the (larger) image tensor.
    num_columns = targets.shape[1]
    if target is not None and not 0 <= target < num_columns:
        raise ValueError(
            f"Invalid target {target!r}: expected None or "
            f"0 <= target < {num_columns}"
        )

    images = torch.load(
        os.path.join(CITYSCAPES_PATH, "test_images.pt"), weights_only=True
    ).float()

    # select target
    if target is not None:
        targets = targets[:, target]

    images, targets = _adjust(images, targets)

    return images, targets
