# Standard library imports
import json
import os
from random import shuffle

# Third-party imports
import numpy as np
import pandas as pd
import torch
import torchio.transforms as tio
import torchvision.transforms as t
from torch.utils.data import Dataset

# Local imports
from utils import window

class RSNADataset(Dataset):
    """Per-image dataset over preprocessed RSNA scans stored as ``.npy`` files.

    Index metadata is read from two JSON files in ``data_dir``:
    ``{op}_images.json`` (a sequence of per-image records, each indexed as
    ``[image_id, label]``) and ``{op}_series.json`` (a mapping whose keys are
    used as series ids). Pixel arrays are loaded from
    ``data_dir/stage_2_train/ID_<image_id>.npy``.

    NOTE(review): with ``weak_supervision=True``, ``__len__`` reports the number
    of series, but ``__getitem__`` still indexes the per-image list. A loader
    therefore only ever sees the first ``len(series_ids)`` images, and raises
    ``IndexError`` if there are more series than images. ``self.series_dir`` is
    computed but not used here, so series-level loading appears unimplemented.
    Confirm the intended behavior before relying on this mode.
    """

    def __init__(
        self,
        data_dir,
        op,
        weak_supervision=False,
        *,
        window_level=40,
        window_width=80,
    ):
        """Load the image and series index for split ``op``.

        Args:
            data_dir: Root directory holding the index JSON files and the
                ``series`` and ``stage_2_train`` subdirectories.
            op: Prefix of the index JSON filenames (presumably a split name
                such as ``"train"``; confirm with callers).
            weak_supervision: If true, ``len(self)`` counts series instead of
                images (see the class-level note).
            window_level: Window centre passed to ``utils.window`` (default 40).
            window_width: Window width passed to ``utils.window`` (default 80).
        """
        image_path = os.path.join(data_dir, f"{op}_images.json")
        series_path = os.path.join(data_dir, f"{op}_series.json")
        self.series_dir = os.path.join(data_dir, "series")
        self.dicom_dir = os.path.join(data_dir, "stage_2_train")
        with open(image_path, "r", encoding="utf-8") as image_f:
            self.images = json.load(image_f)
        with open(series_path, "r", encoding="utf-8") as series_f:
            self.series_dictionary = json.load(series_f)
        self.series_ids = list(self.series_dictionary)
        # Per-channel statistics for the 3-channel normalization below; these
        # match the commonly used ImageNet RGB mean/std.
        self.MEAN = torch.tensor([0.485, 0.456, 0.406])
        self.STD = torch.tensor([0.229, 0.224, 0.225])
        self.normalize = t.Normalize(mean=self.MEAN, std=self.STD)
        self.weak_supervision = weak_supervision
        self.window_level = window_level
        self.window_width = window_width

    def __len__(self):
        """Return the series count in weak-supervision mode, else the image count."""
        return len(self.series_ids) if self.weak_supervision else len(self.images)

    def __getitem__(self, idx):
        """Load, window and normalize the image record at position ``idx``.

        Returns:
            A tuple ``(image, image_norm, label)``:

            - ``image``: the windowed image tiled to 3 channels.
            - ``image_norm``: the same image after per-channel normalization.
            - ``label``: a float32 tensor of shape ``(1,)`` holding
              ``int(record[1])``.
        """
        image_data = self.images[idx]
        image_id = image_data[0]

        # SECURITY NOTE(review): allow_pickle=True lets np.load run arbitrary
        # code from a crafted .npy file. Only point this at trusted data. If the
        # files hold plain numeric arrays, allow_pickle=False would be safer.
        raw = np.load(
            os.path.join(self.dicom_dir, f"ID_{image_id}.npy"), allow_pickle=True
        )
        # Cast through float64 (as the loader always has), then to float32.
        image = torch.from_numpy(raw.astype(np.float64)).float()

        # Drop singleton dims. This assumes the stored array squeezes to (H, W);
        # TODO confirm. Then apply the intensity window and tile to 3 channels
        # so the tensor matches the 3-channel normalization statistics.
        image = image.squeeze()
        image = window(
            image, window_level=self.window_level, window_width=self.window_width
        )
        image = image.repeat(3, 1, 1)
        image_norm = self.normalize(image)

        label = torch.tensor([int(image_data[1])], dtype=torch.float32)
        return image, image_norm, label